Unverified Commit 51f65757 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: sglang prefill router + bump to `0.5.3` (#3498)

parent 179f993a
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run ingress
python3 -m dynamo.frontend \
--http-port=8000 \
--router-mode kv \
--kv-overlap-score-weight 0 \
--router-reset-states &
DYNAMO_PID=$!
# run prefill router
python3 -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks &
PREFILL_ROUTER_PID=$!
# run prefill worker
python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode prefill \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run prefill worker
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode prefill \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode decode \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5560"}' \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 64 \
--tp 1 \
--trust-remote-code \
--disaggregation-mode decode \
--host 0.0.0.0 \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \
--disaggregation-transfer-backend nixl
......@@ -134,10 +134,12 @@ class StandaloneRouterHandler:
logger.error("KvPushRouter not initialized - cannot get best worker")
raise RuntimeError("Router not initialized")
return await self.kv_push_router.best_worker_id(
result = await self.kv_push_router.best_worker_id(
token_ids, router_config_override
)
yield result
def parse_args():
parser = argparse.ArgumentParser(
......
......@@ -77,8 +77,14 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
prefill_client = None
prefill_router_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client")
prefill_router_client = (
await runtime.namespace(dynamo_args.namespace)
.component("router")
.endpoint("best_worker_id")
.client()
)
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
......@@ -94,7 +100,9 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client)
handler = DecodeWorkerHandler(
component, engine, config, publisher, prefill_client, prefill_router_client
)
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
......@@ -141,7 +149,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = PrefillWorkerHandler(component, engine, config)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
handler = PrefillWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......@@ -149,7 +162,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
)
]
......@@ -160,6 +173,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
logging.error(f"Failed to serve endpoints: {e}")
raise
finally:
metrics_task.cancel()
try:
await metrics_task
except asyncio.CancelledError:
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
......
......@@ -9,7 +9,7 @@ from typing import List, Optional, Tuple
import sglang as sgl
import zmq
import zmq.asyncio
from sglang.srt.utils import get_ip, get_zmq_socket
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket
from dynamo.llm import (
ForwardPassMetrics,
......@@ -116,7 +116,7 @@ class DynamoSglangPublisher:
if self.server_args.kv_events_config:
kv_events = json.loads(self.server_args.kv_events_config)
ep = kv_events.get("endpoint")
zmq_ep = ep.replace("*", get_ip()) if ep else None
zmq_ep = ep.replace("*", get_local_ip_auto()) if ep else None
lease_id = self.generate_endpoint.lease_id()
......
......@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
import sglang as sgl
from sglang.srt.utils import get_ip
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component
from dynamo.sglang.args import Config
......@@ -109,6 +109,6 @@ class BaseWorkerHandler(ABC):
inner_tm.server_args.dist_init_addr.split(":")[0]
)
else:
bootstrap_host = get_ip()
bootstrap_host = get_local_ip_auto()
return bootstrap_host, bootstrap_port
......@@ -24,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config: Config,
publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
prefill_router_client: Optional[Client] = None,
) -> None:
"""Initialize decode worker handler.
......@@ -33,6 +34,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
prefill_router_client: Optional client for prefill router in disaggregated mode.
Raises:
ValueError: If prefill_client is not provided in decode serving mode.
......@@ -52,6 +54,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self.prefill_client = prefill_client
logging.info("Decode worker handler initialized")
self.prefill_router_client = prefill_router_client
logging.info("Worker handler initialized")
def cleanup(self) -> None:
......@@ -111,6 +114,27 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
token_ids = request["token_ids"]
stream = await self.prefill_router_client.generate(token_ids)
result = await anext(stream)
(
worker_id,
overlap,
) = result.data() # Returns tuple (worker_id, overlap_amount)
logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}")
prefill_stream = await self.prefill_client.direct(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
worker_id,
)
else:
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
......
......@@ -9,6 +9,7 @@ import sglang as sgl
from dynamo._core import Component
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -16,7 +17,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Handler for prefill workers in disaggregated serving mode."""
def __init__(
self, component: Component, engine: sgl.Engine, config: Config
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
) -> None:
"""Initialize prefill worker handler.
......@@ -24,10 +29,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
component: The Dynamo runtime component.
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance.
"""
self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(component, engine, config)
super().__init__(component, engine, config, publisher)
logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
)
......
......@@ -14,7 +14,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04"
# Make sure to update the dependency version in pyproject.toml when updating this
ARG SGLANG_VERSION="0.5.3rc0"
ARG SGLANG_VERSION="0.5.3"
# Define general architecture ARGs for supporting both x86 and aarch64 builds.
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
ARG SGLANG_IMAGE_TAG="v0.5.3rc0-cu126"
ARG SGLANG_IMAGE_TAG="v0.5.3-cu126"
FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG}
......
......@@ -60,7 +60,7 @@ vllm = [
sglang = [
"uvloop",
"nixl<=0.4.1",
"sglang[all]==0.5.3rc0",
"sglang[all]==0.5.3",
]
[dependency-groups]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment