Unverified Commit 9055b2d3 authored by Yuewei Na's avatar Yuewei Na Committed by GitHub
Browse files

fix: prevent KV block leak from cancel during disagg KV transfer (#7489)


Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
parent 859bd6ea
...@@ -21,7 +21,7 @@ import re ...@@ -21,7 +21,7 @@ import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Protocol, Union
import torch import torch
from tensorrt_llm.executor.result import GenerationResult from tensorrt_llm.executor.result import GenerationResult
...@@ -59,6 +59,53 @@ if TYPE_CHECKING: ...@@ -59,6 +59,53 @@ if TYPE_CHECKING:
configure_dynamo_logging() configure_dynamo_logging()
class _Abortable(Protocol):
    """Structural type for anything that exposes a no-argument ``abort()``.

    Both GenerationResult and _DeferredAbort satisfy it, which lets the
    cancellation helpers accept either one.
    """

    def abort(self) -> None:
        """Request cancellation of the underlying work."""
class _DeferredAbort:
    """Defer GenerationResult.abort() until the first result in disagg decode.

    If abort() is requested before the first generation result has been seen,
    a background asyncio.Task is spawned that iterates the wrapped
    GenerationResult until one result arrives (treated as "KV transfer
    complete") and only then calls the real abort(). Once the first result
    has been signalled, abort() is forwarded immediately.

    The original author notes that iteration reads from TRT-LLM's internal
    result queue and is therefore independent of the Dynamo RPC transport.

    NOTE(review): the background task and the caller's own ``async for`` loop
    both consume from the same GenerationResult; if the caller receives the
    first result, the background task keeps waiting for the next one before
    aborting. Confirm this is acceptable.
    """

    def __init__(self, generation_result: "GenerationResult"):
        """Wrap ``generation_result``; no task is created until abort() is called."""
        self._generation_result = generation_result
        self._first_token_received = False
        # Strong reference to the deferred-abort task. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it ever fires the abort.
        self._abort_task: Optional[asyncio.Task] = None

    def signal_first_token(self) -> None:
        """Called by generate_locally() when first generation result is yielded."""
        self._first_token_received = True

    def abort(self) -> None:
        """Abort immediately if the first token was received, otherwise defer.

        Safe to call repeatedly: while a deferred abort is still pending no
        additional background task is spawned, and once a deferred abort has
        finished any further call is forwarded straight to the real abort().
        Must be called from within a running event loop when deferring.
        """
        if self._first_token_received:
            self._generation_result.abort()
            logging.debug("Deferred abort: first token already received, aborting now")
            return

        pending = self._abort_task
        if pending is not None:
            if not pending.done():
                # A waiter is already in flight; a second one would only
                # compete for results and call abort() twice.
                logging.debug("Deferred abort: background task already pending")
                return
            # The earlier waiter already ran to completion, so there is
            # nothing left to wait for.
            self._generation_result.abort()
            logging.debug("Deferred abort: background task already finished, aborting now")
            return

        logging.debug(
            "Deferred abort: first token not received, spawning background task"
        )
        self._abort_task = asyncio.create_task(self._wait_and_abort())

    async def _wait_and_abort(self) -> None:
        """Background task: read from GenerationResult until first token, then abort."""
        try:
            async for _ in self._generation_result:
                break  # First result = KV transfer complete
        except Exception:
            # Best effort: the abort below must still run, but keep a trace
            # of why waiting for the first result failed.
            logging.debug(
                "Deferred abort: error while waiting for first result",
                exc_info=True,
            )
        self._generation_result.abort()
        logging.debug("Deferred abort: background task completed, abort fired")
@dataclass @dataclass
class RequestHandlerConfig: class RequestHandlerConfig:
""" """
...@@ -183,9 +230,9 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -183,9 +230,9 @@ class HandlerBase(BaseGenerativeHandler):
for tok_id, logprob_info in token_logprobs_dict.items(): for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append( token_top_logprobs.append(
{ {
"rank": logprob_info.rank "rank": (
if hasattr(logprob_info, "rank") logprob_info.rank if hasattr(logprob_info, "rank") else 0
else 0, ),
"token_id": tok_id, "token_id": tok_id,
"token": ( "token": (
logprob_info.decoded_token logprob_info.decoded_token
...@@ -200,12 +247,18 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -200,12 +247,18 @@ class HandlerBase(BaseGenerativeHandler):
return log_probs if log_probs else None, top_logprobs if top_logprobs else None return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation( async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context self,
generation_result: _Abortable,
context: Context,
): ):
""" """
Background task to trigger cancellation if request is cancelled or shutdown Background task to trigger cancellation if request is cancelled or shutdown
event is set. event is set.
In disaggregated decode mode, generation_result may be a _DeferredAbort
wrapper that defers abort() until the first token is received (KV
transfer complete).
Raise EngineShutdown if shutdown event is triggered. Raise EngineShutdown if shutdown event is triggered.
""" """
try: try:
...@@ -252,12 +305,17 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -252,12 +305,17 @@ class HandlerBase(BaseGenerativeHandler):
@asynccontextmanager @asynccontextmanager
async def _cancellation_monitor( async def _cancellation_monitor(
self, generation_result: GenerationResult, context: Context self,
generation_result: _Abortable,
context: Context,
) -> AsyncGenerator[asyncio.Task, None]: ) -> AsyncGenerator[asyncio.Task, None]:
""" """
Monitor for cancellation triggers and cancel by calling Monitor for cancellation triggers and cancel by calling
generation_result.abort(). generation_result.abort().
In disaggregated decode mode, generation_result may be a _DeferredAbort
wrapper that defers abort() until the first token.
Raise EngineShutdown if shutdown event is triggered. Raise EngineShutdown if shutdown event is triggered.
Yields: Yields:
...@@ -838,9 +896,23 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -838,9 +896,23 @@ class HandlerBase(BaseGenerativeHandler):
scheduling_params=scheduling_params, scheduling_params=scheduling_params,
) )
# Monitor for cancellation triggers and cancel by calling generation_result.abort() # In disagg decode mode, wrap abort() to defer until first token
async with self._cancellation_monitor(generation_result, context): # (KV transfer complete).
abort_guard = (
_DeferredAbort(generation_result)
if self.disaggregation_mode == DisaggregationMode.DECODE
else None
)
# Monitor for cancellation triggers and cancel by calling abort()
async with self._cancellation_monitor(
abort_guard or generation_result, context
):
async for res in generation_result: async for res in generation_result:
# Signal first token to deferred abort guard
if abort_guard is not None:
abort_guard.signal_first_token()
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
# can be retrieved. # can be retrieved.
if self.first_generation and self.publisher: if self.first_generation and self.publisher:
......
...@@ -17,6 +17,7 @@ if not torch.cuda.is_available(): ...@@ -17,6 +17,7 @@ if not torch.cuda.is_available():
"CUDA/GPU not available, but tensorrt_llm import and the test require GPU.", "CUDA/GPU not available, but tensorrt_llm import and the test require GPU.",
allow_module_level=True, allow_module_level=True,
) )
from dynamo.llm.exceptions import EngineShutdown
from dynamo.trtllm.constants import DisaggregationMode from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.request_handlers.handler_base import HandlerBase from dynamo.trtllm.request_handlers.handler_base import HandlerBase
...@@ -378,6 +379,139 @@ class TestHandleCancellationAbortToggle: ...@@ -378,6 +379,139 @@ class TestHandleCancellationAbortToggle:
generation_result.abort.assert_not_called() generation_result.abort.assert_not_called()
class TestDeferredAbortGuard:
    """Tests for _DeferredAbort in disaggregated decode cancellation.

    In disaggregated serving, decode abort must be deferred until the first
    generation result is received (indicating KV cache transfer is complete).
    _DeferredAbort wraps GenerationResult.abort() to spawn a background task
    that waits for the first token before calling real abort.
    """

    def _make_handler(self, disable_request_abort: bool = False) -> HandlerBase:
        """Build a handler around a mocked config with no shutdown event."""
        config = MagicMock()
        config.disable_request_abort = disable_request_abort
        config.shutdown_event = None
        return _ConcreteHandler(config)

    @staticmethod
    async def _cancel_tasks(tasks) -> None:
        """Cancel and await ``tasks`` so none is left pending at test teardown."""
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    @pytest.mark.asyncio
    async def test_deferred_abort_before_first_token(self):
        """abort() before first token should NOT call real abort immediately."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        # Make generation_result iterable (background task will try to read it)
        generation_result.__aiter__ = MagicMock(return_value=generation_result)
        never_resolve = asyncio.get_running_loop().create_future()
        generation_result.__anext__ = MagicMock(return_value=never_resolve)

        guard = _DeferredAbort(generation_result)
        tasks_before = asyncio.all_tasks()
        guard.abort()
        spawned = asyncio.all_tasks() - tasks_before

        try:
            # Yield to the loop so the background task actually starts and
            # blocks on the never-resolving future; asserting without
            # yielding would pass even if the abort were not deferred.
            for _ in range(3):
                await asyncio.sleep(0)

            # Real abort should NOT have been called — deferred to background task
            generation_result.abort.assert_not_called()
        finally:
            await self._cancel_tasks(spawned)

    @pytest.mark.asyncio
    async def test_deferred_abort_after_first_token(self):
        """abort() after signal_first_token should call real abort immediately."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        guard = _DeferredAbort(generation_result)
        guard.signal_first_token()
        guard.abort()

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(5)
    async def test_deferred_task_completes(self):
        """Background task should call abort after first result from generation_result."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        result_queue = asyncio.Queue()

        async def mock_anext(self_mock):
            val = await result_queue.get()
            if val is StopAsyncIteration:
                raise StopAsyncIteration
            return val

        generation_result.__aiter__ = MagicMock(return_value=generation_result)
        generation_result.__anext__ = lambda self: mock_anext(self)

        guard = _DeferredAbort(generation_result)
        tasks_before = asyncio.all_tasks()
        guard.abort()  # Spawns background task
        spawned = asyncio.all_tasks() - tasks_before
        assert spawned, "abort() before the first token should spawn a background task"

        generation_result.abort.assert_not_called()

        # Simulate first result arriving (KV transfer complete)
        await result_queue.put("first_token")
        # Wait on the spawned task itself rather than a fixed sleep.
        await asyncio.wait_for(asyncio.gather(*spawned), timeout=1)

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_guard_in_non_disagg_mode(self):
        """Without _DeferredAbort wrapper, abort fires immediately on cancel."""
        handler = self._make_handler(disable_request_abort=False)
        generation_result = MagicMock()
        context = MagicMock()
        killed_future = asyncio.get_running_loop().create_future()
        killed_future.set_result(None)
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-no-guard"

        # Pass real generation_result (no wrapper) — non-disagg path
        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(5)
    async def test_shutdown_calls_abort_directly(self):
        """Shutdown calls abort on whatever is passed (wrapper or real), immediately."""
        handler = self._make_handler(disable_request_abort=False)
        handler.shutdown_event = asyncio.Event()

        # Pass a _DeferredAbort wrapper — shutdown should still call .abort()
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        guard = _DeferredAbort(generation_result)
        context = MagicMock()
        killed_future = asyncio.get_running_loop().create_future()
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-shutdown"

        task = asyncio.create_task(handler._handle_cancellation(guard, context))
        await asyncio.sleep(0.05)

        # Trigger shutdown
        handler.shutdown_event.set()

        with pytest.raises(EngineShutdown):
            await task

        # Shutdown calls guard.abort() → since no first token, spawns background task
        # The important thing is EngineShutdown is raised and abort path is entered

    @pytest.mark.asyncio
    async def test_disable_request_abort_skips_guard(self):
        """When disable_request_abort=True, abort is never called (guard irrelevant)."""
        handler = self._make_handler(disable_request_abort=True)
        generation_result = MagicMock()
        context = MagicMock()
        killed_future = asyncio.get_running_loop().create_future()
        killed_future.set_result(None)
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-disabled"

        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_not_called()
class TestMultimodalGuard: class TestMultimodalGuard:
"""Tests for multimodal guard when --modality multimodal is not configured.""" """Tests for multimodal guard when --modality multimodal is not configured."""
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
...@@ -10,11 +9,7 @@ use tokio::sync::OwnedSemaphorePermit; ...@@ -10,11 +9,7 @@ use tokio::sync::OwnedSemaphorePermit;
use tracing::Instrument; use tracing::Instrument;
use dynamo_kv_router::protocols::{BlockExtraInfo, WorkerId}; use dynamo_kv_router::protocols::{BlockExtraInfo, WorkerId};
use dynamo_runtime::{ use dynamo_runtime::{pipeline::SingleIn, protocols::maybe_error::MaybeError};
engine::AsyncEngineContext,
pipeline::{AsyncEngineContextProvider, Context, SingleIn},
protocols::maybe_error::MaybeError,
};
use super::{InnerPrefillRouter, PrefillError, PrefillResolveDecision, PrefillRouter}; use super::{InnerPrefillRouter, PrefillError, PrefillResolveDecision, PrefillRouter};
use crate::protocols::common::{ use crate::protocols::common::{
...@@ -318,12 +313,9 @@ impl PrefillRouter { ...@@ -318,12 +313,9 @@ impl PrefillRouter {
} }
} }
pub(super) fn link_child_context<T: Send + Sync + 'static>( // NVBugs 5969206: link_child_context removed — linking prefill as a child of
engine_ctx: &Arc<dyn AsyncEngineContext>, // engine_context caused kill propagation that tears down the RPC transport,
request: T, // interrupting NIXL KV cache transfers and leaking blocks permanently.
request_id: &str, // Prefill context is now created without linking (Context::with_id only).
) -> Context<T> { // Abort on the decode side is deferred via the _DeferredAbort guard in
let child_context = Context::with_id(request, request_id.to_string()); // handler_base.py until the first generation result confirms KV receipt.
engine_ctx.link_child(child_context.context());
child_context
}
...@@ -9,8 +9,8 @@ use tokio_util::sync::CancellationToken; ...@@ -9,8 +9,8 @@ use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator; use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn, AsyncEngineContextProvider, Context, ManyOut, Operator, RouterMode, ServerStreamingEngine,
async_trait, SingleIn, async_trait,
}, },
protocols::{EndpointId, annotated::Annotated}, protocols::{EndpointId, annotated::Annotated},
}; };
...@@ -28,7 +28,6 @@ mod execution; ...@@ -28,7 +28,6 @@ mod execution;
mod inner; mod inner;
mod types; mod types;
use execution::link_child_context;
use inner::InnerPrefillRouter; use inner::InnerPrefillRouter;
pub use types::PrefillError; pub use types::PrefillError;
use types::{PrefillOutcome, PrefillResolveDecision, build_decode_router_override}; use types::{PrefillOutcome, PrefillResolveDecision, build_decode_router_override};
...@@ -144,8 +143,16 @@ impl ...@@ -144,8 +143,16 @@ impl
routing.dp_rank = dp_rank; routing.dp_rank = dp_rank;
prefill_req.bootstrap_info = Some(bootstrap_info.clone()); prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context = // NVBugs 5969206: Do NOT link prefill as child of engine context.
link_child_context(&engine_ctx, prefill_req, request_id.as_str()); // Kill propagation tears down the RPC transport, interrupting NIXL
// KV cache transfers and leaking blocks permanently. The prefill
// runs to completion independently; blocks are freed via the normal
// completion path (state 21→22).
// NOTE: This means prefill runs to completion even if the client
// disconnects, wasting prefill compute. This is an accepted
// trade-off (wasted compute vs permanent KV block leak). Future
// work: add NIXL-level cancellation that properly frees blocks.
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Pass the phase barrier to the spawned task. It is released after routing // Pass the phase barrier to the spawned task. It is released after routing
// completes so worker recording finishes before phase changes to Decode. // completes so worker recording finishes before phase changes to Decode.
...@@ -163,8 +170,8 @@ impl ...@@ -163,8 +170,8 @@ impl
// so there is no race with set_phase(Decode) below. // so there is no race with set_phase(Decode) below.
drop(prefill_phase_barrier); drop(prefill_phase_barrier);
let prefill_context = // NVBugs 5969206: Do NOT link prefill as child (same rationale as bootstrap path).
link_child_context(&engine_ctx, prefill_req, request_id.as_str()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
// In Direct mode, pass preselected_worker so execute_prefill uses // In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode). // router.direct() instead of router.generate() (which bails in Direct mode).
...@@ -180,13 +187,17 @@ impl ...@@ -180,13 +187,17 @@ impl
} }
}; };
// Abort if cancelled during prefill // NVBugs 5969206: Do NOT abort decode routing when context is killed.
// In disaggregated serving, the prefill may have completed and KV transfer
// is in flight. Blocking decode here orphans the transfer (no receiver)
// and leaks KV blocks permanently. The decode handler's _DeferredAbort
// guard (handler_base.py) fires the abort once the first generation result confirms KV receipt.
// Log-only; decode routing must proceed for KV transfer cleanup.
if engine_ctx.is_stopped() || engine_ctx.is_killed() { if engine_ctx.is_stopped() || engine_ctx.is_killed() {
tracing::debug!("Abort entering decode after context is stopped or killed"); tracing::debug!(
return Err(anyhow::anyhow!( "Context {} killed/stopped after prefill, allowing decode routing for KV transfer",
"Context id {} is stopped or killed",
engine_ctx.id() engine_ctx.id()
)); );
} }
// Handle prefill result // Handle prefill result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment