"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "e99e467384001e284e0722a33362866b10fed65b"
Unverified commit cf5f65f7, authored by Vladislav Nosivskoy and committed by GitHub
Browse files

feat: add generate health check support for PD SGLang (#6004)


Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent 5fd39ade
...@@ -85,41 +85,88 @@ class SglangHealthCheckPayload(HealthCheckPayload): ...@@ -85,41 +85,88 @@ class SglangHealthCheckPayload(HealthCheckPayload):
super().__init__() super().__init__()
class SglangPrefillHealthCheckPayload(HealthCheckPayload): class SglangDisaggHealthCheckPayload(HealthCheckPayload):
"""SGLang-specific health check payload for prefill workers in disaggregated mode. """SGLang-specific health check payload for PD-disaggregated mode.
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'. Both prefill and decode handlers support flat format with bootstrap_info.
Uses FAKE_BOOTSTRAP_HOST to enable fake-transfer mode, so health checks
don't require real KV-transfer between prefill/decode workers.
Uses bootstrap_room=0 (same as SGLang). This means health checks always go to
DP rank 0. For proper DP coverage, runtime would need to support dynamic payload
generation per health check request.
""" """
def __init__( def __init__(
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False self,
engine: Optional[sgl.Engine] = None,
use_text_input: bool = False,
) -> None: ) -> None:
"""Initialize SGLang prefill health check payload with proper wrapped structure. """Initialize SGLang disaggregated health check payload.
Args: Args:
engine: Optional SGLang Engine instance to extract BOS token from. engine: SGLang Engine instance to extract BOS token and bootstrap port from.
use_text_input: Whether to use text prompt instead of token IDs.
""" """
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
bos_token_id = _get_bos_token_id_from_engine(engine) bos_token_id = _get_bos_token_id_from_engine(engine)
# Get bootstrap port from engine
bootstrap_port = 0
if engine is not None:
try:
inner_tm = engine.tokenizer_manager
bootstrap_port = getattr(
inner_tm.server_args, "disaggregation_bootstrap_port", 0
)
except Exception as e:
logger.warning(f"Failed to get bootstrap port from engine: {e}")
# Create bootstrap_info for fake-transfer mode
# FAKE_BOOTSTRAP_HOST tells SGLang to skip real KV-transfer
# bootstrap_room=0 matches SGLang behavior (always routes to DP rank 0)
# TODO: For proper DP coverage, runtime needs to support dynamic payload generation
bootstrap_info = {
"bootstrap_host": FAKE_BOOTSTRAP_HOST,
"bootstrap_port": bootstrap_port,
"bootstrap_room": 0,
}
self.default_payload = { self.default_payload = {
"request": {}, "bootstrap_info": bootstrap_info,
"sampling_params": { "stop_conditions": {
"max_new_tokens": 1, # Generate only 1 token "max_tokens": 1, # Generate only 1 token
"ignore_eos": False,
},
"sampling_options": {
"temperature": 0.0, "temperature": 0.0,
"top_p": 1.0, "top_p": 1.0,
"top_k": -1, "top_k": -1,
"ignore_eos": False,
}, },
"eos_token_ids": [],
"annotations": [],
} }
if use_text_input: if use_text_input:
self.default_payload["request"]["prompt"] = "Test" # type: ignore self.default_payload["prompt"] = "Test"
else: else:
self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore self.default_payload["token_ids"] = [bos_token_id]
logger.info(
f"Disagg health check configured: "
f"bootstrap_host={FAKE_BOOTSTRAP_HOST}, "
f"bootstrap_port={bootstrap_port}, "
f"bootstrap_room=0"
)
super().__init__() super().__init__()
class SglangPrefillHealthCheckPayload(SglangDisaggHealthCheckPayload):
"""Backward-compatible alias for prefill health checks in disaggregated mode."""
class ImageDiffusionHealthCheckPayload(HealthCheckPayload): class ImageDiffusionHealthCheckPayload(HealthCheckPayload):
"""Image diffusion-specific health check payload for image generation workers. """Image diffusion-specific health check payload for image generation workers.
......
...@@ -9,11 +9,13 @@ from typing import Awaitable, Callable, Optional ...@@ -9,11 +9,13 @@ from typing import Awaitable, Callable, Optional
import sglang as sgl import sglang as sgl
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.health_check import ( from dynamo.sglang.health_check import (
SglangDisaggHealthCheckPayload,
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
) )
...@@ -101,9 +103,14 @@ async def init_decode( ...@@ -101,9 +103,14 @@ async def init_decode(
) )
handler.register_engine_routes(runtime) handler.register_engine_routes(runtime)
health_check_payload = SglangHealthCheckPayload( if config.serving_mode == DisaggregationMode.DECODE:
engine, use_text_input=dynamo_args.use_sglang_tokenizer health_check_payload = SglangDisaggHealthCheckPayload(
).to_dict() engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
else:
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
logging.info(f"Registering model with endpoint types: {dynamo_args.endpoint_types}") logging.info(f"Registering model with endpoint types: {dynamo_args.endpoint_types}")
if dynamo_args.custom_jinja_template and "chat" not in dynamo_args.endpoint_types: if dynamo_args.custom_jinja_template and "chat" not in dynamo_args.endpoint_types:
......
...@@ -13,6 +13,7 @@ from dynamo.llm import ModelInput ...@@ -13,6 +13,7 @@ from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.health_check import ( from dynamo.sglang.health_check import (
SglangDisaggHealthCheckPayload,
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
) )
...@@ -160,7 +161,10 @@ async def init_multimodal_worker( ...@@ -160,7 +161,10 @@ async def init_multimodal_worker(
await handler.async_init() await handler.async_init()
health_check_payload = SglangHealthCheckPayload(engine).to_dict() if config.serving_mode == DisaggregationMode.DECODE:
health_check_payload = SglangDisaggHealthCheckPayload(engine).to_dict()
else:
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
try: try:
await generate_endpoint.serve_endpoint( await generate_endpoint.serve_endpoint(
......
...@@ -86,10 +86,25 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -86,10 +86,25 @@ class PrefillWorkerHandler(BaseWorkerHandler):
k: v for k, v in sampling_params.items() if v is not None k: v for k, v in sampling_params.items() if v is not None
} }
# Use provided bootstrap_room from bootstrap_info if available, otherwise generate one # Use provided bootstrap_info if available (e.g., for health checks with FAKE_BOOTSTRAP_HOST)
# Otherwise use real bootstrap host/port from engine and generate room locally
bootstrap_host = self.bootstrap_host
bootstrap_port = self.bootstrap_port
bootstrap_room = None bootstrap_room = None
bootstrap_info_from_req = inner_request.get("bootstrap_info") bootstrap_info_from_req = inner_request.get("bootstrap_info")
if isinstance(bootstrap_info_from_req, dict): if isinstance(bootstrap_info_from_req, dict):
# Allow overriding bootstrap_host for fake-transfer mode (health checks)
if "bootstrap_host" in bootstrap_info_from_req:
bootstrap_host = bootstrap_info_from_req["bootstrap_host"]
logging.debug(
f"Using request-provided bootstrap_host: {bootstrap_host}"
)
if "bootstrap_port" in bootstrap_info_from_req:
bootstrap_port = bootstrap_info_from_req["bootstrap_port"]
logging.debug(
f"Using request-provided bootstrap_port: {bootstrap_port}"
)
bootstrap_room = bootstrap_info_from_req.get("bootstrap_room") bootstrap_room = bootstrap_info_from_req.get("bootstrap_room")
if bootstrap_room is not None: if bootstrap_room is not None:
logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}") logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}")
...@@ -99,8 +114,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -99,8 +114,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}") logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}")
bootstrap_info = { bootstrap_info = {
"bootstrap_host": self.bootstrap_host, "bootstrap_host": bootstrap_host,
"bootstrap_port": self.bootstrap_port, "bootstrap_port": bootstrap_port,
"bootstrap_room": bootstrap_room, "bootstrap_room": bootstrap_room,
} }
...@@ -122,8 +137,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -122,8 +137,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
**input_param, **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
......
...@@ -11,6 +11,10 @@ import pytest ...@@ -11,6 +11,10 @@ import pytest
import yaml import yaml
from dynamo.sglang.args import parse_args from dynamo.sglang.args import parse_args
from dynamo.sglang.health_check import (
SglangDisaggHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.tests.conftest import make_cli_args_fixture from dynamo.sglang.tests.conftest import make_cli_args_fixture
# Get path relative to this test file # Get path relative to this test file
...@@ -265,3 +269,20 @@ async def test_disagg_config_rejects_dynamo_keys(tmp_path, mock_sglang_cli, capf ...@@ -265,3 +269,20 @@ async def test_disagg_config_rejects_dynamo_keys(tmp_path, mock_sglang_cli, capf
out, err = capfd.readouterr() out, err = capfd.readouterr()
assert "unrecognized arguments: --store-kv mem" in err assert "unrecognized arguments: --store-kv mem" in err
def test_disagg_health_check_payload_includes_bootstrap_info():
payload = SglangDisaggHealthCheckPayload().to_dict()
assert payload["bootstrap_info"]["bootstrap_host"] == "fake_bootstrap_host"
assert payload["bootstrap_info"]["bootstrap_port"] == 0
assert payload["bootstrap_info"]["bootstrap_room"] == 0
assert payload["token_ids"] == [1]
def test_prefill_health_check_payload_is_disagg_compatible_alias():
payload = SglangPrefillHealthCheckPayload().to_dict()
assert "request" not in payload
assert payload["bootstrap_info"]["bootstrap_host"] == "fake_bootstrap_host"
assert payload["stop_conditions"]["max_tokens"] == 1
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment