Unverified Commit 04486c1d authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Request Cancellation when transitioning from Prefill to Decode at KV Prefill Router (#4449)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 45529d06
...@@ -11,8 +11,8 @@ use tokio_util::sync::CancellationToken; ...@@ -11,8 +11,8 @@ use tokio_util::sync::CancellationToken;
use dynamo_runtime::{ use dynamo_runtime::{
component::Endpoint, component::Endpoint,
pipeline::{ pipeline::{
AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter,
PushRouter, RouterMode, ServerStreamingEngine, SingleIn, async_trait, RouterMode, ServerStreamingEngine, SingleIn, async_trait,
}, },
protocols::{annotated::Annotated, maybe_error::MaybeError}, protocols::{annotated::Annotated, maybe_error::MaybeError},
}; };
...@@ -270,6 +270,7 @@ impl ...@@ -270,6 +270,7 @@ impl
// Extract request data while preserving context // Extract request data while preserving context
let (req, context) = request.into_parts(); let (req, context) = request.into_parts();
let request_id = context.id().to_string(); let request_id = context.id().to_string();
let engine_ctx = context.context();
// Save original max_tokens for decode // Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens; let original_max_tokens = req.stop_conditions.max_tokens;
...@@ -280,12 +281,24 @@ impl ...@@ -280,12 +281,24 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate // Link the prefill context as a child so that kill signals propagate
context.controller().link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
let prefill_request = prefill_context; let prefill_request = prefill_context;
// Attempt prefill and handle results // Attempt prefill
match self.call_prefill(prefill_request).await { let prefill_result = self.call_prefill(prefill_request).await;
// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
tracing::debug!("Abort entering decode after context is stopped or killed");
return Err(anyhow::anyhow!(
"Context id {} is stopped or killed",
engine_ctx.id()
));
}
// Handle prefill result
match prefill_result {
Ok(prefill_result) => { Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode"); tracing::debug!("Prefill succeeded, using disaggregated params for decode");
......
...@@ -243,9 +243,9 @@ def test_request_cancellation_sglang_decode_cancel( ...@@ -243,9 +243,9 @@ def test_request_cancellation_sglang_decode_cancel(
request, runtime_services, predownload_models request, runtime_services, predownload_models
): ):
""" """
End-to-end test for request cancellation during remote decode phase. End-to-end test for request cancellation during decode phase.
This test verifies that when a request is cancelled by the client during the remote decode phase, This test verifies that when a request is cancelled by the client during the decode phase,
the system properly handles the cancellation and cleans up resources the system properly handles the cancellation and cleans up resources
on both the prefill and decode workers in a disaggregated setup. on both the prefill and decode workers in a disaggregated setup.
...@@ -267,9 +267,9 @@ def test_request_cancellation_sglang_decode_cancel( ...@@ -267,9 +267,9 @@ def test_request_cancellation_sglang_decode_cancel(
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2) time.sleep(2)
# Step 4: Test request cancellation during remote decode phase # Step 4: Test request cancellation during decode phase
logger.info( logger.info(
"Testing chat completion stream request cancellation during remote decode phase..." "Testing chat completion stream request cancellation during decode phase..."
) )
# Send streaming request (non-blocking) # Send streaming request (non-blocking)
......
...@@ -206,7 +206,7 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -206,7 +206,7 @@ def test_request_cancellation_trtllm_aggregated(
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_disagg_decode_cancel( def test_request_cancellation_trtllm_decode_cancel(
request, runtime_services, predownload_models request, runtime_services, predownload_models
): ):
""" """
...@@ -282,7 +282,7 @@ def test_request_cancellation_trtllm_disagg_decode_cancel( ...@@ -282,7 +282,7 @@ def test_request_cancellation_trtllm_disagg_decode_cancel(
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_disagg_prefill_cancel( def test_request_cancellation_trtllm_prefill_cancel(
request, runtime_services, predownload_models request, runtime_services, predownload_models
): ):
""" """
...@@ -342,6 +342,23 @@ def test_request_cancellation_trtllm_disagg_prefill_cancel( ...@@ -342,6 +342,23 @@ def test_request_cancellation_trtllm_disagg_prefill_cancel(
pattern="issued control message Kill to sender", pattern="issued control message Kill to sender",
) )
# Verify decode worker never received the request
pattern = "Request ID: "
try:
_, decode_log_offset = poll_for_pattern(
process=decode_worker,
pattern=pattern,
max_wait_ms=10,
match_type="contains",
)
pytest.fail(
"Decode worker received request cancelled during prefill phase"
)
except AssertionError as e:
assert str(e).startswith(
f"Failed to find '{pattern}' pattern after 2 iterations "
), f"Unexpected error: {e}"
logger.info( logger.info(
"Completion request cancellation during prefill phase detected successfully" "Completion request cancellation during prefill phase detected successfully"
) )
...@@ -274,13 +274,13 @@ def test_request_cancellation_vllm_decode_cancel( ...@@ -274,13 +274,13 @@ def test_request_cancellation_vllm_decode_cancel(
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_vllm_remote_prefill_cancel( def test_request_cancellation_vllm_prefill_cancel(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services, predownload_models, set_ucx_tls_no_mm
): ):
""" """
End-to-end test for request cancellation during remote prefill phase. End-to-end test for request cancellation during prefill phase.
This test verifies that when a request is cancelled by the client during the remote prefill phase, This test verifies that when a request is cancelled by the client during the prefill phase,
the system properly handles the cancellation and cleans up resources the system properly handles the cancellation and cleans up resources
on both the decode and prefill workers in a disaggregated setup. on both the decode and prefill workers in a disaggregated setup.
""" """
...@@ -334,6 +334,23 @@ def test_request_cancellation_vllm_remote_prefill_cancel( ...@@ -334,6 +334,23 @@ def test_request_cancellation_vllm_remote_prefill_cancel(
pattern="issued control message Kill to sender", pattern="issued control message Kill to sender",
) )
# Verify decode worker never received the request
pattern = "Request ID: "
try:
_, decode_log_offset = poll_for_pattern(
process=decode_worker,
pattern=pattern,
max_wait_ms=10,
match_type="contains",
)
pytest.fail(
"Decode worker received request cancelled during prefill phase"
)
except AssertionError as e:
assert str(e).startswith(
f"Failed to find '{pattern}' pattern after 2 iterations "
), f"Unexpected error: {e}"
logger.info( logger.info(
"Completion request cancellation during remote prefill phase detected successfully" "Completion request cancellation during prefill phase detected successfully"
) )
...@@ -388,6 +388,6 @@ def poll_for_pattern( ...@@ -388,6 +388,6 @@ def poll_for_pattern(
time.sleep(poll_interval_ms / 1000.0) time.sleep(poll_interval_ms / 1000.0)
iteration += 1 iteration += 1
pytest.fail( raise AssertionError(
f"Failed to find '{pattern}' pattern after {max_iterations} iterations ({max_wait_ms}ms)" f"Failed to find '{pattern}' pattern after {max_iterations} iterations ({max_wait_ms}ms)"
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment