Unverified Commit 1ad2a3b8 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: catch Trtllm engine exceptions (#3544)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 5cbfc27e
...@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config):
encode_client=encode_client, encode_client=encode_client,
multimodal_processor=multimodal_processor, multimodal_processor=multimodal_processor,
connector=connector, connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
) )
if next_client: if next_client:
......
...@@ -24,12 +24,14 @@ from typing import AsyncGenerator, Optional, Union ...@@ -24,12 +24,14 @@ from typing import AsyncGenerator, Optional, Union
import torch import torch
from tensorrt_llm.executor.result import GenerationResult from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from dynamo._core import Context from dynamo._core import Context
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
...@@ -74,6 +76,9 @@ class RequestHandlerConfig: ...@@ -74,6 +76,9 @@ class RequestHandlerConfig:
MultimodalRequestProcessor MultimodalRequestProcessor
] = None # for multimodal support ] = None # for multimodal support
connector: Optional[Connector] = None connector: Optional[Connector] = None
runtime: Optional[
DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown
class HandlerBase: class HandlerBase:
...@@ -94,6 +99,8 @@ class HandlerBase: ...@@ -94,6 +99,8 @@ class HandlerBase:
self.multimodal_processor = config.multimodal_processor self.multimodal_processor = config.multimodal_processor
self.first_generation = True self.first_generation = True
self.connector = config.connector self.connector = config.connector
# Store runtime reference for graceful shutdown
self.runtime = config.runtime
def check_error(self, result: dict): def check_error(self, result: dict):
""" """
...@@ -148,6 +155,24 @@ class HandlerBase: ...@@ -148,6 +155,24 @@ class HandlerBase:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
async def _initiate_shutdown(self, error: Exception):
"""Initiate graceful shutdown after fatal error"""
logging.warning(f"Initiating graceful shutdown due to: {error}")
try:
if self.runtime:
logging.info("Shutting down Dynamo runtime...")
self.runtime.shutdown()
if self.engine:
logging.info("Shutting down TensorRT-LLM engine...")
await self.engine.cleanup()
except Exception as cleanup_error:
logging.error(f"Error during graceful shutdown: {cleanup_error}")
finally:
logging.critical("Forcing process exit for restart")
os._exit(1)
async def generate_locally( async def generate_locally(
self, self,
request: dict, request: dict,
...@@ -243,66 +268,99 @@ class HandlerBase: ...@@ -243,66 +268,99 @@ class HandlerBase:
adapters = create_trtllm_adapters(processors) adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters sampling_params.logits_processor = adapters
# NEW: Updated engine call to include multimodal data try:
generation_result = self.engine.llm.generate_async( # NEW: Updated engine call to include multimodal data
inputs=processed_input, # Use the correctly extracted inputs generation_result = self.engine.llm.generate_async(
sampling_params=sampling_params, inputs=processed_input, # Use the correctly extracted inputs
disaggregated_params=disaggregated_params, sampling_params=sampling_params,
streaming=streaming, disaggregated_params=disaggregated_params,
) streaming=streaming,
)
# Use the context manager to handle cancellation monitoring # Use the context manager to handle cancellation monitoring
async with self._cancellation_monitor(generation_result, context): async with self._cancellation_monitor(generation_result, context):
async for res in generation_result: async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
# can be retrieved. # can be retrieved.
if self.first_generation and self.publisher: if self.first_generation and self.publisher:
self.publisher.start() self.publisher.start()
self.first_generation = False self.first_generation = False
# Upon completion, send a final chunk with "stop" as the finish reason. # Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended. # This signals to the client that the stream has ended.
if ( if (
res.finished res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL and self.disaggregation_mode != DisaggregationMode.PREFILL
): ):
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
)
yield final_out
# If we are not done generating, but there are no outputs, return an error
if not res.outputs and not res.finished:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor: if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response( out = self.multimodal_processor.create_response_chunk(
request_id, model_name output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
) )
yield final_out
if res.finished and not out.get("finish_reason"):
# If we are not done generating, but there are no outputs, return an error out["finish_reason"] = "unknown"
if not res.outputs and not res.finished: logging.warning(
yield {"finish_reason": "error", "token_ids": []} "Request finished with no finish reason set - this indicates a possible bug"
break )
output = res.outputs[0] # Yield the chunk to the client and update the token count for the next iteration.
# The engine returns all tokens generated so far. We must calculate the new yield out
# tokens generated in this iteration to create the "delta". num_output_tokens_so_far = next_total_toks
next_total_toks = len(output.token_ids)
if self.multimodal_processor: # 1. Client cancellation - don't shutdown
out = self.multimodal_processor.create_response_chunk( except asyncio.CancelledError:
output, num_output_tokens_so_far, request_id, model_name logging.debug(f"Request {request_id}: Client cancelled")
) # _cancellation_monitor already called abort_request
else: return # Just stop, no error response
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason: # 2. Per-request errors - send to client, don't shutdown
out["finish_reason"] = output.finish_reason except RequestError as e:
if output.stop_reason: logging.warning(f"Request {request_id} error: {e}")
out["stop_reason"] = output.stop_reason yield {"finish_reason": "error", "token_ids": []}
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode. # 3. ALL OTHER ERRORS - graceful shutdown
out["disaggregated_params"] = asdict( except Exception as e:
DisaggregatedParamsCodec.encode(output.disaggregated_params) error_type = type(e).__name__
) error_msg = str(e)
logging.error(
if res.finished and not out.get("finish_reason"): f"Fatal {error_type} in request {request_id}: {error_msg}",
out["finish_reason"] = "unknown" exc_info=True,
logging.warning( )
"Request finished with no finish reason set - this indicates a possible bug"
) # Try to send error to client before shutdown
try:
# Yield the chunk to the client and update the token count for the next iteration. yield {
yield out "finish_reason": "error",
num_output_tokens_so_far = next_total_toks "token_ids": [],
}
except Exception:
pass # Best effort
# Initiate graceful shutdown
await self._initiate_shutdown(e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment