Unverified Commit a8fd1271 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Request Cancellation unary request support (#3004)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 10bfb73a
......@@ -39,16 +39,12 @@ class MiddleServer:
stream = await self.backend_client.generate(request, context=context)
# Stream responses back to client
try:
async for response in stream:
data = response.data()
print(f"Middle server: Forwarding response {data}")
yield data
except ValueError as e:
if str(e) != "Stream ended before generation completed":
raise
print("Middle server: Backend stream ended early due to cancellation")
async for response in stream:
data = response.data()
print(f"Middle server: Forwarding response {data}")
yield data
print("Middle server: Backend stream ended")
async def main():
......
......@@ -38,15 +38,17 @@ async def test_client_context_cancel(server, client):
if iteration_count >= 2:
print("Cancelling after 2 responses...")
context.stop_generating()
break
iteration_count += 1
# Verify we received exactly 3 responses (0, 1, 2)
assert iteration_count == 3
# Give server a moment to process the cancellation
await asyncio.sleep(0.2)
# Verify server detected the cancellation
assert handler.context_is_stopped
assert handler.context_is_killed
assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
......@@ -139,9 +139,9 @@ async def test_middle_server_cancellation(
assert (
"Client: Cancelling after 3 responses..." in client_output
), f"Client output: {client_output}"
assert (
"Middle server: Forwarding response 2" in middle_output
), f"Middle server output: {middle_output}"
assert (
"Server: Cancelled at iteration" in server_output
), f"Server output: {server_output}"
assert (
"Middle server: Backend stream ended early due to cancellation" in middle_output
), f"Middle server output: {middle_output}"
......@@ -17,8 +17,8 @@ use crate::{
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait,
AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait, network::STREAM_ERR_MSG,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
......@@ -55,30 +55,23 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (preprocessed_request, context) = request.transfer(());
let context_id = context.id().to_string();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
let retry_manager =
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit)
RetryManager::build(engine_ctx, preprocessed_request, next, self.migration_limit)
.await?;
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| {
let engine_ctx = engine_ctx_.clone();
async move {
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
return None; // Stop if the context is cancelled or stopped
}
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
}
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
});
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
}
}
struct RetryManager {
context_id: String,
context: Arc<dyn AsyncEngineContext>,
request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
......@@ -87,13 +80,13 @@ struct RetryManager {
impl RetryManager {
pub async fn build(
context_id: String,
context: Arc<dyn AsyncEngineContext>,
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
retries_left: u32,
) -> Result<Self> {
let mut slf = Self {
context_id,
context,
request: preprocessed_request,
next_generate: next,
next_stream: None,
......@@ -115,18 +108,16 @@ impl RetryManager {
}
};
if let Some(response) = response_stream.next().await {
if let Some(err) = response.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if err
if let Some(err) = response.err()
&& err
.chain()
.any(|e| e.to_string().starts_with(STREAM_ERR_MSG))
{
tracing::warn!("Stream disconnected... recreating stream...");
if let Err(err) = self.new_stream().await {
tracing::warn!("Cannot recreate stream: {:#}", err);
} else {
continue;
}
{
tracing::warn!("Stream disconnected... recreating stream...");
if let Err(err) = self.new_stream().await {
tracing::warn!("Cannot recreate stream: {:#}", err);
} else {
continue;
}
}
self.track_response(&response);
......@@ -140,7 +131,8 @@ impl RetryManager {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context_id.clone());
let request = Context::with_id(self.request.clone(), self.context.id().to_string());
self.context.link_child(request.context());
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
......@@ -339,10 +331,8 @@ mod tests {
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let error_response =
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
let _ = tx.send(error_response).await;
});
} else {
......@@ -381,10 +371,8 @@ mod tests {
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let error_response =
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
let _ = tx.send(error_response).await;
});
......@@ -417,10 +405,8 @@ mod tests {
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let error_response =
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
let _ = tx.send(error_response).await;
});
......@@ -434,10 +420,8 @@ mod tests {
// Subsequent calls - immediately send stream error (no successful responses)
tokio::spawn(async move {
// Send the stream error immediately
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let error_response =
Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
let _ = tx.send(error_response).await;
});
......@@ -503,7 +487,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0)
let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 0)
.await
.expect("Failed to build RetryManager");
......@@ -541,7 +526,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
......@@ -580,7 +566,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
......@@ -620,7 +607,8 @@ mod tests {
mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await;
let ctx = Arc::new(Controller::new(context_id.clone()));
let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
......@@ -646,7 +634,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
......@@ -672,11 +661,7 @@ mod tests {
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(
error
.to_string()
.contains("Stream ended before generation completed")
);
assert!(error.to_string().contains(STREAM_ERR_MSG));
}
}
......@@ -698,7 +683,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
......@@ -724,11 +710,7 @@ mod tests {
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(
error
.to_string()
.contains("Stream ended before generation completed")
);
assert!(error.to_string().contains(STREAM_ERR_MSG));
}
}
}
......@@ -358,18 +358,20 @@ impl AsyncEngineContext for Controller {
async fn stopped(&self) {
let mut rx = self.rx.clone();
if *rx.borrow_and_update() != State::Live {
return;
loop {
if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
return;
}
}
let _ = rx.changed().await;
}
async fn killed(&self) {
let mut rx = self.rx.clone();
if *rx.borrow_and_update() == State::Killed {
return;
loop {
if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
return;
}
}
let _ = rx.changed().await;
}
fn stop_generating(&self) {
......
......@@ -27,6 +27,9 @@ use super::{
};
use ingress::push_handler::WorkHandlerMetrics;
// Define stream error message constant
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
// Add Prometheus metrics types
use crate::metrics::MetricsRegistry;
use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
......
......@@ -80,6 +80,7 @@ where
let (addressed_request, context) = request.transfer(());
let (request, address) = addressed_request.into_parts();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
// registration options for the data plane in a singe in / many out configuration
let options = StreamOptions::builder()
......@@ -209,11 +210,18 @@ where
}
}
} else if is_complete_final {
// end of stream
None
} else if engine_ctx_.is_stopped() {
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
// 'is_killed()' here because it implies the stream ended abnormally which should be
// handled by the error branch below.
log::debug!("Request cancelled and then trying to read a response");
None
} else {
Some(U::from_err(
Error::msg("Stream ended before generation completed").into(),
))
// stream ended unexpectedly
log::debug!("{STREAM_ERR_MSG}");
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
}
});
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream};
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::utils::worker_monitor::WorkerMonitor;
use crate::{
component::{Client, Endpoint, InstanceSource},
......@@ -231,11 +231,14 @@ where
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.map(move |res| {
if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
client.report_instance_down(instance_id);
}
// TODO: Standardize error type to avoid using string matching DIS-364
if let Some(err) = res.err()
&& format!("{:?}", err) == STREAM_ERR_MSG
{
tracing::debug!(
"Reporting instance {instance_id} down due to stream error: {err}"
);
client.report_instance_down(instance_id);
}
res
});
......@@ -245,6 +248,9 @@ where
if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
&& matches!(req_err.kind(), NatsNoResponders)
{
tracing::debug!(
"Reporting instance {instance_id} down due to request error: {req_err}"
);
self.client.report_instance_down(instance_id);
}
Err(err)
......
......@@ -264,13 +264,12 @@ where
let mut send_complete_final = true;
while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp);
if let Some(err) = resp.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false;
break;
}
if let Some(err) = resp.err()
&& format!("{:?}", err) == STREAM_ERR_MSG
{
tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false;
break;
}
let resp_wrapper = NetworkStreamWrapper {
data: Some(resp),
......
......@@ -261,6 +261,11 @@ async fn handle_writer(
break;
}
_ = context.stopped() => {
tracing::trace!("context stop signal received; shutting down");
break;
}
msg = bytes_rx.recv() => {
match msg {
Some(msg) => msg,
......
......@@ -138,7 +138,7 @@ class DynamoWorkerProcess(ManagedProcess):
def send_completion_request(
prompt: str, max_tokens: int, timeout: int = 120
prompt: str, max_tokens: int, timeout: int | float = 120
) -> requests.Response:
"""Send a completion request to the frontend"""
payload = {
......@@ -172,7 +172,7 @@ def send_completion_request(
def send_chat_completion_request(
prompt: str, max_tokens: int, timeout: int = 120, stream: bool = False
prompt: str, max_tokens: int, timeout: int | float = 120, stream: bool = False
) -> requests.Response:
"""Send a chat completion request to the frontend"""
payload = {
......@@ -207,11 +207,18 @@ def send_chat_completion_request(
raise
def send_request_and_cancel(request_type: str = "completion", timeout: int = 1):
def send_request_and_cancel(
request_type: str = "completion",
timeout: int | float = 1,
use_long_prompt: bool = False,
):
"""Send a request with short timeout to trigger cancellation"""
logger.info(f"Sending {request_type} request to be cancelled...")
prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
try:
if request_type == "completion":
response = send_completion_request(prompt, 8000, timeout)
......@@ -268,7 +275,7 @@ def verify_request_cancelled(
prefill_worker_process: DynamoWorkerProcess | None = None,
frontend_log_offset: int = 0,
worker_log_offset: int = 0,
prefill_worker_log_offset: int = 0,
assert_cancel_at_prefill: bool = False,
) -> tuple[int, int]:
"""Verify that the worker and frontend logs contain cancellation messages
......@@ -296,7 +303,11 @@ def verify_request_cancelled(
# Check if the same request ID was cancelled
has_worker_cancellation = False
cancellation_pattern = f"Aborted Request ID: {request_id}"
cancellation_pattern = (
f"Aborted Remote Prefill Request ID: {request_id}"
if assert_cancel_at_prefill
else f"Aborted Request ID: {request_id}"
)
for line in new_worker_content.split("\n"):
# Strip ANSI codes and whitespace for pattern matching
clean_line = strip_ansi_codes(line).strip()
......@@ -304,29 +315,39 @@ def verify_request_cancelled(
has_worker_cancellation = True
break
if not has_worker_cancellation:
pytest.fail(
f"Could not find 'Aborted Request ID: {request_id}' pattern in worker log"
)
pytest.fail(f"Could not find '{cancellation_pattern}' pattern in worker log")
# Check if the same request ID was remote prefilled
# Check prefill worker log if provided
if prefill_worker_process is not None:
prefill_worker_log_content = read_log_content(prefill_worker_process._log_path)
new_prefill_worker_content = prefill_worker_log_content[
prefill_worker_log_offset:
]
# Check if the same request ID was remote prefilled
has_remote_prefill = False
remote_prefill_pattern = f"New Prefill Request ID: {request_id}"
for line in new_prefill_worker_content.split("\n"):
for line in prefill_worker_log_content.split("\n"):
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(remote_prefill_pattern):
has_remote_prefill = True
break
if not has_remote_prefill:
pytest.fail(
f"Could not find 'New Prefill Request ID: {request_id}' pattern in prefill worker log"
f"Could not find '{remote_prefill_pattern}' pattern in prefill worker log"
)
# Check for remote prefill cancellation
if assert_cancel_at_prefill:
has_prefill_cancellation = False
prefill_cancellation_pattern = f"Aborted Prefill Request ID: {request_id}"
for line in prefill_worker_log_content.split("\n"):
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(prefill_cancellation_pattern):
has_prefill_cancellation = True
break
if not has_prefill_cancellation:
pytest.fail(
f"Could not find '{prefill_cancellation_pattern}' pattern in prefill worker log"
)
# Check frontend log for cancellation issued pattern
frontend_log_content = read_log_content(frontend_process._log_path)
new_frontend_content = frontend_log_content[frontend_log_offset:]
......@@ -391,7 +412,9 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info(
"Checking for cancellation messages in worker and frontend logs..."
)
time.sleep(0.5) # Make sure logs are written before proceeding
# TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
frontend_log_offset, worker_log_offset = verify_request_cancelled(
frontend,
worker,
......@@ -401,10 +424,6 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info(f"{description} detected successfully")
logger.info(
"All request cancellation tests completed successfully - request cancellation is working correctly"
)
@pytest.mark.vllm
@pytest.mark.gpu_1
......@@ -441,30 +460,26 @@ def test_request_cancellation_vllm_decode(
# Step 4: Test request cancellation for completion scenario only
logger.info(
"Testing completion request cancellation in disaggregated mode..."
"Testing completion request cancellation in decode worker..."
)
send_request_and_cancel("completion")
logger.info(
"Checking for cancellation messages in decode worker, prefill worker, and frontend logs..."
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
time.sleep(0.5) # Make sure logs are written before proceeding
# TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
verify_request_cancelled(frontend, decode_worker, prefill_worker)
logger.info(
"Completion request cancellation detected successfully in disaggregated mode"
)
logger.info(
"Request cancellation test completed successfully in disaggregated mode - request cancellation is working correctly"
)
@pytest.mark.skip(reason="require cancel support before receiving 1st response")
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
def test_request_cancellation_vllm_prefill(request, runtime_services):
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_vllm_prefill(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation on remote prefill.
......@@ -473,3 +488,44 @@ def test_request_cancellation_vllm_prefill(request, runtime_services):
resources on the prefill worker and decode worker sides in a disaggregated
setup.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(request, is_prefill=True)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(request, is_prefill=False)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why the model is not immediately available at the frontend after
# health check returns success.
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only
logger.info(
"Testing completion request cancellation in prefill worker..."
)
send_request_and_cancel("completion", timeout=0.1, use_long_prompt=True)
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
# TODO: Need to wait for prefill to generate first token before seeing
# the cancellation on the logs. DIS-625
time.sleep(3)
verify_request_cancelled(
frontend,
decode_worker,
prefill_worker,
assert_cancel_at_prefill=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment