Unverified Commit 7b43b3a5 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Trtllm health check payload use bos_token_id (#3145)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 517f3024
...@@ -7,8 +7,46 @@ TRT-LLM-specific health check configuration. ...@@ -7,8 +7,46 @@ TRT-LLM-specific health check configuration.
This module defines the default health check payload for TRT-LLM backends. This module defines the default health check payload for TRT-LLM backends.
""" """
import logging
from dynamo.health_check import HealthCheckPayload from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
def _get_bos_token_id_from_tokenizer(tokenizer) -> int:
"""
Extract BOS token ID from the TRT-LLM tokenizer if available.
Args:
tokenizer: TRT-LLM tokenizer object
Returns:
BOS token ID from the tokenizer, or 1 as fallback
Note:
The TransformersTokenizer class wraps a HuggingFace tokenizer.
While TransformersTokenizer doesn't expose bos_token_id directly,
the wrapped HuggingFace tokenizer (accessible via tokenizer.tokenizer) does.
"""
if tokenizer is None:
return 1
try:
if hasattr(tokenizer, "tokenizer"):
inner_tokenizer = getattr(tokenizer, "tokenizer")
bos_token_id = getattr(inner_tokenizer, "bos_token_id", None)
if bos_token_id is not None:
logger.info(
f"Using model's BOS token ID for health check: {bos_token_id}"
)
return int(bos_token_id)
except Exception as e:
logger.debug(f"Failed to get BOS token from tokenizer: {e}")
logger.debug("Using default BOS token ID (1) for health check")
return 1
class TrtllmHealthCheckPayload(HealthCheckPayload): class TrtllmHealthCheckPayload(HealthCheckPayload):
""" """
...@@ -17,14 +55,20 @@ class TrtllmHealthCheckPayload(HealthCheckPayload): ...@@ -17,14 +55,20 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class. Provides TRT-LLM defaults and inherits environment override support from base class.
""" """
def __init__(self): def __init__(self, tokenizer=None):
""" """
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults. Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
Args:
tokenizer: Optional TRT-LLM tokenizer to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
""" """
bos_token_id = _get_bos_token_id_from_tokenizer(tokenizer)
# Set TensorRT-LLM default payload - minimal request that completes quickly # Set TensorRT-LLM default payload - minimal request that completes quickly
# The handler expects token_ids, stop_conditions, and sampling_options # The handler expects token_ids, stop_conditions, and sampling_options
self.default_payload = { self.default_payload = {
"token_ids": [1], # Single token for minimal processing "token_ids": [bos_token_id],
"stop_conditions": { "stop_conditions": {
"max_tokens": 1, # Generate only 1 token "max_tokens": 1, # Generate only 1 token
"stop": None, "stop": None,
......
...@@ -318,7 +318,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -318,7 +318,7 @@ async def init(runtime: DistributedRuntime, config: Config):
) )
# Get health check payload (checks env var and falls back to TensorRT-LLM default) # Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload().to_dict() health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()
if config.publish_events_and_metrics and is_first_worker(config): if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to # Initialize and pass in the publisher to the request handler to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment