Unverified commit 0ce3461a, authored by Tzu-Ling Kan and committed by GitHub
Browse files

feat: Add runtime.endpoint() method to eliminate namespace chaining (#6386)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 6f4b33f7
...@@ -438,10 +438,8 @@ class EngineFactory: ...@@ -438,10 +438,8 @@ class EngineFactory:
reasoning_parser_class = None reasoning_parser_class = None
(namespace_name, component_name, endpoint_name) = instance_id.triple() (namespace_name, component_name, endpoint_name) = instance_id.triple()
generate_endpoint = ( generate_endpoint = self.runtime.endpoint(
self.runtime.namespace(namespace_name) f"{namespace_name}.{component_name}.{endpoint_name}"
.component(component_name)
.endpoint(endpoint_name)
) )
if self.router_config.router_mode == RouterMode.KV: if self.router_config.router_mode == RouterMode.KV:
......
...@@ -71,8 +71,8 @@ async def main(runtime: DistributedRuntime, args): ...@@ -71,8 +71,8 @@ async def main(runtime: DistributedRuntime, args):
logger.info("=" * 60) logger.info("=" * 60)
# Create the GlobalPlanner component # Create the GlobalPlanner component (get from first endpoint)
component = runtime.namespace(namespace).component("GlobalPlanner") component = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request").component()
# Get K8s namespace (where GlobalPlanner pod is running) # Get K8s namespace (where GlobalPlanner pod is running)
k8s_namespace = os.environ.get("POD_NAMESPACE", "default") k8s_namespace = os.environ.get("POD_NAMESPACE", "default")
......
...@@ -69,12 +69,12 @@ async def worker(runtime: DistributedRuntime): ...@@ -69,12 +69,12 @@ async def worker(runtime: DistributedRuntime):
# Initialize connections to local routers # Initialize connections to local routers
await handler.initialize() await handler.initialize()
# Create component in the global router namespace
component = runtime.namespace(config.namespace).component(config.component_name)
# Create endpoints for prefill and decode # Create endpoints for prefill and decode
# Note: We use separate endpoints so we can register them with different ModelTypes # Note: We use separate endpoints so we can register them with different ModelTypes
prefill_endpoint = component.endpoint("prefill_generate") prefill_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.prefill_generate"
)
component = prefill_endpoint.component()
decode_endpoint = component.endpoint("decode_generate") decode_endpoint = component.endpoint("decode_generate")
logger.info("Registering as prefill worker...") logger.info("Registering as prefill worker...")
......
...@@ -79,11 +79,7 @@ class GlobalRouterHandler: ...@@ -79,11 +79,7 @@ class GlobalRouterHandler:
# Connect to prefill pool local routers # Connect to prefill pool local routers
for idx, namespace in enumerate(self.config.prefill_pool_dynamo_namespaces): for idx, namespace in enumerate(self.config.prefill_pool_dynamo_namespaces):
try: try:
endpoint = ( endpoint = self.runtime.endpoint(f"{namespace}.router.generate")
self.runtime.namespace(namespace)
.component("router")
.endpoint("generate")
)
client = await endpoint.client() client = await endpoint.client()
self.prefill_clients[namespace] = client self.prefill_clients[namespace] = client
logger.info( logger.info(
...@@ -98,11 +94,7 @@ class GlobalRouterHandler: ...@@ -98,11 +94,7 @@ class GlobalRouterHandler:
# Connect to decode pool local routers # Connect to decode pool local routers
for idx, namespace in enumerate(self.config.decode_pool_dynamo_namespaces): for idx, namespace in enumerate(self.config.decode_pool_dynamo_namespaces):
try: try:
endpoint = ( endpoint = self.runtime.endpoint(f"{namespace}.router.generate")
self.runtime.namespace(namespace)
.component("router")
.endpoint("generate")
)
client = await endpoint.client() client = await endpoint.client()
self.decode_clients[namespace] = client self.decode_clients[namespace] = client
logger.info( logger.info(
......
...@@ -58,13 +58,11 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig): ...@@ -58,13 +58,11 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await start_planner(runtime, config) await start_planner(runtime, config)
component = runtime.namespace(config.namespace).component("Planner")
async def generate(request: RequestType): async def generate(request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint""" """Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint" yield "mock endpoint"
generate_endpoint = component.endpoint("generate") generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate) # type: ignore[arg-type] await generate_endpoint.serve_endpoint(generate) # type: ignore[arg-type]
......
...@@ -34,10 +34,8 @@ class RemotePlannerClient: ...@@ -34,10 +34,8 @@ class RemotePlannerClient:
async def _ensure_client(self): async def _ensure_client(self):
"""Lazy initialization of endpoint client with retry mechanism""" """Lazy initialization of endpoint client with retry mechanism"""
if self._client is None: if self._client is None:
endpoint = ( endpoint = self.runtime.endpoint(
self.runtime.namespace(self.central_namespace) f"{self.central_namespace}.{self.central_component}.scale_request"
.component(self.central_component)
.endpoint("scale_request")
) )
# Retry logic with exponential backoff # Retry logic with exponential backoff
......
...@@ -486,12 +486,9 @@ class BasePlanner: ...@@ -486,12 +486,9 @@ class BasePlanner:
async def _get_or_create_client(self, component_name: str, endpoint_name: str): async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync.""" """Create a client for the given component and endpoint, with a brief sleep for state sync."""
client = ( client = await self.runtime.endpoint(
await self.runtime.namespace(self.namespace) f"{self.namespace}.{component_name}.{endpoint_name}"
.component(component_name) ).client()
.endpoint(endpoint_name)
.client()
)
# TODO: remove this sleep after rust client() is blocking until watching state # TODO: remove this sleep after rust client() is blocking until watching state
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return client return client
......
...@@ -59,12 +59,9 @@ class StandaloneRouterHandler: ...@@ -59,12 +59,9 @@ class StandaloneRouterHandler:
namespace, component, endpoint = parts namespace, component, endpoint = parts
# Get worker endpoint # Get worker endpoint
worker_endpoint = ( worker_endpoint = self.runtime.endpoint(
self.runtime.namespace(namespace) f"{namespace}.{component}.{endpoint}"
.component(component)
.endpoint(endpoint)
) )
self.worker_client = await worker_endpoint.client() self.worker_client = await worker_endpoint.client()
self.kv_router = KvRouter( self.kv_router = KvRouter(
...@@ -182,17 +179,15 @@ async def worker(runtime: DistributedRuntime): ...@@ -182,17 +179,15 @@ async def worker(runtime: DistributedRuntime):
kv_router_config = build_kv_router_config(config) kv_router_config = build_kv_router_config(config)
# Create service component - use "router" as component name
component = runtime.namespace(config.namespace).component("router")
# Create handler # Create handler
handler = StandaloneRouterHandler( handler = StandaloneRouterHandler(
runtime, config.endpoint, config.router_block_size, kv_router_config runtime, config.endpoint, config.router_block_size, kv_router_config
) )
await handler.initialize() await handler.initialize()
# Expose endpoints # Create endpoints (get component from first endpoint to avoid duplicate metrics registries)
generate_endpoint = component.endpoint("generate") generate_endpoint = runtime.endpoint(f"{config.namespace}.router.generate")
component = generate_endpoint.component()
best_worker_endpoint = component.endpoint("best_worker_id") best_worker_endpoint = component.endpoint("best_worker_id")
logger.debug("Starting to serve endpoints...") logger.debug("Starting to serve endpoints...")
......
...@@ -251,11 +251,10 @@ async def init( ...@@ -251,11 +251,10 @@ async def init(
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
load_time = time.time() - start_time load_time = time.time() - start_time
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -338,11 +337,10 @@ async def init_prefill( ...@@ -338,11 +337,10 @@ async def init_prefill(
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -426,11 +424,10 @@ async def init_diffusion( ...@@ -426,11 +424,10 @@ async def init_diffusion(
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -502,11 +499,10 @@ async def init_embedding( ...@@ -502,11 +499,10 @@ async def init_embedding(
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
...@@ -594,11 +590,10 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -594,11 +590,10 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
if not fs_url: if not fs_url:
raise ValueError("--image-diffusion-fs-url is required for diffusion workers") raise ValueError("--image-diffusion-fs-url is required for diffusion workers")
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Image diffusion doesn't have metrics publisher like LLM # Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later # Could add custom metrics for images/sec, steps/sec later
...@@ -681,11 +676,10 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config): ...@@ -681,11 +676,10 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
"--video-generation-fs-url is required for video generation workers" "--video-generation-fs-url is required for video generation workers"
) )
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = VideoGenerationWorkerHandler( handler = VideoGenerationWorkerHandler(
component, component,
...@@ -729,19 +723,15 @@ async def init_multimodal_processor( ...@@ -729,19 +723,15 @@ async def init_multimodal_processor(
): ):
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# For processor, we need to connect to the encode worker # For processor, we need to connect to the encode worker
encode_worker_client = ( encode_worker_client = await runtime.endpoint(
await runtime.namespace(dynamo_args.namespace) f"{dynamo_args.namespace}.encoder.generate"
.component("encoder") ).client()
.endpoint("generate")
.client()
)
ready_event = asyncio.Event() ready_event = asyncio.Event()
...@@ -787,19 +777,15 @@ async def init_multimodal_encode_worker( ...@@ -787,19 +777,15 @@ async def init_multimodal_encode_worker(
"""Initialize multimodal encode worker component""" """Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# For encode worker, we need to connect to the downstream LLM worker # For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = ( pd_worker_client = await runtime.endpoint(
await runtime.namespace(dynamo_args.namespace) f"{dynamo_args.namespace}.backend.generate"
.component("backend") ).client()
.endpoint("generate")
.client()
)
handler = MultimodalEncodeWorkerHandler( handler = MultimodalEncodeWorkerHandler(
component, config, pd_worker_client, shutdown_event component, config, pd_worker_client, shutdown_event
...@@ -840,22 +826,18 @@ async def init_multimodal_worker( ...@@ -840,22 +826,18 @@ async def init_multimodal_worker(
""" """
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
if config.serving_mode == DisaggregationMode.DECODE: if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for multimodal decode worker") logging.info("Initializing prefill client for multimodal decode worker")
prefill_client = ( prefill_client = await runtime.endpoint(
await runtime.namespace(dynamo_args.namespace) f"{dynamo_args.namespace}.prefill.generate"
.component("prefill") ).client()
.endpoint("generate")
.client()
)
handler = MultimodalWorkerHandler( handler = MultimodalWorkerHandler(
component, engine, config, prefill_client, shutdown_event component, engine, config, prefill_client, shutdown_event
) )
...@@ -895,11 +877,10 @@ async def init_multimodal_prefill_worker( ...@@ -895,11 +877,10 @@ async def init_multimodal_prefill_worker(
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component( generate_endpoint = runtime.endpoint(
dynamo_args.component f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event) handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init() await handler.async_init()
......
...@@ -129,14 +129,9 @@ async def init_llm_worker( ...@@ -129,14 +129,9 @@ async def init_llm_worker(
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.encode_endpoint config.encode_endpoint
) )
encode_client = ( encode_client = await runtime.endpoint(
await runtime.namespace(parsed_namespace) f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
.component(parsed_component_name) ).client()
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
# Convert model path to Path object if it's a local path, otherwise keep as string # Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model) model_path = str(config.model)
...@@ -334,7 +329,10 @@ async def init_llm_worker( ...@@ -334,7 +329,10 @@ async def init_llm_worker(
config.disaggregation_mode, config.disaggregation_mode,
component_gauges=component_gauges, component_gauges=component_gauges,
) as engine: ) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
# should ideally call get_engine_runtime_config # should ideally call get_engine_runtime_config
# this is because we don't have a good way to # this is because we don't have a good way to
...@@ -451,9 +449,7 @@ async def init_llm_worker( ...@@ -451,9 +449,7 @@ async def init_llm_worker(
if config.publish_events_and_metrics: if config.publish_events_and_metrics:
# Initialize and pass in the publisher to the request handler to # Initialize and pass in the publisher to the request handler to
# publish events and metrics. # publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component( kv_listener = endpoint.component()
config.component
)
# Use model as fallback if served_model_name is not provided # Use model as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model model_name_for_metrics = config.served_model_name or config.model
metrics_labels = [ metrics_labels = [
......
...@@ -81,9 +81,11 @@ async def init_video_diffusion_worker( ...@@ -81,9 +81,11 @@ async def init_video_diffusion_worker(
enable_async_cpu_offload=config.enable_async_cpu_offload, enable_async_cpu_offload=config.enable_async_cpu_offload,
) )
# Get the component and endpoint from the runtime # Get the endpoint from the runtime
component = runtime.namespace(config.namespace).component(config.component) endpoint = runtime.endpoint(
endpoint = component.endpoint(config.endpoint) f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
# Initialize the diffusion engine (auto-detects pipeline from model_index.json) # Initialize the diffusion engine (auto-detects pipeline from model_index.json)
engine = DiffusionEngine(diffusion_config) engine = DiffusionEngine(diffusion_config)
......
...@@ -539,9 +539,10 @@ async def init_prefill( ...@@ -539,9 +539,10 @@ async def init_prefill(
""" """
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
...@@ -681,9 +682,10 @@ async def init( ...@@ -681,9 +682,10 @@ async def init(
Instantiate and serve Instantiate and serve
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
load_lora_endpoint = component.endpoint("load_lora") load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora") unload_lora_endpoint = component.endpoint("unload_lora")
...@@ -931,8 +933,10 @@ async def init_omni( ...@@ -931,8 +933,10 @@ async def init_omni(
# Lazy import to avoid loading vllm-omni unless explicitly needed # Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler from dynamo.vllm.omni import OmniHandler
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
generate_endpoint = component.endpoint(config.endpoint) f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
# Initialize OmniHandler with Omni orchestrator # Initialize OmniHandler with Omni orchestrator
handler = OmniHandler( handler = OmniHandler(
......
...@@ -21,8 +21,8 @@ from .multimodal_handlers import ( ...@@ -21,8 +21,8 @@ from .multimodal_handlers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir) # (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges)
EngineSetupResult = tuple[Any, Any, Any, Any] EngineSetupResult = tuple[Any, Any, Any, Any, Any]
SetupVllmEngineFn = Callable[..., EngineSetupResult] SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]] SetupKvEventPublisherFn = Callable[..., Optional[Any]]
...@@ -88,9 +88,10 @@ class WorkerFactory: ...@@ -88,9 +88,10 @@ class WorkerFactory:
- Aggregated (P+D): Prefill and decode on same worker - Aggregated (P+D): Prefill and decode on same worker
- Disaggregated (P→D): Prefill forwards to separate decode worker - Disaggregated (P→D): Prefill forwards to separate decode worker
""" """
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
generate_endpoint = component.endpoint(config.endpoint) )
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
...@@ -114,12 +115,9 @@ class WorkerFactory: ...@@ -114,12 +115,9 @@ class WorkerFactory:
# Set up encode worker client when routing to encoder is enabled # Set up encode worker client when routing to encoder is enabled
encode_worker_client = None encode_worker_client = None
if config.route_to_encoder: if config.route_to_encoder:
encode_worker_client = ( encode_worker_client = await runtime.endpoint(
await runtime.namespace(config.namespace) f"{config.namespace}.encoder.generate"
.component("encoder") ).client()
.endpoint("generate")
.client()
)
logger.info("Waiting for Encoder Worker Instances ...") logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances() await encode_worker_client.wait_for_instances()
logger.info("Connected to encoder workers") logger.info("Connected to encoder workers")
...@@ -127,12 +125,9 @@ class WorkerFactory: ...@@ -127,12 +125,9 @@ class WorkerFactory:
# Set up decode worker client for disaggregated mode # Set up decode worker client for disaggregated mode
decode_worker_client = None decode_worker_client = None
if config.is_prefill_worker: if config.is_prefill_worker:
decode_worker_client = ( decode_worker_client = await runtime.endpoint(
await runtime.namespace(config.namespace) f"{config.namespace}.decoder.generate"
.component("decoder") ).client()
.endpoint("generate")
.client()
)
await decode_worker_client.wait_for_instances() await decode_worker_client.wait_for_instances()
logger.info("Connected to decode worker for disaggregated mode") logger.info("Connected to decode worker for disaggregated mode")
...@@ -201,8 +196,9 @@ class WorkerFactory: ...@@ -201,8 +196,9 @@ class WorkerFactory:
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
) -> None: ) -> None:
"""Initialize standalone multimodal encode worker.""" """Initialize standalone multimodal encode worker."""
component = runtime.namespace(config.namespace).component(config.component) generate_endpoint = runtime.endpoint(
generate_endpoint = component.endpoint(config.endpoint) f"{config.namespace}.{config.component}.{config.endpoint}"
)
handler = EncodeWorkerHandler(config.engine_args) handler = EncodeWorkerHandler(config.engine_args)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -102,10 +102,7 @@ async def triton_worker(runtime: DistributedRuntime, args: argparse.Namespace): ...@@ -102,10 +102,7 @@ async def triton_worker(runtime: DistributedRuntime, args: argparse.Namespace):
f"Environment: DYN_DISCOVERY_BACKEND={os.environ.get('DYN_DISCOVERY_BACKEND', 'NOT SET')}" f"Environment: DYN_DISCOVERY_BACKEND={os.environ.get('DYN_DISCOVERY_BACKEND', 'NOT SET')}"
) )
component = runtime.namespace("triton").component("tritonserver") endpoint = runtime.endpoint("triton.tritonserver.generate")
logger.info("✓ Created component: triton/tritonserver")
endpoint = component.endpoint("generate")
logger.info("✓ Created endpoint: triton/tritonserver/generate") logger.info("✓ Created endpoint: triton/tritonserver/generate")
model_repository = args.model_repository model_repository = args.model_repository
......
...@@ -125,10 +125,8 @@ async def worker(runtime: DistributedRuntime) -> None: ...@@ -125,10 +125,8 @@ async def worker(runtime: DistributedRuntime) -> None:
) )
# Connect to downstream vLLM workers # Connect to downstream vLLM workers
downstream_endpoint = ( downstream_endpoint = runtime.endpoint(
runtime.namespace(args.namespace) f"{args.namespace}.{args.downstream_component}.{args.downstream_endpoint}"
.component(args.downstream_component)
.endpoint(args.downstream_endpoint)
) )
downstream_client = await downstream_endpoint.client() downstream_client = await downstream_endpoint.client()
...@@ -162,8 +160,7 @@ async def worker(runtime: DistributedRuntime) -> None: ...@@ -162,8 +160,7 @@ async def worker(runtime: DistributedRuntime) -> None:
) )
# Register this worker's endpoint # Register this worker's endpoint
component = runtime.namespace(args.namespace).component(args.component) endpoint = runtime.endpoint(f"{args.namespace}.{args.component}.{args.endpoint}")
endpoint = component.endpoint(args.endpoint)
# Use ModelInput.Tokens so Frontend preprocesses the request # Use ModelInput.Tokens so Frontend preprocesses the request
# Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}} # Request format: {token_ids, sampling_options, stop_conditions, extra_args: {messages}}
......
...@@ -54,10 +54,10 @@ async def main(): ...@@ -54,10 +54,10 @@ async def main():
# Connect to middle server or direct server based on argument # Connect to middle server or direct server based on argument
if use_middle_server: if use_middle_server:
endpoint = runtime.namespace("demo").component("middle").endpoint("generate") endpoint = runtime.endpoint("demo.middle.generate")
print("Client connecting to middle server...") print("Client connecting to middle server...")
else: else:
endpoint = runtime.namespace("demo").component("server").endpoint("generate") endpoint = runtime.endpoint("demo.server.generate")
print("Client connecting directly to backend server...") print("Client connecting directly to backend server...")
client = await endpoint.client() client = await endpoint.client()
......
...@@ -21,9 +21,7 @@ class MiddleServer: ...@@ -21,9 +21,7 @@ class MiddleServer:
async def initialize(self): async def initialize(self):
"""Initialize connection to backend servers""" """Initialize connection to backend servers"""
# Connect to backend servers # Connect to backend servers
endpoint = ( endpoint = self.runtime.endpoint("demo.server.generate")
self.runtime.namespace("demo").component("server").endpoint("generate")
)
self.backend_client = await endpoint.client() self.backend_client = await endpoint.client()
await self.backend_client.wait_for_instances() await self.backend_client.wait_for_instances()
print("Middle server: Connected to backend servers") print("Middle server: Connected to backend servers")
...@@ -56,10 +54,8 @@ async def main(): ...@@ -56,10 +54,8 @@ async def main():
handler = MiddleServer(runtime) handler = MiddleServer(runtime)
await handler.initialize() await handler.initialize()
# Create middle server component # Create middle server endpoint
component = runtime.namespace("demo").component("middle") endpoint = runtime.endpoint("demo.middle.generate")
endpoint = component.endpoint("generate")
print("Middle server started") print("Middle server started")
print("Forwarding requests to backend servers...") print("Forwarding requests to backend servers...")
......
...@@ -33,10 +33,8 @@ async def main(): ...@@ -33,10 +33,8 @@ async def main():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "file", "nats") runtime = DistributedRuntime(loop, "file", "nats")
# Create server component # Create server endpoint
component = runtime.namespace("demo").component("server") endpoint = runtime.endpoint("demo.server.generate")
endpoint = component.endpoint("generate")
handler = DemoServer() handler = DemoServer()
print("Demo server started") print("Demo server started")
......
...@@ -23,9 +23,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -23,9 +23,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
# Get endpoint # Get endpoint
endpoint = ( endpoint = runtime.endpoint("hello_world.backend.generate")
runtime.namespace("hello_world").component("backend").endpoint("generate")
)
# Create client and wait for service to be ready # Create client and wait for service to be ready
client = await endpoint.client() client = await endpoint.client()
......
...@@ -27,13 +27,9 @@ async def worker(runtime: DistributedRuntime): ...@@ -27,13 +27,9 @@ async def worker(runtime: DistributedRuntime):
component_name = "backend" component_name = "backend"
endpoint_name = "generate" endpoint_name = "generate"
component = runtime.namespace(namespace_name).component(component_name) endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.{endpoint_name}")
logger.info(f"Created service {namespace_name}/{component_name}") logger.info(f"Serving endpoint {namespace_name}/{component_name}/{endpoint_name}")
endpoint = component.endpoint(endpoint_name)
logger.info(f"Serving endpoint {endpoint_name}")
await endpoint.serve_endpoint(content_generator) await endpoint.serve_endpoint(content_generator)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment