Unverified. Commit 49ecfe60, authored by Keiven C, committed by GitHub
Browse files

fix: add prefill metrics support for SGLang disaggregated mode (#3992)


Signed-off-by: Keiven Chang <keivenchang@users.noreply.github.com>
parent c9ca5c40
...@@ -18,7 +18,7 @@ from dynamo.sglang.health_check import ( ...@@ -18,7 +18,7 @@ from dynamo.sglang.health_check import (
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
) )
from dynamo.sglang.publisher import setup_sgl_metrics from dynamo.sglang.publisher import setup_prometheus_registry, setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_readiness_gate from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
...@@ -94,12 +94,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -94,12 +94,14 @@ async def init(runtime: DistributedRuntime, config: Config):
) )
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
# Note that when engine.server_args.enable_metrics is True, it'll also
# gather internal SGLang Prometheus metrics from all worker processes.
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
# Register Prometheus metrics callback if enabled
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
...@@ -160,6 +162,10 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -160,6 +162,10 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
# Register Prometheus metrics callback if enabled
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
handler = PrefillWorkerHandler(component, engine, config, publisher) handler = PrefillWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
...@@ -206,6 +212,10 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): ...@@ -206,6 +212,10 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
) )
# Register Prometheus metrics callback if enabled
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
......
...@@ -196,6 +196,33 @@ class DynamoSglangPublisher: ...@@ -196,6 +196,33 @@ class DynamoSglangPublisher:
self._record(worker_stats, kv_stats, spec_decode_stats) self._record(worker_stats, kv_stats, spec_decode_stats)
def setup_prometheus_registry(
    engine: sgl.Engine, generate_endpoint: Endpoint
) -> CollectorRegistry:
    """Create a multiprocess-aware Prometheus registry and attach it to an endpoint.

    A fresh ``CollectorRegistry`` is created, a ``MultiProcessCollector`` is
    bound to it, and the registry is then passed to
    ``register_engine_metrics_callback`` together with the ``"sglang:"``
    metric prefix filter.

    Background: SGLang uses a multiprocess architecture where metrics are
    stored in shared memory, and ``MultiProcessCollector`` aggregates metrics
    from all worker processes. The collected ``sglang:*`` metrics are exposed
    via the metrics server endpoint (typically port 8081) when
    ``DYN_SYSTEM_ENABLED=true``.

    Args:
        engine: The SGLang engine instance. The body of this function does
            not read it; it is part of the established call signature.
        generate_endpoint: The Dynamo endpoint for generation requests; the
            metrics callback is registered against it.

    Returns:
        The newly created ``CollectorRegistry`` with the multiprocess
        collector attached.
    """
    sglang_registry = CollectorRegistry()

    # Only the constructor's effect on the registry is wanted here, so the
    # collector instance itself is intentionally not kept.
    multiprocess.MultiProcessCollector(sglang_registry)

    callback_options = {
        "endpoint": generate_endpoint,
        "registry": sglang_registry,
        "metric_prefix_filter": "sglang:",
    }
    register_engine_metrics_callback(**callback_options)

    return sglang_registry
async def setup_sgl_metrics( async def setup_sgl_metrics(
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
...@@ -220,18 +247,6 @@ async def setup_sgl_metrics( ...@@ -220,18 +247,6 @@ async def setup_sgl_metrics(
publisher.init_engine_metrics_publish() publisher.init_engine_metrics_publish()
publisher.init_kv_event_publish() publisher.init_kv_event_publish()
# Register Prometheus metrics callback if enabled
if engine.server_args.enable_metrics:
# SGLang uses multiprocess architecture where metrics are stored in shared memory.
# MultiProcessCollector aggregates metrics from all worker processes.
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
register_engine_metrics_callback(
endpoint=generate_endpoint,
registry=registry,
metric_prefix_filter="sglang:",
)
task = asyncio.create_task(publisher.run()) task = asyncio.create_task(publisher.run())
logging.info("SGLang metrics loop started") logging.info("SGLang metrics loop started")
return publisher, task, metrics_labels return publisher, task, metrics_labels
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launches a disaggregated SGLang deployment from a single script: the Dynamo
# frontend and a prefill worker run in the background, and a decode worker
# runs in the foreground. Both workers are started with the same
# --mem-fraction-static value and with --enable-metrics.
#
# Usage: ./disagg_same_gpu.sh [GPU_MEM_FRACTION]
#   GPU_MEM_FRACTION: Fraction of GPU memory to use per worker (default: 0.45)
#   Example: ./disagg_same_gpu.sh 0.45
#
# Environment:
#   PREFILL_STARTUP_WAIT_SECS: seconds to pause after launching the prefill
#                              worker before launching the decode worker
#                              (default: 5)
#
# Exit status: 1 if the GPU memory check fails, 130 on SIGINT, 143 on SIGTERM,
# otherwise the exit status of the decode worker.

# GPU memory fraction to use per worker (default: 0.45 = 45% each = 90% total for both workers)
GPU_MEM_FRACTION="${1:-0.45}"

# Pause between starting the prefill worker and the decode worker.
PREFILL_STARTUP_WAIT_SECS="${PREFILL_STARTUP_WAIT_SECS:-5}"

# Check GPU memory before starting disaggregated mode on single GPU.
# Testing the assignment directly keeps the status check tied to the command
# it belongs to.
if ! FREE_GPU_GB=$(python3 -c "import torch; print(torch.cuda.mem_get_info()[0]/1024**3)" 2>/dev/null); then
    echo "Error: Failed to check GPU memory. Is PyTorch with CUDA available?"
    exit 1
fi

REQUIRED_GB=16

# Use Python for floating-point comparison to avoid bc dependency.
# The values are passed as arguments instead of being spliced into the Python
# source, so unexpected characters in them cannot alter the program text.
if python3 -c "import sys; sys.exit(0 if float(sys.argv[1]) >= float(sys.argv[2]) else 1)" \
    "$FREE_GPU_GB" "$REQUIRED_GB"; then
    echo "GPU memory check passed: ${FREE_GPU_GB}GB available (required: ${REQUIRED_GB}GB)"
else
    echo "Error: Insufficient GPU memory. Required: ${REQUIRED_GB}GB, Available: ${FREE_GPU_GB}GB"
    echo "Please free up GPU memory before running disaggregated mode on single GPU."
    exit 1
fi

# PIDs of the background processes; empty until the process has been started.
DYNAMO_PID=""
PREFILL_PID=""

# Stop whichever background processes have been started and wait for them.
cleanup() {
    # Disarm the traps first so cleanup runs once, even when a signal handler
    # is followed by the EXIT trap.
    trap - EXIT INT TERM

    echo "Cleaning up background processes..."
    local pids=()
    if [ -n "$DYNAMO_PID" ]; then
        pids+=("$DYNAMO_PID")
    fi
    if [ -n "$PREFILL_PID" ]; then
        pids+=("$PREFILL_PID")
    fi
    if [ "${#pids[@]}" -gt 0 ]; then
        kill "${pids[@]}" 2>/dev/null || true
        wait "${pids[@]}" 2>/dev/null || true
    fi
    echo "Cleanup complete."
}

# Setup cleanup traps. The signal handlers exit explicitly: without that, the
# script would resume after the handler returns and could go on to launch the
# decode worker after the other processes were already stopped.
trap cleanup EXIT
trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM

# run ingress with KV router mode for disaggregated setup
python3 -m dynamo.frontend --router-mode kv --http-port=8000 &
DYNAMO_PID=$!

# run prefill worker with metrics on port 8081
DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \
python3 -m dynamo.sglang \
    --model-path Qwen/Qwen3-0.6B \
    --served-model-name Qwen/Qwen3-0.6B \
    --page-size 16 \
    --tp 1 \
    --trust-remote-code \
    --disaggregation-mode prefill \
    --disaggregation-bootstrap-port 12345 \
    --host 0.0.0.0 \
    --disaggregation-transfer-backend nixl \
    --mem-fraction-static "${GPU_MEM_FRACTION}" \
    --chunked-prefill-size 4096 \
    --max-prefill-tokens 4096 \
    --enable-memory-saver \
    --delete-ckpt-after-loading \
    --max-running-requests 2 \
    --enable-metrics &
PREFILL_PID=$!

# Wait for prefill worker to initialize before starting decode worker
# This prevents both workers from competing for GPU memory simultaneously, which can cause OOM.
# The prefill worker needs time to:
# 1. Load model weights and allocate its memory fraction
# 2. Initialize KV cache with --delete-ckpt-after-loading to free checkpoint memory
# 3. Register with NATS service discovery so decode worker can find it
# NOTE: this is a fixed pause, not a readiness check; raise
# PREFILL_STARTUP_WAIT_SECS on slower machines.
echo "Waiting for prefill worker to initialize..."
sleep "${PREFILL_STARTUP_WAIT_SECS}"

# run decode worker with metrics on port 8082 (foreground)
DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8082 \
python3 -m dynamo.sglang \
    --model-path Qwen/Qwen3-0.6B \
    --served-model-name Qwen/Qwen3-0.6B \
    --page-size 16 \
    --tp 1 \
    --trust-remote-code \
    --disaggregation-mode decode \
    --disaggregation-bootstrap-port 12345 \
    --host 0.0.0.0 \
    --disaggregation-transfer-backend nixl \
    --mem-fraction-static "${GPU_MEM_FRACTION}" \
    --chunked-prefill-size 4096 \
    --max-prefill-tokens 4096 \
    --enable-memory-saver \
    --delete-ckpt-after-loading \
    --max-running-requests 2 \
    --enable-metrics
...@@ -64,6 +64,25 @@ sglang_configs = { ...@@ -64,6 +64,25 @@ sglang_configs = {
models_port=8000, models_port=8000,
request_payloads=[chat_payload_default(), completion_payload_default()], request_payloads=[chat_payload_default(), completion_payload_default()],
), ),
"disaggregated_same_gpu": SGLangConfig(
# Uses disagg_same_gpu.sh for single-GPU disaggregated testing
# Validates metrics from both prefill (port 8081) and decode (port 8082) workers
name="disaggregated_same_gpu",
directory=sglang_dir,
script_name="disagg_same_gpu.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
env={},
models_port=8000,
request_payloads=[
chat_payload_default(),
completion_payload_default(),
# Validate dynamo_component_* and sglang:* metrics from prefill worker (port 8081)
metric_payload_default(min_num_requests=6, backend="sglang", port=8081),
# Validate dynamo_component_* and sglang:* metrics from decode worker (port 8082)
metric_payload_default(min_num_requests=6, backend="sglang", port=8082),
],
),
"kv_events": SGLangConfig( "kv_events": SGLangConfig(
name="kv_events", name="kv_events",
directory=sglang_dir, directory=sglang_dir,
......
...@@ -67,6 +67,7 @@ def metric_payload_default( ...@@ -67,6 +67,7 @@ def metric_payload_default(
repeat_count: int = 1, repeat_count: int = 1,
expected_log: Optional[List[str]] = None, expected_log: Optional[List[str]] = None,
backend: Optional[str] = None, backend: Optional[str] = None,
port: int = 8081,
) -> MetricsPayload: ) -> MetricsPayload:
return MetricsPayload( return MetricsPayload(
body={}, body={},
...@@ -75,6 +76,7 @@ def metric_payload_default( ...@@ -75,6 +76,7 @@ def metric_payload_default(
expected_response=[], expected_response=[],
min_num_requests=min_num_requests, min_num_requests=min_num_requests,
backend=backend, backend=backend,
port=port,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment