Unverified Commit 56f423c6 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Sglang canary health check (#3103)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 8c89a555
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
sglang-specific health check configuration.
This module defines the default health check payload for sglang backends.
"""
from dynamo.health_check import HealthCheckPayload
class SglangHealthCheckPayload(HealthCheckPayload):
"""
sglang-specific health check payload.
Provides sglang defaults and inherits environment override support from base class.
"""
def __init__(self):
"""
Initialize sglang health check payload with sglang-specific defaults.
The format matches what DecodeWorkerHandler expects from the frontend.
"""
self.default_payload = {
"token_ids": [1], # Single token for minimal processing
"stop_conditions": {
"max_tokens": 1, # Generate only 1 token
"ignore_eos": False,
},
"sampling_options": {
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
},
"eos_token_ids": [],
"annotations": [],
}
super().__init__()
class SglangPrefillHealthCheckPayload(HealthCheckPayload):
"""
SGLang-specific health check payload for prefill workers in disaggregated mode.
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
"""
def __init__(self):
"""
Initialize SGLang prefill health check payload with proper wrapped structure.
"""
self.default_payload = {
"request": {
"token_ids": [1], # Single token for minimal processing
},
"sampling_params": {
"max_new_tokens": 1, # Generate only 1 token
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"ignore_eos": False,
},
}
super().__init__()
...@@ -15,6 +15,10 @@ from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig ...@@ -15,6 +15,10 @@ from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_sgl_metrics from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler
...@@ -112,6 +116,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -112,6 +116,8 @@ async def init(runtime: DistributedRuntime, config: Config):
ready_event.set() ready_event.set()
logging.info("Model registration succeeded; processing queued requests") logging.info("Model registration succeeded; processing queued requests")
health_check_payload = SglangHealthCheckPayload().to_dict()
try: try:
# Start endpoint immediately and register model concurrently # Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set # Requests queue until ready_event is set
...@@ -120,6 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -120,6 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config):
handler.generate, handler.generate,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
), ),
register_model(), register_model(),
) )
...@@ -150,11 +157,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -150,11 +157,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler = PrefillWorkerHandler(component, engine, config) handler = PrefillWorkerHandler(component, engine, config)
health_check_payload = SglangPrefillHealthCheckPayload().to_dict()
tasks = [ tasks = [
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, handler.generate,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)], metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
) )
] ]
......
...@@ -53,7 +53,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -53,7 +53,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"] sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"]
return sampling_params return sampling_params
async def generate(self, request: str): async def generate(self, request: dict):
sampling_params = self._build_sampling_params(request) sampling_params = self._build_sampling_params(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
...@@ -62,7 +62,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -62,7 +62,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
DisaggPreprocessedRequest( DisaggPreprocessedRequest(
request=request, request=request,
sampling_params=sampling_params, sampling_params=sampling_params,
).model_dump_json() ).model_dump()
) )
bootstrap_info = None bootstrap_info = None
......
...@@ -6,7 +6,6 @@ import logging ...@@ -6,7 +6,6 @@ import logging
import random import random
import socket import socket
import msgspec
import sglang as sgl import sglang as sgl
from sglang.srt.utils import get_ip from sglang.srt.utils import get_ip
...@@ -46,8 +45,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -46,8 +45,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
async def generate(self, request: str): async def generate(self, request: dict):
req = msgspec.json.decode(request, type=dict)
bootstrap_room = self._generate_bootstrap_room() bootstrap_room = self._generate_bootstrap_room()
bootstrap_info = { bootstrap_info = {
...@@ -59,8 +57,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -59,8 +57,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield bootstrap_info yield bootstrap_info
results = await self.engine.async_generate( results = await self.engine.async_generate(
input_ids=req["request"]["token_ids"], input_ids=request["request"]["token_ids"],
sampling_params=req["sampling_params"], sampling_params=request["sampling_params"],
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=self.bootstrap_port,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment