Unverified Commit 41d7d549 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Configurable Request Cancellation abort passage to TRT-LLM (#6445)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent d82b0050
...@@ -152,6 +152,13 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -152,6 +152,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="If set, publish events and metrics to Dynamo components.", help="If set, publish events and metrics to Dynamo components.",
) )
add_negatable_bool_argument(
g,
flag_name="--disable-request-abort",
env_var="DYN_TRTLLM_DISABLE_REQUEST_ABORT",
default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
)
add_argument( add_argument(
g, g,
flag_name="--disaggregation-mode", flag_name="--disaggregation-mode",
...@@ -363,6 +370,7 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -363,6 +370,7 @@ class DynamoTrtllmConfig(ConfigBase):
extra_engine_args: str extra_engine_args: str
override_engine_args: str override_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disable_request_abort: bool
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
modality: Modality modality: Modality
......
...@@ -72,6 +72,7 @@ class RequestHandlerConfig: ...@@ -72,6 +72,7 @@ class RequestHandlerConfig:
kv_block_size: int = 32 kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None shutdown_event: Optional[asyncio.Event] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True
class HandlerBase(BaseGenerativeHandler): class HandlerBase(BaseGenerativeHandler):
...@@ -101,6 +102,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -101,6 +102,7 @@ class HandlerBase(BaseGenerativeHandler):
self.runtime = config.runtime self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event self.shutdown_event = config.shutdown_event
self.disable_request_abort = config.disable_request_abort
def check_error(self, result: dict): def check_error(self, result: dict):
""" """
...@@ -208,13 +210,15 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -208,13 +210,15 @@ class HandlerBase(BaseGenerativeHandler):
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
# Abort the generation # Abort the generation unless disabled
# Temporary: if self.disable_request_abort:
# Disable calling abort() on the engine, which may get stuck if a logging.debug(
# sufficiently large number of concurrent requests is cancelled. f"Request ID {context.id()} cancelled but abort() skipped "
# Note to restore: "(DYN_TRTLLM_DISABLE_REQUEST_ABORT=true)"
# call `generation_result.abort()`; and then )
# log `logging.debug(f"Aborted Request ID: {context.id()}")` else:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Clean up any remaining background task # Clean up any remaining background task
for task in pending: for task in pending:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from unittest import mock from unittest import mock
from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
...@@ -218,3 +220,51 @@ class TestGuidedDecodingFromToolChoice: ...@@ -218,3 +220,51 @@ class TestGuidedDecodingFromToolChoice:
# Downstream code (TRT-LLM sampling_params.py) accesses these attributes: # Downstream code (TRT-LLM sampling_params.py) accesses these attributes:
assert result.guided_decoding.json_object is False assert result.guided_decoding.json_object is False
assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"] assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"]
class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
async def generate(self, *args, **kwargs):
raise NotImplementedError
class TestHandleCancellationAbortToggle:
"""Tests for the disable_request_abort toggle in _handle_cancellation."""
def _make_handler(self, disable_request_abort: bool) -> HandlerBase:
"""Create a HandlerBase with mocked config."""
config = MagicMock()
config.disable_request_abort = disable_request_abort
config.shutdown_event = None
return _ConcreteHandler(config)
@pytest.mark.asyncio
async def test_abort_called_by_default(self):
handler = self._make_handler(disable_request_abort=False)
generation_result = MagicMock()
context = MagicMock()
# async_killed_or_stopped returns an awaitable that resolves immediately
# (simulating the client cancelling the request)
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-1"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_called_once()
@pytest.mark.asyncio
async def test_abort_not_called_when_disabled(self):
handler = self._make_handler(disable_request_abort=True)
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-2"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called()
...@@ -426,6 +426,7 @@ async def init_llm_worker( ...@@ -426,6 +426,7 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb, encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
disable_request_abort=config.disable_request_abort,
) )
# Register the model with runtime config # Register the model with runtime config
......
...@@ -187,8 +187,7 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -187,8 +187,7 @@ def test_request_cancellation_trtllm_aggregated(
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start an aggregated worker # Step 2: Start an aggregated worker (allocates its own system_port)
# Step 2: Start a single worker (allocates its own system_port)
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, frontend.frontend_port, mode="prefill_and_decode" request, frontend.frontend_port, mode="prefill_and_decode"
) as worker: ) as worker:
...@@ -217,10 +216,10 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -217,10 +216,10 @@ def test_request_cancellation_trtllm_aggregated(
frontend.frontend_port, request_type frontend.frontend_port, request_type
) )
# Poll for "New Request ID" pattern # Poll for "AggregatedHandler Request ID" pattern
request_id, worker_log_offset = poll_for_pattern( request_id, worker_log_offset = poll_for_pattern(
process=worker, process=worker,
pattern="New Request ID: ", pattern="AggregatedHandler Request ID: ",
log_offset=worker_log_offset, log_offset=worker_log_offset,
match_type="contains", match_type="contains",
) )
......
...@@ -236,7 +236,7 @@ def test_request_migration_trtllm_aggregated( ...@@ -236,7 +236,7 @@ def test_request_migration_trtllm_aggregated(
frontend, frontend,
worker1, worker1,
worker2, worker2,
receiving_pattern="New Request ID: ", receiving_pattern="AggregatedHandler Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment