Unverified Commit de5c39df authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

refactor: clean up sglang apis [1/n] (#3432)

parent f5ba9e7f
......@@ -28,6 +28,7 @@ python3 -m dynamo.sglang \
--served-model-name silence09/DeepSeek-R1-Small-2layers \
--tp 2 \
--dp-size 2 \
--page-size 16 \
--enable-dp-attention \
--trust-remote-code \
--disaggregation-mode prefill \
......@@ -42,6 +43,7 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \
--served-model-name silence09/DeepSeek-R1-Small-2layers \
--tp 2 \
--dp-size 2 \
--page-size 16 \
--enable-dp-attention \
--trust-remote-code \
--disaggregation-mode decode \
......
......@@ -63,6 +63,8 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl &
# run SGLang multimodal decode worker
......@@ -74,6 +76,8 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl &
# Wait for all background processes to complete
......
......@@ -10,7 +10,7 @@ import sys
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional
from sglang.srt.server_args import ServerArgs
......@@ -110,6 +110,8 @@ class DisaggregationMode(Enum):
class Config:
"""Combined configuration container for SGLang server and Dynamo args."""
def __init__(self, server_args: ServerArgs, dynamo_args: DynamoArgs) -> None:
self.server_args = server_args
self.dynamo_args = dynamo_args
......@@ -131,6 +133,19 @@ def _set_parser(
dynamo_str: Optional[str],
arg_name: str = "tool-call-parser",
) -> Optional[str]:
"""Resolve parser name from SGLang and Dynamo arguments.
Args:
sglang_str: Parser value from SGLang argument.
dynamo_str: Parser value from Dynamo argument.
arg_name: Name of the parser argument for logging.
Returns:
Resolved parser name, preferring Dynamo's value if both set.
Raises:
ValueError: If parser name is not valid.
"""
# If both are present, give preference to dynamo_str
if sglang_str is not None and dynamo_str is not None:
logging.warning(
......@@ -157,8 +172,16 @@ def _set_parser(
def parse_args(args: list[str]) -> Config:
"""
Parse all arguments and return Config with server_args and dynamo_args
"""Parse CLI arguments and return combined configuration.
Args:
args: Command-line argument strings.
Returns:
Config object with server_args and dynamo_args.
Raises:
SystemExit: If arguments are invalid or incompatible.
"""
parser = argparse.ArgumentParser()
......@@ -288,9 +311,14 @@ def parse_args(args: list[str]) -> Config:
@contextlib.contextmanager
def reserve_free_port(host: str = "localhost"):
"""
Find and reserve a free port until context exits.
def reserve_free_port(host: str = "localhost") -> Generator[int, None, None]:
"""Find and reserve a free port until context exits.
Args:
host: Host address to bind to.
Yields:
Available port number.
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
......@@ -302,7 +330,17 @@ def reserve_free_port(host: str = "localhost"):
def parse_endpoint(endpoint: str) -> List[str]:
"""Parse endpoint string into namespace, component, and endpoint parts."""
"""Parse endpoint string into namespace, component, and endpoint parts.
Args:
endpoint: Endpoint string in 'dyn://namespace.component.endpoint' format.
Returns:
List of [namespace, component, endpoint] strings.
Raises:
ValueError: If endpoint format is invalid.
"""
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
......@@ -316,11 +354,11 @@ def parse_endpoint(endpoint: str) -> List[str]:
return endpoint_parts
def _reserve_disaggregation_bootstrap_port() -> int:
    """Pick a currently free port for ``disaggregation_bootstrap_port``.

    Each worker requires its own bootstrap port, so this asks
    ``reserve_free_port`` for an unused one instead of relying on a fixed
    value that could collide with another worker on the same machine.

    Returns:
        The port number yielded by ``reserve_free_port``.
    """
    # NOTE(review): the context manager exits as soon as we return, so the
    # port is presumably no longer held by the time the caller binds it and
    # another process could take it in between. Confirm this is acceptable.
    with reserve_free_port() as port:
        return port
......@@ -8,21 +8,23 @@ This module defines the default health check payload for sglang backends.
"""
import logging
from typing import Optional
import sglang as sgl
from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
def _get_bos_token_id_from_engine(engine) -> int:
"""
Extract BOS token ID from the SGLang engine's tokenizer if available.
def _get_bos_token_id_from_engine(engine: Optional[sgl.Engine]) -> int:
"""Extract BOS token ID from the SGLang engine's tokenizer.
Args:
engine: SGLang Engine instance
engine: SGLang Engine instance.
Returns:
BOS token ID from the model's tokenizer, or 1 as fallback
BOS token ID from the model's tokenizer, or 1 as fallback.
"""
if engine is None:
return 1
......@@ -46,21 +48,16 @@ def _get_bos_token_id_from_engine(engine) -> int:
class SglangHealthCheckPayload(HealthCheckPayload):
"""
sglang-specific health check payload.
"""SGLang-specific health check payload for decode workers.
Provides sglang defaults and inherits environment override support from base class.
Provides SGLang defaults and inherits environment override support from base class.
"""
def __init__(self, engine=None):
"""
Initialize sglang health check payload with sglang-specific defaults.
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
"""Initialize SGLang health check payload with model-specific BOS token.
Args:
engine: Optional SGLang Engine instance to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
The format matches what DecodeWorkerHandler expects from the frontend.
"""
bos_token_id = _get_bos_token_id_from_engine(engine)
......@@ -82,19 +79,16 @@ class SglangHealthCheckPayload(HealthCheckPayload):
class SglangPrefillHealthCheckPayload(HealthCheckPayload):
"""
SGLang-specific health check payload for prefill workers in disaggregated mode.
"""SGLang-specific health check payload for prefill workers in disaggregated mode.
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
"""
def __init__(self, engine=None):
"""
Initialize SGLang prefill health check payload with proper wrapped structure.
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
"""Initialize SGLang prefill health check payload with proper wrapped structure.
Args:
engine: Optional SGLang Engine instance to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_engine(engine)
......
......@@ -2,25 +2,23 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import signal
import sys
import sglang as sgl
import uvloop
from sglang.srt.utils import get_ip
from dynamo.llm import ModelInput, ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args, parse_endpoint
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import (
DecodeWorkerHandler,
MultimodalEncodeWorkerHandler,
......@@ -73,8 +71,6 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# TODO: think about implementing DisaggregationStrategy for P->D
# TODO: implement a `next` field in the config to dynamically set the next client
prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client")
......@@ -85,46 +81,15 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(engine, component)
kv_publisher = None
if server_args.kv_events_config:
kv_events = json.loads(server_args.kv_events_config)
ep = kv_events.get("endpoint")
zmq_ep = ep.replace("*", get_ip()) if ep else None
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=server_args.page_size,
zmq_endpoint=zmq_ep,
)
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(
component, engine, config, publisher, kv_publisher, prefill_client
)
async def register_model():
"""Register the model and signal readiness"""
registration_success = await register_llm_with_runtime_config(
engine,
generate_endpoint,
server_args,
dynamo_args,
)
if not registration_success:
logging.error("Model registration failed; shutting down")
runtime.shutdown()
raise RuntimeError("Model registration failed")
# Model is ready - allow queued requests to proceed
ready_event.set()
logging.info("Model registration succeeded; processing queued requests")
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client)
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
......@@ -138,7 +103,13 @@ async def init(runtime: DistributedRuntime, config: Config):
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_model(),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
......@@ -198,51 +169,35 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# For processor, we need to connect to the encode worker
# Default endpoint for encode worker
encode_endpoint = f"dyn://{dynamo_args.namespace}.encoder.generate"
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
encode_endpoint
)
encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
await runtime.namespace(dynamo_args.namespace)
.component("encoder")
.endpoint("generate")
.client()
)
ready_event = asyncio.Event()
handler = MultimodalProcessorHandler(component, config, encode_worker_client)
logging.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
async def register_model():
"""Register the model and signal readiness"""
registration_success = await register_llm_with_runtime_config(
None, # engine,
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
)
if not registration_success:
logging.error("Model registration failed; shutting down")
runtime.shutdown()
raise RuntimeError("Model registration failed")
logging.info("Model registration succeeded; processing queued requests")
try:
# Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
),
register_model(),
register_llm_with_readiness_gate(
None, # engine
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
......@@ -262,17 +217,11 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# For encode worker, we need to connect to the downstream worker (LLM worker)
# Default endpoint for LLM worker
llm_endpoint = f"dyn://{dynamo_args.namespace}.backend.generate"
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
llm_endpoint
)
# For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
await runtime.namespace(dynamo_args.namespace)
.component("backend")
.endpoint("generate")
.client()
)
......@@ -311,29 +260,18 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
engine = sgl.Engine(server_args=server_args)
# Setup handler based on serving mode
if config.serving_mode == DisaggregationMode.DECODE:
# Decode mode: create prefill client
prefill_endpoint = f"dyn://{dynamo_args.namespace}.prefill.generate"
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
prefill_endpoint
)
logging.info("Initializing prefill client for multimodal decode worker")
prefill_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
handler = MultimodalWorkerHandler(
component, engine, config, None, None, prefill_client
)
handler = MultimodalWorkerHandler(component, engine, config, prefill_client)
else:
# Aggregated mode: no prefill client needed
handler = MultimodalWorkerHandler(component, engine, config)
# Initialize async components
await handler.async_init()
try:
......
......@@ -2,13 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
from typing import List, Optional, Tuple
import sglang as sgl
import zmq
import zmq.asyncio
from sglang.srt.utils import get_zmq_socket
from sglang.srt.utils import get_ip, get_zmq_socket
from dynamo.llm import (
ForwardPassMetrics,
......@@ -16,24 +17,41 @@ from dynamo.llm import (
SpecDecodeStats,
WorkerMetricsPublisher,
WorkerStats,
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
)
from dynamo.runtime import Component
from dynamo.runtime import Component, Endpoint
from dynamo.sglang.args import Config
class DynamoSglangStatPublisher:
class DynamoSglangPublisher:
"""
Handles SGLang metrics reception and publishing.
Handles SGLang kv events and metrics reception and publishing.
"""
def __init__(
self,
engine: sgl.Engine,
config: Config,
component: Component,
generate_endpoint: Endpoint,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
"""Initialize the SGLang publisher for metrics and KV events.
Args:
engine: The SGLang engine instance.
config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests.
metrics_labels: Optional list of label key-value pairs for metrics.
"""
self.engine = engine
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component, metrics_labels)
self.server_args = config.server_args
self.generate_endpoint = generate_endpoint
self.component = component
self.metrics_publisher = WorkerMetricsPublisher()
self.metrics_publisher.create_endpoint(component, metrics_labels)
# Set default values (can be overridden later if needed)
self.request_total_slots = 1024
......@@ -47,11 +65,11 @@ class DynamoSglangStatPublisher:
)
async def run(self) -> None:
"""Main loop to receive scheduler metrics and publish them"""
"""Continuously receive scheduler metrics from ZMQ socket and publish them."""
while True:
try:
kv_metrics = await self._sock.recv_pyobj() # type: ignore
self.record_values(
self._record_values(
request_active_slots=kv_metrics.request_active_slots,
request_total_slots=kv_metrics.request_total_slots,
kv_active_blocks=kv_metrics.kv_active_blocks,
......@@ -66,7 +84,8 @@ class DynamoSglangStatPublisher:
"Failed to receive or publish SGLang scheduler metrics"
)
def init_publish(self) -> None:
def init_engine_metrics_publish(self) -> None:
"""Publish initial dummy metrics to bootstrap the metrics endpoint."""
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=self.request_total_slots,
......@@ -85,22 +104,54 @@ class DynamoSglangStatPublisher:
spec_decode_stats=None,
)
logging.info("Sending dummy metrics to initialize")
self.inner.publish(metrics)
self.metrics_publisher.publish(metrics)
def init_kv_event_publish(self) -> Optional[ZmqKvEventPublisher]:
    """Create the ZMQ KV event publisher when KV events are configured.

    ``self.kv_publisher`` is always (re)assigned by this call; it stays
    ``None`` when ``server_args.kv_events_config`` is empty or unset.

    Returns:
        The newly created ZmqKvEventPublisher, or None when no KV events
        config is present.
    """
    self.kv_publisher = None
    raw_config = self.server_args.kv_events_config
    if not raw_config:
        return self.kv_publisher

    configured_endpoint = json.loads(raw_config).get("endpoint")
    # Any "*" in the configured endpoint is swapped for the value of get_ip().
    if configured_endpoint:
        zmq_endpoint = configured_endpoint.replace("*", get_ip())
    else:
        zmq_endpoint = None

    publisher_config = ZmqKvEventPublisherConfig(
        worker_id=self.generate_endpoint.lease_id(),
        kv_block_size=self.server_args.page_size,
        zmq_endpoint=zmq_endpoint,
    )
    logging.info(f"Setting up ZMQ kv event publisher at {zmq_endpoint}")
    self.kv_publisher = ZmqKvEventPublisher(
        component=self.component, config=publisher_config
    )
    return self.kv_publisher
def _record(
    self,
    worker_stats: WorkerStats,
    kv_stats: KvStats,
    spec_decode_stats: Optional[SpecDecodeStats] = None,
) -> None:
    """Bundle the given stats into ForwardPassMetrics and publish them.

    Args:
        worker_stats: Worker-level statistics.
        kv_stats: KV cache statistics.
        spec_decode_stats: Optional speculative decoding statistics.
    """
    metrics = ForwardPassMetrics(
        worker_stats=worker_stats,
        kv_stats=kv_stats,
        spec_decode_stats=spec_decode_stats,
    )
    # Publish through the WorkerMetricsPublisher stored on this instance.
    self.metrics_publisher.publish(metrics)
def record_values(
def _record_values(
self,
request_active_slots: int,
request_total_slots: int,
......@@ -112,6 +163,19 @@ class DynamoSglangStatPublisher:
data_parallel_rank: Optional[int] = None,
spec_decode_stats: Optional[SpecDecodeStats] = None,
) -> None:
"""Create stats objects from raw values and publish.
Args:
request_active_slots: Number of active request slots.
request_total_slots: Total number of request slots.
kv_active_blocks: Number of active KV cache blocks.
kv_total_blocks: Total number of KV cache blocks.
num_requests_waiting: Number of queued requests.
gpu_cache_usage_perc: GPU cache utilization percentage.
gpu_prefix_cache_hit_rate: Prefix cache hit rate.
data_parallel_rank: Optional data parallel rank.
spec_decode_stats: Optional speculative decoding statistics.
"""
worker_stats = WorkerStats(
request_active_slots=request_active_slots,
request_total_slots=request_total_slots,
......@@ -126,19 +190,32 @@ class DynamoSglangStatPublisher:
gpu_cache_usage_perc=gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
)
self.record(worker_stats, kv_stats, spec_decode_stats)
self._record(worker_stats, kv_stats, spec_decode_stats)
async def setup_sgl_metrics(
engine: sgl.Engine,
config: Config,
component: Component,
) -> tuple[DynamoSglangStatPublisher, asyncio.Task, list[tuple[str, str]]]:
"""
Convenience bootstrap: create endpoint, publish an initial update, and start the metrics loop.
generate_endpoint: Endpoint,
) -> tuple[DynamoSglangPublisher, asyncio.Task, list[tuple[str, str]]]:
"""Create publisher, initialize metrics, and start the metrics publishing loop.
Args:
engine: The SGLang engine instance.
config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests.
Returns:
Tuple of (publisher instance, running asyncio task, metrics labels).
"""
metrics_labels = [("model", engine.server_args.served_model_name)]
publisher = DynamoSglangStatPublisher(engine, component, metrics_labels)
publisher.init_publish()
publisher = DynamoSglangPublisher(
engine, config, component, generate_endpoint, metrics_labels
)
publisher.init_engine_metrics_publish()
publisher.init_kv_event_publish()
task = asyncio.create_task(publisher.run())
logging.info("SGLang metrics loop started")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import Optional
......@@ -12,17 +13,24 @@ from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.sglang.args import DynamoArgs
async def register_llm_with_runtime_config(
async def _register_llm_with_runtime_config(
engine: sgl.Engine,
endpoint: Endpoint,
server_args: ServerArgs,
dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens,
) -> bool:
"""Register LLM with runtime config
"""Register LLM with the Dynamo runtime.
Args:
engine: The SGLang engine instance.
endpoint: The Dynamo endpoint for communication.
server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration.
input_type: Expected model input type. Defaults to ModelInput.Tokens.
Returns:
bool: True if registration succeeded, False if it failed
True if registration succeeded, False otherwise.
"""
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = input_type
......@@ -55,7 +63,16 @@ async def register_llm_with_runtime_config(
async def _get_runtime_config(
engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs
) -> Optional[ModelRuntimeConfig]:
"""Get runtime config from SGLang engine"""
"""Extract runtime configuration from SGLang engine and args.
Args:
engine: The SGLang engine instance.
server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration.
Returns:
ModelRuntimeConfig with extracted values, or None if extraction fails.
"""
runtime_config = ModelRuntimeConfig()
# set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.reasoning_parser
......@@ -109,3 +126,43 @@ async def _get_runtime_config(
except Exception as e:
logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
return runtime_config
async def register_llm_with_readiness_gate(
    engine: Optional[sgl.Engine],
    generate_endpoint: Endpoint,
    server_args: ServerArgs,
    dynamo_args: DynamoArgs,
    input_type: Optional[ModelInput] = ModelInput.Tokens,
    readiness_gate: Optional[asyncio.Event] = None,
) -> None:
    """Register the LLM with the Dynamo runtime and signal readiness.

    Wraps ``_register_llm_with_runtime_config``: on success the optional
    readiness gate is set; on failure the engine (when one was given) is
    shut down and an error is raised.

    Args:
        engine: The SGLang engine instance, or None when the caller has no
            engine to shut down on failure.
        generate_endpoint: The Dynamo endpoint for generation requests.
        server_args: SGLang server configuration.
        dynamo_args: Dynamo-specific configuration.
        input_type: Expected model input type. Defaults to ModelInput.Tokens.
        readiness_gate: Optional event set once registration succeeds.

    Raises:
        RuntimeError: If model registration fails.
    """
    registered = await _register_llm_with_runtime_config(
        engine,
        generate_endpoint,
        server_args,
        dynamo_args,
        input_type,
    )

    if not registered:
        logging.error("Model registration failed; shutting down")
        # The engine is optional, so only shut it down when one exists.
        if engine is not None:
            engine.shutdown()
        raise RuntimeError("Model registration failed")

    if readiness_gate is not None:
        readiness_gate.set()
    logging.info("Model registration succeeded; processing queued requests")
......@@ -3,28 +3,46 @@
import logging
import time
from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Client, Component
from dynamo.llm import WorkerMetricsPublisher, ZmqKvEventPublisher
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes."""
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
metrics_publisher: WorkerMetricsPublisher,
kv_publisher: ZmqKvEventPublisher = None,
prefill_client: Client = None,
):
publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
) -> None:
"""Initialize decode worker handler.
Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
Raises:
ValueError: If prefill_client is not provided in decode serving mode.
"""
super().__init__(
component, engine, config, metrics_publisher, kv_publisher, prefill_client
component,
engine,
config,
publisher,
prefill_client,
)
if self.serving_mode == DisaggregationMode.DECODE:
if self.prefill_client is None:
......@@ -36,13 +54,21 @@ class DecodeWorkerHandler(BaseWorkerHandler):
logging.info("Worker handler initialized")
def cleanup(self) -> None:
    """Shut down the SGLang engine, then run the base-class cleanup."""
    self.engine.shutdown()
    logging.info("Engine shutdown")
    super().cleanup()
def _build_sampling_params(self, request: dict) -> dict:
"""Build sampling params depending on request from frontend"""
def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Build sampling params from request format.
Args:
request: Request dict in either token-based or OpenAI format.
Returns:
Dict of sampling parameters for SGLang engine.
"""
if self.skip_tokenizer_init:
# Token-based request format
sampling_opts = request.get("sampling_options", {})
......@@ -66,7 +92,20 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return {k: v for k, v in param_mapping.items() if v is not None}
async def generate(self, request: dict):
async def generate(
self, request: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate response in aggregated or disaggregated mode.
Args:
request: Request dict with input and sampling parameters.
Yields:
Response dicts with token_ids or OpenAI-formatted chunks.
Raises:
RuntimeError: If no bootstrap info received from prefill worker.
"""
sampling_params = self._build_sampling_params(request)
input_param = self._get_input_param(request)
......@@ -115,7 +154,20 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async for out in self._process_text_stream(agg):
yield out
async def _process_token_stream(self, stream_source):
async def _process_token_stream(
self, stream_source: AsyncGenerator[Dict[str, Any], None]
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process token-based stream output.
Args:
stream_source: Async generator from engine.async_generate.
Yields:
Dict with token_ids and optional finish_reason.
Raises:
ValueError: If response missing output_ids.
"""
num_output_tokens_so_far = 0
async for res in stream_source:
......@@ -134,8 +186,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
yield out
async def _process_text_stream(self, stream_source):
"""Process stream for text input mode"""
async def _process_text_stream(
self, stream_source: AsyncGenerator[Dict[str, Any], None]
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process text-based stream output in OpenAI format.
Args:
stream_source: Async generator from engine.async_generate.
Yields:
OpenAI-formatted chat completion chunk dicts.
"""
count = 0
async for res in stream_source:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import random
import socket
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
import sglang as sgl
from sglang.srt.utils import get_ip
from dynamo._core import Client, Component
from dynamo.llm import WorkerMetricsPublisher, ZmqKvEventPublisher
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
class BaseWorkerHandler(ABC):
"""Abstract base class for SGLang worker handlers."""
def __init__(
    self,
    component: Component,
    engine: sgl.Engine,
    config: Config,
    publisher: Optional[DynamoSglangPublisher] = None,
    prefill_client: Optional[Client] = None,
) -> None:
    """Initialize base worker handler.

    Args:
        component: The Dynamo runtime component.
        engine: The SGLang engine instance.
        config: SGLang and Dynamo configuration.
        publisher: Optional publisher whose ``metrics_publisher`` and
            ``kv_publisher`` attributes are copied onto the handler.
        prefill_client: Optional client for prefill worker in disaggregated mode.
    """
    self.component = component
    self.engine = engine
    self.config = config

    # Both publisher attributes are always defined on the handler; they are
    # None when no combined publisher was supplied.
    if publisher is not None:
        self.metrics_publisher = publisher.metrics_publisher
        self.kv_publisher = publisher.kv_publisher
    else:
        self.metrics_publisher = None
        self.kv_publisher = None

    self.prefill_client = prefill_client
    self.serving_mode = config.serving_mode
    self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
@abstractmethod
async def generate(self, request: Dict[str, Any]):
    """Generate response from request.

    Subclasses must override this; the base implementation does nothing.

    Args:
        request: Request dict with input and parameters.

    Yields:
        Response data (format varies by handler implementation).
    """
    pass
def cleanup(self) -> None:
    """Cleanup resources. No-op here; override in subclasses as needed."""
    pass
def _get_input_param(self, request: dict) -> dict:
"""Get the appropriate input parameter for SGLang"""
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Get the appropriate input parameter for SGLang engine.
Args:
request: Request dict with token_ids or messages.
Returns:
Dict with either input_ids or prompt for engine.
"""
if self.skip_tokenizer_init:
return {"input_ids": request["token_ids"]}
else:
......@@ -47,3 +81,34 @@ class BaseWorkerHandler(ABC):
request["messages"], tokenize=False, add_generation_prompt=True
)
return {"prompt": prompt}
@staticmethod
def _generate_bootstrap_room() -> int:
"""Generate a unique bootstrap room ID for disaggregated serving.
Returns:
Random 63-bit integer.
"""
return random.randint(0, 2**63 - 1)
@staticmethod
def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]:
    """Read the bootstrap host and port from an SGLang engine.

    The port comes from the engine's server args. The host is the resolved
    host part of ``dist_init_addr`` when that is set, otherwise the value
    returned by ``get_ip()``.

    Args:
        engine: The SGLang engine instance.

    Returns:
        Tuple of (bootstrap_host, bootstrap_port).
    """
    server_args = engine.tokenizer_manager.server_args
    bootstrap_port = server_args.disaggregation_bootstrap_port

    dist_init_addr = server_args.dist_init_addr
    if not dist_init_addr:
        return get_ip(), bootstrap_port

    # dist_init_addr is split on ":"; only the part before the first colon is resolved.
    init_host = dist_init_addr.split(":")[0]
    return socket.gethostbyname(init_host), bootstrap_port
......@@ -4,17 +4,13 @@
import asyncio
import json
import logging
import random
import socket
from typing import AsyncIterator
import sglang as sgl
import torch
from sglang.srt.utils import get_ip
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component
from dynamo.llm import WorkerMetricsPublisher, ZmqKvEventPublisher
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
from dynamo.sglang.utils.multimodal_protocol import (
......@@ -205,30 +201,6 @@ class StreamProcessor:
}
yield json.dumps(error_output)
class BootstrapManager:
"""Handles bootstrap coordination for disaggregated mode"""
@staticmethod
def generate_bootstrap_room() -> int:
"""Generate a unique bootstrap room ID"""
return random.randint(0, 2**63 - 1)
@staticmethod
def get_bootstrap_info(engine: sgl.Engine) -> tuple[str, int]:
"""Extract bootstrap info from SGLang engine"""
inner_tm = engine.tokenizer_manager
bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port
if inner_tm.server_args.dist_init_addr:
bootstrap_host = socket.gethostbyname(
inner_tm.server_args.dist_init_addr.split(":")[0]
)
else:
bootstrap_host = get_ip()
return bootstrap_host, bootstrap_port
@staticmethod
def create_bootstrap_info(
bootstrap_host: str, bootstrap_port: int, bootstrap_room: int
......@@ -269,13 +241,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
component: Component,
engine: sgl.Engine,
config: Config,
metrics_publisher: WorkerMetricsPublisher = None,
kv_publisher: ZmqKvEventPublisher = None,
prefill_client: Client = None,
):
super().__init__(
component, engine, config, metrics_publisher, kv_publisher, prefill_client
)
super().__init__(component, engine, config, None, prefill_client)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......@@ -444,15 +412,13 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
"""
def __init__(self, component: Component, engine: sgl.Engine, config: Config):
super().__init__(component, engine, config, None, None, None)
super().__init__(component, engine, config)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
# Get bootstrap info using BootstrapManager
self.bootstrap_host, self.bootstrap_port = BootstrapManager.get_bootstrap_info(
engine
)
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(engine)
logger.info(
f"Multimodal prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
......@@ -474,12 +440,14 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
disagg_request = self._validate_and_parse_disagg_request(disagg_request)
# Generate and return bootstrap info first (like regular SGLang)
bootstrap_room = BootstrapManager.generate_bootstrap_room()
bootstrap_info = BootstrapManager.create_bootstrap_info(
self.bootstrap_host, self.bootstrap_port, bootstrap_room
)
bootstrap_room = self._generate_bootstrap_room()
bootstrap_info = {
"bootstrap_host": self.bootstrap_host,
"bootstrap_port": self.bootstrap_port,
"bootstrap_room": bootstrap_room,
}
yield json.dumps(bootstrap_info)
yield bootstrap_info
# Process prefill generation
await self._process_prefill_generation(disagg_request, bootstrap_room)
......
......@@ -3,11 +3,9 @@
import asyncio
import logging
import random
import socket
from typing import Any, AsyncGenerator, Dict
import sglang as sgl
from sglang.srt.utils import get_ip
from dynamo._core import Component
from dynamo.sglang.args import Config
......@@ -15,37 +13,42 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class PrefillWorkerHandler(BaseWorkerHandler):
"""Handler for prefill workers in disaggregated serving mode."""

def __init__(
    self, component: Component, engine: sgl.Engine, config: Config
) -> None:
    """Initialize prefill worker handler.

    Args:
        component: The Dynamo runtime component.
        engine: The SGLang engine instance.
        config: SGLang and Dynamo configuration.
    """
    # The engine is stored and the bootstrap endpoint resolved before the
    # base initializer runs; the helper takes the engine explicitly.
    self.engine = engine
    self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
    super().__init__(component, engine, config)
    logging.info(
        f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
    )
def _generate_bootstrap_room(self) -> int:
    """Return a random bootstrap room ID in the inclusive range [0, 2**63 - 1]."""
    upper_bound = 2**63 - 1
    return random.randint(0, upper_bound)
def cleanup(self) -> None:
    """Shutdown the prefill engine and cleanup resources."""
    # Shut the engine down first, then let the base class run its own cleanup.
    self.engine.shutdown()
    logging.info("Prefill engine shutdown")
    super().cleanup()
def _get_bootstrap_info(self):
"""Bootstrap info from tokenizer manager"""
inner_tm = self.engine.tokenizer_manager
bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port
if inner_tm.server_args.dist_init_addr:
bootstrap_host = socket.gethostbyname(
inner_tm.server_args.dist_init_addr.split(":")[0]
)
else:
bootstrap_host = get_ip()
async def generate(
self, request: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate prefill output and provide bootstrap info for decode worker.
return bootstrap_host, bootstrap_port
Args:
request: Request dict with 'request' and 'sampling_params' keys.
async def generate(self, request: dict):
Yields:
Bootstrap info dict with host, port, and room for decode worker connection.
"""
bootstrap_room = self._generate_bootstrap_room()
bootstrap_info = {
......@@ -69,6 +72,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
asyncio.create_task(self._consume_results(results))
async def _consume_results(self, results: AsyncGenerator[Any, None]) -> None:
    """Consume async generator results without processing.

    Args:
        results: Async generator from engine.async_generate.
    """
    # Each item is discarded; the loop only drives the generator to exhaustion.
    async for _ in results:
        pass
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment