Unverified Commit c5f5ab60 authored by ishandhanani, committed by GitHub
Browse files

fix: enable DP attention KV events for multi-node deployments (#5589)


Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: huitianbai <huitianbai@gmail.com>
Co-authored-by: Huitianqi Bai <huitianbai@users.noreply.github.com>
parent aea57cae
...@@ -20,7 +20,11 @@ from dynamo.sglang.health_check import ( ...@@ -20,7 +20,11 @@ from dynamo.sglang.health_check import (
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
) )
from dynamo.sglang.publisher import setup_prometheus_registry, setup_sgl_metrics from dynamo.sglang.publisher import (
DynamoSglangPublisher,
setup_prometheus_registry,
setup_sgl_metrics,
)
from dynamo.sglang.register import register_llm_with_readiness_gate from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
...@@ -38,32 +42,38 @@ configure_dynamo_logging() ...@@ -38,32 +42,38 @@ configure_dynamo_logging()
async def _handle_non_leader_node( async def _handle_non_leader_node(
engine: sgl.Engine, engine: sgl.Engine,
generate_endpoint, publisher: DynamoSglangPublisher,
metrics_task: asyncio.Task,
) -> None: ) -> None:
""" """
Handle non-leader node (node_rank >= 1) in multi-node deployments. Handle non-leader node (node_rank >= 1) in multi-node deployments.
Non-leader nodes only run scheduler processes and don't handle requests, Non-leader nodes run scheduler processes but don't handle requests directly.
but they should still expose metrics via Dynamo's metrics endpoint. They still need:
- KV event publishing (subscribe to local DP ranks, forward to NATS)
- Metrics collection from local schedulers
- Prometheus metrics exposure
Args: Args:
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang configuration including server args. publisher: The DynamoSglangPublisher for metrics and KV events.
component: The Dynamo runtime component. metrics_task: The asyncio task running the metrics loop.
generate_endpoint: The Dynamo endpoint for generation requests.
""" """
logging.info( logging.info(
f"Non-leader node detected (node_rank={engine.server_args.node_rank})." f"Non-leader node detected (node_rank={engine.server_args.node_rank}). "
"Running with metrics and KV event publishing for local DP ranks."
) )
# Only setup Prometheus registry to expose SGLang metrics from shared memory try:
# Non-leader nodes don't need Dynamo metrics publishing or KV events
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
logging.info("Prometheus metrics registry configured for non-leader node")
# Wait indefinitely - the process will be terminated via signal handlers # Wait indefinitely - the process will be terminated via signal handlers
await asyncio.Event().wait() await asyncio.Event().wait()
finally:
metrics_task.cancel()
try:
await metrics_task
except asyncio.CancelledError:
pass
publisher.cleanup()
async def worker(): async def worker():
...@@ -137,13 +147,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,13 +147,8 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node parallelism) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes only run scheduler processes and expose metrics # Non-leader nodes need KV event publishing for their local DP ranks
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
...@@ -152,6 +157,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -152,6 +157,12 @@ async def init(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics: if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint) setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
...@@ -160,7 +171,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -160,7 +171,6 @@ async def init(runtime: DistributedRuntime, config: Config):
) )
handler.register_engine_routes(runtime) handler.register_engine_routes(runtime)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload( health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict() ).to_dict()
...@@ -224,17 +234,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -224,17 +234,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node tensor parallelism) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes only run scheduler processes and expose metrics # Non-leader nodes need KV event publishing for their local DP ranks
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
...@@ -243,6 +244,16 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -243,6 +244,16 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics: if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint) setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint component, engine, config, publisher, generate_endpoint
) )
...@@ -310,12 +321,8 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -310,12 +321,8 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node parallelism) # Setup metrics and KV events for ALL nodes (including non-leader)
if server_args.node_rank >= 1: # Non-leader nodes need KV event publishing for their local DP ranks
await _handle_non_leader_node(engine, generate_endpoint)
return
# Setup metrics publisher
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
...@@ -324,6 +331,12 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -324,6 +331,12 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics: if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint) setup_prometheus_registry(engine, generate_endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes run schedulers and publish KV events, but don't serve requests
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, publisher, metrics_task)
return
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
......
...@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple ...@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import sglang as sgl import sglang as sgl
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.disaggregation.kv_events import ZmqEventPublisher
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -48,6 +49,11 @@ def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str: ...@@ -48,6 +49,11 @@ def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str:
return endpoint_template.replace("*", formatted_ip) return endpoint_template.replace("*", formatted_ip)
# Note: We use SGLang's ZmqEventPublisher.offset_endpoint_port() directly
# to ensure perfect alignment between publisher (SGLang) and subscriber (dynamo).
# This is the same pattern used by dynamo+vLLM.
class DynamoSglangPublisher: class DynamoSglangPublisher:
""" """
Handles SGLang kv events and metrics reception and publishing. Handles SGLang kv events and metrics reception and publishing.
...@@ -81,15 +87,43 @@ class DynamoSglangPublisher: ...@@ -81,15 +87,43 @@ class DynamoSglangPublisher:
# Set default values (can be overridden later if needed) # Set default values (can be overridden later if needed)
self.dp_rank = 0 self.dp_rank = 0
# ZMQ setup for receiving scheduler metrics self._running = True
self.kv_publishers: List[ZmqKvEventPublisher] = []
# ZMQ setup for receiving scheduler metrics (leader node only)
# Non-leader nodes don't receive scheduler metrics via this socket - they only
# need KV event publishing which is set up separately in init_kv_event_publish()
node_rank = getattr(self.server_args, "node_rank", 0) or 0
if node_rank == 0:
self._ctx = zmq.asyncio.Context() # type: ignore self._ctx = zmq.asyncio.Context() # type: ignore
self._sock = get_zmq_socket( self._sock = get_zmq_socket(
self._ctx, zmq.PULL, self.engine.port_args.metrics_ipc_name, True # type: ignore self._ctx,
zmq.PULL,
self.engine.port_args.metrics_ipc_name,
True, # type: ignore
)
else:
self._ctx = None
self._sock = None
logging.info(
f"Non-leader node (node_rank={node_rank}): skipping scheduler metrics "
"ZMQ socket setup. KV event publishing will still be configured."
) )
async def run(self) -> None: async def run(self) -> None:
"""Continuously receive scheduler metrics from ZMQ socket and publish them.""" """Continuously receive scheduler metrics from ZMQ socket and publish them.
while True:
On non-leader nodes (node_rank >= 1), this is a no-op since they don't have
a scheduler metrics socket. They only publish KV events via init_kv_event_publish().
"""
if self._sock is None:
# Non-leader node: no scheduler metrics to receive
# Just wait until stopped (KV events are handled by separate publishers)
while self._running:
await asyncio.sleep(1)
return
while self._running:
try: try:
kv_metrics = await self._sock.recv_pyobj() # type: ignore kv_metrics = await self._sock.recv_pyobj() # type: ignore
dp_rank = ( dp_rank = (
...@@ -99,26 +133,101 @@ class DynamoSglangPublisher: ...@@ -99,26 +133,101 @@ class DynamoSglangPublisher:
) )
self.metrics_publisher.publish(dp_rank, kv_metrics.kv_active_blocks) self.metrics_publisher.publish(dp_rank, kv_metrics.kv_active_blocks)
except Exception: except Exception:
if self._running:
logging.exception( logging.exception(
"Failed to receive or publish SGLang scheduler metrics" "Failed to receive or publish SGLang scheduler metrics"
) )
def cleanup(self) -> None:
"""Clean up ZMQ resources."""
self._running = False
# Close ZMQ socket and context
if self._sock is not None:
try:
self._sock.close(linger=0)
except Exception as e:
logging.warning(f"Failed to close ZMQ socket: {e}")
if self._ctx is not None:
try:
self._ctx.term()
except Exception as e:
logging.warning(f"Failed to terminate ZMQ context: {e}")
# Shutdown kv publishers
for publisher in self.kv_publishers:
try:
publisher.shutdown()
except Exception as e:
logging.warning(f"Failed to shutdown kv publisher: {e}")
logging.info("DynamoSglangPublisher cleanup complete")
def init_engine_metrics_publish(self) -> None: def init_engine_metrics_publish(self) -> None:
"""Publish initial dummy metrics to bootstrap the metrics endpoint.""" """Publish initial dummy metrics to bootstrap the metrics endpoint."""
logging.info("Sending dummy metrics to initialize") logging.info("Sending dummy metrics to initialize")
self.metrics_publisher.publish(self.dp_rank, 0) self.metrics_publisher.publish(self.dp_rank, 0)
def init_kv_event_publish(self) -> Optional[ZmqKvEventPublisher]: def init_kv_event_publish(self) -> List[ZmqKvEventPublisher]:
"""Initialize KV event publisher if configured. """Initialize KV event publisher(s) if configured.
For DP attention mode, creates one subscriber per LOCAL DP rank port.
Each SGLang scheduler in DP attention mode publishes to a unique port
(base_port + attn_dp_rank). In multi-node setups, each node's dynamo.sglang
instance subscribes only to the DP ranks running on that node.
Multi-node handling:
- Each node runs dynamo.sglang alongside its local SGLang DP ranks
- Each dynamo.sglang subscribes only to LOCAL DP ranks (same node)
- SGLang binds locally (wildcard), Dynamo connects locally
- NATS handles cross-node event distribution
Returns: Returns:
ZmqKvEventPublisher instance if kv_events_config is set, None otherwise. List of ZmqKvEventPublisher instances if kv_events_config is set,
empty list otherwise.
""" """
self.kv_publisher = None
if self.server_args.kv_events_config: if self.server_args.kv_events_config:
kv_events = json.loads(self.server_args.kv_events_config) kv_events = json.loads(self.server_args.kv_events_config)
ep = kv_events.get("endpoint") base_ep = kv_events.get("endpoint")
zmq_ep = format_zmq_endpoint(ep, get_local_ip_auto()) if ep else None local_ip = get_local_ip_auto()
# Determine DP attention configuration
dp_size = getattr(self.server_args, "dp_size", 1) or 1
enable_dp_attention = getattr(
self.server_args, "enable_dp_attention", False
)
nnodes = getattr(self.server_args, "nnodes", 1) or 1
node_rank = getattr(self.server_args, "node_rank", 0) or 0
if enable_dp_attention and dp_size > 1:
# Calculate which DP ranks are local to this node
# DP ranks are distributed evenly across nodes
local_dp_size = dp_size // nnodes if nnodes > 0 else dp_size
start_dp_rank = node_rank * local_dp_size
end_dp_rank = start_dp_rank + local_dp_size
logging.info(
f"DP attention mode: node_rank={node_rank}, dp_size={dp_size}, "
f"nnodes={nnodes}. Subscribing to local DP ranks [{start_dp_rank}, {end_dp_rank})"
)
else:
# Standard mode: single subscriber for rank 0
start_dp_rank = 0
end_dp_rank = 1
for dp_rank in range(start_dp_rank, end_dp_rank):
# Use SGLang's offset_endpoint_port to ensure alignment with publishers
# This is the same function SGLang schedulers use to determine their bind ports
zmq_ep = ZmqEventPublisher.offset_endpoint_port(base_ep, dp_rank)
if not zmq_ep:
logging.warning(
f"Skipping ZMQ subscriber for dp_rank={dp_rank}: "
f"offset_endpoint_port returned None for base_ep={base_ep}"
)
continue
zmq_ep = format_zmq_endpoint(zmq_ep, local_ip)
zmq_config = ZmqKvEventPublisherConfig( zmq_config = ZmqKvEventPublisherConfig(
worker_id=self.generate_endpoint.connection_id(), worker_id=self.generate_endpoint.connection_id(),
...@@ -126,11 +235,19 @@ class DynamoSglangPublisher: ...@@ -126,11 +235,19 @@ class DynamoSglangPublisher:
zmq_endpoint=zmq_ep, zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer, enable_local_indexer=self.dynamo_args.enable_local_indexer,
) )
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}") logging.info(
self.kv_publisher = ZmqKvEventPublisher( f"Setting up ZMQ kv event subscriber for dp_rank={dp_rank} "
f"(connecting to {zmq_ep})"
)
publisher = ZmqKvEventPublisher(
component=self.component, config=zmq_config component=self.component, config=zmq_config
) )
return self.kv_publisher self.kv_publishers.append(publisher)
# Maintain backward compatibility: set kv_publisher to first publisher if any
self.kv_publisher = self.kv_publishers[0] if self.kv_publishers else None
return self.kv_publishers
def setup_prometheus_registry( def setup_prometheus_registry(
......
...@@ -163,6 +163,13 @@ async def _get_runtime_config( ...@@ -163,6 +163,13 @@ async def _get_runtime_config(
runtime_config.tool_call_parser = dynamo_args.tool_call_parser runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# Set data_parallel_size for DP attention mode
# This enables the router to correctly track per-(worker_id, dp_rank) pairs
dp_size = getattr(server_args, "dp_size", 1) or 1
runtime_config.data_parallel_size = dp_size
if dp_size > 1:
logging.info(f"Registering with data_parallel_size={dp_size}")
# Set bootstrap endpoint for disaggregated serving (prefill workers) # Set bootstrap endpoint for disaggregated serving (prefill workers)
bootstrap_host, bootstrap_port = _get_bootstrap_info_for_config(engine) bootstrap_host, bootstrap_port = _get_bootstrap_info_for_config(engine)
if bootstrap_host and bootstrap_port: if bootstrap_host and bootstrap_port:
......
...@@ -25,9 +25,9 @@ class EmbeddingWorkerHandler(BaseWorkerHandler): ...@@ -25,9 +25,9 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
logging.info("Embedding worker handler initialized") logging.info("Embedding worker handler initialized")
def cleanup(self): def cleanup(self):
super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logging.info("Engine shutdown") logging.info("Engine shutdown")
super().cleanup()
async def generate(self, request: dict, context: Context): async def generate(self, request: dict, context: Context):
""" """
......
...@@ -42,6 +42,7 @@ class BaseWorkerHandler(ABC): ...@@ -42,6 +42,7 @@ class BaseWorkerHandler(ABC):
self.engine = engine self.engine = engine
self.config = config self.config = config
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
self.publisher = publisher
if publisher is not None: if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher self.kv_publisher = publisher.kv_publisher
...@@ -202,7 +203,8 @@ class BaseWorkerHandler(ABC): ...@@ -202,7 +203,8 @@ class BaseWorkerHandler(ABC):
def cleanup(self) -> None: def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed.""" """Cleanup resources. Override in subclasses as needed."""
pass if self.publisher is not None:
self.publisher.cleanup()
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]: def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
request_input = self.input_param_manager.get_input_param( request_input = self.input_param_manager.get_input_param(
......
...@@ -50,9 +50,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -50,9 +50,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def cleanup(self) -> None: def cleanup(self) -> None:
"""Shutdown the engine and cleanup resources.""" """Shutdown the engine and cleanup resources."""
super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logging.info("Engine shutdown") logging.info("Engine shutdown")
super().cleanup()
def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Build sampling params from request format. """Build sampling params from request format.
...@@ -126,6 +126,10 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -126,6 +126,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self._get_trace_header(context) if self.enable_trace else None self._get_trace_header(context) if self.enable_trace else None
) )
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
decode = await self.engine.async_generate( decode = await self.engine.async_generate(
**input_param, **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
...@@ -135,6 +139,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -135,6 +139,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank,
) )
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
...@@ -161,6 +166,10 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -161,6 +166,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self._get_trace_header(context) if self.enable_trace else None self._get_trace_header(context) if self.enable_trace else None
) )
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
agg = await self.engine.async_generate( agg = await self.engine.async_generate(
**input_param, **input_param,
image_data=image_data, image_data=image_data,
...@@ -168,6 +177,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -168,6 +177,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
stream=True, stream=True,
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank,
) )
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
async for out in self._process_token_stream(agg, context): async for out in self._process_token_stream(agg, context):
......
...@@ -49,9 +49,9 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -49,9 +49,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
task.cancel() task.cancel()
self._consume_tasks.clear() self._consume_tasks.clear()
super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logging.info("Prefill engine shutdown") logging.info("Prefill engine shutdown")
super().cleanup()
async def generate( async def generate(
self, request: Dict[str, Any], context: Context self, request: Dict[str, Any], context: Context
......
...@@ -400,9 +400,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -400,9 +400,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
return bootstrap_info return bootstrap_info
def cleanup(self): def cleanup(self):
super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logger.info("Multimodal worker engine shutdown") logger.info("Multimodal worker engine shutdown")
super().cleanup()
class MultimodalPrefillWorkerHandler(BaseWorkerHandler): class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
...@@ -515,6 +515,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -515,6 +515,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
pass pass
def cleanup(self): def cleanup(self):
super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logger.info("Multimodal prefill engine shutdown") logger.info("Multimodal prefill engine shutdown")
super().cleanup()
...@@ -182,6 +182,7 @@ pub(crate) struct ZmqKvEventListener { ...@@ -182,6 +182,7 @@ pub(crate) struct ZmqKvEventListener {
#[pymethods] #[pymethods]
impl ZmqKvEventListener { impl ZmqKvEventListener {
#[new] #[new]
#[pyo3(signature = (zmq_endpoint, zmq_topic, kv_block_size))]
fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> { fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
......
...@@ -423,8 +423,11 @@ pub async fn start_zmq_listener( ...@@ -423,8 +423,11 @@ pub async fn start_zmq_listener(
return; return;
} }
// Connect to the ZMQ endpoint. SGLang binds locally, Dynamo connects.
// In multi-node setups, each node runs dynamo.sglang alongside local SGLang ranks,
// so ZMQ connections are always local. NATS handles cross-node event distribution.
if let Err(e) = socket.connect(&zmq_endpoint).await { if let Err(e) = socket.connect(&zmq_endpoint).await {
tracing::error!("Failed to connect ZMQ SUB socket: {}", e); tracing::error!("Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}");
return; return;
} }
...@@ -1556,7 +1559,7 @@ mod tests_startup_helpers { ...@@ -1556,7 +1559,7 @@ mod tests_startup_helpers {
// Cancellation token so we can stop the listener // Cancellation token so we can stop the listener
let token = dynamo_runtime::CancellationToken::new(); let token = dynamo_runtime::CancellationToken::new();
// Spawn async listener // Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({ let listener_handle = tokio::spawn({
let token = token.clone(); let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4) start_zmq_listener(endpoint.to_string(), topic, tx, token, 4)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment