"vscode:/vscode.git/clone" did not exist on "8601ccdbea876210ccf4e802c9daed8dd06c91b7"
Unverified Commit f8bb53c0 authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

feat: add cached tokens prometheus metric (#4534)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 2845aa1f
......@@ -152,6 +152,7 @@ The Dynamo HTTP Frontend (`python -m dynamo.frontend`) exposes `dynamo_frontend_
- `dynamo_frontend_queued_requests`: Number of requests in HTTP processing queue (gauge)
- `dynamo_frontend_disconnected_clients`: Number of disconnected clients (gauge)
- `dynamo_frontend_input_sequence_tokens`: Input sequence length (histogram)
- `dynamo_frontend_cached_tokens`: Number of cached tokens (prefix cache hits) per request (histogram)
- `dynamo_frontend_inter_token_latency_seconds`: Inter-token latency (histogram)
- `dynamo_frontend_output_sequence_tokens`: Output sequence length (histogram)
- `dynamo_frontend_output_tokens_total`: Total number of output tokens generated (counter)
......
......@@ -55,6 +55,8 @@ class frontend_service:
INPUT_SEQUENCE_TOKENS = "input_sequence_tokens"
# Output sequence length in tokens
OUTPUT_SEQUENCE_TOKENS = "output_sequence_tokens"
# Number of cached tokens (prefix cache hits) per request
CACHED_TOKENS = "cached_tokens"
# Total number of output tokens generated (counter that updates in real-time)
OUTPUT_TOKENS_TOTAL = "output_tokens_total"
# Time to first token in seconds
......@@ -93,6 +95,10 @@ class kvbm:
ONBOARD_BLOCKS_D2D = "onboard_blocks_d2d"
# The number of matched tokens
MATCHED_TOKENS = "matched_tokens"
# Host cache hit rate (0.0-1.0) from the sliding window
HOST_CACHE_HIT_RATE = "host_cache_hit_rate"
# Disk cache hit rate (0.0-1.0) from the sliding window
DISK_CACHE_HIT_RATE = "disk_cache_hit_rate"
class kvrouter:
......
......@@ -165,6 +165,7 @@ pub struct Metrics {
request_duration: HistogramVec,
input_sequence_length: HistogramVec,
output_sequence_length: HistogramVec,
cached_tokens: HistogramVec,
output_tokens_counter: IntCounterVec,
time_to_first_token: HistogramVec,
inter_token_latency: HistogramVec,
......@@ -252,6 +253,8 @@ pub struct ResponseMetricCollector {
// be computed.
last_response_time: Option<Duration>,
osl: usize,
// we track if cached_tokens has been observed to ensure we only increment once per request
cached_tokens_observed: bool,
}
impl Default for Metrics {
......@@ -378,7 +381,7 @@ impl Metrics {
frontend_metric_name(frontend_service::INPUT_SEQUENCE_TOKENS),
"Input sequence length in tokens",
)
.buckets(input_sequence_buckets),
.buckets(input_sequence_buckets.clone()),
&["model"],
)
.unwrap();
......@@ -436,6 +439,16 @@ impl Metrics {
)
.unwrap();
let cached_tokens = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::CACHED_TOKENS),
"Number of cached tokens (prefix cache hits) per request",
)
.buckets(input_sequence_buckets.clone()),
&["model"],
)
.unwrap();
// Runtime configuration metrics
// Note: Some of these metrics represent counter-like values from source systems,
// but are implemented as gauges because they are copied/synchronized from upstream
......@@ -502,6 +515,7 @@ impl Metrics {
request_duration,
input_sequence_length,
output_sequence_length,
cached_tokens,
output_tokens_counter,
time_to_first_token,
inter_token_latency,
......@@ -597,6 +611,7 @@ impl Metrics {
registry.register(Box::new(self.request_duration.clone()))?;
registry.register(Box::new(self.input_sequence_length.clone()))?;
registry.register(Box::new(self.output_sequence_length.clone()))?;
registry.register(Box::new(self.cached_tokens.clone()))?;
registry.register(Box::new(self.output_tokens_counter.clone()))?;
registry.register(Box::new(self.time_to_first_token.clone()))?;
registry.register(Box::new(self.inter_token_latency.clone()))?;
......@@ -830,6 +845,7 @@ impl ResponseMetricCollector {
last_response_time: None,
start_time: Instant::now(),
osl: 0,
cached_tokens_observed: false,
}
}
......@@ -843,6 +859,19 @@ impl ResponseMetricCollector {
self.is_first_token
}
/// Observe cached tokens (prefix cache hits), observing only once per request when value is available
pub fn observe_cached_tokens(&mut self, cached_tokens: Option<usize>) {
if let Some(tokens) = cached_tokens
&& !self.cached_tokens_observed
{
self.cached_tokens_observed = true;
self.metrics
.cached_tokens
.with_label_values(&[&self.model])
.observe(tokens as f64);
}
}
/// Observe a response with input sequence length and number of new tokens
pub fn observe_response(&mut self, isl: usize, num_tokens: usize) {
if num_tokens == 0 {
......@@ -943,11 +972,13 @@ impl<T> From<crate::types::Annotated<T>> for EventConverter<T> {
///
/// This function handles metrics collection, http_queue_guard management, and converts
/// annotated responses to SSE events for streaming responses.
///
/// Returns None for metrics annotation events (events without SSE data payload).
pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
annotated: EventConverter<T>,
response_collector: &mut ResponseMetricCollector,
http_queue_guard: &mut Option<HttpQueueGuard>,
) -> Result<Event, axum::Error> {
) -> Result<Option<Event>, axum::Error> {
use crate::preprocessor::LLMMetricAnnotation;
let mut annotated = annotated.0;
......@@ -955,6 +986,7 @@ pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
// update metrics
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) {
response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_cached_tokens(metrics.cached_tokens);
// Drop http_queue_guard on first token for streaming
if response_collector.is_first_token()
......@@ -976,11 +1008,11 @@ pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
let mut event = Event::default();
if let Some(data) = annotated.data {
if let Some(ref data) = annotated.data {
event = event.json_data(data)?;
}
if let Some(msg) = annotated.event {
if let Some(ref msg) = annotated.event {
if msg == "error" {
let msgs = annotated
.comment
......@@ -996,7 +1028,12 @@ pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
}
}
Ok(event)
// Filter out metrics annotation events (events without SSE data payload)
if annotated.data.is_none() && annotated.event.is_none() {
Ok(None)
} else {
Ok(Some(event))
}
}
/// Create a new router with optional custom backend metrics support
......@@ -1357,4 +1394,120 @@ mod tests {
20
);
}
#[test]
fn test_cached_tokens_once_per_request() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
metrics.register(&registry).unwrap();
let model = "test-model";
let expected_metric_name = "dynamo_frontend_cached_tokens";
let mut collector = metrics.clone().create_response_collector(model);
// Create histogram handle first
let _histogram = metrics.cached_tokens.with_label_values(&[model]);
// First call should observe and record 1 sample
collector.observe_cached_tokens(Some(100));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
// Second call with same collector should not observe again (idempotent)
collector.observe_cached_tokens(Some(50));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
// Third call with different value should still be idempotent
collector.observe_cached_tokens(Some(75));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
}
#[test]
fn test_metrics_annotation_event_handling() {
use crate::preprocessor::LLMMetricAnnotation;
use crate::types::Annotated;
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
metrics.register(&registry).unwrap();
let model = "test-model";
let expected_metric_name = "dynamo_frontend_cached_tokens";
let mut collector = metrics.clone().create_response_collector(model);
// Create a metrics annotation event (event without SSE data payload)
let mut annotated = Annotated::<
crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
> {
id: None,
data: None,
event: Some(crate::preprocessor::ANNOTATION_LLM_METRICS.to_string()),
comment: None,
};
// Add metrics annotation with cached_tokens
let llm_metrics = LLMMetricAnnotation {
input_tokens: 10,
output_tokens: 20,
chunk_tokens: 5,
cached_tokens: Some(15),
};
let annotation = llm_metrics.to_annotation::<()>().unwrap();
annotated.event = annotation.event;
annotated.comment = annotation.comment;
// Process the event
let mut http_queue_guard = None;
let result = process_response_using_event_converter_and_observe_metrics(
EventConverter::from(annotated),
&mut collector,
&mut http_queue_guard,
);
// Should return Ok(None) for metrics annotation events
assert!(matches!(result, Ok(None)));
// Should have observed the cached tokens from the metrics annotation event
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
}
}
......@@ -411,14 +411,20 @@ async fn completions_single(
if streaming {
// For streaming, we'll drop the http_queue_guard on the first token
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
});
let stream = stream
.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
// Transpose Result<Option<T>> -> Option<Result<T>>
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
......@@ -567,14 +573,20 @@ async fn completions_batch(
if streaming {
// For streaming, we'll drop the http_queue_guard on the first token
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
});
let stream = merged_stream
.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
// Transpose Result<Option<T>> -> Option<Result<T>>
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
......@@ -942,15 +954,21 @@ async fn chat_completions(
stream_handle.arm(); // allows the system to detect client disconnects and cancel the LLM generation
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.map(move |response| {
// Calls observe_response() on each token
// EventConverter will detect `event: "error"` and convert to SSE error events
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
});
let stream = stream
.map(move |response| {
// Calls observe_response() on each token
// EventConverter will detect `event: "error"` and convert to SSE error events
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
// Transpose Result<Option<T>> -> Option<Result<T>>
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
......
......@@ -71,6 +71,7 @@ pub struct LLMMetricAnnotation {
pub input_tokens: usize,
pub output_tokens: usize,
pub chunk_tokens: usize,
pub cached_tokens: Option<usize>,
}
impl LLMMetricAnnotation {
......@@ -646,6 +647,7 @@ impl OpenAIPreprocessor {
input_tokens: isl,
output_tokens: current_osl,
chunk_tokens,
cached_tokens: None,
};
if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
......@@ -673,20 +675,39 @@ impl OpenAIPreprocessor {
// again. The stream is exhausted and will panic if polled after None.
inner.finished = true;
// Check if we need to send a usage chunk
if inner.response_generator.is_usage_enabled()
&& inner.finish_reason_sent
&& !inner.usage_chunk_sent
{
if inner.finish_reason_sent && !inner.usage_chunk_sent {
inner.usage_chunk_sent = true;
// Create the final usage chunk
let usage_chunk = inner.response_generator.create_usage_chunk();
let usage = inner.response_generator.get_usage();
let llm_metrics = LLMMetricAnnotation {
input_tokens: usage.prompt_tokens as usize,
output_tokens: usage.completion_tokens as usize,
chunk_tokens: 0,
cached_tokens: usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens.map(|c| c as usize)),
};
// Create annotation string
let annotation = llm_metrics.to_annotation::<()>().unwrap_or_else(|e| {
tracing::warn!("Failed to serialize metrics: {}", e);
Annotated::<()>::from_data(())
});
// Send the usage chunk if needed
let data = if inner.response_generator.is_usage_enabled() {
Some(usage_chunk)
} else {
None
};
let annotated_usage = Annotated::<Resp> {
id: None,
data: Some(usage_chunk),
event: None,
comment: None,
data,
event: Some(ANNOTATION_LLM_METRICS.to_string()),
comment: annotation.comment,
};
tracing::trace!(
......
......@@ -225,6 +225,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Check if usage tracking is enabled.
fn is_usage_enabled(&self) -> bool;
/// Get the current usage statistics with properly calculated total_tokens.
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage;
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
......
......@@ -289,8 +289,7 @@ impl DeltaGenerator {
/// # Returns
/// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
let usage = self.get_usage();
dynamo_async_openai::types::CreateChatCompletionStreamResponse {
id: self.id.clone(),
......@@ -309,6 +308,12 @@ impl DeltaGenerator {
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
usage
}
}
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
......@@ -328,27 +333,25 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
&mut self,
delta: crate::protocols::common::llm_backend::BackendOutput,
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
// Aggregate token usage if enabled.
if self.options.enable_usage {
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
// Aggregate token usage even if usage tracking is disabled for metrics tracking
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
let logprobs = self.create_logprobs(
......@@ -438,6 +441,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
}
#[cfg(test)]
......
......@@ -223,8 +223,7 @@ impl DeltaGenerator {
/// # Returns
/// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
let usage = self.get_usage();
let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
......@@ -244,6 +243,12 @@ impl DeltaGenerator {
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
usage
}
}
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
......@@ -251,27 +256,25 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
&mut self,
delta: common::llm_backend::BackendOutput,
) -> anyhow::Result<NvCreateCompletionResponse> {
// aggregate usage
if self.options.enable_usage {
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
// Aggregate token usage even if usage tracking is disabled for metrics tracking
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
let logprobs = self.create_logprobs(
......@@ -342,4 +345,8 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
}
......@@ -208,11 +208,29 @@ async fn test_streaming_without_usage() {
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
// Verify we got exactly 3 chunks (no extra usage chunk)
assert_eq!(chunks.len(), 3, "Should have exactly 3 content chunks");
// Filter out metrics annotation events (events without SSE data payload)
let content_chunks: Vec<_> = chunks
.into_iter()
.filter(|chunk| {
// Metrics annotation events have event=Some(ANNOTATION_LLM_METRICS) and data=None
!(chunk
.event
.as_ref()
.map(|e| e == "llm_metrics")
.unwrap_or(false)
&& chunk.data.is_none())
})
.collect();
// Verify we got exactly 3 content chunks (no extra usage chunk)
assert_eq!(
content_chunks.len(),
3,
"Should have exactly 3 content chunks"
);
// Verify all chunks have usage: None
for (i, chunk) in chunks.iter().enumerate() {
for (i, chunk) in content_chunks.iter().enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.usage.is_none(),
......@@ -322,15 +340,29 @@ async fn test_streaming_with_usage_false() {
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
// Filter out metrics annotation events (events without SSE data payload)
let content_chunks: Vec<_> = chunks
.into_iter()
.filter(|chunk| {
// Metrics annotation events have event=Some(ANNOTATION_LLM_METRICS) and data=None
!(chunk
.event
.as_ref()
.map(|e| e == "llm_metrics")
.unwrap_or(false)
&& chunk.data.is_none())
})
.collect();
// Verify we got exactly 3 chunks (no extra usage chunk when explicitly false)
assert_eq!(
chunks.len(),
content_chunks.len(),
3,
"Should have exactly 3 content chunks when include_usage is false"
);
// Verify all chunks have usage: None
for (i, chunk) in chunks.iter().enumerate() {
for (i, chunk) in content_chunks.iter().enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.usage.is_none(),
......
......@@ -113,6 +113,9 @@ pub mod frontend_service {
/// Output sequence length in tokens
pub const OUTPUT_SEQUENCE_TOKENS: &str = "output_sequence_tokens";
/// Number of cached tokens (prefix cache hits) per request
pub const CACHED_TOKENS: &str = "cached_tokens";
/// Total number of output tokens generated (counter that updates in real-time)
pub const OUTPUT_TOKENS_TOTAL: &str = "output_tokens_total";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment