Unverified Commit 0ce3461a authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Add runtime.endpoint() method to eliminate namespace chaining (#6386)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 6f4b33f7
......@@ -267,19 +267,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
pd_worker_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime)
......
......@@ -224,19 +224,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
pd_worker_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime)
......
......@@ -299,19 +299,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
encode_worker_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
handler = Processor(args, config.engine_args, encode_worker_client)
......
......@@ -271,19 +271,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
pd_worker_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime)
......
......@@ -233,12 +233,9 @@ class VllmPDWorker(VllmBaseWorker):
parsed_component_name,
parsed_endpoint_name,
) = parse_endpoint(self.downstream_endpoint)
self.decode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
self.decode_worker_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
if "video" in self.engine_args.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8
......@@ -435,9 +432,10 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
if args.worker_type in ["prefill", "encode_prefill"]:
......
......@@ -31,9 +31,7 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/bls").component("bar")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("examples/bls.bar.generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -24,18 +24,8 @@ uvloop.install()
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
foo = (
await runtime.namespace("examples/bls")
.component("foo")
.endpoint("generate")
.client()
)
bar = (
await runtime.namespace("examples/bls")
.component("bar")
.endpoint("generate")
.client()
)
foo = await runtime.endpoint("examples/bls.foo.generate").client()
bar = await runtime.endpoint("examples/bls.bar.generate").client()
# hello world showed us the client has a .generate, which uses the default load balancer
# however, you can explicitly opt-in to client side load balancing by using the `round_robin`
......
......@@ -30,9 +30,7 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/bls").component("foo")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("examples/bls.foo.generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
endpoint = runtime.endpoint(f"{ns}.backend.generate")
# create client
client = await endpoint.client()
......
......@@ -43,9 +43,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{ns}.backend.generate")
print("Started server instance")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
endpoint = runtime.endpoint(f"{ns}.backend.generate")
# create client
client = await endpoint.client()
......
......@@ -48,9 +48,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{ns}.backend.generate")
print("Started server instance")
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
......
......@@ -88,9 +88,9 @@ async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
endpoint = component.endpoint(config.endpoint)
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
await register_model(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
......
......@@ -101,9 +101,9 @@ async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
endpoint = component.endpoint(config.endpoint)
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
await register_model(
ModelInput.Text, ModelType.Chat | ModelType.Completions, endpoint, config.model
)
......
......@@ -99,9 +99,9 @@ async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
endpoint = component.endpoint(config.endpoint)
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
await register_model(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
......
......@@ -31,9 +31,7 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/pipeline").component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("examples/pipeline.backend.generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -35,17 +35,10 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
# client to the next component - in this case the middle component
next = (
await runtime.namespace("examples/pipeline")
.component("middle")
.endpoint("generate")
.client()
)
next = await runtime.endpoint("examples/pipeline.middle.generate").client()
# create endpoint service for frontend component
component = runtime.namespace("examples/pipeline").component("frontend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("examples/pipeline.frontend.generate")
handler = RequestHandler(next)
await endpoint.serve_endpoint(handler.generate)
......
......@@ -35,17 +35,10 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
# client to backend
backend = (
await runtime.namespace("examples/pipeline")
.component("backend")
.endpoint("generate")
.client()
)
backend = await runtime.endpoint("examples/pipeline.backend.generate").client()
# create endpoint service for middle component
component = runtime.namespace("examples/pipeline").component("middle")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("examples/pipeline.middle.generate")
await endpoint.serve_endpoint(RequestHandler(backend).generate)
......
......@@ -31,12 +31,7 @@ async def worker(runtime: DistributedRuntime):
- `frontend` call `middle` which calls `backend`
- each component transforms the request before passing it to the backend
"""
pipeline = (
await runtime.namespace("examples/pipeline")
.component("frontend")
.endpoint("generate")
.client()
)
pipeline = await runtime.endpoint("examples/pipeline.frontend.generate").client()
async for char in await pipeline.round_robin("hello from"):
print(char)
......
......@@ -26,7 +26,7 @@ async def worker(runtime: DistributedRuntime):
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace("dynamo").component("backend").endpoint("generate")
endpoint = runtime.endpoint("dynamo.backend.generate")
# create client
client = await endpoint.client()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment