Commit 85b5c808 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

feat: unify per-request metric tracking (#5004)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent b823575e
...@@ -22,7 +22,9 @@ use tracing; ...@@ -22,7 +22,9 @@ use tracing;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener}; use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::protocols::common::timing::RequestTracker;
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json;
#[pyfunction] #[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> { pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
...@@ -1055,6 +1057,31 @@ pub(crate) struct KvPushRouter { ...@@ -1055,6 +1057,31 @@ pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<llm_rs::kv_router::KvPushRouter>,
} }
/// Copy the tracker's worker routing info into `data.disaggregated_params["worker_id"]`.
///
/// Does nothing when the tracker has no worker info. When `disaggregated_params`
/// already holds a JSON object, the `"worker_id"` key is inserted into it (replacing
/// any previous value); otherwise `disaggregated_params` is replaced with a fresh
/// object containing only `"worker_id"`.
///
/// Needed by the Python bindings because the raw `LLMEngineOutput` doesn't go
/// through DeltaGenerator (which adds nvext).
///
/// # Panics
///
/// Panics if the worker info cannot be serialized to JSON.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    if let Some(info) = tracker.get_worker_info() {
        let value =
            serde_json::to_value(&info).expect("WorkerIdInfo serialization should not fail");

        let existing_object = data
            .disaggregated_params
            .as_mut()
            .and_then(|params| params.as_object_mut());

        match existing_object {
            // Existing JSON object: add/overwrite just the one key.
            Some(map) => {
                map.insert("worker_id".to_string(), value);
            }
            // Missing or non-object params: start a new object.
            None => data.disaggregated_params = Some(json!({"worker_id": value})),
        }
    }
}
// TODO: can this reuse the stream conversion method in Client bindings? // TODO: can this reuse the stream conversion method in Client bindings?
impl KvPushRouter { impl KvPushRouter {
/// Helper method to process a request and create a Python async generator /// Helper method to process a request and create a Python async generator
...@@ -1062,6 +1089,7 @@ impl KvPushRouter { ...@@ -1062,6 +1089,7 @@ impl KvPushRouter {
py: Python<'p>, py: Python<'p>,
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<llm_rs::kv_router::KvPushRouter>,
request: llm_rs::protocols::common::preprocessor::PreprocessedRequest, request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
tracker: Option<Arc<RequestTracker>>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let single_in = SingleIn::new(request); let single_in = SingleIn::new(request);
...@@ -1071,7 +1099,17 @@ impl KvPushRouter { ...@@ -1071,7 +1099,17 @@ impl KvPushRouter {
// Spawn a task to process the stream // Spawn a task to process the stream
tokio::spawn(async move { tokio::spawn(async move {
let mut stream = stream; let mut stream = stream;
while let Some(response) = stream.next().await { let mut first_item = true;
while let Some(mut response) = stream.next().await {
// Inject worker_id into first response if tracker is available
if first_item {
first_item = false;
if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
inject_worker_id_from_tracker(data, tracker);
}
}
// Convert LLMEngineOutput to PyObject // Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| { let py_response = Python::with_gil(|py| {
pythonize(py, &response.data) pythonize(py, &response.data)
...@@ -1190,6 +1228,9 @@ impl KvPushRouter { ...@@ -1190,6 +1228,9 @@ impl KvPushRouter {
None None
}; };
// Create tracker to capture worker routing info from KvRouter
let tracker = Arc::new(RequestTracker::new());
// Build the PreprocessedRequest // Build the PreprocessedRequest
let mut request_builder = let mut request_builder =
llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder(); llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
...@@ -1201,7 +1242,8 @@ impl KvPushRouter { ...@@ -1201,7 +1242,8 @@ impl KvPushRouter {
.output_options(output_options) .output_options(output_options)
.router_config_override(router_config_override) .router_config_override(router_config_override)
.dp_rank(dp_rank) .dp_rank(dp_rank)
.extra_args(extra_args); .extra_args(extra_args)
.tracker(Some(tracker.clone()));
// Set backend_instance_id if worker_id is provided // Set backend_instance_id if worker_id is provided
if let Some(worker_id) = worker_id { if let Some(worker_id) = worker_id {
...@@ -1211,7 +1253,7 @@ impl KvPushRouter { ...@@ -1211,7 +1253,7 @@ impl KvPushRouter {
let request = request_builder.build().map_err(to_pyerr)?; let request = request_builder.build().map_err(to_pyerr)?;
// Use the helper method to process the request // Use the helper method to process the request
Self::process_request_to_stream(py, self.inner.clone(), request) Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
} }
fn generate_from_request<'p>( fn generate_from_request<'p>(
...@@ -1220,11 +1262,21 @@ impl KvPushRouter { ...@@ -1220,11 +1262,21 @@ impl KvPushRouter {
request: PyObject, request: PyObject,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the request directly into PreprocessedRequest // Depythonize the request directly into PreprocessedRequest
let request: llm_rs::protocols::common::preprocessor::PreprocessedRequest = let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
depythonize(request.bind(py)).map_err(to_pyerr)?; depythonize(request.bind(py)).map_err(to_pyerr)?;
// Create tracker if not already set, to capture worker routing info
let tracker = match request.tracker {
Some(ref t) => t.clone(),
None => {
let t = Arc::new(RequestTracker::new());
request.tracker = Some(t.clone());
t
}
};
// Use the helper method to process the request // Use the helper method to process the request
Self::process_request_to_stream(py, self.inner.clone(), request) Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
} }
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))] #[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
......
...@@ -22,8 +22,6 @@ use futures::stream::{self, StreamExt}; ...@@ -22,8 +22,6 @@ use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use crate::protocols::openai::nvext::WorkerIdInfo;
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
pub mod prefill_router; pub mod prefill_router;
...@@ -58,6 +56,7 @@ use crate::{ ...@@ -58,6 +56,7 @@ use crate::{
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
protocols::common::timing::RequestPhase,
tokens::SequenceHash, tokens::SequenceHash,
}; };
...@@ -97,46 +96,6 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId { ...@@ -97,46 +96,6 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId {
} }
} }
/// Worker type requested through the `query_instance_id` annotation.
///
/// The annotation value tells the router which worker pool to select from:
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// With `rename_all = "lowercase"` the variants (de)serialize as `"prefill"` and
/// `"decode"`.
///
/// Note: an empty value ("query_instance_id:") has no variant here; it is handled
/// by PrefillRouter for disagg orchestration — NOTE(review): confirm against the
/// router code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
/// Query for a prefill worker (disaggregated serving)
Prefill,
/// Query for a decode worker (disaggregated serving)
Decode,
}
impl std::fmt::Display for QueryInstanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QueryInstanceType::Prefill => write!(f, "prefill"),
QueryInstanceType::Decode => write!(f, "decode"),
}
}
}
impl std::str::FromStr for QueryInstanceType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"prefill" => Ok(QueryInstanceType::Prefill),
"decode" => Ok(QueryInstanceType::Decode),
_ => Err(format!(
"Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
)),
}
}
}
/// Creates a DiscoveryQuery for the KV router in the given namespace. /// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery { pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint { DiscoveryQuery::Endpoint {
...@@ -771,38 +730,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -771,38 +730,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Extract context ID for request tracking // Extract context ID for request tracking
let context_id = request.context().id().to_string(); let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request and parse its type // Simple query-only detection: presence of query_instance_id annotation means query-only mode
// Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated) let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
let query_instance_annotation = request.get_annotation_value("query_instance_id"); // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
let is_gaie_agg_query = query_instance_annotation let phase = request
.tracker
.as_ref() .as_ref()
.is_some_and(|s| s.is_empty()); .map(|t| t.phase())
let query_instance_type: Option<QueryInstanceType> = .unwrap_or(RequestPhase::Aggregated);
if let Some(type_str) = &query_instance_annotation {
match type_str.parse::<QueryInstanceType>() { // Get pre-selected worker based on phase
Ok(t) => Some(t), let preselected = match phase {
Err(_) if type_str.is_empty() => { RequestPhase::Prefill => request.target_prefill_worker_id,
// Empty value is valid for aggregated mode, not a warning RequestPhase::Decode => request.target_decode_worker_id,
None RequestPhase::Aggregated => None,
}
Err(e) => {
tracing::warn!("Invalid query_instance_id type '{type_str}': {e}");
None
}
}
} else {
None
}; };
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id { let block_size = self.chooser.block_size() as usize;
// If instance_id is set, use it and compute actual overlap let (instance_id, dp_rank, overlap_amount) =
if let Some(id) = preselected.or(request.backend_instance_id) {
// Route to pre-selected or explicitly specified worker
let dp_rank = request.dp_rank.unwrap_or(0); let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_type.is_some() {
tracing::debug!( tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation" worker_id = id,
dp_rank = dp_rank,
?phase,
"Routing to specified worker"
); );
}
// Compute actual overlap blocks by querying the indexer // Compute actual overlap blocks by querying the indexer
let block_hashes = let block_hashes =
...@@ -811,6 +766,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -811,6 +766,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let worker = WorkerWithDpRank::new(id, dp_rank); let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0); let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
if !is_query_only {
self.chooser self.chooser
.add_request( .add_request(
context_id.clone(), context_id.clone(),
...@@ -819,74 +775,44 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -819,74 +775,44 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
worker, worker,
) )
.await; .await;
}
(id, dp_rank, overlap_blocks) (id, dp_rank, overlap_blocks)
} else { } else {
// Otherwise, find the best match // Find the best worker match
// Don't update states if this is a query-only request (any query_instance_id annotation) // Don't update states if this is a query-only request
let should_update_states = query_instance_annotation.is_none();
let (best_worker, overlap_amount) = self let (best_worker, overlap_amount) = self
.chooser .chooser
.find_best_match( .find_best_match(
Some(&context_id), Some(&context_id),
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
should_update_states, !is_query_only,
) )
.await?; .await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount) (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
}; };
// If request has a query_instance_id annotation, return worker selection info // Record metrics in tracker: KV hit rate and worker ID based on phase
// without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params if let Some(ref tracker) = request.tracker {
// containing worker_id info, same structure as normal execution for uniform extraction. let isl_blocks = request.token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_worker(instance_id);
}
// Handle query-only requests: early return with worker info
if is_query_only {
let stream_context = request.context().clone(); let stream_context = request.context().clone();
// Tracker is always created for query-only requests (delta generator enables tracking
// when query_instance_id annotation is present)
let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());
// Handle query-only requests (GAIE Stage 1)
if query_instance_type.is_some() || is_gaie_agg_query {
let worker_id_info = if is_gaie_agg_query {
// GAIE Aggregated mode: same worker serves both prefill and decode
tracing::trace!( tracing::trace!(
query_type = "aggregated", ?phase,
worker_id = instance_id, worker_id = instance_id,
"Returning aggregated worker selection (same worker for prefill and decode)" ?worker_id_info,
); "Returning worker selection (query-only mode)"
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: Some(instance_id),
}
} else {
match query_instance_type.unwrap() {
QueryInstanceType::Prefill => {
tracing::trace!(
query_type = "prefill",
prefill_worker_id = instance_id,
"Returning prefill worker selection"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: None,
}
}
QueryInstanceType::Decode => {
// Get prefill_worker_id from annotation (set by caller after prefill selection)
let prefill_worker_id = request
.get_annotation_value("prefill_worker_id")
.and_then(|s| s.parse::<u64>().ok());
tracing::trace!(
query_type = "decode",
prefill_worker_id = ?prefill_worker_id,
decode_worker_id = instance_id,
"Returning decode worker selection"
); );
WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(instance_id),
}
}
}
};
// Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
let output = LLMEngineOutput { let output = LLMEngineOutput {
disaggregated_params: Some(json!({ disaggregated_params: Some(json!({
"worker_id": worker_id_info, "worker_id": worker_id_info,
...@@ -898,25 +824,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -898,25 +824,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response]); let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context)); return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} }
// Route to worker
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank); backend_input.dp_rank = Some(dp_rank);
// Get prefill worker ID from prefill_result if available
// In aggregated mode, prefill_result is None, so we use decode_worker_id for both
let decode_worker_id = instance_id;
let prefill_worker_id = backend_input
.prefill_result
.as_ref()
.and_then(|prefill_result| {
prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id)
})
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
...@@ -926,7 +837,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -926,7 +837,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false; let mut prefill_marked = false;
let mut first_item = true;
loop { loop {
tokio::select! { tokio::select! {
...@@ -938,7 +848,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -938,7 +848,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
item = response_stream.next() => { item = response_stream.next() => {
let Some(mut item) = item else { let Some(item) = item else {
break; break;
}; };
...@@ -949,34 +859,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -949,34 +859,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
prefill_marked = true; prefill_marked = true;
} }
// Always inject worker_id in first item's disaggregated_params
// This is needed for:
// 1. PrefillRouter to know which prefill worker was chosen
// 2. Client response when extra_fields contains "worker_id"
if first_item {
first_item = false;
let Some(ref mut data) = item.data else {
yield item;
continue;
};
// prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
// decode_worker_id is always the current instance_id
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(decode_worker_id),
};
let worker_id_json = serde_json::to_value(&worker_id_info)
.expect("WorkerIdInfo serialization should not fail");
if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
obj.insert("worker_id".to_string(), worker_id_json);
} else {
data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
}
}
yield item; yield item;
} }
} }
......
...@@ -20,9 +20,10 @@ use dynamo_runtime::{ ...@@ -20,9 +20,10 @@ use dynamo_runtime::{
use crate::{ use crate::{
discovery::ModelManager, discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, QueryInstanceType, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult}, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::RequestPhase,
protocols::openai::nvext::WorkerIdInfo, protocols::openai::nvext::WorkerIdInfo,
}; };
...@@ -375,7 +376,7 @@ impl PrefillRouter { ...@@ -375,7 +376,7 @@ impl PrefillRouter {
.retain(|a| !a.starts_with("query_instance_id")); .retain(|a| !a.starts_with("query_instance_id"));
prefill_req prefill_req
.annotations .annotations
.push(format!("query_instance_id:{}", QueryInstanceType::Prefill)); .push(format!("query_instance_id:{}", RequestPhase::Prefill));
} else if let Some(prefill_worker_id) = prefill_req.target_prefill_worker_id { } else if let Some(prefill_worker_id) = prefill_req.target_prefill_worker_id {
// GAIE Stage 2: Route to pre-selected prefill worker from the stage 1 // GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
tracing::debug!( tracing::debug!(
...@@ -404,7 +405,7 @@ impl PrefillRouter { ...@@ -404,7 +405,7 @@ impl PrefillRouter {
.retain(|a| !a.starts_with("query_instance_id")); .retain(|a| !a.starts_with("query_instance_id"));
decode_req decode_req
.annotations .annotations
.push(format!("query_instance_id:{}", QueryInstanceType::Decode)); .push(format!("query_instance_id:{}", RequestPhase::Decode));
decode_req decode_req
.annotations .annotations
.push(format!("prefill_worker_id:{worker_id}")); .push(format!("prefill_worker_id:{worker_id}"));
...@@ -477,6 +478,12 @@ impl ...@@ -477,6 +478,12 @@ impl
); );
} }
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
...@@ -487,6 +494,12 @@ impl ...@@ -487,6 +494,12 @@ impl
// Fallback to original: Wait for prefill to complete // Fallback to original: Wait for prefill to complete
tracing::debug!("Using original prefill path"); tracing::debug!("Using original prefill path");
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
...@@ -496,12 +509,26 @@ impl ...@@ -496,12 +509,26 @@ impl
} }
} else { } else {
// GAIE Stage 1: Use original path (no bootstrap optimization) // GAIE Stage 1: Use original path (no bootstrap optimization)
// But first check if prefill router is activated - if not, skip to avoid setting phase
if self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
Err(PrefillError::NotActivated)
} else {
tracing::debug!("Using original prefill path (GAIE Stage 1)");
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context) self.call_prefill(prefill_context)
.await .await
.map(|(result, worker_id)| (Some(result), worker_id, None)) .map(|(result, worker_id)| (Some(result), worker_id, None))
}
}; };
// Abort if cancelled during prefill // Abort if cancelled during prefill
...@@ -518,6 +545,11 @@ impl ...@@ -518,6 +545,11 @@ impl
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => { Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
tracing::debug!("Prefill completed, proceeding to decode"); tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Decode);
}
let mut decode_req = req; let mut decode_req = req;
// Update request with prefill result // Update request with prefill result
......
...@@ -237,7 +237,6 @@ impl OpenAIPreprocessor { ...@@ -237,7 +237,6 @@ impl OpenAIPreprocessor {
builder.output_options(request.extract_output_options()?); builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id, extra_fields, and worker IDs from nvext if present // Extract backend_instance_id, extra_fields, and worker IDs from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id); builder.backend_instance_id(nvext.backend_instance_id);
...@@ -943,7 +942,10 @@ impl ...@@ -943,7 +942,10 @@ impl
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
// convert the chat completion request to a common completion request // convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request).await?; let (mut common_request, annotations) = self.preprocess_request(&request).await?;
// Attach the timing tracker to the request so downstream components can record metrics
common_request.tracker = response_generator.tracker();
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
...@@ -1092,7 +1094,10 @@ impl ...@@ -1092,7 +1094,10 @@ impl
let annotations = self.gather_tokens(&request, &mut builder, None)?; let annotations = self.gather_tokens(&request, &mut builder, None)?;
self.gather_multi_modal_data(&request, &mut builder).await?; self.gather_multi_modal_data(&request, &mut builder).await?;
let common_request = builder.build()?; let mut common_request = builder.build()?;
// Attach the timing tracker to the request so downstream components can record metrics
common_request.tracker = response_generator.tracker();
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as u32); response_generator.update_isl(common_request.token_ids.len() as u32);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride; use crate::kv_router::RouterConfigOverride;
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
...@@ -82,10 +85,6 @@ pub struct PreprocessedRequest { ...@@ -82,10 +85,6 @@ pub struct PreprocessedRequest {
#[builder(default)] #[builder(default)]
pub annotations: Vec<String>, pub annotations: Vec<String>,
/// Estimated number of prefix hit tokens (only used in kv aware routing)
#[builder(default)]
pub estimated_prefix_hit_num_blocks: Option<u32>,
/// Targeted backend instance ID for the request /// Targeted backend instance ID for the request
#[builder(default)] #[builder(default)]
pub backend_instance_id: Option<u64>, pub backend_instance_id: Option<u64>,
...@@ -119,6 +118,11 @@ pub struct PreprocessedRequest { ...@@ -119,6 +118,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>, pub extra_fields: Option<Vec<String>>,
/// Optional request tracker for per-request metrics (shared with DeltaGenerator)
#[builder(default)]
#[serde(skip)]
pub tracker: Option<Arc<RequestTracker>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2) /// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the prefill request will be routed to this specific worker. /// When set, the prefill request will be routed to this specific worker.
#[builder(default)] #[builder(default)]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Per-request timing tracker for capturing request lifecycle metrics. //! Per-request tracker for capturing request lifecycle metrics.
//! //!
//! This module provides [`RequestTimingTracker`] for tracking timing information //! This module provides [`RequestTracker`] for tracking timing and routing information
//! that can be returned to clients via the `nvext` response field. //! that can be returned to clients via the `nvext` response field.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::OnceLock; use std::sync::{Mutex, OnceLock};
use std::time::{Instant, SystemTime, UNIX_EPOCH}; use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// Per-request timing tracker. use crate::protocols::openai::nvext::WorkerIdInfo;
/// Phase of the request in disaggregated serving.
/// ///
/// Captures timing information throughout the request lifecycle: /// Used to determine which worker ID field to record when routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RequestPhase {
    /// Prefill-only phase of a disaggregated request.
    Prefill,
    /// Decode phase of a disaggregated request.
    Decode,
    /// Aggregated mode: one worker handles both prefill and decode.
    /// This is the `Default` value.
    #[default]
    Aggregated,
}

impl std::fmt::Display for RequestPhase {
    /// Writes the lowercase phase name: `prefill`, `decode`, or `aggregated`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Prefill => "prefill",
            Self::Decode => "decode",
            Self::Aggregated => "aggregated",
        };
        f.write_str(label)
    }
}
/// Per-request tracker for timing and routing metrics.
///
/// Captures information throughout the request lifecycle:
/// - `request_received`: When the request was received /// - `request_received`: When the request was received
/// - `prefill_start_time`: When prefill started (for disaggregated serving)
/// - `first_token_time`: When the first token was generated (set once via OnceLock) /// - `first_token_time`: When the first token was generated (set once via OnceLock)
/// - `request_finish_time`: When the request finished (set once via OnceLock) /// - `request_finish_time`: When the request finished (set once via OnceLock)
/// - KV cache hit rate information
/// ///
/// The `OnceLock` fields ensure that timing values are set exactly once, /// The `OnceLock` fields ensure that values are set exactly once,
/// which is important for disaggregated serving where the "first token" /// which is important for disaggregated serving where the "first token"
/// might appear multiple times. /// might appear multiple times.
pub struct RequestTimingTracker { #[derive(Debug)]
pub struct RequestTracker {
/// When the request was received (monotonic clock for duration calculations) /// When the request was received (monotonic clock for duration calculations)
request_received: Instant, request_received: Instant,
/// When the request was received (wall clock time as epoch milliseconds) /// When the request was received (wall clock time as epoch milliseconds)
request_received_epoch_ms: u64, request_received_epoch_ms: u64,
/// When prefill started (for disaggregated serving) - set once via OnceLock
prefill_start_time: OnceLock<Instant>,
/// When the first token was generated - set once via OnceLock /// When the first token was generated - set once via OnceLock
first_token_time: OnceLock<Instant>, first_token_time: OnceLock<Instant>,
/// When the request finished - set once via OnceLock /// When the request finished - set once via OnceLock
request_finish_time: OnceLock<Instant>, request_finish_time: OnceLock<Instant>,
/// KV cache overlap blocks (prefix cache hits) - set once via OnceLock
kv_overlap_blocks: OnceLock<u32>,
/// Input sequence length in blocks (for hit rate calculation) - set once via OnceLock
isl_blocks: OnceLock<usize>,
/// Prefill worker ID (for disaggregated serving) - set once via OnceLock
prefill_worker_id: OnceLock<u64>,
/// Decode worker ID - set once via OnceLock
decode_worker_id: OnceLock<u64>,
/// Request phase (Prefill/Decode/Aggregated)
phase: Mutex<RequestPhase>,
} }
impl RequestTimingTracker { impl RequestTracker {
/// Create a new timing tracker, capturing the current time as request received. /// Create a new request tracker, capturing the current time as request received.
pub fn new() -> Self { pub fn new() -> Self {
let now = Instant::now(); let now = Instant::now();
let epoch_ms = SystemTime::now() let epoch_ms = SystemTime::now()
...@@ -43,12 +90,23 @@ impl RequestTimingTracker { ...@@ -43,12 +90,23 @@ impl RequestTimingTracker {
.map(|d| d.as_millis() as u64) .map(|d| d.as_millis() as u64)
.unwrap_or(0); .unwrap_or(0);
RequestTimingTracker { RequestTracker {
request_received: now, request_received: now,
request_received_epoch_ms: epoch_ms, request_received_epoch_ms: epoch_ms,
prefill_start_time: OnceLock::new(),
first_token_time: OnceLock::new(), first_token_time: OnceLock::new(),
request_finish_time: OnceLock::new(), request_finish_time: OnceLock::new(),
kv_overlap_blocks: OnceLock::new(),
isl_blocks: OnceLock::new(),
prefill_worker_id: OnceLock::new(),
decode_worker_id: OnceLock::new(),
phase: Mutex::new(RequestPhase::Aggregated),
}
} }
/// Record when prefill started. Returns true if this was the first call.
pub fn record_prefill_start(&self) -> bool {
self.prefill_start_time.set(Instant::now()).is_ok()
} }
pub fn record_first_token(&self) -> bool { pub fn record_first_token(&self) -> bool {
...@@ -59,6 +117,27 @@ impl RequestTimingTracker { ...@@ -59,6 +117,27 @@ impl RequestTimingTracker {
self.request_finish_time.set(Instant::now()).is_ok() self.request_finish_time.set(Instant::now()).is_ok()
} }
/// Store the KV cache hit counts for this request.
///
/// Both values are write-once, and each write is attempted regardless of whether
/// the other one was accepted. Returns `true` only when this call stored both
/// values, i.e. neither had been set before.
pub fn record_kv_hit(&self, overlap_blocks: u32, isl_blocks: usize) -> bool {
    // Tuple fields are evaluated left to right, so both `set` calls always run.
    let stored = (
        self.kv_overlap_blocks.set(overlap_blocks).is_ok(),
        self.isl_blocks.set(isl_blocks).is_ok(),
    );
    matches!(stored, (true, true))
}
/// Milliseconds elapsed between receiving the request and the recorded prefill start.
///
/// Returns `None` when no prefill start has been recorded.
pub fn prefill_wait_time_ms(&self) -> Option<f64> {
    let prefill_start = self.prefill_start_time.get()?;
    let waited = prefill_start.duration_since(self.request_received);
    Some(waited.as_secs_f64() * 1000.0)
}
/// Milliseconds elapsed between the recorded prefill start and the first token.
///
/// Returns `None` unless both timestamps have been recorded.
pub fn prefill_time_ms(&self) -> Option<f64> {
    self.prefill_start_time
        .get()
        .zip(self.first_token_time.get())
        .map(|(started, first_token)| {
            first_token.duration_since(*started).as_secs_f64() * 1000.0
        })
}
pub fn ttft_ms(&self) -> Option<f64> { pub fn ttft_ms(&self) -> Option<f64> {
self.first_token_time self.first_token_time
.get() .get()
...@@ -75,16 +154,84 @@ impl RequestTimingTracker { ...@@ -75,16 +154,84 @@ impl RequestTimingTracker {
self.request_received_epoch_ms self.request_received_epoch_ms
} }
/// KV cache hit rate: recorded overlap blocks divided by recorded input blocks.
///
/// Returns `None` when either value has not been recorded or when the input
/// block count is zero. The quotient is not clamped, so it stays within
/// `0.0..=1.0` only if the recorded overlap does not exceed the input block count.
pub fn kv_hit_rate(&self) -> Option<f64> {
    match (self.kv_overlap_blocks.get(), self.isl_blocks.get()) {
        (Some(&overlap), Some(&isl)) if isl != 0 => Some(overlap as f64 / isl as f64),
        _ => None,
    }
}
/// Store `id` as the prefill worker for this request.
///
/// The slot is write-once: returns `true` if this call stored the value and
/// `false` if a prefill worker ID was already present (the old value is kept).
pub fn record_prefill_worker(&self, id: u64) -> bool {
    let outcome = self.prefill_worker_id.set(id);
    outcome.is_ok()
}
/// Store `id` as the decode worker for this request.
///
/// The slot is write-once: returns `true` if this call stored the value and
/// `false` if a decode worker ID was already present (the old value is kept).
pub fn record_decode_worker(&self, id: u64) -> bool {
    let outcome = self.decode_worker_id.set(id);
    outcome.is_ok()
}
/// Set the request phase. May be called repeatedly; the latest value wins.
///
/// Never panics on a poisoned lock: the guarded value is a plain `Copy` enum
/// that cannot be left half-written, so the guard is recovered from the poison
/// error and overwritten as usual.
pub fn set_phase(&self, phase: RequestPhase) {
    let mut current = self
        .phase
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *current = phase;
}
/// Get the current request phase.
///
/// Never panics on a poisoned lock: the guarded value is a plain `Copy` enum
/// that cannot be left half-written, so the value is read through the guard
/// recovered from the poison error.
pub fn phase(&self) -> RequestPhase {
    *self
        .phase
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Record `instance_id` in the worker slot(s) selected by the current phase.
///
/// - `Prefill`: stored as the prefill worker only
/// - `Decode`: stored as the decode worker only
/// - `Aggregated`: stored as both the prefill and the decode worker
///
/// The phase is read once, up front. Each slot is write-once, so a slot that
/// already holds an ID is left unchanged.
pub fn record_worker(&self, instance_id: u64) {
    // Exhaustive match: adding a phase variant forces this mapping to be revisited.
    let (as_prefill, as_decode) = match self.phase() {
        RequestPhase::Prefill => (true, false),
        RequestPhase::Decode => (false, true),
        RequestPhase::Aggregated => (true, true),
    };

    if as_prefill {
        self.record_prefill_worker(instance_id);
    }
    if as_decode {
        self.record_decode_worker(instance_id);
    }
}
/// Snapshot the recorded worker IDs.
///
/// Returns `None` when neither a prefill nor a decode worker ID has been
/// recorded. Otherwise returns a `WorkerIdInfo` carrying whichever IDs are
/// present; an unrecorded side is `None`.
pub fn get_worker_info(&self) -> Option<WorkerIdInfo> {
    let recorded = (
        self.prefill_worker_id.get().copied(),
        self.decode_worker_id.get().copied(),
    );
    match recorded {
        (None, None) => None,
        (prefill_worker_id, decode_worker_id) => Some(WorkerIdInfo {
            prefill_worker_id,
            decode_worker_id,
        }),
    }
}
pub fn get_timing_info(&self) -> TimingInfo { pub fn get_timing_info(&self) -> TimingInfo {
TimingInfo { TimingInfo {
request_received_ms: self.request_received_epoch_ms, request_received_ms: self.request_received_epoch_ms,
prefill_wait_time_ms: self.prefill_wait_time_ms(),
prefill_time_ms: self.prefill_time_ms(),
ttft_ms: self.ttft_ms(), ttft_ms: self.ttft_ms(),
total_time_ms: self.total_time_ms(), total_time_ms: self.total_time_ms(),
kv_hit_rate: self.kv_hit_rate(),
} }
} }
} }
impl Default for RequestTimingTracker { impl Default for RequestTracker {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
...@@ -99,6 +246,14 @@ pub struct TimingInfo { ...@@ -99,6 +246,14 @@ pub struct TimingInfo {
/// When the request was received (epoch milliseconds) /// When the request was received (epoch milliseconds)
pub request_received_ms: u64, pub request_received_ms: u64,
/// Time from request received to prefill start (queue/wait time) in milliseconds
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_wait_time_ms: Option<f64>,
/// Time from prefill start to first token (prefill execution time) in milliseconds
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_time_ms: Option<f64>,
/// Time to first token in milliseconds /// Time to first token in milliseconds
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub ttft_ms: Option<f64>, pub ttft_ms: Option<f64>,
...@@ -106,4 +261,8 @@ pub struct TimingInfo { ...@@ -106,4 +261,8 @@ pub struct TimingInfo {
/// Total request time in milliseconds /// Total request time in milliseconds
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub total_time_ms: Option<f64>, pub total_time_ms: Option<f64>,
/// KV cache hit rate (0.0 to 1.0) - ratio of cached blocks to total input blocks
#[serde(skip_serializing_if = "Option::is_none")]
pub kv_hit_rate: Option<f64>,
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
protocols::{ protocols::{
common::{self, timing::RequestTimingTracker}, common::{self, timing::RequestTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo}, openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
}, },
types::TokenIdType, types::TokenIdType,
}; };
...@@ -44,11 +46,20 @@ impl NvCreateChatCompletionRequest { ...@@ -44,11 +46,20 @@ impl NvCreateChatCompletionRequest {
/// # Returns /// # Returns
/// * [`DeltaGenerator`] configured with model name and response options. /// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Check if client requested timing in extra_fields // Enable tracking if:
let enable_timing = self // 1. Client requested timing in extra_fields, OR
// 2. query_instance_id annotation is present (needs worker_id tracking for response)
let enable_tracking = self
.nvext() .nvext()
.and_then(|nv| nv.extra_fields.as_ref()) .map(|nv| {
.is_some_and(|fields| fields.iter().any(|f| f == "timing")); nv.extra_fields
.as_ref()
.is_some_and(|fields| fields.iter().any(|f| f == "timing"))
|| nv.annotations.as_ref().is_some_and(|annots| {
annots.iter().any(|a| a.starts_with("query_instance_id"))
})
})
.unwrap_or(false);
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: self enable_usage: self
...@@ -59,7 +70,7 @@ impl NvCreateChatCompletionRequest { ...@@ -59,7 +70,7 @@ impl NvCreateChatCompletionRequest {
.unwrap_or(false), .unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false) enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0, || self.inner.top_logprobs.unwrap_or(0) > 0,
enable_timing, enable_tracking,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
}; };
...@@ -74,8 +85,8 @@ pub struct DeltaGeneratorOptions { ...@@ -74,8 +85,8 @@ pub struct DeltaGeneratorOptions {
pub enable_usage: bool, pub enable_usage: bool,
/// Determines whether log probabilities should be included in the response. /// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool, pub enable_logprobs: bool,
/// Determines whether timing information should be included in the response's nvext. /// Determines whether request tracking (timing, KV hit rate) should be enabled.
pub enable_timing: bool, pub enable_tracking: bool,
pub runtime_config: ModelRuntimeConfig, pub runtime_config: ModelRuntimeConfig,
} }
...@@ -99,8 +110,8 @@ pub struct DeltaGenerator { ...@@ -99,8 +110,8 @@ pub struct DeltaGenerator {
msg_counter: u64, msg_counter: u64,
/// Configuration options for response generation. /// Configuration options for response generation.
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
/// Optional timing tracker for per-request timing metrics. /// Optional request tracker for per-request metrics (shared with PreprocessedRequest).
timing_tracker: Option<RequestTimingTracker>, tracker: Option<Arc<RequestTracker>>,
} }
impl DeltaGenerator { impl DeltaGenerator {
...@@ -133,9 +144,9 @@ impl DeltaGenerator { ...@@ -133,9 +144,9 @@ impl DeltaGenerator {
let chatcmpl_id = format!("chatcmpl-{request_id}"); let chatcmpl_id = format!("chatcmpl-{request_id}");
// Create timing tracker if timing is enabled // Create request tracker if tracking is enabled
let timing_tracker = if options.enable_timing { let tracker = if options.enable_tracking {
Some(RequestTimingTracker::new()) Some(Arc::new(RequestTracker::new()))
} else { } else {
None None
}; };
...@@ -150,10 +161,15 @@ impl DeltaGenerator { ...@@ -150,10 +161,15 @@ impl DeltaGenerator {
usage, usage,
msg_counter: 0, msg_counter: 0,
options, options,
timing_tracker, tracker,
} }
} }
/// Hands out another shared handle to the request tracker, or `None` when
/// tracking is disabled. Intended for sharing with `PreprocessedRequest`.
pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
    // `Arc::clone` only bumps the reference count; the tracker itself is shared.
    self.tracker.as_ref().map(Arc::clone)
}
/// Updates the prompt token usage count. /// Updates the prompt token usage count.
/// ///
/// # Arguments /// # Arguments
...@@ -396,16 +412,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -396,16 +412,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
); );
// Record first token time (only succeeds on first call due to OnceLock) // Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker { if let Some(ref tracker) = self.tracker {
tracker.record_first_token(); tracker.record_first_token();
} }
// Extract worker_id and token_ids from disaggregated_params // Get worker_id info from tracker (set by KvPushRouter based on phase)
let worker_id_info = delta let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta let token_ids = delta
.disaggregated_params .disaggregated_params
...@@ -415,7 +427,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -415,7 +427,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| { self.tracker.as_ref().map(|tracker| {
tracker.record_finish(); tracker.record_finish();
tracker.get_timing_info() tracker.get_timing_info()
}) })
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{ use crate::{
protocols::{ protocols::{
common::{self, timing::RequestTimingTracker}, common::{self, timing::RequestTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo}, openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
}, },
types::TokenIdType, types::TokenIdType,
}; };
...@@ -39,11 +41,20 @@ impl NvCreateCompletionRequest { ...@@ -39,11 +41,20 @@ impl NvCreateCompletionRequest {
// put this method on the request // put this method on the request
// inspect the request to extract options // inspect the request to extract options
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Check if client requested timing in extra_fields // Enable tracking if:
let enable_timing = self // 1. Client requested timing in extra_fields, OR
// 2. query_instance_id annotation is present (needs worker_id tracking for response)
let enable_tracking = self
.nvext() .nvext()
.and_then(|nv| nv.extra_fields.as_ref()) .map(|nv| {
.is_some_and(|fields| fields.iter().any(|f| f == "timing")); nv.extra_fields
.as_ref()
.is_some_and(|fields| fields.iter().any(|f| f == "timing"))
|| nv.annotations.as_ref().is_some_and(|annots| {
annots.iter().any(|a| a.starts_with("query_instance_id"))
})
})
.unwrap_or(false);
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: self enable_usage: self
...@@ -53,7 +64,7 @@ impl NvCreateCompletionRequest { ...@@ -53,7 +64,7 @@ impl NvCreateCompletionRequest {
.map(|opts| opts.include_usage) .map(|opts| opts.include_usage)
.unwrap_or(false), .unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0, enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_timing, enable_tracking,
}; };
DeltaGenerator::new(self.inner.model.clone(), options, request_id) DeltaGenerator::new(self.inner.model.clone(), options, request_id)
...@@ -64,7 +75,7 @@ impl NvCreateCompletionRequest { ...@@ -64,7 +75,7 @@ impl NvCreateCompletionRequest {
pub struct DeltaGeneratorOptions { pub struct DeltaGeneratorOptions {
pub enable_usage: bool, pub enable_usage: bool,
pub enable_logprobs: bool, pub enable_logprobs: bool,
pub enable_timing: bool, pub enable_tracking: bool,
} }
pub struct DeltaGenerator { pub struct DeltaGenerator {
...@@ -75,7 +86,7 @@ pub struct DeltaGenerator { ...@@ -75,7 +86,7 @@ pub struct DeltaGenerator {
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
usage: dynamo_async_openai::types::CompletionUsage, usage: dynamo_async_openai::types::CompletionUsage,
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
timing_tracker: Option<RequestTimingTracker>, tracker: Option<Arc<RequestTracker>>,
} }
impl DeltaGenerator { impl DeltaGenerator {
...@@ -101,9 +112,9 @@ impl DeltaGenerator { ...@@ -101,9 +112,9 @@ impl DeltaGenerator {
let completion_id = format!("cmpl-{request_id}"); let completion_id = format!("cmpl-{request_id}");
// Create timing tracker if timing is enabled // Create request tracker if tracking is enabled
let timing_tracker = if options.enable_timing { let tracker = if options.enable_tracking {
Some(RequestTimingTracker::new()) Some(Arc::new(RequestTracker::new()))
} else { } else {
None None
}; };
...@@ -116,10 +127,15 @@ impl DeltaGenerator { ...@@ -116,10 +127,15 @@ impl DeltaGenerator {
system_fingerprint: None, system_fingerprint: None,
usage, usage,
options, options,
timing_tracker, tracker,
} }
} }
/// Hands out another shared handle to the request tracker, or `None` when
/// tracking is disabled. Intended for sharing with `PreprocessedRequest`.
pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
    // `Arc::clone` only bumps the reference count; the tracker itself is shared.
    self.tracker.as_ref().map(Arc::clone)
}
pub fn update_isl(&mut self, isl: u32) { pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl; self.usage.prompt_tokens = isl;
} }
...@@ -291,16 +307,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -291,16 +307,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs); let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Record first token time (only succeeds on first call due to OnceLock) // Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker { if let Some(ref tracker) = self.tracker {
tracker.record_first_token(); tracker.record_first_token();
} }
// Extract worker_id and token_ids from disaggregated_params // Get worker_id info from tracker (set by KvPushRouter based on phase)
let worker_id_info = delta let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta let token_ids = delta
.disaggregated_params .disaggregated_params
...@@ -310,7 +322,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -310,7 +322,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| { self.tracker.as_ref().map(|tracker| {
tracker.record_finish(); tracker.record_finish();
tracker.get_timing_info() tracker.get_timing_info()
}) })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment