Unverified Commit 51f65757 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: sglang prefill router + bump to `0.5.3` (#3498)

parent 179f993a
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run ingress
python3 -m dynamo.frontend \
--http-port=8000 \
--router-mode kv \
--kv-overlap-score-weight 0 \
--router-reset-states &
DYNAMO_PID=$!
# run prefill router
python3 -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks &
PREFILL_ROUTER_PID=$!
# run prefill worker
python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode prefill \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run prefill worker
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode prefill \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode decode \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5560"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode decode \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \
--disaggregation-transfer-backend nixl
...@@ -134,10 +134,12 @@ class StandaloneRouterHandler: ...@@ -134,10 +134,12 @@ class StandaloneRouterHandler:
logger.error("KvPushRouter not initialized - cannot get best worker") logger.error("KvPushRouter not initialized - cannot get best worker")
raise RuntimeError("Router not initialized") raise RuntimeError("Router not initialized")
return await self.kv_push_router.best_worker_id( result = await self.kv_push_router.best_worker_id(
token_ids, router_config_override token_ids, router_config_override
) )
yield result
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
......
...@@ -77,8 +77,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -77,8 +77,14 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
prefill_client = None prefill_client = None
prefill_router_client = None
if config.serving_mode == DisaggregationMode.DECODE: if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client") prefill_router_client = (
await runtime.namespace(dynamo_args.namespace)
.component("router")
.endpoint("best_worker_id")
.client()
)
prefill_client = ( prefill_client = (
await runtime.namespace(dynamo_args.namespace) await runtime.namespace(dynamo_args.namespace)
.component("prefill") .component("prefill")
...@@ -94,7 +100,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -94,7 +100,9 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client) handler = DecodeWorkerHandler(
component, engine, config, publisher, prefill_client, prefill_router_client
)
health_check_payload = SglangHealthCheckPayload(engine).to_dict() health_check_payload = SglangHealthCheckPayload(engine).to_dict()
...@@ -141,7 +149,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -141,7 +149,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = PrefillWorkerHandler(component, engine, config) # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
handler = PrefillWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
...@@ -149,7 +162,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -149,7 +162,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, handler.generate,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)], metrics_labels=metrics_labels,
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
) )
] ]
...@@ -160,6 +173,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -160,6 +173,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
raise raise
finally: finally:
metrics_task.cancel()
try:
await metrics_task
except asyncio.CancelledError:
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup() handler.cleanup()
......
...@@ -9,7 +9,7 @@ from typing import List, Optional, Tuple ...@@ -9,7 +9,7 @@ from typing import List, Optional, Tuple
import sglang as sgl import sglang as sgl
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.utils import get_ip, get_zmq_socket from sglang.srt.utils import get_local_ip_auto, get_zmq_socket
from dynamo.llm import ( from dynamo.llm import (
ForwardPassMetrics, ForwardPassMetrics,
...@@ -116,7 +116,7 @@ class DynamoSglangPublisher: ...@@ -116,7 +116,7 @@ class DynamoSglangPublisher:
if self.server_args.kv_events_config: if self.server_args.kv_events_config:
kv_events = json.loads(self.server_args.kv_events_config) kv_events = json.loads(self.server_args.kv_events_config)
ep = kv_events.get("endpoint") ep = kv_events.get("endpoint")
zmq_ep = ep.replace("*", get_ip()) if ep else None zmq_ep = ep.replace("*", get_local_ip_auto()) if ep else None
lease_id = self.generate_endpoint.lease_id() lease_id = self.generate_endpoint.lease_id()
......
...@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod ...@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import sglang as sgl import sglang as sgl
from sglang.srt.utils import get_ip from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component from dynamo._core import Client, Component
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
...@@ -109,6 +109,6 @@ class BaseWorkerHandler(ABC): ...@@ -109,6 +109,6 @@ class BaseWorkerHandler(ABC):
inner_tm.server_args.dist_init_addr.split(":")[0] inner_tm.server_args.dist_init_addr.split(":")[0]
) )
else: else:
bootstrap_host = get_ip() bootstrap_host = get_local_ip_auto()
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
...@@ -24,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -24,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None, prefill_client: Optional[Client] = None,
prefill_router_client: Optional[Client] = None,
) -> None: ) -> None:
"""Initialize decode worker handler. """Initialize decode worker handler.
...@@ -33,6 +34,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -33,6 +34,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker. publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode. prefill_client: Optional client for prefill worker in disaggregated mode.
prefill_router_client: Optional client for prefill router in disaggregated mode.
Raises: Raises:
ValueError: If prefill_client is not provided in decode serving mode. ValueError: If prefill_client is not provided in decode serving mode.
...@@ -52,6 +54,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -52,6 +54,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self.prefill_client = prefill_client self.prefill_client = prefill_client
logging.info("Decode worker handler initialized") logging.info("Decode worker handler initialized")
self.prefill_router_client = prefill_router_client
logging.info("Worker handler initialized") logging.info("Worker handler initialized")
def cleanup(self) -> None: def cleanup(self) -> None:
...@@ -111,12 +114,33 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -111,12 +114,33 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker # request the bootstrap info from the target prefill worker
prefill_stream = await self.prefill_client.generate( if (
DisaggPreprocessedRequest( self.prefill_router_client is not None
request=request, and self.prefill_router_client.instance_ids()
sampling_params=sampling_params, ):
).model_dump() token_ids = request["token_ids"]
) stream = await self.prefill_router_client.generate(token_ids)
result = await anext(stream)
(
worker_id,
overlap,
) = result.data() # Returns tuple (worker_id, overlap_amount)
logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}")
prefill_stream = await self.prefill_client.direct(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
worker_id,
)
else:
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump()
)
bootstrap_info = None bootstrap_info = None
async for info in prefill_stream: async for info in prefill_stream:
......
...@@ -9,6 +9,7 @@ import sglang as sgl ...@@ -9,6 +9,7 @@ import sglang as sgl
from dynamo._core import Component from dynamo._core import Component
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -16,7 +17,11 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -16,7 +17,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Handler for prefill workers in disaggregated serving mode.""" """Handler for prefill workers in disaggregated serving mode."""
def __init__( def __init__(
self, component: Component, engine: sgl.Engine, config: Config self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
) -> None: ) -> None:
"""Initialize prefill worker handler. """Initialize prefill worker handler.
...@@ -24,10 +29,11 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -24,10 +29,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
component: The Dynamo runtime component. component: The Dynamo runtime component.
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance.
""" """
self.engine = engine self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine) self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(component, engine, config) super().__init__(component, engine, config, publisher)
logging.info( logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}" f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
) )
......
...@@ -14,7 +14,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" ...@@ -14,7 +14,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04"
# Make sure to update the dependency version in pyproject.toml when updating this # Make sure to update the dependency version in pyproject.toml when updating this
ARG SGLANG_VERSION="0.5.3rc0" ARG SGLANG_VERSION="0.5.3"
# Define general architecture ARGs for supporting both x86 and aarch64 builds. # Define general architecture ARGs for supporting both x86 and aarch64 builds.
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
ARG SGLANG_IMAGE_TAG="v0.5.3rc0-cu126" ARG SGLANG_IMAGE_TAG="v0.5.3-cu126"
FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG}
......
...@@ -60,7 +60,7 @@ vllm = [ ...@@ -60,7 +60,7 @@ vllm = [
sglang = [ sglang = [
"uvloop", "uvloop",
"nixl<=0.4.1", "nixl<=0.4.1",
"sglang[all]==0.5.3rc0", "sglang[all]==0.5.3",
] ]
[dependency-groups] [dependency-groups]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment