"lib/llm/src/vscode:/vscode.git/clone" did not exist on "0862f87bd474f25ff81305e414f2aa4a535f2890"
Unverified Commit 0ce3461a authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Add runtime.endpoint() method to eliminate namespace chaining (#6386)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 6f4b33f7
......@@ -438,10 +438,8 @@ class EngineFactory:
reasoning_parser_class = None
(namespace_name, component_name, endpoint_name) = instance_id.triple()
generate_endpoint = (
self.runtime.namespace(namespace_name)
.component(component_name)
.endpoint(endpoint_name)
generate_endpoint = self.runtime.endpoint(
f"{namespace_name}.{component_name}.{endpoint_name}"
)
if self.router_config.router_mode == RouterMode.KV:
......
......@@ -71,8 +71,8 @@ async def main(runtime: DistributedRuntime, args):
logger.info("=" * 60)
# Create the GlobalPlanner component
component = runtime.namespace(namespace).component("GlobalPlanner")
# Create the GlobalPlanner component (get from first endpoint)
component = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request").component()
# Get K8s namespace (where GlobalPlanner pod is running)
k8s_namespace = os.environ.get("POD_NAMESPACE", "default")
......
......@@ -69,12 +69,12 @@ async def worker(runtime: DistributedRuntime):
# Initialize connections to local routers
await handler.initialize()
# Create component in the global router namespace
component = runtime.namespace(config.namespace).component(config.component_name)
# Create endpoints for prefill and decode
# Note: We use separate endpoints so we can register them with different ModelTypes
prefill_endpoint = component.endpoint("prefill_generate")
prefill_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.prefill_generate"
)
component = prefill_endpoint.component()
decode_endpoint = component.endpoint("decode_generate")
logger.info("Registering as prefill worker...")
......
......@@ -79,11 +79,7 @@ class GlobalRouterHandler:
# Connect to prefill pool local routers
for idx, namespace in enumerate(self.config.prefill_pool_dynamo_namespaces):
try:
endpoint = (
self.runtime.namespace(namespace)
.component("router")
.endpoint("generate")
)
endpoint = self.runtime.endpoint(f"{namespace}.router.generate")
client = await endpoint.client()
self.prefill_clients[namespace] = client
logger.info(
......@@ -98,11 +94,7 @@ class GlobalRouterHandler:
# Connect to decode pool local routers
for idx, namespace in enumerate(self.config.decode_pool_dynamo_namespaces):
try:
endpoint = (
self.runtime.namespace(namespace)
.component("router")
.endpoint("generate")
)
endpoint = self.runtime.endpoint(f"{namespace}.router.generate")
client = await endpoint.client()
self.decode_clients[namespace] = client
logger.info(
......
......@@ -58,13 +58,11 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await start_planner(runtime, config)
component = runtime.namespace(config.namespace).component("Planner")
async def generate(request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint"
generate_endpoint = component.endpoint("generate")
generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate) # type: ignore[arg-type]
......
......@@ -34,10 +34,8 @@ class RemotePlannerClient:
async def _ensure_client(self):
"""Lazy initialization of endpoint client with retry mechanism"""
if self._client is None:
endpoint = (
self.runtime.namespace(self.central_namespace)
.component(self.central_component)
.endpoint("scale_request")
endpoint = self.runtime.endpoint(
f"{self.central_namespace}.{self.central_component}.scale_request"
)
# Retry logic with exponential backoff
......
......@@ -486,12 +486,9 @@ class BasePlanner:
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
client = (
await self.runtime.namespace(self.namespace)
.component(component_name)
.endpoint(endpoint_name)
.client()
)
client = await self.runtime.endpoint(
f"{self.namespace}.{component_name}.{endpoint_name}"
).client()
# TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1)
return client
......
......@@ -59,12 +59,9 @@ class StandaloneRouterHandler:
namespace, component, endpoint = parts
# Get worker endpoint
worker_endpoint = (
self.runtime.namespace(namespace)
.component(component)
.endpoint(endpoint)
worker_endpoint = self.runtime.endpoint(
f"{namespace}.{component}.{endpoint}"
)
self.worker_client = await worker_endpoint.client()
self.kv_router = KvRouter(
......@@ -182,17 +179,15 @@ async def worker(runtime: DistributedRuntime):
kv_router_config = build_kv_router_config(config)
# Create service component - use "router" as component name
component = runtime.namespace(config.namespace).component("router")
# Create handler
handler = StandaloneRouterHandler(
runtime, config.endpoint, config.router_block_size, kv_router_config
)
await handler.initialize()
# Expose endpoints
generate_endpoint = component.endpoint("generate")
# Create endpoints (get component from first endpoint to avoid duplicate metrics registries)
generate_endpoint = runtime.endpoint(f"{config.namespace}.router.generate")
component = generate_endpoint.component()
best_worker_endpoint = component.endpoint("best_worker_id")
logger.debug("Starting to serve endpoints...")
......
......@@ -251,11 +251,10 @@ async def init(
engine = sgl.Engine(server_args=server_args)
load_time = time.time() - start_time
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -338,11 +337,10 @@ async def init_prefill(
engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -426,11 +424,10 @@ async def init_diffusion(
engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -502,11 +499,10 @@ async def init_embedding(
engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
......@@ -594,11 +590,10 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
if not fs_url:
raise ValueError("--image-diffusion-fs-url is required for diffusion workers")
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later
......@@ -681,11 +676,10 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
"--video-generation-fs-url is required for video generation workers"
)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
handler = VideoGenerationWorkerHandler(
component,
......@@ -729,19 +723,15 @@ async def init_multimodal_processor(
):
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# For processor, we need to connect to the encode worker
encode_worker_client = (
await runtime.namespace(dynamo_args.namespace)
.component("encoder")
.endpoint("generate")
.client()
)
encode_worker_client = await runtime.endpoint(
f"{dynamo_args.namespace}.encoder.generate"
).client()
ready_event = asyncio.Event()
......@@ -787,19 +777,15 @@ async def init_multimodal_encode_worker(
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
# For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = (
await runtime.namespace(dynamo_args.namespace)
.component("backend")
.endpoint("generate")
.client()
)
pd_worker_client = await runtime.endpoint(
f"{dynamo_args.namespace}.backend.generate"
).client()
handler = MultimodalEncodeWorkerHandler(
component, config, pd_worker_client, shutdown_event
......@@ -840,22 +826,18 @@ async def init_multimodal_worker(
"""
server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
engine = sgl.Engine(server_args=server_args)
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for multimodal decode worker")
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
prefill_client = await runtime.endpoint(
f"{dynamo_args.namespace}.prefill.generate"
).client()
handler = MultimodalWorkerHandler(
component, engine, config, prefill_client, shutdown_event
)
......@@ -895,11 +877,10 @@ async def init_multimodal_prefill_worker(
engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
component = generate_endpoint.component()
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init()
......
......@@ -129,14 +129,9 @@ async def init_llm_worker(
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.encode_endpoint
)
encode_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
encode_client = await runtime.endpoint(
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model)
......@@ -334,7 +329,10 @@ async def init_llm_worker(
config.disaggregation_mode,
component_gauges=component_gauges,
) as engine:
endpoint = component.endpoint(config.endpoint)
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
# should ideally call get_engine_runtime_config
# this is because we don't have a good way to
......@@ -451,9 +449,7 @@ async def init_llm_worker(
if config.publish_events_and_metrics:
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
kv_listener = endpoint.component()
# Use model as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model
metrics_labels = [
......
......@@ -81,9 +81,11 @@ async def init_video_diffusion_worker(
enable_async_cpu_offload=config.enable_async_cpu_offload,
)
# Get the component and endpoint from the runtime
component = runtime.namespace(config.namespace).component(config.component)
endpoint = component.endpoint(config.endpoint)
# Get the endpoint from the runtime
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
# Initialize the diffusion engine (auto-detects pipeline from model_index.json)
engine = DiffusionEngine(diffusion_config)
......
......@@ -539,9 +539,10 @@ async def init_prefill(
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
# Use pre-created engine if provided (checkpoint mode), otherwise create new
......@@ -681,9 +682,10 @@ async def init(
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
......@@ -931,8 +933,10 @@ async def init_omni(
# Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
# Initialize OmniHandler with Omni orchestrator
handler = OmniHandler(
......
......@@ -21,8 +21,8 @@ from .multimodal_handlers import (
logger = logging.getLogger(__name__)
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir)
EngineSetupResult = tuple[Any, Any, Any, Any]
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges)
EngineSetupResult = tuple[Any, Any, Any, Any, Any]
SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]]
......@@ -88,9 +88,10 @@ class WorkerFactory:
- Aggregated (P+D): Prefill and decode on same worker
- Disaggregated (P→D): Prefill forwards to separate decode worker
"""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
# Use pre-created engine if provided (checkpoint mode), otherwise create new
......@@ -114,12 +115,9 @@ class WorkerFactory:
# Set up encode worker client when routing to encoder is enabled
encode_worker_client = None
if config.route_to_encoder:
encode_worker_client = (
await runtime.namespace(config.namespace)
.component("encoder")
.endpoint("generate")
.client()
)
encode_worker_client = await runtime.endpoint(
f"{config.namespace}.encoder.generate"
).client()
logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
logger.info("Connected to encoder workers")
......@@ -127,12 +125,9 @@ class WorkerFactory:
# Set up decode worker client for disaggregated mode
decode_worker_client = None
if config.is_prefill_worker:
decode_worker_client = (
await runtime.namespace(config.namespace)
.component("decoder")
.endpoint("generate")
.client()
)
decode_worker_client = await runtime.endpoint(
f"{config.namespace}.decoder.generate"
).client()
await decode_worker_client.wait_for_instances()
logger.info("Connected to decode worker for disaggregated mode")
......@@ -201,8 +196,9 @@ class WorkerFactory:
shutdown_event: asyncio.Event,
) -> None:
"""Initialize standalone multimodal encode worker."""
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
handler = EncodeWorkerHandler(config.engine_args)
await handler.async_init(runtime)
......
......@@ -102,10 +102,7 @@ async def triton_worker(runtime: DistributedRuntime, args: argparse.Namespace):
f"Environment: DYN_DISCOVERY_BACKEND={os.environ.get('DYN_DISCOVERY_BACKEND', 'NOT SET')}"
)
component = runtime.namespace("triton").component("tritonserver")
logger.info("✓ Created component: triton/tritonserver")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("triton.tritonserver.generate")
logger.info("✓ Created endpoint: triton/tritonserver/generate")
model_repository = args.model_repository
......
......@@ -125,10 +125,8 @@ async def worker(runtime: DistributedRuntime) -> None:
)
# Connect to downstream vLLM workers
downstream_endpoint = (
runtime.namespace(args.namespace)
.component(args.downstream_component)
.endpoint(args.downstream_endpoint)
downstream_endpoint = runtime.endpoint(
f"{args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
)
downstream_client = await downstream_endpoint.client()
......@@ -162,8 +160,7 @@ async def worker(runtime: DistributedRuntime) -> None:
)
# Register this worker's endpoint
component = runtime.namespace(args.namespace).component(args.component)
endpoint = component.endpoint(args.endpoint)
endpoint = runtime.endpoint(f"{args.namespace}.{args.component}.{args.endpoint}")
# Use ModelInput.Tokens so Frontend preprocesses the request
# Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
......
......@@ -54,10 +54,10 @@ async def main():
# Connect to middle server or direct server based on argument
if use_middle_server:
endpoint = runtime.namespace("demo").component("middle").endpoint("generate")
endpoint = runtime.endpoint("demo.middle.generate")
print("Client connecting to middle server...")
else:
endpoint = runtime.namespace("demo").component("server").endpoint("generate")
endpoint = runtime.endpoint("demo.server.generate")
print("Client connecting directly to backend server...")
client = await endpoint.client()
......
......@@ -21,9 +21,7 @@ class MiddleServer:
async def initialize(self):
"""Initialize connection to backend servers"""
# Connect to backend servers
endpoint = (
self.runtime.namespace("demo").component("server").endpoint("generate")
)
endpoint = self.runtime.endpoint("demo.server.generate")
self.backend_client = await endpoint.client()
await self.backend_client.wait_for_instances()
print("Middle server: Connected to backend servers")
......@@ -56,10 +54,8 @@ async def main():
handler = MiddleServer(runtime)
await handler.initialize()
# Create middle server component
component = runtime.namespace("demo").component("middle")
endpoint = component.endpoint("generate")
# Create middle server endpoint
endpoint = runtime.endpoint("demo.middle.generate")
print("Middle server started")
print("Forwarding requests to backend servers...")
......
......@@ -33,10 +33,8 @@ async def main():
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file", "nats")
# Create server component
component = runtime.namespace("demo").component("server")
endpoint = component.endpoint("generate")
# Create server endpoint
endpoint = runtime.endpoint("demo.server.generate")
handler = DemoServer()
print("Demo server started")
......
......@@ -23,9 +23,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
# Get endpoint
endpoint = (
runtime.namespace("hello_world").component("backend").endpoint("generate")
)
endpoint = runtime.endpoint("hello_world.backend.generate")
# Create client and wait for service to be ready
client = await endpoint.client()
......
......@@ -27,13 +27,9 @@ async def worker(runtime: DistributedRuntime):
component_name = "backend"
endpoint_name = "generate"
component = runtime.namespace(namespace_name).component(component_name)
endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.{endpoint_name}")
logger.info(f"Created service {namespace_name}/{component_name}")
endpoint = component.endpoint(endpoint_name)
logger.info(f"Serving endpoint {endpoint_name}")
logger.info(f"Serving endpoint {namespace_name}/{component_name}/{endpoint_name}")
await endpoint.serve_endpoint(content_generator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment