Unverified commit 7b43b3a5, authored by Tzu-Ling Kan and committed by GitHub
Browse files

feat: Trtllm health check payload use bos_token_id (#3145)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 517f3024
......@@ -7,8 +7,46 @@ TRT-LLM-specific health check configuration.
This module defines the default health check payload for TRT-LLM backends.
"""
import logging
from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
def _get_bos_token_id_from_tokenizer(tokenizer) -> int:
    """
    Extract the BOS token ID to use in the health check payload.

    The lookup is best-effort: any failure falls back to a default of 1 so
    that building the health check payload never raises.

    Args:
        tokenizer: TRT-LLM tokenizer object, or None.

    Returns:
        The BOS token ID found on the tokenizer, or 1 as fallback when the
        tokenizer is None, exposes no BOS token ID, or the lookup fails.

    Note:
        The TransformersTokenizer class wraps a HuggingFace tokenizer.
        While TransformersTokenizer doesn't expose bos_token_id directly,
        the wrapped HuggingFace tokenizer (accessible via tokenizer.tokenizer)
        does. The wrapped tokenizer is therefore consulted first; the object
        itself is consulted second so that a tokenizer exposing bos_token_id
        directly is also supported.
    """
    default_bos_token_id = 1

    if tokenizer is None:
        return default_bos_token_id

    try:
        # Wrapped tokenizer first (TransformersTokenizer case), then the
        # object itself (tokenizer that exposes bos_token_id directly).
        for candidate in (getattr(tokenizer, "tokenizer", None), tokenizer):
            if candidate is None:
                continue
            raw_bos_token_id = getattr(candidate, "bos_token_id", None)
            if raw_bos_token_id is None:
                continue
            # Convert before logging so we never announce a value that is
            # subsequently rejected by int().
            bos_token_id = int(raw_bos_token_id)
            logger.info(
                "Using model's BOS token ID for health check: %s", bos_token_id
            )
            return bos_token_id
    except Exception as e:
        # Deliberately broad: tokenizer attributes may be properties that
        # raise arbitrary errors, and a health check default must not crash
        # worker initialization.
        logger.debug("Failed to get BOS token from tokenizer: %s", e)

    logger.debug(
        "Using default BOS token ID (%d) for health check", default_bos_token_id
    )
    return default_bos_token_id
class TrtllmHealthCheckPayload(HealthCheckPayload):
"""
......@@ -17,14 +55,20 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class.
"""
def __init__(self):
def __init__(self, tokenizer=None):
"""
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
Args:
tokenizer: Optional TRT-LLM tokenizer to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_tokenizer(tokenizer)
# Set TensorRT-LLM default payload - minimal request that completes quickly
# The handler expects token_ids, stop_conditions, and sampling_options
self.default_payload = {
"token_ids": [1], # Single token for minimal processing
"token_ids": [bos_token_id],
"stop_conditions": {
"max_tokens": 1, # Generate only 1 token
"stop": None,
......
......@@ -318,7 +318,7 @@ async def init(runtime: DistributedRuntime, config: Config):
)
# Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload().to_dict()
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()
if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment