Unverified Commit 6a612a66 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: vllm. Use prefill-specific health check payload and use bos as token_id (#3126)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 396755eb
......@@ -7,8 +7,43 @@ vLLM-specific health check configuration.
This module defines the default health check payload for vLLM backends.
"""
import logging
from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
def _get_bos_token_id_from_engine(engine_client) -> int:
"""
Extract BOS token ID from the vLLM engine client's tokenizer if available.
Args:
engine_client: vLLM AsyncLLM engine client
Returns:
BOS token ID from the model's tokenizer, or 1 as fallback
"""
if engine_client is None:
return 1
try:
tokenizer_group = getattr(engine_client, "tokenizer", None)
if tokenizer_group:
tokenizer = getattr(tokenizer_group, "tokenizer", None)
if tokenizer:
bos_token_id = getattr(tokenizer, "bos_token_id", None)
if bos_token_id is not None:
logger.info(
f"Using model's BOS token ID for health check: {bos_token_id}"
)
return int(bos_token_id)
except Exception as e:
logger.debug(f"Failed to get BOS token from engine: {e}")
logger.debug("Using default BOS token ID (1) for health check")
return 1
class VllmHealthCheckPayload(HealthCheckPayload):
"""
......@@ -17,14 +52,20 @@ class VllmHealthCheckPayload(HealthCheckPayload):
Provides vLLM defaults and inherits environment override support from base class.
"""
def __init__(self):
def __init__(self, engine_client=None):
"""
Initialize vLLM health check payload with vLLM-specific defaults.
Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Set vLLM default payload - minimal request that completes quickly
# The handler expects token_ids, sampling_options, and stop_conditions
self.default_payload = {
"token_ids": [1], # Single token for minimal processing
"token_ids": [bos_token_id],
"sampling_options": {
"max_tokens": 1,
"temperature": 0.0,
......@@ -38,3 +79,44 @@ class VllmHealthCheckPayload(HealthCheckPayload):
},
}
super().__init__()
class VllmPrefillHealthCheckPayload(HealthCheckPayload):
"""
vLLM-specific health check payload for prefill workers in disaggregated mode.
The prefill handler expects a different structure with 'request_id' and 'sampling_params'.
"""
def __init__(self, engine_client=None):
"""
Initialize vLLM prefill health check payload with proper structure.
Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Prefill handler expects request_id, token_ids, and sampling_params
# The sampling_params are converted via msgspec in the handler
self.default_payload = {
"request_id": "health_check",
"token_ids": [bos_token_id],
"sampling_params": {
"max_tokens": 1,
"min_tokens": 1,
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"detokenize": False,
"include_stop_str_in_output": False,
"ignore_eos": False,
"extra_args": {
"kv_transfer_params": {
"do_remote_decode": True,
}
},
},
}
super().__init__()
......@@ -24,7 +24,7 @@ from dynamo.runtime.logging import configure_dynamo_logging
from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import VllmHealthCheckPayload
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .publisher import StatLoggerFactory
configure_dynamo_logging()
......@@ -145,8 +145,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params
)
# Get health check payload (checks env var and falls back to vLLM default)
health_check_payload = VllmHealthCheckPayload().to_dict()
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
try:
logger.debug("Starting serve_endpoint for prefill worker")
......@@ -261,8 +260,7 @@ async def init(runtime: DistributedRuntime, config: Config):
custom_template_path=config.custom_jinja_template,
)
# Get health check payload (checks env var and falls back to vLLM default)
health_check_payload = VllmHealthCheckPayload().to_dict()
health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()
try:
logger.debug("Starting serve_endpoint for decode worker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment