Unverified commit 0ce3461a, authored by Tzu-Ling Kan and committed by GitHub
Browse files

feat: Add runtime.endpoint() method to eliminate namespace chaining (#6386)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 6f4b33f7
...@@ -267,19 +267,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -267,19 +267,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint args.downstream_endpoint
) )
pd_worker_client = ( pd_worker_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client) handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -224,19 +224,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -224,19 +224,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint args.downstream_endpoint
) )
pd_worker_client = ( pd_worker_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client) handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -299,19 +299,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -299,19 +299,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint args.downstream_endpoint
) )
encode_worker_client = ( encode_worker_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
handler = Processor(args, config.engine_args, encode_worker_client) handler = Processor(args, config.engine_args, encode_worker_client)
......
...@@ -271,19 +271,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -271,19 +271,16 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint args.downstream_endpoint
) )
pd_worker_client = ( pd_worker_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client) handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -233,12 +233,9 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -233,12 +233,9 @@ class VllmPDWorker(VllmBaseWorker):
parsed_component_name, parsed_component_name,
parsed_endpoint_name, parsed_endpoint_name,
) = parse_endpoint(self.downstream_endpoint) ) = parse_endpoint(self.downstream_endpoint)
self.decode_worker_client = ( self.decode_worker_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
if "video" in self.engine_args.model.lower(): if "video" in self.engine_args.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8 self.EMBEDDINGS_DTYPE = torch.uint8
...@@ -435,9 +432,10 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -435,9 +432,10 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
if args.worker_type in ["prefill", "encode_prefill"]: if args.worker_type in ["prefill", "encode_prefill"]:
......
...@@ -31,9 +31,7 @@ class RequestHandler: ...@@ -31,9 +31,7 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/bls").component("bar") endpoint = runtime.endpoint("examples/bls.bar.generate")
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler().generate) await endpoint.serve_endpoint(RequestHandler().generate)
......
...@@ -24,18 +24,8 @@ uvloop.install() ...@@ -24,18 +24,8 @@ uvloop.install()
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
foo = ( foo = await runtime.endpoint("examples/bls.foo.generate").client()
await runtime.namespace("examples/bls") bar = await runtime.endpoint("examples/bls.bar.generate").client()
.component("foo")
.endpoint("generate")
.client()
)
bar = (
await runtime.namespace("examples/bls")
.component("bar")
.endpoint("generate")
.client()
)
# hello world showed us the client has a .generate, which uses the default load balancer # hello world showed us the client has a .generate, which uses the default load balancer
# however, you can explicitly opt-in to client side load balancing by using the `round_robin` # however, you can explicitly opt-in to client side load balancing by using the `round_robin`
......
...@@ -30,9 +30,7 @@ class RequestHandler: ...@@ -30,9 +30,7 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/bls").component("foo") endpoint = runtime.endpoint("examples/bls.foo.generate")
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler().generate) await endpoint.serve_endpoint(RequestHandler().generate)
......
...@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` client and call the `generate` endpoint Instantiate a `backend` client and call the `generate` endpoint
""" """
# get endpoint # get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate") endpoint = runtime.endpoint(f"{ns}.backend.generate")
# create client # create client
client = await endpoint.client() client = await endpoint.client()
......
...@@ -43,9 +43,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -43,9 +43,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` component and serve the `generate` endpoint Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints A `Component` can serve multiple endpoints
""" """
component = runtime.namespace(ns).component("backend") endpoint = runtime.endpoint(f"{ns}.backend.generate")
endpoint = component.endpoint("generate")
print("Started server instance") print("Started server instance")
await endpoint.serve_endpoint(RequestHandler().generate) await endpoint.serve_endpoint(RequestHandler().generate)
......
...@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -30,7 +30,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` client and call the `generate` endpoint Instantiate a `backend` client and call the `generate` endpoint
""" """
# get endpoint # get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate") endpoint = runtime.endpoint(f"{ns}.backend.generate")
# create client # create client
client = await endpoint.client() client = await endpoint.client()
......
...@@ -48,9 +48,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -48,9 +48,7 @@ async def init(runtime: DistributedRuntime, ns: str):
Instantiate a `backend` component and serve the `generate` endpoint Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints A `Component` can serve multiple endpoints
""" """
component = runtime.namespace(ns).component("backend") endpoint = runtime.endpoint(f"{ns}.backend.generate")
endpoint = component.endpoint("generate")
print("Started server instance") print("Started server instance")
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
......
...@@ -88,9 +88,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -88,9 +88,9 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
endpoint = component.endpoint(config.endpoint) )
await register_model( await register_model(
ModelInput.Tokens, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions, ModelType.Chat | ModelType.Completions,
......
...@@ -101,9 +101,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -101,9 +101,9 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
endpoint = component.endpoint(config.endpoint) )
await register_model( await register_model(
ModelInput.Text, ModelType.Chat | ModelType.Completions, endpoint, config.model ModelInput.Text, ModelType.Chat | ModelType.Completions, endpoint, config.model
) )
......
...@@ -99,9 +99,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -99,9 +99,9 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
endpoint = component.endpoint(config.endpoint) )
await register_model( await register_model(
ModelInput.Tokens, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions, ModelType.Chat | ModelType.Completions,
......
...@@ -31,9 +31,7 @@ class RequestHandler: ...@@ -31,9 +31,7 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/pipeline").component("backend") endpoint = runtime.endpoint("examples/pipeline.backend.generate")
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler().generate) await endpoint.serve_endpoint(RequestHandler().generate)
......
...@@ -35,17 +35,10 @@ class RequestHandler: ...@@ -35,17 +35,10 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
# client to the next component - in this case the middle component # client to the next component - in this case the middle component
next = ( next = await runtime.endpoint("examples/pipeline.middle.generate").client()
await runtime.namespace("examples/pipeline")
.component("middle")
.endpoint("generate")
.client()
)
# create endpoint service for frontend component # create endpoint service for frontend component
component = runtime.namespace("examples/pipeline").component("frontend") endpoint = runtime.endpoint("examples/pipeline.frontend.generate")
endpoint = component.endpoint("generate")
handler = RequestHandler(next) handler = RequestHandler(next)
await endpoint.serve_endpoint(handler.generate) await endpoint.serve_endpoint(handler.generate)
......
...@@ -35,17 +35,10 @@ class RequestHandler: ...@@ -35,17 +35,10 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
# client to backend # client to backend
backend = ( backend = await runtime.endpoint("examples/pipeline.backend.generate").client()
await runtime.namespace("examples/pipeline")
.component("backend")
.endpoint("generate")
.client()
)
# create endpoint service for middle component # create endpoint service for middle component
component = runtime.namespace("examples/pipeline").component("middle") endpoint = runtime.endpoint("examples/pipeline.middle.generate")
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler(backend).generate) await endpoint.serve_endpoint(RequestHandler(backend).generate)
......
...@@ -31,12 +31,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -31,12 +31,7 @@ async def worker(runtime: DistributedRuntime):
- `frontend` call `middle` which calls `backend` - `frontend` call `middle` which calls `backend`
- each component transforms the request before passing it to the backend - each component transforms the request before passing it to the backend
""" """
pipeline = ( pipeline = await runtime.endpoint("examples/pipeline.frontend.generate").client()
await runtime.namespace("examples/pipeline")
.component("frontend")
.endpoint("generate")
.client()
)
async for char in await pipeline.round_robin("hello from"): async for char in await pipeline.round_robin("hello from"):
print(char) print(char)
......
...@@ -26,7 +26,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -26,7 +26,7 @@ async def worker(runtime: DistributedRuntime):
Instantiate a `backend` client and call the `generate` endpoint Instantiate a `backend` client and call the `generate` endpoint
""" """
# get endpoint # get endpoint
endpoint = runtime.namespace("dynamo").component("backend").endpoint("generate") endpoint = runtime.endpoint("dynamo.backend.generate")
# create client # create client
client = await endpoint.client() client = await endpoint.client()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment