Unverified Commit 2845aa1f authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: include ttft and total request time in nvext (engine-agnostic) (#4880)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarKeiven C <213854356+keivenchang@users.noreply.github.com>
parent 52503032
......@@ -22,6 +22,7 @@ use super::TokenIdType;
pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;
pub mod timing;
/// SamplingOptionsProvider is a trait that allows the caller to extract the sampling options from
/// the object that implements it. This will mutate the object.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Per-request timing tracker for capturing request lifecycle metrics.
//!
//! This module provides [`RequestTimingTracker`] for tracking timing information
//! that can be returned to clients via the `nvext` response field.
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// Per-request timing tracker.
///
/// Captures timing information throughout the request lifecycle:
/// - `request_received`: When the request was received
/// - `first_token_time`: When the first token was generated (set once via OnceLock)
/// - `request_finish_time`: When the request finished (set once via OnceLock)
///
/// The `OnceLock` fields ensure that timing values are set exactly once,
/// which is important for disaggregated serving where the "first token"
/// might appear multiple times.
pub struct RequestTimingTracker {
/// When the request was received (monotonic clock for duration calculations)
request_received: Instant,
/// When the request was received (wall clock time as epoch milliseconds)
request_received_epoch_ms: u64,
/// When the first token was generated - set once via OnceLock
first_token_time: OnceLock<Instant>,
/// When the request finished - set once via OnceLock
request_finish_time: OnceLock<Instant>,
}
impl RequestTimingTracker {
/// Create a new timing tracker, capturing the current time as request received.
pub fn new() -> Self {
let now = Instant::now();
let epoch_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
RequestTimingTracker {
request_received: now,
request_received_epoch_ms: epoch_ms,
first_token_time: OnceLock::new(),
request_finish_time: OnceLock::new(),
}
}
pub fn record_first_token(&self) -> bool {
self.first_token_time.set(Instant::now()).is_ok()
}
pub fn record_finish(&self) -> bool {
self.request_finish_time.set(Instant::now()).is_ok()
}
pub fn ttft_ms(&self) -> Option<f64> {
self.first_token_time
.get()
.map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0)
}
pub fn total_time_ms(&self) -> Option<f64> {
self.request_finish_time
.get()
.map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0)
}
pub fn request_received_epoch_ms(&self) -> u64 {
self.request_received_epoch_ms
}
pub fn get_timing_info(&self) -> TimingInfo {
TimingInfo {
request_received_ms: self.request_received_epoch_ms,
ttft_ms: self.ttft_ms(),
total_time_ms: self.total_time_ms(),
}
}
}
impl Default for RequestTimingTracker {
fn default() -> Self {
Self::new()
}
}
/// Timing information for response injection.
///
/// This struct is serialized and included in the response's `nvext` field
/// when the client requests timing information via `extra_fields: ["timing"]`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TimingInfo {
/// When the request was received (epoch milliseconds)
pub request_received_ms: u64,
/// Time to first token in milliseconds
#[serde(skip_serializing_if = "Option::is_none")]
pub ttft_ms: Option<f64>,
/// Total request time in milliseconds
#[serde(skip_serializing_if = "Option::is_none")]
pub total_time_ms: Option<f64>,
}
......@@ -5,8 +5,8 @@ use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}
use crate::{
local_model::runtime_config::ModelRuntimeConfig,
protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
common::{self, timing::RequestTimingTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
},
types::TokenIdType,
};
......@@ -44,6 +44,12 @@ impl NvCreateChatCompletionRequest {
/// # Returns
/// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Check if client requested timing in extra_fields
let enable_timing = self
.nvext()
.and_then(|nv| nv.extra_fields.as_ref())
.is_some_and(|fields| fields.iter().any(|f| f == "timing"));
let options = DeltaGeneratorOptions {
enable_usage: self
.inner
......@@ -53,6 +59,7 @@ impl NvCreateChatCompletionRequest {
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
enable_timing,
runtime_config: ModelRuntimeConfig::default(),
};
......@@ -67,12 +74,13 @@ pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
/// Determines whether timing information should be included in the response's nvext.
pub enable_timing: bool,
pub runtime_config: ModelRuntimeConfig,
}
/// Generates incremental chat completion responses in a streaming fashion.
#[derive(Debug)]
pub struct DeltaGenerator {
/// Unique identifier for the chat completion session.
id: String,
......@@ -91,6 +99,8 @@ pub struct DeltaGenerator {
msg_counter: u64,
/// Configuration options for response generation.
options: DeltaGeneratorOptions,
/// Optional timing tracker for per-request timing metrics.
timing_tracker: Option<RequestTimingTracker>,
}
impl DeltaGenerator {
......@@ -123,6 +133,13 @@ impl DeltaGenerator {
let chatcmpl_id = format!("chatcmpl-{request_id}");
// Create timing tracker if timing is enabled
let timing_tracker = if options.enable_timing {
Some(RequestTimingTracker::new())
} else {
None
};
Self {
id: chatcmpl_id,
object: "chat.completion.chunk".to_string(),
......@@ -133,6 +150,7 @@ impl DeltaGenerator {
usage,
msg_counter: 0,
options,
timing_tracker,
}
}
......@@ -365,24 +383,44 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let index = 0;
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_info) = delta
// Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker {
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
tracker.record_finish();
tracker.get_timing_info()
})
} else {
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info.clone()),
worker_id: worker_id_info.clone(),
timing: timing_info,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
worker_id_info.prefill_worker_id,
worker_id_info.decode_worker_id
);
if let Some(ref info) = worker_id_info {
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
}
}
......
......@@ -4,8 +4,8 @@
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{
protocols::{
common,
openai::nvext::{NvExtResponse, WorkerIdInfo},
common::{self, timing::RequestTimingTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
},
types::TokenIdType,
};
......@@ -39,6 +39,12 @@ impl NvCreateCompletionRequest {
// put this method on the request
// inspect the request to extract options
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Check if client requested timing in extra_fields
let enable_timing = self
.nvext()
.and_then(|nv| nv.extra_fields.as_ref())
.is_some_and(|fields| fields.iter().any(|f| f == "timing"));
let options = DeltaGeneratorOptions {
enable_usage: self
.inner
......@@ -47,6 +53,7 @@ impl NvCreateCompletionRequest {
.map(|opts| opts.include_usage)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_timing,
};
DeltaGenerator::new(self.inner.model.clone(), options, request_id)
......@@ -57,9 +64,9 @@ impl NvCreateCompletionRequest {
pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
pub enable_logprobs: bool,
pub enable_timing: bool,
}
#[derive(Debug, Clone)]
pub struct DeltaGenerator {
id: String,
object: String,
......@@ -68,6 +75,7 @@ pub struct DeltaGenerator {
system_fingerprint: Option<String>,
usage: dynamo_async_openai::types::CompletionUsage,
options: DeltaGeneratorOptions,
timing_tracker: Option<RequestTimingTracker>,
}
impl DeltaGenerator {
......@@ -93,6 +101,13 @@ impl DeltaGenerator {
let completion_id = format!("cmpl-{request_id}");
// Create timing tracker if timing is enabled
let timing_tracker = if options.enable_timing {
Some(RequestTimingTracker::new())
} else {
None
};
Self {
id: completion_id,
object: "text_completion".to_string(),
......@@ -101,6 +116,7 @@ impl DeltaGenerator {
system_fingerprint: None,
usage,
options,
timing_tracker,
}
}
......@@ -271,24 +287,44 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let index = delta.index.unwrap_or(0);
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_info) = delta
// Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker {
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
{
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
tracker.record_finish();
tracker.get_timing_info()
})
} else {
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info.clone()),
worker_id: worker_id_info.clone(),
timing: timing_info,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
worker_id_info.prefill_worker_id,
worker_id_info.decode_worker_id
);
if let Some(ref info) = worker_id_info {
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
}
}
......
......@@ -5,6 +5,8 @@ use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
pub use crate::protocols::common::timing::TimingInfo;
pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>;
fn raw_prompt(&self) -> Option<String>;
......@@ -28,6 +30,11 @@ pub struct NvExtResponse {
/// Worker ID information (prefill and decode worker IDs)
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
/// Per-request timing information
/// Populated when client requests `extra_fields: ["timing"]`
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
}
/// NVIDIA LLM extensions to the OpenAI API
......@@ -76,7 +83,7 @@ pub struct NvExt {
/// Extra fields to be included in the response's nvext
/// This is a list of field names that should be populated in the response
/// Supported fields: "worker_id"
/// Supported fields: "worker_id", "timing", which has a 1:1 mapping with the NvExtResponse names
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
......
......@@ -134,6 +134,23 @@ def verify_response_worker_ids(
)
def verify_response_timing(timing_info: dict[str, Any]) -> None:
"""Verify timing info has valid values (ttft_ms > 0, total_time_ms > 0)."""
ttft_ms = timing_info.get("ttft_ms")
total_time_ms = timing_info.get("total_time_ms")
assert ttft_ms is not None and ttft_ms > 0, f"Expected ttft_ms > 0, got: {ttft_ms}"
assert (
total_time_ms is not None and total_time_ms > 0
), f"Expected total_time_ms > 0, got: {total_time_ms}"
assert (
total_time_ms >= ttft_ms
), f"Expected total_time_ms >= ttft_ms, got {total_time_ms} < {ttft_ms}"
logger.info(
f"✓ Verified timing: ttft_ms={ttft_ms:.2f}, total_time_ms={total_time_ms:.2f}"
)
########################################################
# Utility functions
########################################################
......@@ -1646,7 +1663,7 @@ def _test_router_decisions_disagg(
# Each iteration adds more content to extend the prefix
progressive_content = " ".join([base_content] * (i + 1))
# Create payload with worker_id in extra_fields to get prefill/decode worker IDs
# Create payload with worker_id and timing in extra_fields
payload = {
**test_payload,
"messages": [
......@@ -1655,7 +1672,7 @@ def _test_router_decisions_disagg(
"content": progressive_content,
}
],
"nvext": {"extra_fields": ["worker_id"]},
"nvext": {"extra_fields": ["worker_id", "timing"]},
"stream": True,
}
......@@ -1669,9 +1686,10 @@ def _test_router_decisions_disagg(
response.status == 200
), f"Request {i + 1} failed with status {response.status}"
# Collect all chunks and look for nvext with worker_id
# Collect all chunks and look for nvext with worker_id and timing
prefill_wid = None
decode_wid = None
timing_info = None
async for line in response.content:
if not line:
......@@ -1687,24 +1705,29 @@ def _test_router_decisions_disagg(
try:
data = json.loads(data_str)
# Check for nvext.worker_id in the response
# Check for nvext in the response
nvext = data.get("nvext", {})
worker_id_info = nvext.get("worker_id", {})
if worker_id_info:
if "prefill_worker_id" in worker_id_info:
prefill_wid = worker_id_info[
"prefill_worker_id"
]
if "decode_worker_id" in worker_id_info:
decode_wid = worker_id_info["decode_worker_id"]
if nvext:
worker_id_info = nvext.get("worker_id", {})
if worker_id_info:
if "prefill_worker_id" in worker_id_info:
prefill_wid = worker_id_info[
"prefill_worker_id"
]
if "decode_worker_id" in worker_id_info:
decode_wid = worker_id_info[
"decode_worker_id"
]
# Timing info appears in final chunk
if "timing" in nvext:
timing_info = nvext["timing"]
except json.JSONDecodeError:
continue
logger.info(
f"Request {i + 1}: prefill_worker_id={prefill_wid}, "
f"decode_worker_id={decode_wid}"
f"decode_worker_id={decode_wid}, timing={timing_info}"
)
if prefill_wid is not None:
......@@ -1712,6 +1735,12 @@ def _test_router_decisions_disagg(
if decode_wid is not None:
decode_worker_ids.append(decode_wid)
# Verify timing info is present and valid
assert (
timing_info is not None
), f"Request {i + 1}: Expected timing info in final chunk, got None"
verify_response_timing(timing_info)
# Small delay between requests
await asyncio.sleep(0.5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment