Unverified Commit 1ad2a3b8 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: catch Trtllm engine exceptions (#3544)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 5cbfc27e
...@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config):
encode_client=encode_client, encode_client=encode_client,
multimodal_processor=multimodal_processor, multimodal_processor=multimodal_processor,
connector=connector, connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
) )
if next_client: if next_client:
......
...@@ -24,12 +24,14 @@ from typing import AsyncGenerator, Optional, Union ...@@ -24,12 +24,14 @@ from typing import AsyncGenerator, Optional, Union
import torch import torch
from tensorrt_llm.executor.result import GenerationResult from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from dynamo._core import Context from dynamo._core import Context
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
...@@ -74,6 +76,9 @@ class RequestHandlerConfig: ...@@ -74,6 +76,9 @@ class RequestHandlerConfig:
MultimodalRequestProcessor MultimodalRequestProcessor
] = None # for multimodal support ] = None # for multimodal support
connector: Optional[Connector] = None connector: Optional[Connector] = None
runtime: Optional[
DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown
class HandlerBase: class HandlerBase:
...@@ -94,6 +99,8 @@ class HandlerBase: ...@@ -94,6 +99,8 @@ class HandlerBase:
self.multimodal_processor = config.multimodal_processor self.multimodal_processor = config.multimodal_processor
self.first_generation = True self.first_generation = True
self.connector = config.connector self.connector = config.connector
# Store runtime reference for graceful shutdown
self.runtime = config.runtime
def check_error(self, result: dict): def check_error(self, result: dict):
""" """
...@@ -148,6 +155,24 @@ class HandlerBase: ...@@ -148,6 +155,24 @@ class HandlerBase:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
async def _initiate_shutdown(self, error: Exception):
"""Initiate graceful shutdown after fatal error"""
logging.warning(f"Initiating graceful shutdown due to: {error}")
try:
if self.runtime:
logging.info("Shutting down Dynamo runtime...")
self.runtime.shutdown()
if self.engine:
logging.info("Shutting down TensorRT-LLM engine...")
await self.engine.cleanup()
except Exception as cleanup_error:
logging.error(f"Error during graceful shutdown: {cleanup_error}")
finally:
logging.critical("Forcing process exit for restart")
os._exit(1)
async def generate_locally( async def generate_locally(
self, self,
request: dict, request: dict,
...@@ -243,6 +268,7 @@ class HandlerBase: ...@@ -243,6 +268,7 @@ class HandlerBase:
adapters = create_trtllm_adapters(processors) adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters sampling_params.logits_processor = adapters
try:
# NEW: Updated engine call to include multimodal data # NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async( generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs inputs=processed_input, # Use the correctly extracted inputs
...@@ -306,3 +332,35 @@ class HandlerBase: ...@@ -306,3 +332,35 @@ class HandlerBase:
# Yield the chunk to the client and update the token count for the next iteration. # Yield the chunk to the client and update the token count for the next iteration.
yield out yield out
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
# 1. Client cancellation - don't shutdown
except asyncio.CancelledError:
logging.debug(f"Request {request_id}: Client cancelled")
# _cancellation_monitor already called abort_request
return # Just stop, no error response
# 2. Per-request errors - send to client, don't shutdown
except RequestError as e:
logging.warning(f"Request {request_id} error: {e}")
yield {"finish_reason": "error", "token_ids": []}
# 3. ALL OTHER ERRORS - graceful shutdown
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logging.error(
f"Fatal {error_type} in request {request_id}: {error_msg}",
exc_info=True,
)
# Try to send error to client before shutdown
try:
yield {
"finish_reason": "error",
"token_ids": [],
}
except Exception:
pass # Best effort
# Initiate graceful shutdown
await self._initiate_shutdown(e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment