Unverified Commit 41d7d549 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Configurable Request Cancellation abort passage to TRT-LLM (#6445)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent d82b0050
......@@ -152,6 +152,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="If set, publish events and metrics to Dynamo components.",
)
add_negatable_bool_argument(
g,
flag_name="--disable-request-abort",
env_var="DYN_TRTLLM_DISABLE_REQUEST_ABORT",
default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
)
add_argument(
g,
flag_name="--disaggregation-mode",
......@@ -363,6 +370,7 @@ class DynamoTrtllmConfig(ConfigBase):
extra_engine_args: str
override_engine_args: str
publish_events_and_metrics: bool
disable_request_abort: bool
disaggregation_mode: DisaggregationMode
modality: Modality
......
......@@ -72,6 +72,7 @@ class RequestHandlerConfig:
kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True
class HandlerBase(BaseGenerativeHandler):
......@@ -101,6 +102,7 @@ class HandlerBase(BaseGenerativeHandler):
self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event
self.disable_request_abort = config.disable_request_abort
def check_error(self, result: dict):
"""
......@@ -208,13 +210,15 @@ class HandlerBase(BaseGenerativeHandler):
return_when=asyncio.FIRST_COMPLETED,
)
# Abort the generation
# Temporary:
# Disable calling abort() on the engine, which may get stuck if a
# sufficiently large number of concurrent requests is cancelled.
# Note to restore:
# call `generation_result.abort()`; and then
# log `logging.debug(f"Aborted Request ID: {context.id()}")`
# Abort the generation unless disabled
if self.disable_request_abort:
logging.debug(
f"Request ID {context.id()} cancelled but abort() skipped "
"(DYN_TRTLLM_DISABLE_REQUEST_ABORT=true)"
)
else:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Clean up any remaining background task
for task in pending:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from dataclasses import dataclass
from unittest import mock
from unittest.mock import MagicMock
import pytest
import torch
......@@ -218,3 +220,51 @@ class TestGuidedDecodingFromToolChoice:
# Downstream code (TRT-LLM sampling_params.py) accesses these attributes:
assert result.guided_decoding.json_object is False
assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"]
class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
async def generate(self, *args, **kwargs):
raise NotImplementedError
class TestHandleCancellationAbortToggle:
"""Tests for the disable_request_abort toggle in _handle_cancellation."""
def _make_handler(self, disable_request_abort: bool) -> HandlerBase:
"""Create a HandlerBase with mocked config."""
config = MagicMock()
config.disable_request_abort = disable_request_abort
config.shutdown_event = None
return _ConcreteHandler(config)
@pytest.mark.asyncio
async def test_abort_called_by_default(self):
handler = self._make_handler(disable_request_abort=False)
generation_result = MagicMock()
context = MagicMock()
# async_killed_or_stopped returns an awaitable that resolves immediately
# (simulating the client cancelling the request)
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-1"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_called_once()
@pytest.mark.asyncio
async def test_abort_not_called_when_disabled(self):
handler = self._make_handler(disable_request_abort=True)
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-2"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called()
......@@ -426,6 +426,7 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
disable_request_abort=config.disable_request_abort,
)
# Register the model with runtime config
......
......@@ -187,8 +187,7 @@ def test_request_cancellation_trtllm_aggregated(
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start an aggregated worker
# Step 2: Start a single worker (allocates its own system_port)
# Step 2: Start an aggregated worker (allocates its own system_port)
with DynamoWorkerProcess(
request, frontend.frontend_port, mode="prefill_and_decode"
) as worker:
......@@ -217,10 +216,10 @@ def test_request_cancellation_trtllm_aggregated(
frontend.frontend_port, request_type
)
# Poll for "New Request ID" pattern
# Poll for "AggregatedHandler Request ID" pattern
request_id, worker_log_offset = poll_for_pattern(
process=worker,
pattern="New Request ID: ",
pattern="AggregatedHandler Request ID: ",
log_offset=worker_log_offset,
match_type="contains",
)
......
......@@ -236,7 +236,7 @@ def test_request_migration_trtllm_aggregated(
frontend,
worker1,
worker2,
receiving_pattern="New Request ID: ",
receiving_pattern="AggregatedHandler Request ID: ",
migration_limit=migration_limit,
immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment