"docs/vscode:/vscode.git/clone" did not exist on "912a4d4b7ef39eda11b93c11f74b32e8eaf1906a"
Unverified Commit 56f423c6 authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: Sglang canary health check (#3103)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 8c89a555
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
sglang-specific health check configuration.
This module defines the default health check payload for sglang backends.
"""
from dynamo.health_check import HealthCheckPayload
class SglangHealthCheckPayload(HealthCheckPayload):
    """
    Health check payload carrying the sglang defaults.

    Only the default request body is defined here; everything else
    (including environment override support) comes from the base class.
    """

    def __init__(self):
        """
        Populate ``default_payload`` with a minimal one-token request.

        The layout follows the request format DecodeWorkerHandler expects
        from the frontend.
        """
        # Cap generation at a single token so the probe stays cheap.
        stop_conditions = {
            "max_tokens": 1,
            "ignore_eos": False,
        }
        sampling_options = {
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
        }

        # Assigned before the base initializer runs, as in the original
        # ordering -- presumably the base class reads it; confirm there.
        self.default_payload = {
            "token_ids": [1],  # one input token keeps processing minimal
            "stop_conditions": stop_conditions,
            "sampling_options": sampling_options,
            "eos_token_ids": [],
            "annotations": [],
        }
        super().__init__()
class SglangPrefillHealthCheckPayload(HealthCheckPayload):
    """
    Health check payload for SGLang prefill workers in disaggregated mode.

    The prefill handler expects a wrapped body with a 'request' entry and a
    'sampling_params' entry, so the defaults are shaped that way.
    """

    def __init__(self):
        """
        Populate ``default_payload`` with the wrapped prefill structure.
        """
        # One input token keeps processing minimal.
        wrapped_request = {
            "token_ids": [1],
        }
        # Cap generation at a single new token.
        sampling_params = {
            "max_new_tokens": 1,
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
            "ignore_eos": False,
        }

        # Assigned before the base initializer runs, as in the original
        # ordering -- presumably the base class reads it; confirm there.
        self.default_payload = {
            "request": wrapped_request,
            "sampling_params": sampling_params,
        }
        super().__init__()
......@@ -15,6 +15,10 @@ from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler
......@@ -112,6 +116,8 @@ async def init(runtime: DistributedRuntime, config: Config):
ready_event.set()
logging.info("Model registration succeeded; processing queued requests")
health_check_payload = SglangHealthCheckPayload().to_dict()
try:
# Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
......@@ -120,6 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config):
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_model(),
)
......@@ -150,11 +157,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler = PrefillWorkerHandler(component, engine, config)
health_check_payload = SglangPrefillHealthCheckPayload().to_dict()
tasks = [
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
)
]
......
......@@ -53,7 +53,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"]
return sampling_params
async def generate(self, request: str):
async def generate(self, request: dict):
sampling_params = self._build_sampling_params(request)
if self.serving_mode == DisaggregationMode.DECODE:
......@@ -62,7 +62,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump_json()
).model_dump()
)
bootstrap_info = None
......
......@@ -6,7 +6,6 @@ import logging
import random
import socket
import msgspec
import sglang as sgl
from sglang.srt.utils import get_ip
......@@ -46,8 +45,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
return bootstrap_host, bootstrap_port
async def generate(self, request: str):
req = msgspec.json.decode(request, type=dict)
async def generate(self, request: dict):
bootstrap_room = self._generate_bootstrap_room()
bootstrap_info = {
......@@ -59,8 +57,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield bootstrap_info
results = await self.engine.async_generate(
input_ids=req["request"]["token_ids"],
sampling_params=req["sampling_params"],
input_ids=request["request"]["token_ids"],
sampling_params=request["sampling_params"],
stream=True,
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment