Unverified Commit e252f82d authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Frontend and Runtime Request Cancellation Metrics (#7493)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 4159c6fd
......@@ -21,7 +21,7 @@ use super::kserve::inference;
// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
metrics::{CancellationLabels, Endpoint, InflightGuard, process_response_and_observe_metrics},
};
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
......@@ -54,14 +54,22 @@ pub async fn completion_response_stream(
// create the context for the request
// [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref());
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: "grpc_completions".to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id.clone());
let context = request.context();
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let streaming = request.inner.stream.unwrap_or(false);
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
// update the request to always stream
let request = request.map(|mut req| {
req.inner.stream = Some(true);
......
......@@ -20,7 +20,7 @@ use validator::Validate;
use crate::http::service::metrics::InflightGuard;
use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, process_response_and_observe_metrics},
metrics::{CancellationLabels, Endpoint, process_response_and_observe_metrics},
};
use crate::protocols::tensor;
......@@ -60,13 +60,22 @@ pub async fn tensor_response_stream(
) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> {
// create the context for the request
let request_id = get_or_create_request_id(request.id.as_deref());
let cancellation_labels = CancellationLabels {
model: request.model.clone(),
endpoint: Endpoint::Tensor.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id.clone());
let context = request.context();
// [gluo TODO] revisit metrics to properly expose it
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
......
......@@ -30,7 +30,7 @@ use tracing::Instrument;
use super::{
RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
metrics::{Endpoint, process_response_and_observe_metrics},
metrics::{CancellationLabels, Endpoint, process_response_and_observe_metrics},
service_v2,
};
use crate::preprocessor::OpenAIPreprocessor;
......@@ -125,12 +125,22 @@ async fn handler_anthropic_messages(
// Create request context
let request_id = get_or_create_request_id(None, &headers);
let streaming = request.stream;
let cancellation_labels = CancellationLabels {
model: request.model.clone(),
endpoint: Endpoint::AnthropicMessages.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id);
let context = request.context();
// Create connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response =
tokio::spawn(anthropic_messages(state, template, request, stream_handle).in_current_span())
......
......@@ -33,7 +33,7 @@ use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use crate::http::service::metrics::{ErrorType, InflightGuard, Metrics};
use crate::http::service::metrics::{CancellationLabels, ErrorType, InflightGuard, Metrics};
#[derive(Clone, Copy)]
pub enum ConnectionStatus {
......@@ -100,6 +100,7 @@ impl Drop for ConnectionHandle {
pub async fn create_connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>,
metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) -> (ConnectionHandle, ConnectionHandle) {
// these oneshot channels monitor possible disconnects from the client in two different scopes:
// - the local task (connection_handle)
......@@ -113,6 +114,7 @@ pub async fn create_connection_monitor(
connection_rx,
stream_rx,
metrics,
cancellation_labels,
));
// Two handles, the first is armed, the second is disarmed
......@@ -128,6 +130,7 @@ async fn connection_monitor(
connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) {
match connection_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
......@@ -135,6 +138,7 @@ async fn connection_monitor(
tracing::trace!("Connection closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics {
metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
}
engine_context.kill();
}
......@@ -149,6 +153,7 @@ async fn connection_monitor(
tracing::trace!("Stream closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics {
metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
}
engine_context.kill();
}
......
......@@ -256,6 +256,7 @@ pub struct Metrics {
model_kv_cache_block_size: IntGaugeVec,
model_migration_limit: IntGaugeVec,
model_migration_total: IntCounterVec,
model_cancellation_total: IntCounterVec,
}
// Inflight tracks requests from HTTP handler start until complete response is finished.
......@@ -322,6 +323,13 @@ pub enum RequestType {
Stream,
}
/// Labels for cancellation metrics
pub struct CancellationLabels {
pub model: String,
pub endpoint: String,
pub request_type: String,
}
/// Status
#[derive(PartialEq)]
pub enum Status {
......@@ -654,6 +662,15 @@ impl Metrics {
)
.unwrap();
let model_cancellation_total = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_CANCELLATION_TOTAL),
"Total number of request cancellations",
),
&["model", "endpoint", "request_type"],
)
.unwrap();
Metrics {
request_counter,
inflight_gauge,
......@@ -674,6 +691,7 @@ impl Metrics {
model_kv_cache_block_size,
model_migration_limit,
model_migration_total,
model_cancellation_total,
}
}
......@@ -778,6 +796,7 @@ impl Metrics {
registry.register(Box::new(self.model_kv_cache_block_size.clone()))?;
registry.register(Box::new(self.model_migration_limit.clone()))?;
registry.register(Box::new(self.model_migration_total.clone()))?;
registry.register(Box::new(self.model_cancellation_total.clone()))?;
Ok(())
}
......@@ -861,6 +880,20 @@ impl Metrics {
.get()
}
/// Increment the cancellation counter
pub fn inc_cancellation(&self, labels: &CancellationLabels) {
self.model_cancellation_total
.with_label_values(&[&labels.model, &labels.endpoint, &labels.request_type])
.inc();
}
/// Get the current cancellation count
pub fn get_cancellation_count(&self, labels: &CancellationLabels) -> u64 {
self.model_cancellation_total
.with_label_values(&[&labels.model, &labels.endpoint, &labels.request_type])
.get()
}
/// Create a new [`InflightGuard`] for the given model and annotate if its a streaming request,
/// and the kind of endpoint that was hit
///
......
......@@ -36,7 +36,8 @@ use super::{
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
error::HttpError,
metrics::{
Endpoint, ErrorType, EventConverter, process_response_and_observe_metrics,
CancellationLabels, Endpoint, ErrorType, EventConverter,
process_response_and_observe_metrics,
process_response_using_event_converter_and_observe_metrics,
},
service_v2,
......@@ -341,12 +342,22 @@ async fn handler_completions(
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: Endpoint::Completions.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id);
let context = request.context();
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
// possibly long running task
// if this returns a streaming response, the stream handle will be armed and captured by the response stream
......@@ -789,12 +800,22 @@ async fn handler_chat_completions(
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: Endpoint::ChatCompletions.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id);
let context = request.context();
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response =
tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span())
......@@ -1388,12 +1409,22 @@ async fn handler_responses(
// create the context for the request
let request_id = get_or_create_request_id(None, &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone().unwrap_or_default(),
endpoint: Endpoint::Responses.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id);
let context = request.context();
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, stream_handle) = create_connection_monitor(
context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response =
tokio::spawn(responses(state, template, request, stream_handle).in_current_span())
......@@ -2054,8 +2085,16 @@ async fn video_stream(
// video_stream returns the streaming body directly (graceful handler exit).
// The stream_handle is armed below and lives inside the monitored stream so that
// a client disconnect (body drop) signals the engine context to cancel.
let (mut connection_handle, mut stream_handle) =
create_connection_monitor(ctx.clone(), Some(state.metrics_clone())).await;
let (mut connection_handle, mut stream_handle) = create_connection_monitor(
ctx.clone(),
Some(state.metrics_clone()),
CancellationLabels {
model: model.clone(),
endpoint: Endpoint::Videos.to_string(),
request_type: "stream".to_string(),
},
)
.await;
connection_handle.disarm();
let mut http_queue_guard = Some(http_queue_guard);
......
......@@ -232,6 +232,9 @@ pub mod frontend_service {
/// Total number of request migrations due to worker unavailability
pub const MODEL_MIGRATION_TOTAL: &str = "model_migration_total";
/// Total number of request cancellations
pub const MODEL_CANCELLATION_TOTAL: &str = "model_cancellation_total";
/// Active decode blocks (KV cache blocks) per worker
/// Gauge metric tracking current KV cache block utilization for each worker
pub const WORKER_ACTIVE_DECODE_BLOCKS: &str = "worker_active_decode_blocks";
......@@ -346,6 +349,9 @@ pub mod work_handler {
/// Total number of errors in work handler processing
pub const ERRORS_TOTAL: &str = "errors_total";
/// Total number of requests cancelled by work handler (client stop/kill or disconnect)
pub const CANCELLATION_TOTAL: &str = "cancellation_total";
/// Network transit: frontend send to backend receive (wall-clock, cross-process)
pub const NETWORK_TRANSIT_SECONDS: &str = "network_transit_seconds";
......
......@@ -24,6 +24,7 @@ pub struct WorkHandlerMetrics {
pub request_bytes: IntCounter,
pub response_bytes: IntCounter,
pub error_counter: IntCounterVec,
pub cancellation_total: IntCounter,
}
impl WorkHandlerMetrics {
......@@ -34,6 +35,7 @@ impl WorkHandlerMetrics {
request_bytes: IntCounter,
response_bytes: IntCounter,
error_counter: IntCounterVec,
cancellation_total: IntCounter,
) -> Self {
Self {
request_counter,
......@@ -42,6 +44,7 @@ impl WorkHandlerMetrics {
request_bytes,
response_bytes,
error_counter,
cancellation_total,
}
}
......@@ -90,6 +93,12 @@ impl WorkHandlerMetrics {
metrics_labels,
)?;
let cancellation_total = metrics.create_intcounter(
work_handler::CANCELLATION_TOTAL,
"Total number of requests cancelled by work handler",
metrics_labels,
)?;
Ok(Self::new(
request_counter,
request_duration,
......@@ -97,6 +106,7 @@ impl WorkHandlerMetrics {
request_bytes,
response_bytes,
error_counter,
cancellation_total,
))
}
}
......@@ -218,6 +228,7 @@ where
let mut publisher = tcp::client::TcpClient::create_response_stream(
request.context(),
control_msg.connection_info,
self.metrics().map(|m| m.cancellation_total.clone()),
)
.await
.map_err(|e| {
......
......@@ -124,8 +124,11 @@ mod tests {
let context_rank1 = Context::with_id((), context_rank0.id().to_string());
// connect to the server socket
let mut send_stream =
client::TcpClient::create_response_stream(context_rank1.context(), connection_info)
let mut send_stream = client::TcpClient::create_response_stream(
context_rank1.context(),
connection_info,
None,
)
.await
.unwrap();
println!("Client connected");
......
......@@ -12,6 +12,8 @@ use tokio::{
};
use tokio_util::codec::{FramedRead, FramedWrite};
use prometheus::IntCounter;
use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{
......@@ -63,6 +65,7 @@ impl TcpClient {
pub async fn create_response_stream(
context: Arc<dyn AsyncEngineContext>,
info: ConnectionInfo,
cancellation_counter: Option<IntCounter>,
) -> Result<StreamSender> {
let info =
TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
......@@ -97,7 +100,12 @@ impl TcpClient {
// captured by the monitor task
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
let reader_task = tokio::spawn(handle_reader(
framed_reader,
context.clone(),
alive_tx,
cancellation_counter,
));
// transport specific handshake message
let handshake = CallHomeHandshake {
......@@ -213,9 +221,11 @@ async fn handle_reader(
framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>,
alive_tx: tokio::sync::oneshot::Sender<()>,
cancellation_counter: Option<IntCounter>,
) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
let mut framed_reader = framed_reader;
let mut alive_tx = alive_tx;
let mut cancellation_counted = false;
loop {
tokio::select! {
msg = framed_reader.next() => {
......@@ -233,9 +243,17 @@ async fn handle_reader(
match msg {
ControlMessage::Stop => {
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
cancellation_counted = true;
}
context.stop();
}
ControlMessage::Kill => {
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
cancellation_counted = true;
}
context.kill();
}
ControlMessage::Sentinel => {
......@@ -256,6 +274,11 @@ async fn handle_reader(
}
None => {
tracing::debug!("tcp stream closed by server");
// If no Stop/Kill was received, this is a cancellation where frontend
// dropped the connection
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
}
break;
}
}
......@@ -768,10 +791,9 @@ mod tests {
// Spawn the reader task
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
let reader_handle = tokio::spawn(async move {
handle_reader(framed_reader, controller_clone, alive_tx, None).await
});
// Send Stop control message from server
framed_server
......@@ -805,10 +827,9 @@ mod tests {
// Spawn the reader task
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
let reader_handle = tokio::spawn(async move {
handle_reader(framed_reader, controller_clone, alive_tx, None).await
});
// Send Kill control message from server
framed_server
......@@ -842,7 +863,9 @@ mod tests {
// Spawn the reader task
let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
tokio::spawn(
async move { handle_reader(framed_reader, controller, alive_tx, None).await },
);
// Drop the alive_rx to close the channel (simulating writer finishing)
drop(alive_rx);
......@@ -869,7 +892,9 @@ mod tests {
// Spawn the reader task
let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
tokio::spawn(
async move { handle_reader(framed_reader, controller, alive_tx, None).await },
);
// Close the framed server to signal EOF to the client
framed_server.close().await.unwrap();
......@@ -896,10 +921,9 @@ mod tests {
// Spawn the reader task
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
let reader_handle = tokio::spawn(async move {
handle_reader(framed_reader, controller_clone, alive_tx, None).await
});
// Send multiple Stop messages (first one will stop, subsequent ones are no-ops)
framed_server
......@@ -937,10 +961,9 @@ mod tests {
// Spawn the reader task
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
let reader_handle = tokio::spawn(async move {
handle_reader(framed_reader, controller_clone, alive_tx, None).await
});
// Send Stop first, then Kill
framed_server
......
......@@ -20,6 +20,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern,
read_streaming_responses,
send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
......@@ -242,7 +244,7 @@ def test_request_cancellation_sglang_aggregated(
),
]
for request_type, description in test_scenarios:
for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking)
......@@ -291,6 +293,17 @@ def test_request_cancellation_sglang_aggregated(
logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
)
@pytest.mark.timeout(300) # 3x average
@pytest.mark.gpu_2
......@@ -396,3 +409,19 @@ def test_request_cancellation_sglang_decode_cancel(
logger.info(
"Chat completion stream cancellation in decode phase detected successfully"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
......@@ -22,6 +22,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern,
read_streaming_responses,
send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
......@@ -208,7 +210,7 @@ def test_request_cancellation_trtllm_aggregated(
),
]
for request_type, description in test_scenarios:
for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking)
......@@ -248,6 +250,18 @@ def test_request_cancellation_trtllm_aggregated(
logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
component="tensorrt_llm",
)
@pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_decode_cancel(
......@@ -332,6 +346,23 @@ def test_request_cancellation_trtllm_decode_cancel(
"Chat completion stream cancellation in decode phase detected successfully"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
@pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_prefill_cancel(
......@@ -424,6 +455,23 @@ def test_request_cancellation_trtllm_prefill_cancel(
"Completion request cancellation during prefill phase detected successfully"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=0,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=1,
component="prefill",
)
@pytest.mark.xfail(reason="Test fails only on CI", strict=False)
@pytest.mark.timeout(195) # 3x average
......@@ -523,3 +571,20 @@ def test_request_cancellation_trtllm_kv_transfer_cancel(
logger.info(
"Workers are functional after cancellation during KV transfer"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
......@@ -21,6 +21,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern,
read_streaming_responses,
send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
......@@ -242,7 +244,7 @@ def test_request_cancellation_vllm_aggregated(
),
]
for request_type, description in test_scenarios:
for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking)
......@@ -282,6 +284,17 @@ def test_request_cancellation_vllm_aggregated(
logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
)
@pytest.mark.timeout(150) # 3x average
@pytest.mark.nightly
......@@ -365,6 +378,22 @@ def test_request_cancellation_vllm_decode_cancel(
"Chat completion stream cancellation in decode phase detected successfully"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
@pytest.mark.timeout(150) # 3x average
@pytest.mark.nightly
......@@ -458,3 +487,19 @@ def test_request_cancellation_vllm_prefill_cancel(
logger.info(
"Completion request cancellation during prefill phase detected successfully"
)
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=0,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=1,
component="prefill",
)
......@@ -298,6 +298,170 @@ def read_streaming_responses(
)
def _parse_frontend_cancellation_metric(
metrics_text: str, model_name: str, endpoint: str, request_type: str
) -> int:
"""
Parse the frontend cancellation metric from Prometheus metrics text.
Args:
metrics_text: Raw Prometheus metrics text
model_name: The model name label value
endpoint: The endpoint label value (e.g. "completions", "chat_completions")
request_type: The request_type label value ("stream" or "unary")
Returns:
The metric count, or 0 if not found
"""
for line in metrics_text.splitlines():
if not line.startswith("dynamo_frontend_model_cancellation_total{"):
continue
if (
f'endpoint="{endpoint}"' in line
and f'model="{model_name}"' in line
and f'request_type="{request_type}"' in line
):
parts = line.rsplit(None, 1)
if len(parts) == 2:
try:
return int(float(parts[1]))
except ValueError:
pass
return 0
def _parse_runtime_cancellation_metric(
metrics_text: str,
namespace: str = "dynamo",
component: str = "backend",
endpoint: str = "generate",
) -> int:
"""
Parse the runtime cancellation metric from Prometheus metrics text.
The metric is dynamo_component_cancellation_total with auto-injected
labels (dynamo_namespace, dynamo_component, dynamo_endpoint).
Args:
metrics_text: Raw Prometheus metrics text
namespace: Expected dynamo_namespace label value
component: Expected dynamo_component label value
endpoint: Expected dynamo_endpoint label value
Returns:
The metric count, or 0 if not found
"""
for line in metrics_text.splitlines():
if not line.startswith("dynamo_component_cancellation_total{"):
continue
if (
f'dynamo_namespace="{namespace}"' in line
and f'dynamo_component="{component}"' in line
and f'dynamo_endpoint="{endpoint}"' in line
):
parts = line.rsplit(None, 1)
if len(parts) == 2:
try:
return int(float(parts[1]))
except ValueError:
pass
return 0
def _resolve_cancellation_labels(request_type: str) -> tuple[str, str]:
"""
Map a test request type to frontend metric labels.
Args:
request_type: One of "completion", "chat_completion", "chat_completion_stream"
Returns:
(endpoint, request_type_label) tuple
"""
mapping = {
"completion": ("completions", "unary"),
"chat_completion": ("chat_completions", "unary"),
"chat_completion_stream": ("chat_completions", "stream"),
}
if request_type not in mapping:
pytest.fail(f"Unknown request type: {request_type}")
return mapping[request_type]
def verify_frontend_cancellation_metrics(
frontend_port: int,
request_type: str,
expected_count: int = 0,
) -> None:
"""
Verify frontend cancellation metrics.
Args:
frontend_port: Port where the frontend /metrics is served
request_type: The test request type ("completion", "chat_completion", "chat_completion_stream")
expected_count: Expected cancellation count for this request type
"""
endpoint, req_type_label = _resolve_cancellation_labels(request_type)
frontend_metrics_url = f"http://localhost:{frontend_port}/metrics"
try:
response = requests.get(frontend_metrics_url, timeout=5)
response.raise_for_status()
except requests.RequestException as e:
pytest.fail(
f"Failed to fetch frontend metrics from {frontend_metrics_url}: {e}"
)
frontend_text = response.text
count = _parse_frontend_cancellation_metric(
frontend_text, FAULT_TOLERANCE_MODEL_NAME, endpoint, req_type_label
)
logger.info(
f"Frontend cancellation metrics - endpoint={endpoint}, "
f"request_type={req_type_label}: {count}"
)
assert count == expected_count, (
f"Frontend: expected {expected_count} cancellations "
f"for endpoint={endpoint}, request_type={req_type_label}, "
f"but got {count}"
)
def verify_runtime_cancellation_metrics(
worker_system_port: int,
expected_count: int = 0,
component: str = "backend",
) -> None:
"""
Verify runtime (worker) cancellation metrics.
Args:
worker_system_port: Port where the worker /metrics is served
expected_count: Expected cumulative cancellation count
component: The dynamo_component label value (e.g. "backend", "prefill")
"""
worker_metrics_url = f"http://localhost:{worker_system_port}/metrics"
try:
response = requests.get(worker_metrics_url, timeout=5)
response.raise_for_status()
except requests.RequestException as e:
pytest.fail(f"Failed to fetch worker metrics from {worker_metrics_url}: {e}")
worker_text = response.text
count = _parse_runtime_cancellation_metric(worker_text, component=component)
logger.info(f"Runtime cancellation metrics (component={component}): {count}")
assert count == expected_count, (
f"Runtime (component={component}): expected {expected_count} cancellations, "
f"but got {count}"
)
def read_log_content(log_path: str | None) -> str:
"""Read log content from a file"""
if log_path is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment