Unverified Commit a9b74dc2 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: request migration for trtllm (#5599)

parent 66c36996
...@@ -71,8 +71,9 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 ...@@ -71,8 +71,9 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging() configure_dynamo_logging()
async def graceful_shutdown(runtime): async def graceful_shutdown(runtime, shutdown_event):
logging.info("Received shutdown signal, shutting down DistributedRuntime") logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown() runtime.shutdown()
logging.info("DistributedRuntime shutdown complete") logging.info("DistributedRuntime shutdown complete")
...@@ -128,6 +129,9 @@ async def worker(): ...@@ -128,6 +129,9 @@ async def worker():
config = cmd_line_args() config = cmd_line_args()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Create shutdown event
shutdown_event = asyncio.Event()
# Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics) # Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics)
runtime = DistributedRuntime( runtime = DistributedRuntime(
loop, config.store_kv, config.request_plane, config.use_kv_events loop, config.store_kv, config.request_plane, config.use_kv_events
...@@ -136,17 +140,19 @@ async def worker(): ...@@ -136,17 +140,19 @@ async def worker():
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly # Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown") logging.info("Signal handlers set up for graceful shutdown")
await init(runtime, config) await init(runtime, config, shutdown_event)
async def init(runtime: DistributedRuntime, config: Config): async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
""" """
Instantiate and serve Instantiate and serve
""" """
...@@ -425,6 +431,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -425,6 +431,7 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime=runtime, # Pass runtime for graceful shutdown runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event,
) )
# Register the model with runtime config # Register the model with runtime config
......
...@@ -67,6 +67,7 @@ class RequestHandlerConfig: ...@@ -67,6 +67,7 @@ class RequestHandlerConfig:
] = None # DistributedRuntime reference for graceful shutdown ] = None # DistributedRuntime reference for graceful shutdown
metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector
kv_block_size: int = 32 kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None
class HandlerBase: class HandlerBase:
...@@ -88,6 +89,7 @@ class HandlerBase: ...@@ -88,6 +89,7 @@ class HandlerBase:
# Store runtime reference for graceful shutdown # Store runtime reference for graceful shutdown
self.runtime = config.runtime self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event
def check_error(self, result: dict): def check_error(self, result: dict):
""" """
...@@ -170,18 +172,49 @@ class HandlerBase: ...@@ -170,18 +172,49 @@ class HandlerBase:
return log_probs if log_probs else None, top_logprobs if top_logprobs else None return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation( async def _handle_cancellation_and_shutdown(
self, generation_result: GenerationResult, context: Context self, generation_result: GenerationResult, context: Context
): ):
"""Background task to handle cancellation by monitoring context state.""" """
Background task to handle cancellation and shutdown by monitoring both signals.
Raises GeneratorExit if shutdown was triggered; returns None if the request was cancelled/stopped or if this monitoring task itself was cancelled.
"""
try: try:
# Wait asynchronously for cancellation signal instead of polling cancellation_task = context.async_killed_or_stopped()
await context.async_killed_or_stopped()
# Build list of futures/tasks to wait for
wait_for = [cancellation_task]
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
# Wait for whichever happens first
done, pending = await asyncio.wait(
wait_for,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Abort the generation # Abort the generation
generation_result.abort() generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}") logging.debug(f"Aborted Request ID: {context.id()}")
# If the shutdown event is what completed, raise GeneratorExit to signal it to the caller
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError: except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes # Task was cancelled, which is expected when generation completes normally
pass pass
@asynccontextmanager @asynccontextmanager
...@@ -189,28 +222,32 @@ class HandlerBase: ...@@ -189,28 +222,32 @@ class HandlerBase:
self, generation_result: GenerationResult, context: Context self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]: ) -> AsyncGenerator[asyncio.Task, None]:
""" """
Context manager for monitoring request cancellation. Context manager for monitoring request cancellation and shutdown.
Automatically creates a background task to monitor for cancellation
and shutdown events, cleaning it up when the context exits.
Automatically creates a background task to monitor for cancellation and If shutdown event was triggered, raises GeneratorExit on exit.
cleans it up when the context exits.
Yields: Yields:
asyncio.Task: The cancellation monitoring task asyncio.Task: The monitoring task
""" """
cancellation_task = asyncio.create_task( monitor_task = asyncio.create_task(
self._handle_cancellation(generation_result, context) self._handle_cancellation_and_shutdown(generation_result, context)
) )
try: try:
yield cancellation_task yield monitor_task
finally: finally:
# Clean up the background cancellation task # Clean up the background monitoring task
if not cancellation_task.done(): if not monitor_task.done():
cancellation_task.cancel() monitor_task.cancel()
try: try:
await cancellation_task await monitor_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
else:
monitor_task.result()
def _decode_disaggregated_params_from_prefill( def _decode_disaggregated_params_from_prefill(
self, prefill_result: dict self, prefill_result: dict
...@@ -653,7 +690,7 @@ class HandlerBase: ...@@ -653,7 +690,7 @@ class HandlerBase:
trace_headers=trace_headers, trace_headers=trace_headers,
) )
# Use the context manager to handle cancellation monitoring # Use the context manager to handle cancellation and shutdown monitoring
async with self._cancellation_monitor(generation_result, context): async with self._cancellation_monitor(generation_result, context):
async for res in generation_result: async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment