"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "3023c6258ab4f989d070abc4fa29fde8971d51e2"
Unverified commit e252f82d, authored by Jacky, committed by GitHub
Browse files

feat: Frontend and Runtime Request Cancellation Metrics (#7493)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 4159c6fd
...@@ -21,7 +21,7 @@ use super::kserve::inference; ...@@ -21,7 +21,7 @@ use super::kserve::inference;
// [gluo NOTE] These are common utilities that should be shared between frontends // [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{ use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor}, disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics}, metrics::{CancellationLabels, Endpoint, InflightGuard, process_response_and_observe_metrics},
}; };
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt}; use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
...@@ -54,14 +54,22 @@ pub async fn completion_response_stream( ...@@ -54,14 +54,22 @@ pub async fn completion_response_stream(
// create the context for the request // create the context for the request
// [WIP] from request id. // [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref()); let request_id = get_or_create_request_id(request.inner.user.as_deref());
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: "grpc_completions".to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id.clone()); let request = Context::with_id(request, request_id.clone());
let context = request.context(); let context = request.context();
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
let streaming = request.inner.stream.unwrap_or(false); cancellation_labels,
)
.await;
// update the request to always stream // update the request to always stream
let request = request.map(|mut req| { let request = request.map(|mut req| {
req.inner.stream = Some(true); req.inner.stream = Some(true);
......
...@@ -20,7 +20,7 @@ use validator::Validate; ...@@ -20,7 +20,7 @@ use validator::Validate;
use crate::http::service::metrics::InflightGuard; use crate::http::service::metrics::InflightGuard;
use crate::http::service::{ use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor}, disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, process_response_and_observe_metrics}, metrics::{CancellationLabels, Endpoint, process_response_and_observe_metrics},
}; };
use crate::protocols::tensor; use crate::protocols::tensor;
...@@ -60,13 +60,22 @@ pub async fn tensor_response_stream( ...@@ -60,13 +60,22 @@ pub async fn tensor_response_stream(
) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> { ) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> {
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.id.as_deref()); let request_id = get_or_create_request_id(request.id.as_deref());
let cancellation_labels = CancellationLabels {
model: request.model.clone(),
endpoint: Endpoint::Tensor.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id.clone()); let request = Context::with_id(request, request_id.clone());
let context = request.context(); let context = request.context();
// [gluo TODO] revisit metrics to properly expose it // [gluo TODO] revisit metrics to properly expose it
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
// todo - make the protocols be optional for model name // todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default // todo - when optional, if none, apply a default
......
...@@ -30,7 +30,7 @@ use tracing::Instrument; ...@@ -30,7 +30,7 @@ use tracing::Instrument;
use super::{ use super::{
RouteDoc, RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects}, disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
metrics::{Endpoint, process_response_and_observe_metrics}, metrics::{CancellationLabels, Endpoint, process_response_and_observe_metrics},
service_v2, service_v2,
}; };
use crate::preprocessor::OpenAIPreprocessor; use crate::preprocessor::OpenAIPreprocessor;
...@@ -125,12 +125,22 @@ async fn handler_anthropic_messages( ...@@ -125,12 +125,22 @@ async fn handler_anthropic_messages(
// Create request context // Create request context
let request_id = get_or_create_request_id(None, &headers); let request_id = get_or_create_request_id(None, &headers);
let streaming = request.stream;
let cancellation_labels = CancellationLabels {
model: request.model.clone(),
endpoint: Endpoint::AnthropicMessages.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
let context = request.context(); let context = request.context();
// Create connection handles // Create connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response = let response =
tokio::spawn(anthropic_messages(state, template, request, stream_handle).in_current_span()) tokio::spawn(anthropic_messages(state, template, request, stream_handle).in_current_span())
......
...@@ -33,7 +33,7 @@ use dynamo_runtime::engine::AsyncEngineContext; ...@@ -33,7 +33,7 @@ use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::sync::Arc; use std::sync::Arc;
use crate::http::service::metrics::{ErrorType, InflightGuard, Metrics}; use crate::http::service::metrics::{CancellationLabels, ErrorType, InflightGuard, Metrics};
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub enum ConnectionStatus { pub enum ConnectionStatus {
...@@ -100,6 +100,7 @@ impl Drop for ConnectionHandle { ...@@ -100,6 +100,7 @@ impl Drop for ConnectionHandle {
pub async fn create_connection_monitor( pub async fn create_connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>, engine_context: Arc<dyn AsyncEngineContext>,
metrics: Option<Arc<Metrics>>, metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) -> (ConnectionHandle, ConnectionHandle) { ) -> (ConnectionHandle, ConnectionHandle) {
// these oneshot channels monitor possible disconnects from the client in two different scopes: // these oneshot channels monitor possible disconnects from the client in two different scopes:
// - the local task (connection_handle) // - the local task (connection_handle)
...@@ -113,6 +114,7 @@ pub async fn create_connection_monitor( ...@@ -113,6 +114,7 @@ pub async fn create_connection_monitor(
connection_rx, connection_rx,
stream_rx, stream_rx,
metrics, metrics,
cancellation_labels,
)); ));
// Two handles, the first is armed, the second is disarmed // Two handles, the first is armed, the second is disarmed
...@@ -128,6 +130,7 @@ async fn connection_monitor( ...@@ -128,6 +130,7 @@ async fn connection_monitor(
connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>, connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>, stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
metrics: Option<Arc<Metrics>>, metrics: Option<Arc<Metrics>>,
cancellation_labels: CancellationLabels,
) { ) {
match connection_rx.await { match connection_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => { Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
...@@ -135,6 +138,7 @@ async fn connection_monitor( ...@@ -135,6 +138,7 @@ async fn connection_monitor(
tracing::trace!("Connection closed unexpectedly; issuing cancellation"); tracing::trace!("Connection closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics { if let Some(metrics) = &metrics {
metrics.inc_client_disconnect(); metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
} }
engine_context.kill(); engine_context.kill();
} }
...@@ -149,6 +153,7 @@ async fn connection_monitor( ...@@ -149,6 +153,7 @@ async fn connection_monitor(
tracing::trace!("Stream closed unexpectedly; issuing cancellation"); tracing::trace!("Stream closed unexpectedly; issuing cancellation");
if let Some(metrics) = &metrics { if let Some(metrics) = &metrics {
metrics.inc_client_disconnect(); metrics.inc_client_disconnect();
metrics.inc_cancellation(&cancellation_labels);
} }
engine_context.kill(); engine_context.kill();
} }
......
...@@ -256,6 +256,7 @@ pub struct Metrics { ...@@ -256,6 +256,7 @@ pub struct Metrics {
model_kv_cache_block_size: IntGaugeVec, model_kv_cache_block_size: IntGaugeVec,
model_migration_limit: IntGaugeVec, model_migration_limit: IntGaugeVec,
model_migration_total: IntCounterVec, model_migration_total: IntCounterVec,
model_cancellation_total: IntCounterVec,
} }
// Inflight tracks requests from HTTP handler start until complete response is finished. // Inflight tracks requests from HTTP handler start until complete response is finished.
...@@ -322,6 +323,13 @@ pub enum RequestType { ...@@ -322,6 +323,13 @@ pub enum RequestType {
Stream, Stream,
} }
/// Label values attached to a single increment of the cancellation counter.
///
/// The three fields map, in order, onto the `model`, `endpoint` and
/// `request_type` labels of the cancellation metric.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CancellationLabels {
    /// Model name taken from the incoming request.
    pub model: String,
    /// Name of the endpoint that received the request.
    pub endpoint: String,
    /// `"stream"` for streaming requests, `"unary"` otherwise.
    pub request_type: String,
}
/// Status /// Status
#[derive(PartialEq)] #[derive(PartialEq)]
pub enum Status { pub enum Status {
...@@ -654,6 +662,15 @@ impl Metrics { ...@@ -654,6 +662,15 @@ impl Metrics {
) )
.unwrap(); .unwrap();
let model_cancellation_total = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_CANCELLATION_TOTAL),
"Total number of request cancellations",
),
&["model", "endpoint", "request_type"],
)
.unwrap();
Metrics { Metrics {
request_counter, request_counter,
inflight_gauge, inflight_gauge,
...@@ -674,6 +691,7 @@ impl Metrics { ...@@ -674,6 +691,7 @@ impl Metrics {
model_kv_cache_block_size, model_kv_cache_block_size,
model_migration_limit, model_migration_limit,
model_migration_total, model_migration_total,
model_cancellation_total,
} }
} }
...@@ -778,6 +796,7 @@ impl Metrics { ...@@ -778,6 +796,7 @@ impl Metrics {
registry.register(Box::new(self.model_kv_cache_block_size.clone()))?; registry.register(Box::new(self.model_kv_cache_block_size.clone()))?;
registry.register(Box::new(self.model_migration_limit.clone()))?; registry.register(Box::new(self.model_migration_limit.clone()))?;
registry.register(Box::new(self.model_migration_total.clone()))?; registry.register(Box::new(self.model_migration_total.clone()))?;
registry.register(Box::new(self.model_cancellation_total.clone()))?;
Ok(()) Ok(())
} }
...@@ -861,6 +880,20 @@ impl Metrics { ...@@ -861,6 +880,20 @@ impl Metrics {
.get() .get()
} }
/// Record one cancelled request under the given `model` / `endpoint` /
/// `request_type` label combination.
pub fn inc_cancellation(&self, labels: &CancellationLabels) {
    // Order must match the label names the counter was created with.
    let label_values = [
        labels.model.as_str(),
        labels.endpoint.as_str(),
        labels.request_type.as_str(),
    ];
    let counter = self.model_cancellation_total.with_label_values(&label_values);
    counter.inc();
}
/// Read the current value of the cancellation counter for the given
/// `model` / `endpoint` / `request_type` label combination.
pub fn get_cancellation_count(&self, labels: &CancellationLabels) -> u64 {
    // Order must match the label names the counter was created with.
    let label_values = [
        labels.model.as_str(),
        labels.endpoint.as_str(),
        labels.request_type.as_str(),
    ];
    let counter = self.model_cancellation_total.with_label_values(&label_values);
    counter.get()
}
/// Create a new [`InflightGuard`] for the given model and annotate if its a streaming request, /// Create a new [`InflightGuard`] for the given model and annotate if its a streaming request,
/// and the kind of endpoint that was hit /// and the kind of endpoint that was hit
/// ///
......
...@@ -36,7 +36,8 @@ use super::{ ...@@ -36,7 +36,8 @@ use super::{
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects}, disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
error::HttpError, error::HttpError,
metrics::{ metrics::{
Endpoint, ErrorType, EventConverter, process_response_and_observe_metrics, CancellationLabels, Endpoint, ErrorType, EventConverter,
process_response_and_observe_metrics,
process_response_using_event_converter_and_observe_metrics, process_response_using_event_converter_and_observe_metrics,
}, },
service_v2, service_v2,
...@@ -341,12 +342,22 @@ async fn handler_completions( ...@@ -341,12 +342,22 @@ async fn handler_completions(
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers); let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: Endpoint::Completions.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
let context = request.context(); let context = request.context();
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
// possibly long running task // possibly long running task
// if this returns a streaming response, the stream handle will be armed and captured by the response stream // if this returns a streaming response, the stream handle will be armed and captured by the response stream
...@@ -789,12 +800,22 @@ async fn handler_chat_completions( ...@@ -789,12 +800,22 @@ async fn handler_chat_completions(
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers); let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone(),
endpoint: Endpoint::ChatCompletions.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
let context = request.context(); let context = request.context();
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response = let response =
tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span()) tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span())
...@@ -1388,12 +1409,22 @@ async fn handler_responses( ...@@ -1388,12 +1409,22 @@ async fn handler_responses(
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(None, &headers); let request_id = get_or_create_request_id(None, &headers);
let streaming = request.inner.stream.unwrap_or(false);
let cancellation_labels = CancellationLabels {
model: request.inner.model.clone().unwrap_or_default(),
endpoint: Endpoint::Responses.to_string(),
request_type: if streaming { "stream" } else { "unary" }.to_string(),
};
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
let context = request.context(); let context = request.context();
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = let (mut connection_handle, stream_handle) = create_connection_monitor(
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await; context.clone(),
Some(state.metrics_clone()),
cancellation_labels,
)
.await;
let response = let response =
tokio::spawn(responses(state, template, request, stream_handle).in_current_span()) tokio::spawn(responses(state, template, request, stream_handle).in_current_span())
...@@ -2054,8 +2085,16 @@ async fn video_stream( ...@@ -2054,8 +2085,16 @@ async fn video_stream(
// video_stream returns the streaming body directly (graceful handler exit). // video_stream returns the streaming body directly (graceful handler exit).
// The stream_handle is armed below and lives inside the monitored stream so that // The stream_handle is armed below and lives inside the monitored stream so that
// a client disconnect (body drop) signals the engine context to cancel. // a client disconnect (body drop) signals the engine context to cancel.
let (mut connection_handle, mut stream_handle) = let (mut connection_handle, mut stream_handle) = create_connection_monitor(
create_connection_monitor(ctx.clone(), Some(state.metrics_clone())).await; ctx.clone(),
Some(state.metrics_clone()),
CancellationLabels {
model: model.clone(),
endpoint: Endpoint::Videos.to_string(),
request_type: "stream".to_string(),
},
)
.await;
connection_handle.disarm(); connection_handle.disarm();
let mut http_queue_guard = Some(http_queue_guard); let mut http_queue_guard = Some(http_queue_guard);
......
...@@ -232,6 +232,9 @@ pub mod frontend_service { ...@@ -232,6 +232,9 @@ pub mod frontend_service {
/// Total number of request migrations due to worker unavailability /// Total number of request migrations due to worker unavailability
pub const MODEL_MIGRATION_TOTAL: &str = "model_migration_total"; pub const MODEL_MIGRATION_TOTAL: &str = "model_migration_total";
/// Total number of request cancellations
/// Counter metric registered by the frontend with the `model`, `endpoint`
/// and `request_type` labels
pub const MODEL_CANCELLATION_TOTAL: &str = "model_cancellation_total";
/// Active decode blocks (KV cache blocks) per worker /// Active decode blocks (KV cache blocks) per worker
/// Gauge metric tracking current KV cache block utilization for each worker /// Gauge metric tracking current KV cache block utilization for each worker
pub const WORKER_ACTIVE_DECODE_BLOCKS: &str = "worker_active_decode_blocks"; pub const WORKER_ACTIVE_DECODE_BLOCKS: &str = "worker_active_decode_blocks";
...@@ -346,6 +349,9 @@ pub mod work_handler { ...@@ -346,6 +349,9 @@ pub mod work_handler {
/// Total number of errors in work handler processing /// Total number of errors in work handler processing
pub const ERRORS_TOTAL: &str = "errors_total"; pub const ERRORS_TOTAL: &str = "errors_total";
/// Total number of requests cancelled by work handler (client stop/kill or disconnect)
/// Counter metric created alongside the other work handler metrics
pub const CANCELLATION_TOTAL: &str = "cancellation_total";
/// Network transit: frontend send to backend receive (wall-clock, cross-process) /// Network transit: frontend send to backend receive (wall-clock, cross-process)
pub const NETWORK_TRANSIT_SECONDS: &str = "network_transit_seconds"; pub const NETWORK_TRANSIT_SECONDS: &str = "network_transit_seconds";
......
...@@ -24,6 +24,7 @@ pub struct WorkHandlerMetrics { ...@@ -24,6 +24,7 @@ pub struct WorkHandlerMetrics {
pub request_bytes: IntCounter, pub request_bytes: IntCounter,
pub response_bytes: IntCounter, pub response_bytes: IntCounter,
pub error_counter: IntCounterVec, pub error_counter: IntCounterVec,
pub cancellation_total: IntCounter,
} }
impl WorkHandlerMetrics { impl WorkHandlerMetrics {
...@@ -34,6 +35,7 @@ impl WorkHandlerMetrics { ...@@ -34,6 +35,7 @@ impl WorkHandlerMetrics {
request_bytes: IntCounter, request_bytes: IntCounter,
response_bytes: IntCounter, response_bytes: IntCounter,
error_counter: IntCounterVec, error_counter: IntCounterVec,
cancellation_total: IntCounter,
) -> Self { ) -> Self {
Self { Self {
request_counter, request_counter,
...@@ -42,6 +44,7 @@ impl WorkHandlerMetrics { ...@@ -42,6 +44,7 @@ impl WorkHandlerMetrics {
request_bytes, request_bytes,
response_bytes, response_bytes,
error_counter, error_counter,
cancellation_total,
} }
} }
...@@ -90,6 +93,12 @@ impl WorkHandlerMetrics { ...@@ -90,6 +93,12 @@ impl WorkHandlerMetrics {
metrics_labels, metrics_labels,
)?; )?;
let cancellation_total = metrics.create_intcounter(
work_handler::CANCELLATION_TOTAL,
"Total number of requests cancelled by work handler",
metrics_labels,
)?;
Ok(Self::new( Ok(Self::new(
request_counter, request_counter,
request_duration, request_duration,
...@@ -97,6 +106,7 @@ impl WorkHandlerMetrics { ...@@ -97,6 +106,7 @@ impl WorkHandlerMetrics {
request_bytes, request_bytes,
response_bytes, response_bytes,
error_counter, error_counter,
cancellation_total,
)) ))
} }
} }
...@@ -218,6 +228,7 @@ where ...@@ -218,6 +228,7 @@ where
let mut publisher = tcp::client::TcpClient::create_response_stream( let mut publisher = tcp::client::TcpClient::create_response_stream(
request.context(), request.context(),
control_msg.connection_info, control_msg.connection_info,
self.metrics().map(|m| m.cancellation_total.clone()),
) )
.await .await
.map_err(|e| { .map_err(|e| {
......
...@@ -124,10 +124,13 @@ mod tests { ...@@ -124,10 +124,13 @@ mod tests {
let context_rank1 = Context::with_id((), context_rank0.id().to_string()); let context_rank1 = Context::with_id((), context_rank0.id().to_string());
// connect to the server socket // connect to the server socket
let mut send_stream = let mut send_stream = client::TcpClient::create_response_stream(
client::TcpClient::create_response_stream(context_rank1.context(), connection_info) context_rank1.context(),
.await connection_info,
.unwrap(); None,
)
.await
.unwrap();
println!("Client connected"); println!("Client connected");
// the client can now setup it's end of the stream and if it errors, it can send a message // the client can now setup it's end of the stream and if it errors, it can send a message
......
...@@ -12,6 +12,8 @@ use tokio::{ ...@@ -12,6 +12,8 @@ use tokio::{
}; };
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
use prometheus::IntCounter;
use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo}; use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext; use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{ use crate::pipeline::network::{
...@@ -63,6 +65,7 @@ impl TcpClient { ...@@ -63,6 +65,7 @@ impl TcpClient {
pub async fn create_response_stream( pub async fn create_response_stream(
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
info: ConnectionInfo, info: ConnectionInfo,
cancellation_counter: Option<IntCounter>,
) -> Result<StreamSender> { ) -> Result<StreamSender> {
let info = let info =
TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?; TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
...@@ -97,7 +100,12 @@ impl TcpClient { ...@@ -97,7 +100,12 @@ impl TcpClient {
// captured by the monitor task // captured by the monitor task
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>(); let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx)); let reader_task = tokio::spawn(handle_reader(
framed_reader,
context.clone(),
alive_tx,
cancellation_counter,
));
// transport specific handshake message // transport specific handshake message
let handshake = CallHomeHandshake { let handshake = CallHomeHandshake {
...@@ -213,9 +221,11 @@ async fn handle_reader( ...@@ -213,9 +221,11 @@ async fn handle_reader(
framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>, framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
alive_tx: tokio::sync::oneshot::Sender<()>, alive_tx: tokio::sync::oneshot::Sender<()>,
cancellation_counter: Option<IntCounter>,
) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> { ) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
let mut framed_reader = framed_reader; let mut framed_reader = framed_reader;
let mut alive_tx = alive_tx; let mut alive_tx = alive_tx;
let mut cancellation_counted = false;
loop { loop {
tokio::select! { tokio::select! {
msg = framed_reader.next() => { msg = framed_reader.next() => {
...@@ -233,9 +243,17 @@ async fn handle_reader( ...@@ -233,9 +243,17 @@ async fn handle_reader(
match msg { match msg {
ControlMessage::Stop => { ControlMessage::Stop => {
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
cancellation_counted = true;
}
context.stop(); context.stop();
} }
ControlMessage::Kill => { ControlMessage::Kill => {
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
cancellation_counted = true;
}
context.kill(); context.kill();
} }
ControlMessage::Sentinel => { ControlMessage::Sentinel => {
...@@ -256,6 +274,11 @@ async fn handle_reader( ...@@ -256,6 +274,11 @@ async fn handle_reader(
} }
None => { None => {
tracing::debug!("tcp stream closed by server"); tracing::debug!("tcp stream closed by server");
// If no Stop/Kill was received, this is a cancellation where frontend
// dropped the connection
if let Some(counter) = &cancellation_counter && !cancellation_counted {
counter.inc();
}
break; break;
} }
} }
...@@ -768,10 +791,9 @@ mod tests { ...@@ -768,10 +791,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let controller_clone = controller.clone(); let controller_clone = controller.clone();
let reader_handle = let reader_handle = tokio::spawn(async move {
tokio::spawn( handle_reader(framed_reader, controller_clone, alive_tx, None).await
async move { handle_reader(framed_reader, controller_clone, alive_tx).await }, });
);
// Send Stop control message from server // Send Stop control message from server
framed_server framed_server
...@@ -805,10 +827,9 @@ mod tests { ...@@ -805,10 +827,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let controller_clone = controller.clone(); let controller_clone = controller.clone();
let reader_handle = let reader_handle = tokio::spawn(async move {
tokio::spawn( handle_reader(framed_reader, controller_clone, alive_tx, None).await
async move { handle_reader(framed_reader, controller_clone, alive_tx).await }, });
);
// Send Kill control message from server // Send Kill control message from server
framed_server framed_server
...@@ -842,7 +863,9 @@ mod tests { ...@@ -842,7 +863,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let reader_handle = let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await }); tokio::spawn(
async move { handle_reader(framed_reader, controller, alive_tx, None).await },
);
// Drop the alive_rx to close the channel (simulating writer finishing) // Drop the alive_rx to close the channel (simulating writer finishing)
drop(alive_rx); drop(alive_rx);
...@@ -869,7 +892,9 @@ mod tests { ...@@ -869,7 +892,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let reader_handle = let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await }); tokio::spawn(
async move { handle_reader(framed_reader, controller, alive_tx, None).await },
);
// Close the framed server to signal EOF to the client // Close the framed server to signal EOF to the client
framed_server.close().await.unwrap(); framed_server.close().await.unwrap();
...@@ -896,10 +921,9 @@ mod tests { ...@@ -896,10 +921,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let controller_clone = controller.clone(); let controller_clone = controller.clone();
let reader_handle = let reader_handle = tokio::spawn(async move {
tokio::spawn( handle_reader(framed_reader, controller_clone, alive_tx, None).await
async move { handle_reader(framed_reader, controller_clone, alive_tx).await }, });
);
// Send multiple Stop messages (first one will stop, subsequent ones are no-ops) // Send multiple Stop messages (first one will stop, subsequent ones are no-ops)
framed_server framed_server
...@@ -937,10 +961,9 @@ mod tests { ...@@ -937,10 +961,9 @@ mod tests {
// Spawn the reader task // Spawn the reader task
let controller_clone = controller.clone(); let controller_clone = controller.clone();
let reader_handle = let reader_handle = tokio::spawn(async move {
tokio::spawn( handle_reader(framed_reader, controller_clone, alive_tx, None).await
async move { handle_reader(framed_reader, controller_clone, alive_tx).await }, });
);
// Send Stop first, then Kill // Send Stop first, then Kill
framed_server framed_server
......
...@@ -20,6 +20,8 @@ from tests.fault_tolerance.cancellation.utils import ( ...@@ -20,6 +20,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern, poll_for_pattern,
read_streaming_responses, read_streaming_responses,
send_cancellable_request, send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
) )
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -242,7 +244,7 @@ def test_request_cancellation_sglang_aggregated( ...@@ -242,7 +244,7 @@ def test_request_cancellation_sglang_aggregated(
), ),
] ]
for request_type, description in test_scenarios: for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...") logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking) # Send the request (non-blocking)
...@@ -291,6 +293,17 @@ def test_request_cancellation_sglang_aggregated( ...@@ -291,6 +293,17 @@ def test_request_cancellation_sglang_aggregated(
logger.info(f"{description} detected successfully") logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
)
@pytest.mark.timeout(300) # 3x average @pytest.mark.timeout(300) # 3x average
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -396,3 +409,19 @@ def test_request_cancellation_sglang_decode_cancel( ...@@ -396,3 +409,19 @@ def test_request_cancellation_sglang_decode_cancel(
logger.info( logger.info(
"Chat completion stream cancellation in decode phase detected successfully" "Chat completion stream cancellation in decode phase detected successfully"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
...@@ -22,6 +22,8 @@ from tests.fault_tolerance.cancellation.utils import ( ...@@ -22,6 +22,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern, poll_for_pattern,
read_streaming_responses, read_streaming_responses,
send_cancellable_request, send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
) )
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -208,7 +210,7 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -208,7 +210,7 @@ def test_request_cancellation_trtllm_aggregated(
), ),
] ]
for request_type, description in test_scenarios: for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...") logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking) # Send the request (non-blocking)
...@@ -248,6 +250,18 @@ def test_request_cancellation_trtllm_aggregated( ...@@ -248,6 +250,18 @@ def test_request_cancellation_trtllm_aggregated(
logger.info(f"{description} detected successfully") logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
component="tensorrt_llm",
)
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_decode_cancel( def test_request_cancellation_trtllm_decode_cancel(
...@@ -332,6 +346,23 @@ def test_request_cancellation_trtllm_decode_cancel( ...@@ -332,6 +346,23 @@ def test_request_cancellation_trtllm_decode_cancel(
"Chat completion stream cancellation in decode phase detected successfully" "Chat completion stream cancellation in decode phase detected successfully"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
def test_request_cancellation_trtllm_prefill_cancel( def test_request_cancellation_trtllm_prefill_cancel(
...@@ -424,6 +455,23 @@ def test_request_cancellation_trtllm_prefill_cancel( ...@@ -424,6 +455,23 @@ def test_request_cancellation_trtllm_prefill_cancel(
"Completion request cancellation during prefill phase detected successfully" "Completion request cancellation during prefill phase detected successfully"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=0,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=1,
component="prefill",
)
@pytest.mark.xfail(reason="Test fails only on CI", strict=False) @pytest.mark.xfail(reason="Test fails only on CI", strict=False)
@pytest.mark.timeout(195) # 3x average @pytest.mark.timeout(195) # 3x average
...@@ -523,3 +571,20 @@ def test_request_cancellation_trtllm_kv_transfer_cancel( ...@@ -523,3 +571,20 @@ def test_request_cancellation_trtllm_kv_transfer_cancel(
logger.info( logger.info(
"Workers are functional after cancellation during KV transfer" "Workers are functional after cancellation during KV transfer"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
component="tensorrt_llm",
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
...@@ -21,6 +21,8 @@ from tests.fault_tolerance.cancellation.utils import ( ...@@ -21,6 +21,8 @@ from tests.fault_tolerance.cancellation.utils import (
poll_for_pattern, poll_for_pattern,
read_streaming_responses, read_streaming_responses,
send_cancellable_request, send_cancellable_request,
verify_frontend_cancellation_metrics,
verify_runtime_cancellation_metrics,
) )
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -242,7 +244,7 @@ def test_request_cancellation_vllm_aggregated( ...@@ -242,7 +244,7 @@ def test_request_cancellation_vllm_aggregated(
), ),
] ]
for request_type, description in test_scenarios: for idx, (request_type, description) in enumerate(test_scenarios):
logger.info(f"Testing {description.lower()}...") logger.info(f"Testing {description.lower()}...")
# Send the request (non-blocking) # Send the request (non-blocking)
...@@ -282,6 +284,17 @@ def test_request_cancellation_vllm_aggregated( ...@@ -282,6 +284,17 @@ def test_request_cancellation_vllm_aggregated(
logger.info(f"{description} detected successfully") logger.info(f"{description} detected successfully")
# Verify cancellation metrics after each scenario
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type=request_type,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=worker.system_port,
expected_count=idx + 1,
)
@pytest.mark.timeout(150) # 3x average @pytest.mark.timeout(150) # 3x average
@pytest.mark.nightly @pytest.mark.nightly
...@@ -365,6 +378,22 @@ def test_request_cancellation_vllm_decode_cancel( ...@@ -365,6 +378,22 @@ def test_request_cancellation_vllm_decode_cancel(
"Chat completion stream cancellation in decode phase detected successfully" "Chat completion stream cancellation in decode phase detected successfully"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="chat_completion_stream",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=0,
component="prefill",
)
@pytest.mark.timeout(150) # 3x average @pytest.mark.timeout(150) # 3x average
@pytest.mark.nightly @pytest.mark.nightly
...@@ -458,3 +487,19 @@ def test_request_cancellation_vllm_prefill_cancel( ...@@ -458,3 +487,19 @@ def test_request_cancellation_vllm_prefill_cancel(
logger.info( logger.info(
"Completion request cancellation during prefill phase detected successfully" "Completion request cancellation during prefill phase detected successfully"
) )
# Verify cancellation metrics
verify_frontend_cancellation_metrics(
frontend_port=frontend.frontend_port,
request_type="completion",
expected_count=1,
)
verify_runtime_cancellation_metrics(
worker_system_port=decode_worker.system_port,
expected_count=0,
)
verify_runtime_cancellation_metrics(
worker_system_port=prefill_worker.system_port,
expected_count=1,
component="prefill",
)
...@@ -298,6 +298,170 @@ def read_streaming_responses( ...@@ -298,6 +298,170 @@ def read_streaming_responses(
) )
def _parse_frontend_cancellation_metric(
metrics_text: str, model_name: str, endpoint: str, request_type: str
) -> int:
"""
Parse the frontend cancellation metric from Prometheus metrics text.
Args:
metrics_text: Raw Prometheus metrics text
model_name: The model name label value
endpoint: The endpoint label value (e.g. "completions", "chat_completions")
request_type: The request_type label value ("stream" or "unary")
Returns:
The metric count, or 0 if not found
"""
for line in metrics_text.splitlines():
if not line.startswith("dynamo_frontend_model_cancellation_total{"):
continue
if (
f'endpoint="{endpoint}"' in line
and f'model="{model_name}"' in line
and f'request_type="{request_type}"' in line
):
parts = line.rsplit(None, 1)
if len(parts) == 2:
try:
return int(float(parts[1]))
except ValueError:
pass
return 0
def _parse_runtime_cancellation_metric(
metrics_text: str,
namespace: str = "dynamo",
component: str = "backend",
endpoint: str = "generate",
) -> int:
"""
Parse the runtime cancellation metric from Prometheus metrics text.
The metric is dynamo_component_cancellation_total with auto-injected
labels (dynamo_namespace, dynamo_component, dynamo_endpoint).
Args:
metrics_text: Raw Prometheus metrics text
namespace: Expected dynamo_namespace label value
component: Expected dynamo_component label value
endpoint: Expected dynamo_endpoint label value
Returns:
The metric count, or 0 if not found
"""
for line in metrics_text.splitlines():
if not line.startswith("dynamo_component_cancellation_total{"):
continue
if (
f'dynamo_namespace="{namespace}"' in line
and f'dynamo_component="{component}"' in line
and f'dynamo_endpoint="{endpoint}"' in line
):
parts = line.rsplit(None, 1)
if len(parts) == 2:
try:
return int(float(parts[1]))
except ValueError:
pass
return 0
def _resolve_cancellation_labels(request_type: str) -> tuple[str, str]:
"""
Map a test request type to frontend metric labels.
Args:
request_type: One of "completion", "chat_completion", "chat_completion_stream"
Returns:
(endpoint, request_type_label) tuple
"""
mapping = {
"completion": ("completions", "unary"),
"chat_completion": ("chat_completions", "unary"),
"chat_completion_stream": ("chat_completions", "stream"),
}
if request_type not in mapping:
pytest.fail(f"Unknown request type: {request_type}")
return mapping[request_type]
def verify_frontend_cancellation_metrics(
    frontend_port: int,
    request_type: str,
    expected_count: int = 0,
) -> None:
    """
    Fetch the frontend /metrics page and assert the cancellation count.

    The test request type is mapped to its (endpoint, request_type) labels,
    the matching dynamo_frontend_model_cancellation_total sample is looked up
    for FAULT_TOLERANCE_MODEL_NAME, and its value is compared with
    expected_count.

    Args:
        frontend_port: Port where the frontend /metrics is served
        request_type: The test request type ("completion", "chat_completion", "chat_completion_stream")
        expected_count: Expected cancellation count for this request type
    """
    endpoint_label, request_type_label = _resolve_cancellation_labels(request_type)
    metrics_url = f"http://localhost:{frontend_port}/metrics"

    try:
        reply = requests.get(metrics_url, timeout=5)
        reply.raise_for_status()
    except requests.RequestException as exc:
        # A fetch failure is reported as a test failure, not a raw exception.
        pytest.fail(f"Failed to fetch frontend metrics from {metrics_url}: {exc}")

    observed = _parse_frontend_cancellation_metric(
        reply.text,
        FAULT_TOLERANCE_MODEL_NAME,
        endpoint_label,
        request_type_label,
    )
    logger.info(
        f"Frontend cancellation metrics - endpoint={endpoint_label}, "
        f"request_type={request_type_label}: {observed}"
    )

    assert observed == expected_count, (
        f"Frontend: expected {expected_count} cancellations "
        f"for endpoint={endpoint_label}, request_type={request_type_label}, "
        f"but got {observed}"
    )
def verify_runtime_cancellation_metrics(
    worker_system_port: int,
    expected_count: int = 0,
    component: str = "backend",
) -> None:
    """
    Fetch a worker's /metrics page and assert the cancellation count.

    The dynamo_component_cancellation_total sample for the given component is
    looked up in the fetched text and compared with expected_count.

    Args:
        worker_system_port: Port where the worker /metrics is served
        expected_count: Expected cumulative cancellation count
        component: The dynamo_component label value (e.g. "backend", "prefill")
    """
    metrics_url = f"http://localhost:{worker_system_port}/metrics"

    try:
        reply = requests.get(metrics_url, timeout=5)
        reply.raise_for_status()
    except requests.RequestException as exc:
        # A fetch failure is reported as a test failure, not a raw exception.
        pytest.fail(f"Failed to fetch worker metrics from {metrics_url}: {exc}")

    observed = _parse_runtime_cancellation_metric(reply.text, component=component)
    logger.info(f"Runtime cancellation metrics (component={component}): {observed}")

    assert observed == expected_count, (
        f"Runtime (component={component}): expected {expected_count} cancellations, "
        f"but got {observed}"
    )
def read_log_content(log_path: str | None) -> str: def read_log_content(log_path: str | None) -> str:
"""Read log content from a file""" """Read log content from a file"""
if log_path is None: if log_path is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment