"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "4ae77dfd42041dc2defe21f6ccf76aecb4478812"
Unverified Commit d401c1a4 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

refactor: Removes obsolete TRT-LLM request-abort disable wiring and re-enables...


refactor: Removes obsolete TRT-LLM request-abort disable wiring and re-enables cancellation test coverage (#8562)
Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 6ba534a6
...@@ -153,13 +153,6 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -153,13 +153,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="If set, publish events and metrics to Dynamo components.", help="If set, publish events and metrics to Dynamo components.",
) )
add_negatable_bool_argument(
g,
flag_name="--disable-request-abort",
env_var="DYN_TRTLLM_DISABLE_REQUEST_ABORT",
default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
)
add_argument( add_argument(
g, g,
flag_name="--load-format", flag_name="--load-format",
...@@ -503,7 +496,6 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -503,7 +496,6 @@ class DynamoTrtllmConfig(ConfigBase):
extra_engine_args: str extra_engine_args: str
override_engine_args: str override_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disable_request_abort: bool
load_format: str load_format: str
model_loader_extra_config: str model_loader_extra_config: str
guided_decoding_backend: Optional[str] = None guided_decoding_backend: Optional[str] = None
......
...@@ -216,7 +216,6 @@ class RequestHandlerConfig: ...@@ -216,7 +216,6 @@ class RequestHandlerConfig:
shutdown_event: Optional[asyncio.Event] = None shutdown_event: Optional[asyncio.Event] = None
generate_endpoint: Optional[Any] = None generate_endpoint: Optional[Any] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True
additional_metrics: Optional["AdditionalMetricsCollector"] = None additional_metrics: Optional["AdditionalMetricsCollector"] = None
max_seq_len: Optional[int] = None max_seq_len: Optional[int] = None
disagg_machine_id: int = 0 # 10-bit machine_id for snowflake disagg_request_id disagg_machine_id: int = 0 # 10-bit machine_id for snowflake disagg_request_id
...@@ -249,7 +248,6 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -249,7 +248,6 @@ class HandlerBase(BaseGenerativeHandler):
self.kv_block_size: int = config.kv_block_size self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event self.shutdown_event = config.shutdown_event
self.generate_endpoint = config.generate_endpoint self.generate_endpoint = config.generate_endpoint
self.disable_request_abort = config.disable_request_abort
self.additional_metrics = config.additional_metrics self.additional_metrics = config.additional_metrics
self.max_seq_len = config.max_seq_len self.max_seq_len = config.max_seq_len
self.disagg_machine_id = config.disagg_machine_id self.disagg_machine_id = config.disagg_machine_id
...@@ -464,15 +462,8 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -464,15 +462,8 @@ class HandlerBase(BaseGenerativeHandler):
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
# Abort the generation unless disabled generation_result.abort()
if self.disable_request_abort: logging.debug(f"Aborted Request ID: {context.id()}")
logging.debug(
f"Request ID {context.id()} cancelled but abort() skipped "
"(DYN_TRTLLM_DISABLE_REQUEST_ABORT=true)"
)
else:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Clean up any remaining background task # Clean up any remaining background task
for task in pending: for task in pending:
......
...@@ -338,47 +338,6 @@ class _ConcreteHandler(HandlerBase): ...@@ -338,47 +338,6 @@ class _ConcreteHandler(HandlerBase):
raise NotImplementedError raise NotImplementedError
class TestHandleCancellationAbortToggle:
"""Tests for the disable_request_abort toggle in _handle_cancellation."""
def _make_handler(self, disable_request_abort: bool) -> HandlerBase:
"""Create a HandlerBase with mocked config."""
config = MagicMock()
config.disable_request_abort = disable_request_abort
config.shutdown_event = None
return _ConcreteHandler(config)
@pytest.mark.asyncio
async def test_abort_called_by_default(self):
handler = self._make_handler(disable_request_abort=False)
generation_result = MagicMock()
context = MagicMock()
# async_killed_or_stopped returns an awaitable that resolves immediately
# (simulating the client cancelling the request)
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-1"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_called_once()
@pytest.mark.asyncio
async def test_abort_not_called_when_disabled(self):
handler = self._make_handler(disable_request_abort=True)
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-id-2"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called()
class TestDeferredAbortGuard: class TestDeferredAbortGuard:
"""Tests for _DeferredAbort in disaggregated decode cancellation. """Tests for _DeferredAbort in disaggregated decode cancellation.
...@@ -388,9 +347,8 @@ class TestDeferredAbortGuard: ...@@ -388,9 +347,8 @@ class TestDeferredAbortGuard:
that waits for the first token before calling real abort. that waits for the first token before calling real abort.
""" """
def _make_handler(self, disable_request_abort: bool = False) -> HandlerBase: def _make_handler(self) -> HandlerBase:
config = MagicMock() config = MagicMock()
config.disable_request_abort = disable_request_abort
config.shutdown_event = None config.shutdown_event = None
return _ConcreteHandler(config) return _ConcreteHandler(config)
...@@ -456,7 +414,7 @@ class TestDeferredAbortGuard: ...@@ -456,7 +414,7 @@ class TestDeferredAbortGuard:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_guard_in_non_disagg_mode(self): async def test_no_guard_in_non_disagg_mode(self):
"""Without _DeferredAbort wrapper, abort fires immediately on cancel.""" """Without _DeferredAbort wrapper, abort fires immediately on cancel."""
handler = self._make_handler(disable_request_abort=False) handler = self._make_handler()
generation_result = MagicMock() generation_result = MagicMock()
context = MagicMock() context = MagicMock()
killed_future = asyncio.get_event_loop().create_future() killed_future = asyncio.get_event_loop().create_future()
...@@ -472,7 +430,7 @@ class TestDeferredAbortGuard: ...@@ -472,7 +430,7 @@ class TestDeferredAbortGuard:
@pytest.mark.timeout(5) @pytest.mark.timeout(5)
async def test_shutdown_calls_abort_directly(self): async def test_shutdown_calls_abort_directly(self):
"""Shutdown calls abort on whatever is passed (wrapper or real), immediately.""" """Shutdown calls abort on whatever is passed (wrapper or real), immediately."""
handler = self._make_handler(disable_request_abort=False) handler = self._make_handler()
handler.shutdown_event = asyncio.Event() handler.shutdown_event = asyncio.Event()
# Pass a _DeferredAbort wrapper — shutdown should still call .abort() # Pass a _DeferredAbort wrapper — shutdown should still call .abort()
...@@ -497,20 +455,6 @@ class TestDeferredAbortGuard: ...@@ -497,20 +455,6 @@ class TestDeferredAbortGuard:
# Shutdown calls guard.abort() → since no first token, spawns background task # Shutdown calls guard.abort() → since no first token, spawns background task
# The important thing is EngineShutdown is raised and abort path is entered # The important thing is EngineShutdown is raised and abort path is entered
@pytest.mark.asyncio
async def test_disable_request_abort_skips_guard(self):
"""When disable_request_abort=True, abort is never called (guard irrelevant)."""
handler = self._make_handler(disable_request_abort=True)
generation_result = MagicMock()
context = MagicMock()
killed_future = asyncio.get_event_loop().create_future()
killed_future.set_result(None)
context.async_killed_or_stopped.return_value = killed_future
context.id.return_value = "test-disabled"
await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called()
class TestMultimodalGuard: class TestMultimodalGuard:
"""Tests for multimodal guard when --modality multimodal is not configured.""" """Tests for multimodal guard when --modality multimodal is not configured."""
......
...@@ -593,7 +593,6 @@ async def init_llm_worker( ...@@ -593,7 +593,6 @@ async def init_llm_worker(
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb, encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
disable_request_abort=config.disable_request_abort,
additional_metrics=additional_metrics, additional_metrics=additional_metrics,
max_seq_len=config.max_seq_len, max_seq_len=config.max_seq_len,
disagg_machine_id=int(endpoint.connection_id()) % 1021, disagg_machine_id=int(endpoint.connection_id()) % 1021,
......
...@@ -40,7 +40,6 @@ pytestmark = [ ...@@ -40,7 +40,6 @@ pytestmark = [
pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
pytest.mark.nightly, pytest.mark.nightly,
pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True), pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
pytest.mark.skip(reason="Cancellation is temporarily disabled"),
] ]
...@@ -364,6 +363,7 @@ def test_request_cancellation_trtllm_decode_cancel( ...@@ -364,6 +363,7 @@ def test_request_cancellation_trtllm_decode_cancel(
) )
@pytest.mark.skip(reason="TRT-LLM prefill cancellation is disabled due to reliability")
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_prefill_cancel( def test_request_cancellation_trtllm_prefill_cancel(
request, runtime_services_dynamic_ports, predownload_models request, runtime_services_dynamic_ports, predownload_models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment