Unverified Commit a8fd1271 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Request Cancellation unary request support (#3004)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 10bfb73a
...@@ -39,16 +39,12 @@ class MiddleServer: ...@@ -39,16 +39,12 @@ class MiddleServer:
stream = await self.backend_client.generate(request, context=context) stream = await self.backend_client.generate(request, context=context)
# Stream responses back to client # Stream responses back to client
try: async for response in stream:
async for response in stream: data = response.data()
data = response.data() print(f"Middle server: Forwarding response {data}")
print(f"Middle server: Forwarding response {data}") yield data
yield data
print("Middle server: Backend stream ended")
except ValueError as e:
if str(e) != "Stream ended before generation completed":
raise
print("Middle server: Backend stream ended early due to cancellation")
async def main(): async def main():
......
...@@ -38,15 +38,17 @@ async def test_client_context_cancel(server, client): ...@@ -38,15 +38,17 @@ async def test_client_context_cancel(server, client):
if iteration_count >= 2: if iteration_count >= 2:
print("Cancelling after 2 responses...") print("Cancelling after 2 responses...")
context.stop_generating() context.stop_generating()
break
iteration_count += 1 iteration_count += 1
# Verify we received exactly 3 responses (0, 1, 2)
assert iteration_count == 3
# Give server a moment to process the cancellation # Give server a moment to process the cancellation
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# Verify server detected the cancellation # Verify server detected the cancellation
assert handler.context_is_stopped assert handler.context_is_stopped
assert handler.context_is_killed assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler # TODO: Test with _generate_until_asyncio_cancelled server handler
...@@ -139,9 +139,9 @@ async def test_middle_server_cancellation( ...@@ -139,9 +139,9 @@ async def test_middle_server_cancellation(
assert ( assert (
"Client: Cancelling after 3 responses..." in client_output "Client: Cancelling after 3 responses..." in client_output
), f"Client output: {client_output}" ), f"Client output: {client_output}"
assert (
"Middle server: Forwarding response 2" in middle_output
), f"Middle server output: {middle_output}"
assert ( assert (
"Server: Cancelled at iteration" in server_output "Server: Cancelled at iteration" in server_output
), f"Server output: {server_output}" ), f"Server output: {server_output}"
assert (
"Middle server: Backend stream ended early due to cancellation" in middle_output
), f"Middle server output: {middle_output}"
...@@ -17,8 +17,8 @@ use crate::{ ...@@ -17,8 +17,8 @@ use crate::{
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream, AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait, ServerStreamingEngine, SingleIn, async_trait, network::STREAM_ERR_MSG,
}, },
protocols::{annotated::Annotated, maybe_error::MaybeError}, protocols::{annotated::Annotated, maybe_error::MaybeError},
}; };
...@@ -55,30 +55,23 @@ impl ...@@ -55,30 +55,23 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (preprocessed_request, context) = request.transfer(()); let (preprocessed_request, context) = request.transfer(());
let context_id = context.id().to_string();
let engine_ctx = context.context(); let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone(); let engine_ctx_ = engine_ctx.clone();
let retry_manager = let retry_manager =
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit) RetryManager::build(engine_ctx, preprocessed_request, next, self.migration_limit)
.await?; .await?;
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| { let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
let engine_ctx = engine_ctx_.clone(); retry_manager
async move { .next()
if engine_ctx.is_stopped() || engine_ctx.is_killed() { .await
return None; // Stop if the context is cancelled or stopped .map(|response| (response, retry_manager))
}
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
}
}); });
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx)) Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
} }
} }
struct RetryManager { struct RetryManager {
context_id: String, context: Arc<dyn AsyncEngineContext>,
request: PreprocessedRequest, request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>, next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
...@@ -87,13 +80,13 @@ struct RetryManager { ...@@ -87,13 +80,13 @@ struct RetryManager {
impl RetryManager { impl RetryManager {
pub async fn build( pub async fn build(
context_id: String, context: Arc<dyn AsyncEngineContext>,
preprocessed_request: PreprocessedRequest, preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
retries_left: u32, retries_left: u32,
) -> Result<Self> { ) -> Result<Self> {
let mut slf = Self { let mut slf = Self {
context_id, context,
request: preprocessed_request, request: preprocessed_request,
next_generate: next, next_generate: next,
next_stream: None, next_stream: None,
...@@ -115,18 +108,16 @@ impl RetryManager { ...@@ -115,18 +108,16 @@ impl RetryManager {
} }
}; };
if let Some(response) = response_stream.next().await { if let Some(response) = response_stream.next().await {
if let Some(err) = response.err() { if let Some(err) = response.err()
const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; && err
if err
.chain() .chain()
.any(|e| e.to_string().starts_with(STREAM_ERR_MSG)) .any(|e| e.to_string().starts_with(STREAM_ERR_MSG))
{ {
tracing::warn!("Stream disconnected... recreating stream..."); tracing::warn!("Stream disconnected... recreating stream...");
if let Err(err) = self.new_stream().await { if let Err(err) = self.new_stream().await {
tracing::warn!("Cannot recreate stream: {:#}", err); tracing::warn!("Cannot recreate stream: {:#}", err);
} else { } else {
continue; continue;
}
} }
} }
self.track_response(&response); self.track_response(&response);
...@@ -140,7 +131,8 @@ impl RetryManager { ...@@ -140,7 +131,8 @@ impl RetryManager {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None; let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
while self.retries_left > 0 { while self.retries_left > 0 {
self.retries_left -= 1; self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context_id.clone()); let request = Context::with_id(self.request.clone(), self.context.id().to_string());
self.context.link_child(request.context());
response_stream = Some(self.next_generate.generate(request).await); response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>() && let Some(req_err) = err.downcast_ref::<NatsRequestError>()
...@@ -339,10 +331,8 @@ mod tests { ...@@ -339,10 +331,8 @@ mod tests {
} }
} }
// Send the specific error that triggers retry logic // Send the specific error that triggers retry logic
let error_response = Annotated::from_err( let error_response =
anyhow::Error::msg("Stream ended before generation completed") Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
.into(),
);
let _ = tx.send(error_response).await; let _ = tx.send(error_response).await;
}); });
} else { } else {
...@@ -381,10 +371,8 @@ mod tests { ...@@ -381,10 +371,8 @@ mod tests {
} }
} }
// Send the specific error that triggers retry logic // Send the specific error that triggers retry logic
let error_response = Annotated::from_err( let error_response =
anyhow::Error::msg("Stream ended before generation completed") Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
.into(),
);
let _ = tx.send(error_response).await; let _ = tx.send(error_response).await;
}); });
...@@ -417,10 +405,8 @@ mod tests { ...@@ -417,10 +405,8 @@ mod tests {
} }
} }
// Send the specific error that triggers retry logic // Send the specific error that triggers retry logic
let error_response = Annotated::from_err( let error_response =
anyhow::Error::msg("Stream ended before generation completed") Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
.into(),
);
let _ = tx.send(error_response).await; let _ = tx.send(error_response).await;
}); });
...@@ -434,10 +420,8 @@ mod tests { ...@@ -434,10 +420,8 @@ mod tests {
// Subsequent calls - immediately send stream error (no successful responses) // Subsequent calls - immediately send stream error (no successful responses)
tokio::spawn(async move { tokio::spawn(async move {
// Send the stream error immediately // Send the stream error immediately
let error_response = Annotated::from_err( let error_response =
anyhow::Error::msg("Stream ended before generation completed") Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
.into(),
);
let _ = tx.send(error_response).await; let _ = tx.send(error_response).await;
}); });
...@@ -503,7 +487,8 @@ mod tests { ...@@ -503,7 +487,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0) let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 0)
.await .await
.expect("Failed to build RetryManager"); .expect("Failed to build RetryManager");
...@@ -541,7 +526,8 @@ mod tests { ...@@ -541,7 +526,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
.await .await
.expect("Failed to build RetryManager"); .expect("Failed to build RetryManager");
...@@ -580,7 +566,8 @@ mod tests { ...@@ -580,7 +566,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
.await .await
.expect("Failed to build RetryManager"); .expect("Failed to build RetryManager");
...@@ -620,7 +607,8 @@ mod tests { ...@@ -620,7 +607,8 @@ mod tests {
mock_engine; mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries // Should fail to build due to initial stream creation failure after exhausting all 3 retries
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await; let ctx = Arc::new(Controller::new(context_id.clone()));
let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;
assert!(retry_manager_result.is_err()); assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result { if let Err(error) = retry_manager_result {
...@@ -646,7 +634,8 @@ mod tests { ...@@ -646,7 +634,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
.await .await
.expect("Failed to build RetryManager"); .expect("Failed to build RetryManager");
...@@ -672,11 +661,7 @@ mod tests { ...@@ -672,11 +661,7 @@ mod tests {
let error_response = &responses[3]; let error_response = &responses[3];
assert!(error_response.err().is_some()); assert!(error_response.err().is_some());
if let Some(error) = error_response.err() { if let Some(error) = error_response.err() {
assert!( assert!(error.to_string().contains(STREAM_ERR_MSG));
error
.to_string()
.contains("Stream ended before generation completed")
);
} }
} }
...@@ -698,7 +683,8 @@ mod tests { ...@@ -698,7 +683,8 @@ mod tests {
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries let ctx = Arc::new(Controller::new(context_id.clone()));
let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
.await .await
.expect("Failed to build RetryManager"); .expect("Failed to build RetryManager");
...@@ -724,11 +710,7 @@ mod tests { ...@@ -724,11 +710,7 @@ mod tests {
let error_response = &responses[3]; let error_response = &responses[3];
assert!(error_response.err().is_some()); assert!(error_response.err().is_some());
if let Some(error) = error_response.err() { if let Some(error) = error_response.err() {
assert!( assert!(error.to_string().contains(STREAM_ERR_MSG));
error
.to_string()
.contains("Stream ended before generation completed")
);
} }
} }
} }
...@@ -358,18 +358,20 @@ impl AsyncEngineContext for Controller { ...@@ -358,18 +358,20 @@ impl AsyncEngineContext for Controller {
async fn stopped(&self) { async fn stopped(&self) {
let mut rx = self.rx.clone(); let mut rx = self.rx.clone();
if *rx.borrow_and_update() != State::Live { loop {
return; if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
return;
}
} }
let _ = rx.changed().await;
} }
async fn killed(&self) { async fn killed(&self) {
let mut rx = self.rx.clone(); let mut rx = self.rx.clone();
if *rx.borrow_and_update() == State::Killed { loop {
return; if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
return;
}
} }
let _ = rx.changed().await;
} }
fn stop_generating(&self) { fn stop_generating(&self) {
......
...@@ -27,6 +27,9 @@ use super::{ ...@@ -27,6 +27,9 @@ use super::{
}; };
use ingress::push_handler::WorkHandlerMetrics; use ingress::push_handler::WorkHandlerMetrics;
// Define stream error message constant
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
// Add Prometheus metrics types // Add Prometheus metrics types
use crate::metrics::MetricsRegistry; use crate::metrics::MetricsRegistry;
use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge}; use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
......
...@@ -80,6 +80,7 @@ where ...@@ -80,6 +80,7 @@ where
let (addressed_request, context) = request.transfer(()); let (addressed_request, context) = request.transfer(());
let (request, address) = addressed_request.into_parts(); let (request, address) = addressed_request.into_parts();
let engine_ctx = context.context(); let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
// registration options for the data plane in a singe in / many out configuration // registration options for the data plane in a singe in / many out configuration
let options = StreamOptions::builder() let options = StreamOptions::builder()
...@@ -209,11 +210,18 @@ where ...@@ -209,11 +210,18 @@ where
} }
} }
} else if is_complete_final { } else if is_complete_final {
// end of stream
None
} else if engine_ctx_.is_stopped() {
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
// 'is_killed()' here because it implies the stream ended abnormally which should be
// handled by the error branch below.
log::debug!("Request cancelled and then trying to read a response");
None None
} else { } else {
Some(U::from_err( // stream ended unexpectedly
Error::msg("Stream ended before generation completed").into(), log::debug!("{STREAM_ERR_MSG}");
)) Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
} }
}); });
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream}; use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::utils::worker_monitor::WorkerMonitor; use crate::utils::worker_monitor::WorkerMonitor;
use crate::{ use crate::{
component::{Client, Endpoint, InstanceSource}, component::{Client, Endpoint, InstanceSource},
...@@ -231,11 +231,14 @@ where ...@@ -231,11 +231,14 @@ where
let engine_ctx = stream.context(); let engine_ctx = stream.context();
let client = self.client.clone(); let client = self.client.clone();
let stream = stream.map(move |res| { let stream = stream.map(move |res| {
if let Some(err) = res.err() { // TODO: Standardize error type to avoid using string matching DIS-364
const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; if let Some(err) = res.err()
if format!("{:?}", err) == STREAM_ERR_MSG { && format!("{:?}", err) == STREAM_ERR_MSG
client.report_instance_down(instance_id); {
} tracing::debug!(
"Reporting instance {instance_id} down due to stream error: {err}"
);
client.report_instance_down(instance_id);
} }
res res
}); });
...@@ -245,6 +248,9 @@ where ...@@ -245,6 +248,9 @@ where
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
&& matches!(req_err.kind(), NatsNoResponders) && matches!(req_err.kind(), NatsNoResponders)
{ {
tracing::debug!(
"Reporting instance {instance_id} down due to request error: {req_err}"
);
self.client.report_instance_down(instance_id); self.client.report_instance_down(instance_id);
} }
Err(err) Err(err)
......
...@@ -264,13 +264,12 @@ where ...@@ -264,13 +264,12 @@ where
let mut send_complete_final = true; let mut send_complete_final = true;
while let Some(resp) = stream.next().await { while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp); tracing::trace!("Sending response: {:?}", resp);
if let Some(err) = resp.err() { if let Some(err) = resp.err()
const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; && format!("{:?}", err) == STREAM_ERR_MSG
if format!("{:?}", err) == STREAM_ERR_MSG { {
tracing::warn!(STREAM_ERR_MSG); tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false; send_complete_final = false;
break; break;
}
} }
let resp_wrapper = NetworkStreamWrapper { let resp_wrapper = NetworkStreamWrapper {
data: Some(resp), data: Some(resp),
......
...@@ -261,6 +261,11 @@ async fn handle_writer( ...@@ -261,6 +261,11 @@ async fn handle_writer(
break; break;
} }
_ = context.stopped() => {
tracing::trace!("context stop signal received; shutting down");
break;
}
msg = bytes_rx.recv() => { msg = bytes_rx.recv() => {
match msg { match msg {
Some(msg) => msg, Some(msg) => msg,
......
...@@ -138,7 +138,7 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -138,7 +138,7 @@ class DynamoWorkerProcess(ManagedProcess):
def send_completion_request( def send_completion_request(
prompt: str, max_tokens: int, timeout: int = 120 prompt: str, max_tokens: int, timeout: int | float = 120
) -> requests.Response: ) -> requests.Response:
"""Send a completion request to the frontend""" """Send a completion request to the frontend"""
payload = { payload = {
...@@ -172,7 +172,7 @@ def send_completion_request( ...@@ -172,7 +172,7 @@ def send_completion_request(
def send_chat_completion_request( def send_chat_completion_request(
prompt: str, max_tokens: int, timeout: int = 120, stream: bool = False prompt: str, max_tokens: int, timeout: int | float = 120, stream: bool = False
) -> requests.Response: ) -> requests.Response:
"""Send a chat completion request to the frontend""" """Send a chat completion request to the frontend"""
payload = { payload = {
...@@ -207,11 +207,18 @@ def send_chat_completion_request( ...@@ -207,11 +207,18 @@ def send_chat_completion_request(
raise raise
def send_request_and_cancel(request_type: str = "completion", timeout: int = 1): def send_request_and_cancel(
request_type: str = "completion",
timeout: int | float = 1,
use_long_prompt: bool = False,
):
"""Send a request with short timeout to trigger cancellation""" """Send a request with short timeout to trigger cancellation"""
logger.info(f"Sending {request_type} request to be cancelled...") logger.info(f"Sending {request_type} request to be cancelled...")
prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?" prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
try: try:
if request_type == "completion": if request_type == "completion":
response = send_completion_request(prompt, 8000, timeout) response = send_completion_request(prompt, 8000, timeout)
...@@ -268,7 +275,7 @@ def verify_request_cancelled( ...@@ -268,7 +275,7 @@ def verify_request_cancelled(
prefill_worker_process: DynamoWorkerProcess | None = None, prefill_worker_process: DynamoWorkerProcess | None = None,
frontend_log_offset: int = 0, frontend_log_offset: int = 0,
worker_log_offset: int = 0, worker_log_offset: int = 0,
prefill_worker_log_offset: int = 0, assert_cancel_at_prefill: bool = False,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Verify that the worker and frontend logs contain cancellation messages """Verify that the worker and frontend logs contain cancellation messages
...@@ -296,7 +303,11 @@ def verify_request_cancelled( ...@@ -296,7 +303,11 @@ def verify_request_cancelled(
# Check if the same request ID was cancelled # Check if the same request ID was cancelled
has_worker_cancellation = False has_worker_cancellation = False
cancellation_pattern = f"Aborted Request ID: {request_id}" cancellation_pattern = (
f"Aborted Remote Prefill Request ID: {request_id}"
if assert_cancel_at_prefill
else f"Aborted Request ID: {request_id}"
)
for line in new_worker_content.split("\n"): for line in new_worker_content.split("\n"):
# Strip ANSI codes and whitespace for pattern matching # Strip ANSI codes and whitespace for pattern matching
clean_line = strip_ansi_codes(line).strip() clean_line = strip_ansi_codes(line).strip()
...@@ -304,29 +315,39 @@ def verify_request_cancelled( ...@@ -304,29 +315,39 @@ def verify_request_cancelled(
has_worker_cancellation = True has_worker_cancellation = True
break break
if not has_worker_cancellation: if not has_worker_cancellation:
pytest.fail( pytest.fail(f"Could not find '{cancellation_pattern}' pattern in worker log")
f"Could not find 'Aborted Request ID: {request_id}' pattern in worker log"
)
# Check if the same request ID was remote prefilled # Check prefill worker log if provided
if prefill_worker_process is not None: if prefill_worker_process is not None:
prefill_worker_log_content = read_log_content(prefill_worker_process._log_path) prefill_worker_log_content = read_log_content(prefill_worker_process._log_path)
new_prefill_worker_content = prefill_worker_log_content[
prefill_worker_log_offset:
]
# Check if the same request ID was remote prefilled
has_remote_prefill = False has_remote_prefill = False
remote_prefill_pattern = f"New Prefill Request ID: {request_id}" remote_prefill_pattern = f"New Prefill Request ID: {request_id}"
for line in new_prefill_worker_content.split("\n"): for line in prefill_worker_log_content.split("\n"):
clean_line = strip_ansi_codes(line).strip() clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(remote_prefill_pattern): if clean_line.endswith(remote_prefill_pattern):
has_remote_prefill = True has_remote_prefill = True
break break
if not has_remote_prefill: if not has_remote_prefill:
pytest.fail( pytest.fail(
f"Could not find 'New Prefill Request ID: {request_id}' pattern in prefill worker log" f"Could not find '{remote_prefill_pattern}' pattern in prefill worker log"
) )
# Check for remote prefill cancellation
if assert_cancel_at_prefill:
has_prefill_cancellation = False
prefill_cancellation_pattern = f"Aborted Prefill Request ID: {request_id}"
for line in prefill_worker_log_content.split("\n"):
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(prefill_cancellation_pattern):
has_prefill_cancellation = True
break
if not has_prefill_cancellation:
pytest.fail(
f"Could not find '{prefill_cancellation_pattern}' pattern in prefill worker log"
)
# Check frontend log for cancellation issued pattern # Check frontend log for cancellation issued pattern
frontend_log_content = read_log_content(frontend_process._log_path) frontend_log_content = read_log_content(frontend_process._log_path)
new_frontend_content = frontend_log_content[frontend_log_offset:] new_frontend_content = frontend_log_content[frontend_log_offset:]
...@@ -391,7 +412,9 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models ...@@ -391,7 +412,9 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info( logger.info(
"Checking for cancellation messages in worker and frontend logs..." "Checking for cancellation messages in worker and frontend logs..."
) )
time.sleep(0.5) # Make sure logs are written before proceeding # TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
frontend_log_offset, worker_log_offset = verify_request_cancelled( frontend_log_offset, worker_log_offset = verify_request_cancelled(
frontend, frontend,
worker, worker,
...@@ -401,10 +424,6 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models ...@@ -401,10 +424,6 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info(f"{description} detected successfully") logger.info(f"{description} detected successfully")
logger.info(
"All request cancellation tests completed successfully - request cancellation is working correctly"
)
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
...@@ -441,30 +460,26 @@ def test_request_cancellation_vllm_decode( ...@@ -441,30 +460,26 @@ def test_request_cancellation_vllm_decode(
# Step 4: Test request cancellation for completion scenario only # Step 4: Test request cancellation for completion scenario only
logger.info( logger.info(
"Testing completion request cancellation in disaggregated mode..." "Testing completion request cancellation in decode worker..."
) )
send_request_and_cancel("completion") send_request_and_cancel("completion")
logger.info( logger.info(
"Checking for cancellation messages in decode worker, prefill worker, and frontend logs..." "Checking for cancellation messages in decode and prefill worker and frontend logs..."
) )
time.sleep(0.5) # Make sure logs are written before proceeding # TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
verify_request_cancelled(frontend, decode_worker, prefill_worker) verify_request_cancelled(frontend, decode_worker, prefill_worker)
logger.info(
"Completion request cancellation detected successfully in disaggregated mode"
)
logger.info(
"Request cancellation test completed successfully in disaggregated mode - request cancellation is working correctly"
)
@pytest.mark.skip(reason="require cancel support before receiving 1st response")
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
def test_request_cancellation_vllm_prefill(request, runtime_services): @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_vllm_prefill(
request, runtime_services, predownload_models
):
""" """
End-to-end test for request cancellation on remote prefill. End-to-end test for request cancellation on remote prefill.
...@@ -473,3 +488,44 @@ def test_request_cancellation_vllm_prefill(request, runtime_services): ...@@ -473,3 +488,44 @@ def test_request_cancellation_vllm_prefill(request, runtime_services):
resources on the prefill worker and decode worker sides in a disaggregated resources on the prefill worker and decode worker sides in a disaggregated
setup. setup.
""" """
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(request, is_prefill=True)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(request, is_prefill=False)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why the model is not immediately available at the frontend after
# health check returns success.
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only
logger.info(
"Testing completion request cancellation in prefill worker..."
)
send_request_and_cancel("completion", timeout=0.1, use_long_prompt=True)
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
# TODO: Need to wait for prefill to generate first token before seeing
# the cancellation on the logs. DIS-625
time.sleep(3)
verify_request_cancelled(
frontend,
decode_worker,
prefill_worker,
assert_cancel_at_prefill=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment