Unverified Commit b10d103d authored by Jacky's avatar Jacky Committed by GitHub
Browse files

fix: Temporary disable cancellation at TRT-LLM decode worker (#5764)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent aa4ac947
......@@ -172,32 +172,39 @@ class HandlerBase:
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation_and_shutdown(
async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context
):
"""
Background task to handle cancellation and shutdown by monitoring both signals.
Returns 'shutdown' if shutdown was triggered, 'cancelled' if cancelled, None otherwise.
Background task to trigger cancellation if request is cancelled or shutdown
event is set.
Raise GeneratorExit if shutdown event is triggered.
"""
try:
cancellation_task = context.async_killed_or_stopped()
# Build list of futures/tasks to wait for
wait_for = [cancellation_task]
cancellation_triggers = [
context.async_killed_or_stopped(), # Request cancellation
]
# Shutdown cancellation
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
if self.shutdown_event is not None:
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
cancellation_triggers.append(shutdown_task)
# Wait for whichever happens first
# Wait for cancellation to be triggered
done, pending = await asyncio.wait(
wait_for,
cancellation_triggers,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
# Abort the generation
# Temporary: Disabled on DECODE workers to prevent engine hangs in
# disaggregated setups where abort() may cause the engine to get stuck
if self.disaggregation_mode != DisaggregationMode.DECODE:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Clean up any remaining background task
for task in pending:
task.cancel()
try:
......@@ -205,12 +212,8 @@ class HandlerBase:
except asyncio.CancelledError:
pass
# Abort the generation
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Check which event triggered and return the reason
if shutdown_task and shutdown_task in done:
# Raise GeneratorExit if cancellation is due to shutdown event triggered
if shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError:
......@@ -222,31 +225,30 @@ class HandlerBase:
self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]:
"""
Context manager for monitoring request cancellation and shutdown.
Automatically creates a background task to monitor for cancellation
and shutdown events, cleaning it up when the context exits.
Monitor for cancellation triggers and cancel by calling
generation_result.abort().
If shutdown event was triggered, raises GeneratorExit on exit.
Raise GeneratorExit if shutdown event is triggered.
Yields:
asyncio.Task: The monitoring task
asyncio.Task: The cancellation monitoring task
"""
monitor_task = asyncio.create_task(
self._handle_cancellation_and_shutdown(generation_result, context)
self._handle_cancellation(generation_result, context)
)
try:
yield monitor_task
finally:
# Clean up the background monitoring task
if not monitor_task.done():
# Cancellation not triggered - clean up the background monitoring task
monitor_task.cancel()
try:
await monitor_task
except asyncio.CancelledError:
pass
else:
# Cancellation triggered - propagate any exceptions
monitor_task.result()
def _decode_disaggregated_params_from_prefill(
......@@ -690,7 +692,7 @@ class HandlerBase:
trace_headers=trace_headers,
)
# Use the context manager to handle cancellation and shutdown monitoring
# Monitor for cancellation triggers and cancel by calling generation_result.abort()
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
......
......@@ -253,6 +253,9 @@ def test_request_cancellation_trtllm_aggregated(
logger.info(f"{description} detected successfully")
@pytest.mark.xfail(
reason="Decode worker cancellation is temporarily disabled", strict=True
)
@pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_decode_cancel(
request, runtime_services_dynamic_ports, predownload_models
......@@ -429,6 +432,9 @@ def test_request_cancellation_trtllm_prefill_cancel(
)
@pytest.mark.xfail(
reason="Decode worker cancellation is temporarily disabled", strict=True
)
@pytest.mark.xfail(reason="Test fails only on CI", strict=False)
@pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_kv_transfer_cancel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment