Unverified commit 1ad2a3b8, authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: catch Trtllm engine exceptions (#3544)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent 5cbfc27e
......@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config):
encode_client=encode_client,
multimodal_processor=multimodal_processor,
connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
)
if next_client:
......
......@@ -24,12 +24,14 @@ from typing import AsyncGenerator, Optional, Union
import torch
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams
from dynamo._core import Context
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
......@@ -74,6 +76,9 @@ class RequestHandlerConfig:
MultimodalRequestProcessor
] = None # for multimodal support
connector: Optional[Connector] = None
runtime: Optional[
DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown
class HandlerBase:
......@@ -94,6 +99,8 @@ class HandlerBase:
self.multimodal_processor = config.multimodal_processor
self.first_generation = True
self.connector = config.connector
# Store runtime reference for graceful shutdown
self.runtime = config.runtime
def check_error(self, result: dict):
"""
......@@ -148,6 +155,24 @@ class HandlerBase:
except asyncio.CancelledError:
pass
async def _initiate_shutdown(self, error: Exception) -> None:
    """Tear down the runtime and engine after a fatal error, then exit.

    The steps run in order: ``self.runtime.shutdown()`` (if a runtime is
    set), ``await self.engine.cleanup()`` (if an engine is set), and finally
    ``os._exit(1)``. The process is always terminated with exit status 1,
    whether or not the teardown steps succeed.

    Args:
        error: The exception that triggered the shutdown; only used for the
            log message.
    """
    logging.warning("Initiating graceful shutdown due to: %s", error)
    try:
        # Each teardown step has its own try block so that a failure while
        # shutting down the runtime does not skip the engine cleanup.
        try:
            if self.runtime:
                logging.info("Shutting down Dynamo runtime...")
                self.runtime.shutdown()
        except Exception as cleanup_error:
            # logging.exception records the traceback along with the message.
            logging.exception(
                "Error during graceful shutdown: %s", cleanup_error
            )

        try:
            if self.engine:
                logging.info("Shutting down TensorRT-LLM engine...")
                await self.engine.cleanup()
        except Exception as cleanup_error:
            logging.exception(
                "Error during graceful shutdown: %s", cleanup_error
            )
    finally:
        logging.critical("Forcing process exit for restart")
        try:
            # os._exit() skips normal interpreter teardown, so flush the log
            # handlers explicitly or the messages above may never be written.
            logging.shutdown()
        finally:
            # Exit even if flushing the log handlers fails.
            os._exit(1)
async def generate_locally(
self,
request: dict,
......@@ -243,66 +268,99 @@ class HandlerBase:
adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters
# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
)
try:
# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
)
# Use the context manager to handle cancellation monitoring
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if (
res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
# Use the context manager to handle cancellation monitoring
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if (
res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
)
yield final_out
# If we are not done generating, but there are no outputs, return an error
if not res.outputs and not res.finished:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
out = self.multimodal_processor.create_response_chunk(
output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield final_out
# If we are not done generating, but there are no outputs, return an error
if not res.outputs and not res.finished:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor:
out = self.multimodal_processor.create_response_chunk(
output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown"
logging.warning(
"Request finished with no finish reason set - this indicates a possible bug"
)
# Yield the chunk to the client and update the token count for the next iteration.
yield out
num_output_tokens_so_far = next_total_toks
if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown"
logging.warning(
"Request finished with no finish reason set - this indicates a possible bug"
)
# Yield the chunk to the client and update the token count for the next iteration.
yield out
num_output_tokens_so_far = next_total_toks
# 1. Client cancellation - don't shutdown
except asyncio.CancelledError:
logging.debug(f"Request {request_id}: Client cancelled")
# _cancellation_monitor already called abort_request
return # Just stop, no error response
# 2. Per-request errors - send to client, don't shutdown
except RequestError as e:
logging.warning(f"Request {request_id} error: {e}")
yield {"finish_reason": "error", "token_ids": []}
# 3. ALL OTHER ERRORS - graceful shutdown
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logging.error(
f"Fatal {error_type} in request {request_id}: {error_msg}",
exc_info=True,
)
# Try to send error to client before shutdown
try:
yield {
"finish_reason": "error",
"token_ids": [],
}
except Exception:
pass # Best effort
# Initiate graceful shutdown
await self._initiate_shutdown(e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment