Unverified commit b10d103d, authored by Jacky, committed by GitHub
Browse files

fix: Temporarily disable cancellation at TRT-LLM decode worker (#5764)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent aa4ac947
...@@ -172,32 +172,39 @@ class HandlerBase: ...@@ -172,32 +172,39 @@ class HandlerBase:
return log_probs if log_probs else None, top_logprobs if top_logprobs else None return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation_and_shutdown( async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context self, generation_result: GenerationResult, context: Context
): ):
""" """
Background task to handle cancellation and shutdown by monitoring both signals. Background task to trigger cancellation if request is cancelled or shutdown
Returns 'shutdown' if shutdown was triggered, 'cancelled' if cancelled, None otherwise. event is set.
Raise GeneratorExit if shutdown event is triggered.
""" """
try: try:
cancellation_task = context.async_killed_or_stopped() cancellation_triggers = [
context.async_killed_or_stopped(), # Request cancellation
# Build list of futures/tasks to wait for ]
wait_for = [cancellation_task] # Shutdown cancellation
shutdown_task = None shutdown_task = None
if self.shutdown_event is not None:
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait()) shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task) cancellation_triggers.append(shutdown_task)
# Wait for whichever happens first # Wait for cancellation to be triggered
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
wait_for, cancellation_triggers,
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
# Cancel the pending task/future # Abort the generation
# Temporary: Disabled on DECODE workers to prevent engine hangs in
# disaggregated setups where abort() may cause the engine to get stuck
if self.disaggregation_mode != DisaggregationMode.DECODE:
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
# Clean up any remaining background task
for task in pending: for task in pending:
task.cancel() task.cancel()
try: try:
...@@ -205,12 +212,8 @@ class HandlerBase: ...@@ -205,12 +212,8 @@ class HandlerBase:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Abort the generation # Raise GeneratorExit if cancellation is due to shutdown event triggered
generation_result.abort() if shutdown_task in done:
logging.debug(f"Aborted Request ID: {context.id()}")
# Check which event triggered and return the reason
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.") raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError: except asyncio.CancelledError:
...@@ -222,31 +225,30 @@ class HandlerBase: ...@@ -222,31 +225,30 @@ class HandlerBase:
self, generation_result: GenerationResult, context: Context self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]: ) -> AsyncGenerator[asyncio.Task, None]:
""" """
Context manager for monitoring request cancellation and shutdown. Monitor for cancellation triggers and cancel by calling
generation_result.abort().
Automatically creates a background task to monitor for cancellation
and shutdown events, cleaning it up when the context exits.
If shutdown event was triggered, raises GeneratorExit on exit. Raise GeneratorExit if shutdown event is triggered.
Yields: Yields:
asyncio.Task: The monitoring task asyncio.Task: The cancellation monitoring task
""" """
monitor_task = asyncio.create_task( monitor_task = asyncio.create_task(
self._handle_cancellation_and_shutdown(generation_result, context) self._handle_cancellation(generation_result, context)
) )
try: try:
yield monitor_task yield monitor_task
finally: finally:
# Clean up the background monitoring task
if not monitor_task.done(): if not monitor_task.done():
# Cancellation not triggered - clean up the background monitoring task
monitor_task.cancel() monitor_task.cancel()
try: try:
await monitor_task await monitor_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
else: else:
# Cancellation triggered - propagate any exceptions
monitor_task.result() monitor_task.result()
def _decode_disaggregated_params_from_prefill( def _decode_disaggregated_params_from_prefill(
...@@ -690,7 +692,7 @@ class HandlerBase: ...@@ -690,7 +692,7 @@ class HandlerBase:
trace_headers=trace_headers, trace_headers=trace_headers,
) )
# Use the context manager to handle cancellation and shutdown monitoring # Monitor for cancellation triggers and cancel by calling generation_result.abort()
async with self._cancellation_monitor(generation_result, context): async with self._cancellation_monitor(generation_result, context):
async for res in generation_result: async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
......
...@@ -253,6 +253,9 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -253,6 +253,9 @@ def test_request_cancellation_trtllm_aggregated(
logger.info(f"{description} detected successfully") logger.info(f"{description} detected successfully")
@pytest.mark.xfail(
reason="Decode worker cancellation is temporarily disabled", strict=True
)
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_decode_cancel( def test_request_cancellation_trtllm_decode_cancel(
request, runtime_services_dynamic_ports, predownload_models request, runtime_services_dynamic_ports, predownload_models
...@@ -429,6 +432,9 @@ def test_request_cancellation_trtllm_prefill_cancel( ...@@ -429,6 +432,9 @@ def test_request_cancellation_trtllm_prefill_cancel(
) )
@pytest.mark.xfail(
reason="Decode worker cancellation is temporarily disabled", strict=True
)
@pytest.mark.xfail(reason="Test fails only on CI", strict=False) @pytest.mark.xfail(reason="Test fails only on CI", strict=False)
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_kv_transfer_cancel( def test_request_cancellation_trtllm_kv_transfer_cancel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment