"lib/ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "014266954b150a3f2384dbcb53fdeed02abebd55"
Unverified Commit c4334471 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: use bos_token_id for sglang's health check payload (#3123)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 7b43b3a5
...@@ -7,8 +7,43 @@ sglang-specific health check configuration. ...@@ -7,8 +7,43 @@ sglang-specific health check configuration.
This module defines the default health check payload for sglang backends. This module defines the default health check payload for sglang backends.
""" """
import logging
from dynamo.health_check import HealthCheckPayload from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
def _get_bos_token_id_from_engine(engine) -> int:
"""
Extract BOS token ID from the SGLang engine's tokenizer if available.
Args:
engine: SGLang Engine instance
Returns:
BOS token ID from the model's tokenizer, or 1 as fallback
"""
if engine is None:
return 1
try:
tokenizer_manager = getattr(engine, "tokenizer_manager", None)
if tokenizer_manager:
tokenizer = getattr(tokenizer_manager, "tokenizer", None)
if tokenizer:
bos_token_id = getattr(tokenizer, "bos_token_id", None)
if bos_token_id is not None:
logger.info(
f"Using model's BOS token ID for health check: {bos_token_id}"
)
return int(bos_token_id)
except Exception as e:
logger.debug(f"Failed to get BOS token from engine: {e}")
logger.debug("Using default BOS token ID (1) for health check")
return 1
class SglangHealthCheckPayload(HealthCheckPayload): class SglangHealthCheckPayload(HealthCheckPayload):
""" """
...@@ -17,14 +52,20 @@ class SglangHealthCheckPayload(HealthCheckPayload): ...@@ -17,14 +52,20 @@ class SglangHealthCheckPayload(HealthCheckPayload):
Provides sglang defaults and inherits environment override support from base class. Provides sglang defaults and inherits environment override support from base class.
""" """
def __init__(self): def __init__(self, engine=None):
""" """
Initialize sglang health check payload with sglang-specific defaults. Initialize sglang health check payload with sglang-specific defaults.
Args:
engine: Optional SGLang Engine instance to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
The format matches what DecodeWorkerHandler expects from the frontend. The format matches what DecodeWorkerHandler expects from the frontend.
""" """
bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = { self.default_payload = {
"token_ids": [1], # Single token for minimal processing "token_ids": [bos_token_id],
"stop_conditions": { "stop_conditions": {
"max_tokens": 1, # Generate only 1 token "max_tokens": 1, # Generate only 1 token
"ignore_eos": False, "ignore_eos": False,
...@@ -47,13 +88,19 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -47,13 +88,19 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'. The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
""" """
def __init__(self): def __init__(self, engine=None):
""" """
Initialize SGLang prefill health check payload with proper wrapped structure. Initialize SGLang prefill health check payload with proper wrapped structure.
Args:
engine: Optional SGLang Engine instance to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
""" """
bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = { self.default_payload = {
"request": { "request": {
"token_ids": [1], # Single token for minimal processing "token_ids": [bos_token_id],
}, },
"sampling_params": { "sampling_params": {
"max_new_tokens": 1, # Generate only 1 token "max_new_tokens": 1, # Generate only 1 token
......
...@@ -116,7 +116,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -116,7 +116,7 @@ async def init(runtime: DistributedRuntime, config: Config):
ready_event.set() ready_event.set()
logging.info("Model registration succeeded; processing queued requests") logging.info("Model registration succeeded; processing queued requests")
health_check_payload = SglangHealthCheckPayload().to_dict() health_check_payload = SglangHealthCheckPayload(engine).to_dict()
try: try:
# Start endpoint immediately and register model concurrently # Start endpoint immediately and register model concurrently
...@@ -157,7 +157,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -157,7 +157,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler = PrefillWorkerHandler(component, engine, config) handler = PrefillWorkerHandler(component, engine, config)
health_check_payload = SglangPrefillHealthCheckPayload().to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
tasks = [ tasks = [
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment