"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "d51580a418c9a7f72360209b3d65a7f6b794a618"
Commit 9055b2d3 (unverified), authored by Yuewei Na, committed by GitHub
Browse files

fix: prevent KV block leak from cancel during disagg KV transfer (#7489)


Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
parent 859bd6ea
......@@ -21,7 +21,7 @@ import re
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union
import torch
from tensorrt_llm.executor.result import GenerationResult
......@@ -59,6 +59,53 @@ if TYPE_CHECKING:
configure_dynamo_logging()
class _Abortable(Protocol):
"""Structural type for objects that support abort(). Satisfied by both
GenerationResult and _DeferredAbort."""
def abort(self) -> None:
...
class _DeferredAbort:
"""Wraps GenerationResult.abort() to defer until first token in disagg decode.
When abort() is called before the first generation result, spawns a
background asyncio.Task that reads from GenerationResult.aqueue (TRT-LLM's
internal asyncio.Queue, decoupled from Dynamo RPC transport) until the
first result arrives, then calls the real abort().
"""
def __init__(self, generation_result: GenerationResult):
self._generation_result = generation_result
self._first_token_received = False
def signal_first_token(self) -> None:
"""Called by generate_locally() when first generation result is yielded."""
self._first_token_received = True
def abort(self) -> None:
"""Abort immediately if first token received, otherwise defer."""
if self._first_token_received:
self._generation_result.abort()
logging.debug("Deferred abort: first token already received, aborting now")
else:
logging.debug(
"Deferred abort: first token not received, spawning background task"
)
asyncio.create_task(self._wait_and_abort())
async def _wait_and_abort(self) -> None:
"""Background task: read from GenerationResult until first token, then abort."""
try:
async for _ in self._generation_result:
break # First result = KV transfer complete
except Exception:
pass
self._generation_result.abort()
logging.debug("Deferred abort: background task completed, abort fired")
@dataclass
class RequestHandlerConfig:
"""
......@@ -183,9 +230,9 @@ class HandlerBase(BaseGenerativeHandler):
for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append(
{
"rank": logprob_info.rank
if hasattr(logprob_info, "rank")
else 0,
"rank": (
logprob_info.rank if hasattr(logprob_info, "rank") else 0
),
"token_id": tok_id,
"token": (
logprob_info.decoded_token
......@@ -200,12 +247,18 @@ class HandlerBase(BaseGenerativeHandler):
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context
self,
generation_result: _Abortable,
context: Context,
):
"""
Background task to trigger cancellation if request is cancelled or shutdown
event is set.
In disaggregated decode mode, generation_result may be a _DeferredAbort
wrapper that defers abort() until the first token is received (KV
transfer complete).
Raise EngineShutdown if shutdown event is triggered.
"""
try:
......@@ -252,12 +305,17 @@ class HandlerBase(BaseGenerativeHandler):
@asynccontextmanager
async def _cancellation_monitor(
self, generation_result: GenerationResult, context: Context
self,
generation_result: _Abortable,
context: Context,
) -> AsyncGenerator[asyncio.Task, None]:
"""
Monitor for cancellation triggers and cancel by calling
generation_result.abort().
In disaggregated decode mode, generation_result may be a _DeferredAbort
wrapper that defers abort() until the first token.
Raise EngineShutdown if shutdown event is triggered.
Yields:
......@@ -838,9 +896,23 @@ class HandlerBase(BaseGenerativeHandler):
scheduling_params=scheduling_params,
)
# Monitor for cancellation triggers and cancel by calling generation_result.abort()
async with self._cancellation_monitor(generation_result, context):
# In disagg decode mode, wrap abort() to defer until first token
# (KV transfer complete).
abort_guard = (
_DeferredAbort(generation_result)
if self.disaggregation_mode == DisaggregationMode.DECODE
else None
)
# Monitor for cancellation triggers and cancel by calling abort()
async with self._cancellation_monitor(
abort_guard or generation_result, context
):
async for res in generation_result:
# Signal first token to deferred abort guard
if abort_guard is not None:
abort_guard.signal_first_token()
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
......
......@@ -17,6 +17,7 @@ if not torch.cuda.is_available():
"CUDA/GPU not available, but tensorrt_llm import and the test require GPU.",
allow_module_level=True,
)
from dynamo.llm.exceptions import EngineShutdown
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.request_handlers.handler_base import HandlerBase
......@@ -378,6 +379,139 @@ class TestHandleCancellationAbortToggle:
generation_result.abort.assert_not_called()
class TestDeferredAbortGuard:
    """Tests for _DeferredAbort in disaggregated decode cancellation.

    In disaggregated serving, decode abort must be deferred until the first
    generation result is received (indicating KV cache transfer is complete).
    _DeferredAbort wraps GenerationResult.abort() to spawn a background task
    that waits for the first token before calling real abort.
    """

    def _make_handler(self, disable_request_abort: bool = False) -> HandlerBase:
        """Build a handler around a mocked config with no shutdown event."""
        config = MagicMock()
        config.disable_request_abort = disable_request_abort
        config.shutdown_event = None
        return _ConcreteHandler(config)

    @pytest.mark.asyncio
    async def test_deferred_abort_before_first_token(self):
        """abort() before first token should NOT call real abort immediately."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        # Make generation_result iterable (background task will try to read it)
        generation_result.__aiter__ = MagicMock(return_value=generation_result)
        never_resolve = asyncio.get_running_loop().create_future()
        generation_result.__anext__ = MagicMock(return_value=never_resolve)

        guard = _DeferredAbort(generation_result)
        tasks_before = asyncio.all_tasks()
        guard.abort()
        # Only the tasks created by guard.abort(); used for cleanup below.
        spawned = asyncio.all_tasks() - tasks_before

        try:
            # Real abort should NOT have been called — deferred to background task
            generation_result.abort.assert_not_called()

            # Let the background task start and block on the never-resolving
            # result; it must still not have fired the real abort.
            await asyncio.sleep(0)
            generation_result.abort.assert_not_called()
        finally:
            # Do not leave a forever-pending task behind for the event loop.
            for task in spawned:
                task.cancel()
            await asyncio.gather(*spawned, return_exceptions=True)

    @pytest.mark.asyncio
    async def test_deferred_abort_after_first_token(self):
        """abort() after signal_first_token should call real abort immediately."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        guard = _DeferredAbort(generation_result)
        guard.signal_first_token()
        guard.abort()

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(5)
    async def test_deferred_task_completes(self):
        """Background task should call abort after first result from generation_result."""
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        result_queue = asyncio.Queue()

        async def mock_anext(self_mock):
            val = await result_queue.get()
            if val is StopAsyncIteration:
                raise StopAsyncIteration
            return val

        generation_result.__aiter__ = MagicMock(return_value=generation_result)
        generation_result.__anext__ = lambda self: mock_anext(self)

        guard = _DeferredAbort(generation_result)
        guard.abort()  # Spawns background task
        generation_result.abort.assert_not_called()

        # Simulate first result arriving (KV transfer complete)
        await result_queue.put("first_token")
        await asyncio.sleep(0.05)  # Let background task run

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_guard_in_non_disagg_mode(self):
        """Without _DeferredAbort wrapper, abort fires immediately on cancel."""
        handler = self._make_handler(disable_request_abort=False)
        generation_result = MagicMock()
        context = MagicMock()

        # Already-resolved future: the request looks killed/stopped right away.
        killed_future = asyncio.get_running_loop().create_future()
        killed_future.set_result(None)
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-no-guard"

        # Pass real generation_result (no wrapper) — non-disagg path
        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(5)
    async def test_shutdown_calls_abort_directly(self):
        """Shutdown calls abort on whatever is passed (wrapper or real), immediately."""
        handler = self._make_handler(disable_request_abort=False)
        handler.shutdown_event = asyncio.Event()

        # Pass a _DeferredAbort wrapper — shutdown should still call .abort()
        from dynamo.trtllm.request_handlers.handler_base import _DeferredAbort

        generation_result = MagicMock()
        guard = _DeferredAbort(generation_result)
        context = MagicMock()

        # Never resolved: only the shutdown event can end the cancellation wait.
        killed_future = asyncio.get_running_loop().create_future()
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-shutdown"

        task = asyncio.create_task(handler._handle_cancellation(guard, context))
        await asyncio.sleep(0.05)

        # Trigger shutdown
        handler.shutdown_event.set()

        with pytest.raises(EngineShutdown):
            await task

        # Shutdown calls guard.abort() → since no first token, spawns background task
        # The important thing is EngineShutdown is raised and abort path is entered

    @pytest.mark.asyncio
    async def test_disable_request_abort_skips_guard(self):
        """When disable_request_abort=True, abort is never called (guard irrelevant)."""
        handler = self._make_handler(disable_request_abort=True)
        generation_result = MagicMock()
        context = MagicMock()

        # Already-resolved future: the request looks killed/stopped right away.
        killed_future = asyncio.get_running_loop().create_future()
        killed_future.set_result(None)
        context.async_killed_or_stopped.return_value = killed_future
        context.id.return_value = "test-disabled"

        await handler._handle_cancellation(generation_result, context)

        generation_result.abort.assert_not_called()
class TestMultimodalGuard:
"""Tests for multimodal guard when --modality multimodal is not configured."""
......
......@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result;
use futures::StreamExt;
......@@ -10,11 +9,7 @@ use tokio::sync::OwnedSemaphorePermit;
use tracing::Instrument;
use dynamo_kv_router::protocols::{BlockExtraInfo, WorkerId};
use dynamo_runtime::{
engine::AsyncEngineContext,
pipeline::{AsyncEngineContextProvider, Context, SingleIn},
protocols::maybe_error::MaybeError,
};
use dynamo_runtime::{pipeline::SingleIn, protocols::maybe_error::MaybeError};
use super::{InnerPrefillRouter, PrefillError, PrefillResolveDecision, PrefillRouter};
use crate::protocols::common::{
......@@ -318,12 +313,9 @@ impl PrefillRouter {
}
}
/// Wraps `request` in a new `Context` identified by `request_id` and links
/// that context as a child of `engine_ctx` before returning it.
pub(super) fn link_child_context<T: Send + Sync + 'static>(
    engine_ctx: &Arc<dyn AsyncEngineContext>,
    request: T,
    request_id: &str,
) -> Context<T> {
    let id = request_id.to_owned();
    let linked = Context::with_id(request, id);
    // Register the new context under the parent engine context.
    engine_ctx.link_child(linked.context());
    linked
}
// NVBugs 5969206: link_child_context removed — linking prefill as a child of
// engine_context caused kill propagation that tears down the RPC transport,
// interrupting NIXL KV cache transfers and leaking blocks permanently.
// Prefill context is now created without linking (Context::with_id only).
// Abort on the decode side is deferred via the _DeferredAbort wrapper in
// handler_base.py until the first generation result confirms KV receipt.
......@@ -9,8 +9,8 @@ use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
async_trait,
AsyncEngineContextProvider, Context, ManyOut, Operator, RouterMode, ServerStreamingEngine,
SingleIn, async_trait,
},
protocols::{EndpointId, annotated::Annotated},
};
......@@ -28,7 +28,6 @@ mod execution;
mod inner;
mod types;
use execution::link_child_context;
use inner::InnerPrefillRouter;
pub use types::PrefillError;
use types::{PrefillOutcome, PrefillResolveDecision, build_decode_router_override};
......@@ -144,8 +143,16 @@ impl
routing.dp_rank = dp_rank;
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context =
link_child_context(&engine_ctx, prefill_req, request_id.as_str());
// NVBugs 5969206: Do NOT link prefill as child of engine context.
// Kill propagation tears down the RPC transport, interrupting NIXL
// KV cache transfers and leaking blocks permanently. The prefill
// runs to completion independently; blocks are freed via the normal
// completion path (state 21→22).
// NOTE: This means prefill runs to completion even if the client
// disconnects, wasting prefill compute. This is an accepted
// trade-off (wasted compute vs permanent KV block leak). Future
// work: add NIXL-level cancellation that properly frees blocks.
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Pass the phase barrier to the spawned task. It is released after routing
// completes so worker recording finishes before phase changes to Decode.
......@@ -163,8 +170,8 @@ impl
// so there is no race with set_phase(Decode) below.
drop(prefill_phase_barrier);
let prefill_context =
link_child_context(&engine_ctx, prefill_req, request_id.as_str());
// NVBugs 5969206: Do NOT link prefill as child (same rationale as bootstrap path).
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode).
......@@ -180,13 +187,17 @@ impl
}
};
// Abort if cancelled during prefill
// NVBugs 5969206: Do NOT abort decode routing when context is killed.
// In disaggregated serving, the prefill may have completed and KV transfer
// is in flight. Blocking decode here orphans the transfer (no receiver)
// and leaks KV blocks permanently. The decode handler's _DeferredAbort
// guard (handler_base.py) fires the abort once the first result is received.
// Log-only; decode routing must proceed for KV transfer cleanup.
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
tracing::debug!("Abort entering decode after context is stopped or killed");
return Err(anyhow::anyhow!(
"Context id {} is stopped or killed",
tracing::debug!(
"Context {} killed/stopped after prefill, allowing decode routing for KV transfer",
engine_ctx.id()
));
);
}
// Handle prefill result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment