Unverified commit d401c1a4, authored by Jacky, committed by GitHub
Browse files

refactor: Removes obsolete TRT-LLM request-abort disable wiring and re-enables...


refactor: Removes obsolete TRT-LLM request-abort disable wiring and re-enables cancellation test coverage (#8562)
Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 6ba534a6
......@@ -153,13 +153,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="If set, publish events and metrics to Dynamo components.",
)
add_negatable_bool_argument(
g,
flag_name="--disable-request-abort",
env_var="DYN_TRTLLM_DISABLE_REQUEST_ABORT",
default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
)
add_argument(
g,
flag_name="--load-format",
......@@ -503,7 +496,6 @@ class DynamoTrtllmConfig(ConfigBase):
extra_engine_args: str
override_engine_args: str
publish_events_and_metrics: bool
disable_request_abort: bool
load_format: str
model_loader_extra_config: str
guided_decoding_backend: Optional[str] = None
......
......@@ -216,7 +216,6 @@ class RequestHandlerConfig:
shutdown_event: Optional[asyncio.Event] = None
generate_endpoint: Optional[Any] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True
additional_metrics: Optional["AdditionalMetricsCollector"] = None
max_seq_len: Optional[int] = None
disagg_machine_id: int = 0 # 10-bit machine_id for snowflake disagg_request_id
......@@ -249,7 +248,6 @@ class HandlerBase(BaseGenerativeHandler):
self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event
self.generate_endpoint = config.generate_endpoint
self.disable_request_abort = config.disable_request_abort
self.additional_metrics = config.additional_metrics
self.max_seq_len = config.max_seq_len
self.disagg_machine_id = config.disagg_machine_id
......@@ -464,13 +462,6 @@ class HandlerBase(BaseGenerativeHandler):
return_when=asyncio.FIRST_COMPLETED,
)
# Abort the generation unless disabled
if self.disable_request_abort:
logging.debug(
f"Request ID {context.id()} cancelled but abort() skipped "
"(DYN_TRTLLM_DISABLE_REQUEST_ABORT=true)"
)
else:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
......
......@@ -338,47 +338,6 @@ class _ConcreteHandler(HandlerBase):
raise NotImplementedError
class TestHandleCancellationAbortToggle:
    """Tests for the disable_request_abort toggle in _handle_cancellation.

    Each test awaits ``_handle_cancellation`` with a context whose
    ``async_killed_or_stopped()`` awaitable is already resolved, then checks
    whether ``abort()`` was called on the generation result.
    """

    def _make_handler(self, disable_request_abort: bool) -> HandlerBase:
        """Create a HandlerBase with mocked config.

        Args:
            disable_request_abort: Value assigned to the mocked config's
                ``disable_request_abort`` attribute.

        Returns:
            A ``_ConcreteHandler`` constructed from the mocked config.
        """
        config = MagicMock()
        config.disable_request_abort = disable_request_abort
        # MagicMock would otherwise auto-create a mock attribute here.
        # NOTE(review): presumably None keeps the handler off its shutdown
        # path so only cancellation is exercised - confirm against HandlerBase.
        config.shutdown_event = None
        return _ConcreteHandler(config)

    @staticmethod
    def _make_cancelled_context(request_id: str) -> MagicMock:
        """Build a mocked request context that is already cancelled.

        Must be called from inside a running event loop, because the future
        is created on that loop.

        Args:
            request_id: Value returned by the mocked ``context.id()``.

        Returns:
            A ``MagicMock`` whose ``async_killed_or_stopped()`` returns a
            future that has already been resolved.
        """
        # async_killed_or_stopped returns an awaitable that resolves immediately
        # (simulating the client cancelling the request).
        # get_running_loop() is the documented way to obtain the loop from
        # within a coroutine; it returns the loop the test is executing on.
        killed_future = asyncio.get_running_loop().create_future()
        killed_future.set_result(None)

        context = MagicMock()
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = request_id
        return context

    @pytest.mark.asyncio
    async def test_abort_called_by_default(self):
        """With the toggle off, cancellation calls abort() exactly once."""
        handler = self._make_handler(disable_request_abort=False)
        generation_result = MagicMock()
        context = self._make_cancelled_context("test-id-1")

        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    async def test_abort_not_called_when_disabled(self):
        """With the toggle on, cancellation never calls abort()."""
        handler = self._make_handler(disable_request_abort=True)
        generation_result = MagicMock()
        context = self._make_cancelled_context("test-id-2")

        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_not_called()
class TestDeferredAbortGuard:
"""Tests for _DeferredAbort in disaggregated decode cancellation.
......@@ -388,9 +347,8 @@ class TestDeferredAbortGuard:
that waits for the first token before calling real abort.
"""
def _make_handler(self, disable_request_abort: bool = False) -> HandlerBase:
def _make_handler(self) -> HandlerBase:
config = MagicMock()
config.disable_request_abort = disable_request_abort
config.shutdown_event = None
return _ConcreteHandler(config)
......@@ -456,7 +414,7 @@ class TestDeferredAbortGuard:
@pytest.mark.asyncio
async def test_no_guard_in_non_disagg_mode(self):
"""Without _DeferredAbort wrapper, abort fires immediately on cancel."""
handler = self._make_handler(disable_request_abort=False)
handler = self._make_handler()
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
......@@ -472,7 +430,7 @@ class TestDeferredAbortGuard:
@pytest.mark.timeout(5)
async def test_shutdown_calls_abort_directly(self):
"""Shutdown calls abort on whatever is passed (wrapper or real), immediately."""
handler = self._make_handler(disable_request_abort=False)
handler = self._make_handler()
handler.shutdown_event = asyncio.Event()
# Pass a _DeferredAbort wrapper — shutdown should still call .abort()
......@@ -497,20 +455,6 @@ class TestDeferredAbortGuard:
# Shutdown calls guard.abort() → since no first token, spawns background task
# The important thing is EngineShutdown is raised and abort path is entered
@pytest.mark.asyncio
async def test_disable_request_abort_skips_guard(self):
"""When disable_request_abort=True, abort is never called (guard irrelevant)."""
handler = self._make_handler(disable_request_abort=True)
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-disabled"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called()
class TestMultimodalGuard:
"""Tests for multimodal guard when --modality multimodal is not configured."""
......
......@@ -593,7 +593,6 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
disable_request_abort=config.disable_request_abort,
additional_metrics=additional_metrics,
max_seq_len=config.max_seq_len,
disagg_machine_id=int(endpoint.connection_id()) % 1021,
......
......@@ -40,7 +40,6 @@ pytestmark = [
pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
pytest.mark.nightly,
pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
pytest.mark.skip(reason="Cancellation is temporarily disabled"),
]
......@@ -364,6 +363,7 @@ def test_request_cancellation_trtllm_decode_cancel(
)
@pytest.mark.skip(reason="TRT-LLM prefill cancellation is disabled due to reliability")
@pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_prefill_cancel(
request, runtime_services_dynamic_ports, predownload_models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment