Unverified Commit c5f5ab60 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix: enable DP attention KV events for multi-node deployments (#5589)


Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: default avatarhuitianbai <huitianbai@gmail.com>
Co-authored-by: default avatarHuitianqi Bai <huitianbai@users.noreply.github.com>
parent aea57cae
......@@ -20,7 +20,11 @@ from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_prometheus_registry, setup_sgl_metrics
from dynamo.sglang.publisher import (
DynamoSglangPublisher,
setup_prometheus_registry,
setup_sgl_metrics,
)
from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import (
DecodeWorkerHandler,
......@@ -38,32 +42,38 @@ configure_dynamo_logging()
async def _handle_non_leader_node(
engine: sgl.Engine,
generate_endpoint,
publisher: DynamoSglangPublisher,
metrics_task: asyncio.Task,
) -> None:
"""
Handle non-leader node (node_rank >= 1) in multi-node deployments.
Non-leader nodes only run scheduler processes and don't handle requests,
but they should still expose metrics via Dynamo's metrics endpoint.
Non-leader nodes run scheduler processes but don't handle requests directly.
They still need:
- KV event publishing (subscribe to local DP ranks, forward to NATS)
- Metrics collection from local schedulers
- Prometheus metrics exposure
Args:
engine: The SGLang engine instance.
config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests.
publisher: The DynamoSglangPublisher for metrics and KV events.
metrics_task: The asyncio task running the metrics loop.
"""
logging.info(
f"Non-leader node detected (node_rank={engine.server_args.node_rank})."
f"Non-leader node detected (node_rank={engine.server_args.node_rank}). "
"Running with metrics and KV event publishing for local DP ranks."
)
# Only setup Prometheus registry to expose SGLang metrics from shared memory
# Non-leader nodes don't need Dynamo metrics publishing or KV events
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
logging.info("Prometheus metrics registry configured for non-leader node")
try:
# Wait indefinitely - the process will be terminated via signal handlers
await asyncio.Event().wait()
finally:
metrics_task.cancel()
try:
await metrics_task
except asyncio.CancelledError:
pass
publisher.cleanup()
async def worker():
......@@ -137,13 +147,8 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes only run scheduler processes and expose metrics
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# publisher instantiates the metrics and kv event publishers
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
......@@ -152,6 +157,12 @@ async def init(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
......@@ -160,7 +171,6 @@ async def init(runtime: DistributedRuntime, config: Config):
)
handler.register_engine_routes(runtime)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
......@@ -224,17 +234,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node tensor parallelism)
# Non-leader nodes only run scheduler processes and expose metrics
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
# publisher instantiates the metrics and kv event publishers
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
......@@ -243,6 +244,16 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint
)
......@@ -310,12 +321,8 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node parallelism)
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# Setup metrics publisher
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
......@@ -324,6 +331,12 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
......
......@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import sglang as sgl
import zmq
import zmq.asyncio
from sglang.srt.disaggregation.kv_events import ZmqEventPublisher
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address
if TYPE_CHECKING:
......@@ -48,6 +49,11 @@ def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str:
return endpoint_template.replace("*", formatted_ip)
# Note: We use SGLang's ZmqEventPublisher.offset_endpoint_port() directly
# to ensure perfect alignment between publisher (SGLang) and subscriber (dynamo).
# This is the same pattern used by dynamo+vLLM.
class DynamoSglangPublisher:
"""
Handles SGLang kv events and metrics reception and publishing.
......@@ -81,15 +87,43 @@ class DynamoSglangPublisher:
# Set default values (can be overridden later if needed)
self.dp_rank = 0
# ZMQ setup for receiving scheduler metrics
self._running = True
self.kv_publishers: List[ZmqKvEventPublisher] = []
# ZMQ setup for receiving scheduler metrics (leader node only)
# Non-leader nodes don't receive scheduler metrics via this socket - they only
# need KV event publishing which is set up separately in init_kv_event_publish()
node_rank = getattr(self.server_args, "node_rank", 0) or 0
if node_rank == 0:
self._ctx = zmq.asyncio.Context() # type: ignore
self._sock = get_zmq_socket(
self._ctx, zmq.PULL, self.engine.port_args.metrics_ipc_name, True # type: ignore
self._ctx,
zmq.PULL,
self.engine.port_args.metrics_ipc_name,
True, # type: ignore
)
else:
self._ctx = None
self._sock = None
logging.info(
f"Non-leader node (node_rank={node_rank}): skipping scheduler metrics "
"ZMQ socket setup. KV event publishing will still be configured."
)
async def run(self) -> None:
"""Continuously receive scheduler metrics from ZMQ socket and publish them."""
while True:
"""Continuously receive scheduler metrics from ZMQ socket and publish them.
On non-leader nodes (node_rank >= 1), this is a no-op since they don't have
a scheduler metrics socket. They only publish KV events via init_kv_event_publish().
"""
if self._sock is None:
# Non-leader node: no scheduler metrics to receive
# Just wait until stopped (KV events are handled by separate publishers)
while self._running:
await asyncio.sleep(1)
return
while self._running:
try:
kv_metrics = await self._sock.recv_pyobj() # type: ignore
dp_rank = (
......@@ -99,26 +133,101 @@ class DynamoSglangPublisher:
)
self.metrics_publisher.publish(dp_rank, kv_metrics.kv_active_blocks)
except Exception:
if self._running:
logging.exception(
"Failed to receive or publish SGLang scheduler metrics"
)
def cleanup(self) -> None:
"""Clean up ZMQ resources."""
self._running = False
# Close ZMQ socket and context
if self._sock is not None:
try:
self._sock.close(linger=0)
except Exception as e:
logging.warning(f"Failed to close ZMQ socket: {e}")
if self._ctx is not None:
try:
self._ctx.term()
except Exception as e:
logging.warning(f"Failed to terminate ZMQ context: {e}")
# Shutdown kv publishers
for publisher in self.kv_publishers:
try:
publisher.shutdown()
except Exception as e:
logging.warning(f"Failed to shutdown kv publisher: {e}")
logging.info("DynamoSglangPublisher cleanup complete")
def init_engine_metrics_publish(self) -> None:
"""Publish initial dummy metrics to bootstrap the metrics endpoint."""
logging.info("Sending dummy metrics to initialize")
self.metrics_publisher.publish(self.dp_rank, 0)
def init_kv_event_publish(self) -> Optional[ZmqKvEventPublisher]:
"""Initialize KV event publisher if configured.
def init_kv_event_publish(self) -> List[ZmqKvEventPublisher]:
"""Initialize KV event publisher(s) if configured.
For DP attention mode, creates one subscriber per LOCAL DP rank port.
Each SGLang scheduler in DP attention mode publishes to a unique port
(base_port + attn_dp_rank). In multi-node setups, each node's dynamo.sglang
instance subscribes only to the DP ranks running on that node.
Multi-node handling:
- Each node runs dynamo.sglang alongside its local SGLang DP ranks
- Each dynamo.sglang subscribes only to LOCAL DP ranks (same node)
- SGLang binds locally (wildcard), Dynamo connects locally
- NATS handles cross-node event distribution
Returns:
ZmqKvEventPublisher instance if kv_events_config is set, None otherwise.
List of ZmqKvEventPublisher instances if kv_events_config is set,
empty list otherwise.
"""
self.kv_publisher = None
if self.server_args.kv_events_config:
kv_events = json.loads(self.server_args.kv_events_config)
ep = kv_events.get("endpoint")
zmq_ep = format_zmq_endpoint(ep, get_local_ip_auto()) if ep else None
base_ep = kv_events.get("endpoint")
local_ip = get_local_ip_auto()
# Determine DP attention configuration
dp_size = getattr(self.server_args, "dp_size", 1) or 1
enable_dp_attention = getattr(
self.server_args, "enable_dp_attention", False
)
nnodes = getattr(self.server_args, "nnodes", 1) or 1
node_rank = getattr(self.server_args, "node_rank", 0) or 0
if enable_dp_attention and dp_size > 1:
# Calculate which DP ranks are local to this node
# DP ranks are distributed evenly across nodes
local_dp_size = dp_size // nnodes if nnodes > 0 else dp_size
start_dp_rank = node_rank * local_dp_size
end_dp_rank = start_dp_rank + local_dp_size
logging.info(
f"DP attention mode: node_rank={node_rank}, dp_size={dp_size}, "
f"nnodes={nnodes}. Subscribing to local DP ranks [{start_dp_rank}, {end_dp_rank})"
)
else:
# Standard mode: single subscriber for rank 0
start_dp_rank = 0
end_dp_rank = 1
for dp_rank in range(start_dp_rank, end_dp_rank):
# Use SGLang's offset_endpoint_port to ensure alignment with publishers
# This is the same function SGLang schedulers use to determine their bind ports
zmq_ep = ZmqEventPublisher.offset_endpoint_port(base_ep, dp_rank)
if not zmq_ep:
logging.warning(
f"Skipping ZMQ subscriber for dp_rank={dp_rank}: "
f"offset_endpoint_port returned None for base_ep={base_ep}"
)
continue
zmq_ep = format_zmq_endpoint(zmq_ep, local_ip)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=self.generate_endpoint.connection_id(),
......@@ -126,11 +235,19 @@ class DynamoSglangPublisher:
zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer,
)
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
self.kv_publisher = ZmqKvEventPublisher(
logging.info(
f"Setting up ZMQ kv event subscriber for dp_rank={dp_rank} "
f"(connecting to {zmq_ep})"
)
publisher = ZmqKvEventPublisher(
component=self.component, config=zmq_config
)
return self.kv_publisher
self.kv_publishers.append(publisher)
# Maintain backward compatibility: set kv_publisher to first publisher if any
self.kv_publisher = self.kv_publishers[0] if self.kv_publishers else None
return self.kv_publishers
def setup_prometheus_registry(
......
......@@ -163,6 +163,13 @@ async def _get_runtime_config(
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# Set data_parallel_size for DP attention mode
# This enables the router to correctly track per-(worker_id, dp_rank) pairs
dp_size = getattr(server_args, "dp_size", 1) or 1
runtime_config.data_parallel_size = dp_size
if dp_size > 1:
logging.info(f"Registering with data_parallel_size={dp_size}")
# Set bootstrap endpoint for disaggregated serving (prefill workers)
bootstrap_host, bootstrap_port = _get_bootstrap_info_for_config(engine)
if bootstrap_host and bootstrap_port:
......
......@@ -25,9 +25,9 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
logging.info("Embedding worker handler initialized")
def cleanup(self):
super().cleanup()
self.engine.shutdown()
logging.info("Engine shutdown")
super().cleanup()
async def generate(self, request: dict, context: Context):
"""
......
......@@ -42,6 +42,7 @@ class BaseWorkerHandler(ABC):
self.engine = engine
self.config = config
self.generate_endpoint = generate_endpoint
self.publisher = publisher
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
......@@ -202,7 +203,8 @@ class BaseWorkerHandler(ABC):
def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed."""
pass
if self.publisher is not None:
self.publisher.cleanup()
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
request_input = self.input_param_manager.get_input_param(
......
......@@ -50,9 +50,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def cleanup(self) -> None:
"""Shutdown the engine and cleanup resources."""
super().cleanup()
self.engine.shutdown()
logging.info("Engine shutdown")
super().cleanup()
def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Build sampling params from request format.
......@@ -126,6 +126,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self._get_trace_header(context) if self.enable_trace else None
)
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
decode = await self.engine.async_generate(
**input_param,
sampling_params=sampling_params,
......@@ -135,6 +139,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=trace_id,
data_parallel_rank=dp_rank,
)
if self.skip_tokenizer_init:
......@@ -161,6 +166,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self._get_trace_header(context) if self.enable_trace else None
)
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
agg = await self.engine.async_generate(
**input_param,
image_data=image_data,
......@@ -168,6 +177,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
stream=True,
external_trace_header=trace_header,
rid=trace_id,
data_parallel_rank=dp_rank,
)
if self.skip_tokenizer_init:
async for out in self._process_token_stream(agg, context):
......
......@@ -49,9 +49,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
task.cancel()
self._consume_tasks.clear()
super().cleanup()
self.engine.shutdown()
logging.info("Prefill engine shutdown")
super().cleanup()
async def generate(
self, request: Dict[str, Any], context: Context
......
......@@ -400,9 +400,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
return bootstrap_info
def cleanup(self):
super().cleanup()
self.engine.shutdown()
logger.info("Multimodal worker engine shutdown")
super().cleanup()
class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
......@@ -515,6 +515,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
pass
def cleanup(self):
super().cleanup()
self.engine.shutdown()
logger.info("Multimodal prefill engine shutdown")
super().cleanup()
......@@ -182,6 +182,7 @@ pub(crate) struct ZmqKvEventListener {
#[pymethods]
impl ZmqKvEventListener {
#[new]
#[pyo3(signature = (zmq_endpoint, zmq_topic, kv_block_size))]
fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
......
......@@ -423,8 +423,11 @@ pub async fn start_zmq_listener(
return;
}
// Connect to the ZMQ endpoint. SGLang binds locally, Dynamo connects.
// In multi-node setups, each node runs dynamo.sglang alongside local SGLang ranks,
// so ZMQ connections are always local. NATS handles cross-node event distribution.
if let Err(e) = socket.connect(&zmq_endpoint).await {
tracing::error!("Failed to connect ZMQ SUB socket: {}", e);
tracing::error!("Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}");
return;
}
......@@ -1556,7 +1559,7 @@ mod tests_startup_helpers {
// Cancellation token so we can stop the listener
let token = dynamo_runtime::CancellationToken::new();
// Spawn async listener
// Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({
let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment