"lib/bindings/vscode:/vscode.git/clone" did not exist on "3718da8c689a558b7958f462bc3d00a1bbcced3e"
Unverified Commit a9b74dc2 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: request migration for trtllm (#5599)

parent 66c36996
......@@ -71,8 +71,9 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging()
async def graceful_shutdown(runtime):
async def graceful_shutdown(runtime, shutdown_event):
logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
......@@ -128,6 +129,9 @@ async def worker():
config = cmd_line_args()
loop = asyncio.get_running_loop()
# Create shutdown event
shutdown_event = asyncio.Event()
# Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics)
runtime = DistributedRuntime(
loop, config.store_kv, config.request_plane, config.use_kv_events
......@@ -136,17 +140,19 @@ async def worker():
# Set up signal handler for graceful shutdown
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime))
asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
await init(runtime, config)
await init(runtime, config, shutdown_event)
async def init(runtime: DistributedRuntime, config: Config):
async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Instantiate and serve
"""
......@@ -425,6 +431,7 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event,
)
# Register the model with runtime config
......
......@@ -67,6 +67,7 @@ class RequestHandlerConfig:
] = None # DistributedRuntime reference for graceful shutdown
metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector
kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None
class HandlerBase:
......@@ -88,6 +89,7 @@ class HandlerBase:
# Store runtime reference for graceful shutdown
self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event
def check_error(self, result: dict):
"""
......@@ -170,18 +172,49 @@ class HandlerBase:
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation(
async def _handle_cancellation_and_shutdown(
self, generation_result: GenerationResult, context: Context
):
"""Background task to handle cancellation by monitoring context state."""
"""
Background task to handle cancellation and shutdown by monitoring both signals.
Returns 'shutdown' if shutdown was triggered, 'cancelled' if cancelled, None otherwise.
"""
try:
# Wait asynchronously for cancellation signal instead of polling
await context.async_killed_or_stopped()
cancellation_task = context.async_killed_or_stopped()
# Build list of futures/tasks to wait for
wait_for = [cancellation_task]
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
# Wait for whichever happens first
done, pending = await asyncio.wait(
wait_for,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Abort the generation
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Check which event triggered and return the reason
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes
# Task was cancelled, which is expected when generation completes normally
pass
@asynccontextmanager
......@@ -189,28 +222,32 @@ class HandlerBase:
self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]:
"""
Context manager for monitoring request cancellation.
Context manager for monitoring request cancellation and shutdown.
Automatically creates a background task to monitor for cancellation
and shutdown events, cleaning it up when the context exits.
Automatically creates a background task to monitor for cancellation and
cleans it up when the context exits.
If shutdown event was triggered, raises GeneratorExit on exit.
Yields:
asyncio.Task: The cancellation monitoring task
asyncio.Task: The monitoring task
"""
cancellation_task = asyncio.create_task(
self._handle_cancellation(generation_result, context)
monitor_task = asyncio.create_task(
self._handle_cancellation_and_shutdown(generation_result, context)
)
try:
yield cancellation_task
yield monitor_task
finally:
# Clean up the background cancellation task
if not cancellation_task.done():
cancellation_task.cancel()
# Clean up the background monitoring task
if not monitor_task.done():
monitor_task.cancel()
try:
await cancellation_task
await monitor_task
except asyncio.CancelledError:
pass
else:
monitor_task.result()
def _decode_disaggregated_params_from_prefill(
self, prefill_result: dict
......@@ -653,7 +690,7 @@ class HandlerBase:
trace_headers=trace_headers,
)
# Use the context manager to handle cancellation monitoring
# Use the context manager to handle cancellation and shutdown monitoring
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment