Unverified commit 34dd7ee7, authored by ishandhanani, committed by GitHub
Browse files

fix(sglang): enable metrics on multi-node setups (#4238)

parent 5d1ff687
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
import os
import signal import signal
import sys import sys
...@@ -33,6 +34,36 @@ from dynamo.sglang.request_handlers import ( ...@@ -33,6 +34,36 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging() configure_dynamo_logging()
async def _handle_non_leader_node(
    engine: sgl.Engine,
    generate_endpoint,
) -> None:
    """
    Handle non-leader node (node_rank >= 1) in multi-node deployments.

    Non-leader nodes only run scheduler processes and don't handle requests,
    but they should still expose metrics via Dynamo's metrics endpoint.

    This coroutine does not return on its own: after the optional metrics
    setup it awaits a freshly created ``asyncio.Event`` that nothing sets,
    so it only ends when the task is cancelled or the process exits.

    Args:
        engine: The SGLang engine instance. Its ``server_args.node_rank`` is
            logged and ``server_args.enable_metrics`` gates the registry setup.
        generate_endpoint: The Dynamo endpoint for generation requests; passed
            through to ``setup_prometheus_registry``.
    """
    logging.info(
        f"Non-leader node detected (node_rank={engine.server_args.node_rank})."
    )

    # Only setup Prometheus registry to expose SGLang metrics from shared memory.
    # Non-leader nodes don't need Dynamo metrics publishing or KV events.
    if engine.server_args.enable_metrics:
        setup_prometheus_registry(engine, generate_endpoint)
        logging.info("Prometheus metrics registry configured for non-leader node")

    # Wait indefinitely - the process will be terminated via signal handlers.
    await asyncio.Event().wait()
async def worker(): async def worker():
config = await parse_args(sys.argv[1:]) config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config) dump_config(config.dynamo_args.dump_config_to, config)
...@@ -68,6 +99,10 @@ async def worker(): ...@@ -68,6 +99,10 @@ async def worker():
async def init(runtime: DistributedRuntime, config: Config): async def init(runtime: DistributedRuntime, config: Config):
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
# Prevent SGLang from blocking on non-leader nodes
if server_args.node_rank >= 1:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component( component = runtime.namespace(dynamo_args.namespace).component(
...@@ -77,6 +112,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -77,6 +112,12 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node parallelism)
# Non-leader nodes only run scheduler processes and expose metrics
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
prefill_client = None prefill_client = None
prefill_router_client = None prefill_router_client = None
if config.serving_mode == DisaggregationMode.DECODE: if config.serving_mode == DisaggregationMode.DECODE:
...@@ -145,10 +186,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -145,10 +186,11 @@ async def init(runtime: DistributedRuntime, config: Config):
async def init_prefill(runtime: DistributedRuntime, config: Config): async def init_prefill(runtime: DistributedRuntime, config: Config):
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
engine = sgl.Engine(server_args=server_args) # Prevent SGLang from blocking on non-leader nodes
if server_args.node_rank >= 1:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
# Perform dummy warmup for prefill worker to avoid initial TTFT hit engine = sgl.Engine(server_args=server_args)
await _warmup_prefill_engine(engine, server_args)
component = runtime.namespace(dynamo_args.namespace).component( component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component dynamo_args.component
...@@ -157,6 +199,16 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -157,6 +199,16 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(dynamo_args.endpoint) generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Handle non-leader nodes (multi-node tensor parallelism)
# Non-leader nodes only run scheduler processes and expose metrics
if server_args.node_rank >= 1:
await _handle_non_leader_node(engine, generate_endpoint)
return
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment