Unverified Commit f437c8cf authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat(frontend): gate nvext response metadata behind extra_fields (#8252)


Co-authored-by: default avatarAmeenP <ameenp360@gmail.com>
parent 0c98d9a1
......@@ -35,7 +35,7 @@ Include `nvext` as a top-level field alongside standard OpenAI-compatible fields
| `backend_instance_id` | `u64` | `None` | Router | Routes the request to a specific backend instance. |
| `token_data` | `u32[]` | `None` | Preprocessor | Pre-tokenized prompt tokens. When provided with `backend_instance_id`, tokenization is skipped. |
| `max_thinking_tokens` | `u32` | `None` | Backend | Maximum thinking tokens allowed (passed through to backends). |
| `extra_fields` | `string[]` | `None` | Response builder | Fields to include in the response `nvext`. Supported: `"worker_id"`, `"timing"`. |
| `extra_fields` | `string[]` | `None` | Response builder | Fields to include in the response `nvext`. Supported: `"worker_id"`, `"timing"`, `"routed_experts"`. |
| `prefill_worker_id` | `u64` | `None` | Router | Routes the request to a specific prefill worker (disaggregated serving). |
| `decode_worker_id` | `u64` | `None` | Router | Routes the request to a specific decode worker (disaggregated serving). |
| `agent_hints` | object | `None` | Router | Per-request hints for scheduling and load balancing. See [Agent Hints](#agent-hints). |
......@@ -163,6 +163,7 @@ When the client requests response metadata via `extra_fields`, the response incl
|-------|---------------|-------------|
| `worker_id` | `extra_fields: ["worker_id"]` | Prefill/decode worker IDs and data parallel ranks that processed the request. |
| `timing` | `extra_fields: ["timing"]` | Per-request timing information (TTFT, ITL, queue time, etc.). |
| `routed_experts` | `extra_fields: ["routed_experts"]` | Routed expert capture payload returned by SGLang-backed requests. |
| `token_ids` | Automatic (GAIE Stage 1) | Tokenized prompt for reuse in Stage 2 query-only mode. |
### Example response `nvext`
......
......@@ -10,7 +10,7 @@ use crate::{
common::{self, timing::RequestTracker},
openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponse, TimingInfo},
nvext::{NvExtProvider, NvExtResponseFieldSelection},
token_to_utf8_bytes,
},
},
......@@ -51,20 +51,7 @@ impl NvCreateChatCompletionRequest {
/// # Returns
/// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Enable tracking if:
// 1. Client requested timing in extra_fields, OR
// 2. query_instance_id annotation is present (needs worker_id tracking for response)
let enable_tracking = self
.nvext()
.map(|nv| {
nv.extra_fields
.as_ref()
.is_some_and(|fields| fields.iter().any(|f| f == "timing"))
|| nv.annotations.as_ref().is_some_and(|annots| {
annots.iter().any(|a| a.starts_with("query_instance_id"))
})
})
.unwrap_or(false);
let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext());
let options = DeltaGeneratorOptions {
enable_usage: self
......@@ -81,7 +68,7 @@ impl NvCreateChatCompletionRequest {
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
enable_tracking,
response_fields,
runtime_config: ModelRuntimeConfig::default(),
};
......@@ -98,8 +85,8 @@ pub struct DeltaGeneratorOptions {
pub continuous_usage_stats: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
/// Determines whether request tracking (timing, KV hit rate) should be enabled.
pub enable_tracking: bool,
/// Determines which nvext response fields may be emitted for this request.
pub response_fields: NvExtResponseFieldSelection,
pub runtime_config: ModelRuntimeConfig,
}
......@@ -158,7 +145,8 @@ impl DeltaGenerator {
let chatcmpl_id = format!("chatcmpl-{request_id}");
// Always create request tracker for per-worker metrics (TTFT, ITL per worker_id).
// The enable_tracking option only controls whether timing info is included in the response.
// `response_fields` only controls which nvext fields are returned to the client;
// the tracker still records timing/ITL internally for metrics.
let tracker = Some(Arc::new(RequestTracker::new()));
Self {
......@@ -414,60 +402,39 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
delta.stop_reason,
);
// Get worker_id info from tracker (set by KvPushRouter based on phase)
let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.tracker.as_ref().map(|tracker| {
// Record finish for timing/ITL accounting even when timing is not returned to the client.
// Kept at call site because it's a side effect on the tracker — not a gating decision.
if finish_reason.is_some()
&& let Some(ref tracker) = self.tracker
{
tracker.record_finish();
tracker.get_timing_info()
})
} else {
None
};
}
// Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
// Build the nvext response payload via the shared gating helper on
// `NvExtResponseFieldSelection` (see `nvext.rs`). Both chat and
// completions delta generators go through the same helper so the gating
// rules stay in one place.
if let Some(nvext_response) = self.options.response_fields.build_response_nvext(
self.tracker.as_ref(),
delta.disaggregated_params.as_ref(),
finish_reason.is_some(),
) && let Ok(nvext_json) = serde_json::to_value(&nvext_response)
{
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
routed_experts,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json);
if let Some(ref info) = worker_id_info {
if let Some(ref info) = nvext_response.worker_id {
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
if let Some(ref tokens) = nvext_response.token_ids {
tracing::debug!(
"Injected token_ids into chat completion nvext: {} tokens",
tokens.len()
);
}
}
}
Ok(stream_response)
}
......@@ -500,6 +467,8 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::{self, llm_backend::BackendOutput, timing::WORKER_TYPE_PREFILL};
use crate::protocols::openai::DeltaGeneratorExt;
use dynamo_protocols::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
......@@ -564,4 +533,124 @@ mod tests {
"Streaming request should not have stream_options modified"
);
}
fn make_request_with_nvext(
nvext: crate::protocols::openai::nvext::NvExt,
) -> NvCreateChatCompletionRequest {
let mut request = create_test_request();
request.nvext = Some(nvext);
request
}
/// Test fixture: a one-token `BackendOutput` with a `Stop` finish reason whose
/// `disaggregated_params` hold both a `token_ids` array and a `routed_experts`
/// object, so each gated nvext field has data available to emit.
fn final_backend_output() -> BackendOutput {
    let text = "hello".to_string();
    let disaggregated_params = serde_json::json!({
        "token_ids": [11, 22, 33],
        "routed_experts": {"layer_0": [1, 3]}
    });

    BackendOutput {
        token_ids: vec![1],
        tokens: vec![Some(text.clone())],
        text: Some(text),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason: Some(common::FinishReason::Stop),
        stop_reason: None,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: Some(disaggregated_params),
    }
}
/// A request without `nvext` must produce a response without `nvext`, even
/// though a worker was recorded on the tracker and the backend output carries
/// `disaggregated_params`.
#[test]
fn test_plain_request_without_extra_fields_omits_nvext() {
    let mut delta_gen = create_test_request().response_generator("req-no-nvext".to_string());
    delta_gen
        .tracker()
        .expect("tracker")
        .record_worker(42, Some(0), WORKER_TYPE_PREFILL);

    let chunk = delta_gen
        .choice_from_postprocessor(final_backend_output())
        .expect("choice generation");
    assert!(chunk.nvext.is_none());
}
/// With `extra_fields = ["timing"]`, the final chunk's `nvext` contains
/// `timing` and none of the other gated fields.
#[test]
fn test_timing_extra_field_emits_timing_on_final_chunk() {
    use crate::protocols::openai::nvext::NvExt;

    let request = make_request_with_nvext(
        NvExt::builder()
            .extra_fields(vec!["timing".to_string()])
            .build()
            .unwrap(),
    );
    let mut delta_gen = request.response_generator("req-timing".to_string());
    let chunk = delta_gen
        .choice_from_postprocessor(final_backend_output())
        .expect("choice generation");

    let payload = chunk.nvext.expect("nvext present for timing request");
    let absent = |key: &str| payload.get(key).is_none();
    assert!(
        payload.get("timing").is_some(),
        "timing should be emitted when extra_fields=[\"timing\"]"
    );
    assert!(absent("worker_id"));
    assert!(absent("token_ids"));
    assert!(absent("routed_experts"));
}
/// A `query_instance_id:<value>` annotation alone (no `extra_fields`) makes
/// the response `nvext` carry `worker_id` and `token_ids`, but neither
/// `timing` nor `routed_experts`.
#[test]
fn test_query_instance_id_emits_worker_id_and_token_ids() {
    use crate::protocols::openai::nvext::NvExt;

    let request = make_request_with_nvext(
        NvExt::builder()
            .annotations(vec!["query_instance_id:abc".to_string()])
            .build()
            .unwrap(),
    );
    let mut delta_gen = request.response_generator("req-qid".to_string());
    delta_gen
        .tracker()
        .expect("tracker")
        .record_worker(42, Some(0), WORKER_TYPE_PREFILL);

    let chunk = delta_gen
        .choice_from_postprocessor(final_backend_output())
        .expect("choice generation");
    let payload = chunk
        .nvext
        .expect("nvext present for query_instance_id flow");

    let expected_tokens = serde_json::json!([11, 22, 33]);
    assert!(payload.get("worker_id").is_some());
    assert_eq!(payload.get("token_ids"), Some(&expected_tokens));
    // timing is NOT auto-enabled for query_instance_id — it is gated by `extra_fields: ["timing"]`.
    assert!(payload.get("timing").is_none());
    assert!(payload.get("routed_experts").is_none());
}
/// With `extra_fields = ["routed_experts"]`, the `routed_experts` value from
/// `disaggregated_params` is returned unchanged and no other gated field is
/// present.
#[test]
fn test_routed_experts_extra_field_emits_routed_experts() {
    use crate::protocols::openai::nvext::NvExt;

    let request = make_request_with_nvext(
        NvExt::builder()
            .extra_fields(vec!["routed_experts".to_string()])
            .build()
            .unwrap(),
    );
    let mut delta_gen = request.response_generator("req-experts".to_string());
    let chunk = delta_gen
        .choice_from_postprocessor(final_backend_output())
        .expect("choice generation");
    let payload = chunk
        .nvext
        .expect("nvext present for routed_experts request");

    let expected_experts = serde_json::json!({"layer_0": [1, 3]});
    let absent = |key: &str| payload.get(key).is_none();
    assert_eq!(payload.get("routed_experts"), Some(&expected_experts));
    assert!(absent("worker_id"));
    assert!(absent("timing"));
    assert!(absent("token_ids"));
}
}
......@@ -9,7 +9,7 @@ use crate::{
common::{self, timing::RequestTracker},
openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponse, TimingInfo},
nvext::{NvExtProvider, NvExtResponseFieldSelection},
},
},
types::TokenIdType,
......@@ -45,20 +45,7 @@ impl NvCreateCompletionRequest {
// put this method on the request
// inspect the request to extract options
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
// Enable tracking if:
// 1. Client requested timing in extra_fields, OR
// 2. query_instance_id annotation is present (needs worker_id tracking for response)
let enable_tracking = self
.nvext()
.map(|nv| {
nv.extra_fields
.as_ref()
.is_some_and(|fields| fields.iter().any(|f| f == "timing"))
|| nv.annotations.as_ref().is_some_and(|annots| {
annots.iter().any(|a| a.starts_with("query_instance_id"))
})
})
.unwrap_or(false);
let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext());
let options = DeltaGeneratorOptions {
enable_usage: self
......@@ -74,7 +61,7 @@ impl NvCreateCompletionRequest {
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_tracking,
response_fields,
};
DeltaGenerator::new(self.inner.model.clone(), options, request_id)
......@@ -86,7 +73,7 @@ pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
pub continuous_usage_stats: bool,
pub enable_logprobs: bool,
pub enable_tracking: bool,
pub response_fields: NvExtResponseFieldSelection,
}
pub struct DeltaGenerator {
......@@ -124,7 +111,8 @@ impl DeltaGenerator {
let completion_id = format!("cmpl-{request_id}");
// Always create request tracker for per-worker metrics (TTFT, ITL per worker_id).
// The enable_tracking option only controls whether timing info is included in the response.
// `response_fields` only controls which nvext fields are returned to the client;
// the tracker still records timing/ITL internally for metrics.
let tracker = Some(Arc::new(RequestTracker::new()));
Self {
......@@ -308,60 +296,39 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let index = delta.index.unwrap_or(0);
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Get worker_id info from tracker (set by KvPushRouter based on phase)
let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.tracker.as_ref().map(|tracker| {
// Record finish for timing/ITL accounting even when timing is not returned to the client.
// Kept at call site because it's a side effect on the tracker — not a gating decision.
if finish_reason.is_some()
&& let Some(ref tracker) = self.tracker
{
tracker.record_finish();
tracker.get_timing_info()
})
} else {
None
};
}
// Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
// Build the nvext response payload via the shared gating helper on
// `NvExtResponseFieldSelection` (see `nvext.rs`). Both chat and
// completions delta generators go through the same helper so the gating
// rules stay in one place.
if let Some(nvext_response) = self.options.response_fields.build_response_nvext(
self.tracker.as_ref(),
delta.disaggregated_params.as_ref(),
finish_reason.is_some(),
) && let Ok(nvext_json) = serde_json::to_value(&nvext_response)
{
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
routed_experts,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.nvext = Some(nvext_json);
if let Some(ref info) = worker_id_info {
if let Some(ref info) = nvext_response.worker_id {
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
if let Some(ref tokens) = nvext_response.token_ids {
tracing::debug!(
"Injected token_ids into completions nvext: {} tokens",
tokens.len()
);
}
}
}
Ok(response)
}
......@@ -390,3 +357,147 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
self.tracker.clone()
}
}
/// Tests covering which `nvext` response fields the completions delta
/// generator emits for a final chunk, depending on the request's `nvext`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::{self, llm_backend::BackendOutput, timing::WORKER_TYPE_PREFILL};
    use crate::protocols::openai::DeltaGeneratorExt;
    use dynamo_protocols::types::{CreateCompletionRequestArgs, Prompt};

    /// Baseline completion request: model `"test-model"`, prompt `"test"`,
    /// no `nvext`, no metadata.
    fn create_test_request() -> NvCreateCompletionRequest {
        NvCreateCompletionRequest {
            inner: CreateCompletionRequestArgs::default()
                .model("test-model")
                .prompt(Prompt::String("test".to_string()))
                .build()
                .expect("completion request"),
            common: Default::default(),
            nvext: None,
            metadata: None,
            unsupported_fields: Default::default(),
        }
    }

    /// The baseline request with its `nvext` field set to `Some(nvext)`.
    fn make_request_with_nvext(
        nvext: crate::protocols::openai::nvext::NvExt,
    ) -> NvCreateCompletionRequest {
        NvCreateCompletionRequest {
            nvext: Some(nvext),
            ..create_test_request()
        }
    }

    /// A one-token `BackendOutput` with a `Stop` finish reason whose
    /// `disaggregated_params` hold both `token_ids` and `routed_experts`.
    fn final_backend_output() -> BackendOutput {
        let text = "hello".to_string();
        let disaggregated_params = serde_json::json!({
            "token_ids": [11, 22, 33],
            "routed_experts": {"layer_0": [1, 3]}
        });

        BackendOutput {
            token_ids: vec![1],
            tokens: vec![Some(text.clone())],
            text: Some(text),
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(common::FinishReason::Stop),
            stop_reason: None,
            index: Some(0),
            completion_usage: None,
            disaggregated_params: Some(disaggregated_params),
        }
    }

    /// No `nvext` on the request means no `nvext` on the response, even with a
    /// worker recorded on the tracker.
    #[test]
    fn test_plain_request_without_extra_fields_omits_nvext() {
        let mut delta_gen = create_test_request().response_generator("req-no-nvext".to_string());
        delta_gen
            .tracker()
            .expect("tracker")
            .record_worker(42, Some(0), WORKER_TYPE_PREFILL);

        let reply = delta_gen
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");
        assert!(reply.nvext.is_none());
    }

    /// `extra_fields = ["timing"]` yields `timing` only.
    #[test]
    fn test_timing_extra_field_emits_timing_on_final_chunk() {
        use crate::protocols::openai::nvext::NvExt;

        let request = make_request_with_nvext(
            NvExt::builder()
                .extra_fields(vec!["timing".to_string()])
                .build()
                .unwrap(),
        );
        let mut delta_gen = request.response_generator("req-timing".to_string());
        let reply = delta_gen
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");

        let payload = reply.nvext.expect("nvext present for timing request");
        let absent = |key: &str| payload.get(key).is_none();
        assert!(
            payload.get("timing").is_some(),
            "timing should be emitted when extra_fields=[\"timing\"]"
        );
        assert!(absent("worker_id"));
        assert!(absent("token_ids"));
        assert!(absent("routed_experts"));
    }

    /// A `query_instance_id:<value>` annotation yields `worker_id` and
    /// `token_ids`, but not `timing` or `routed_experts`.
    #[test]
    fn test_query_instance_id_emits_worker_id_and_token_ids() {
        use crate::protocols::openai::nvext::NvExt;

        let request = make_request_with_nvext(
            NvExt::builder()
                .annotations(vec!["query_instance_id:abc".to_string()])
                .build()
                .unwrap(),
        );
        let mut delta_gen = request.response_generator("req-qid".to_string());
        delta_gen
            .tracker()
            .expect("tracker")
            .record_worker(42, Some(0), WORKER_TYPE_PREFILL);

        let reply = delta_gen
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");
        let payload = reply
            .nvext
            .expect("nvext present for query_instance_id flow");

        let expected_tokens = serde_json::json!([11, 22, 33]);
        assert!(payload.get("worker_id").is_some());
        assert_eq!(payload.get("token_ids"), Some(&expected_tokens));
        // timing is NOT auto-enabled for query_instance_id — it is gated by `extra_fields: ["timing"]`.
        assert!(payload.get("timing").is_none());
        assert!(payload.get("routed_experts").is_none());
    }

    /// `extra_fields = ["routed_experts"]` yields the `routed_experts` value
    /// from `disaggregated_params` and nothing else.
    #[test]
    fn test_routed_experts_extra_field_emits_routed_experts() {
        use crate::protocols::openai::nvext::NvExt;

        let request = make_request_with_nvext(
            NvExt::builder()
                .extra_fields(vec!["routed_experts".to_string()])
                .build()
                .unwrap(),
        );
        let mut delta_gen = request.response_generator("req-experts".to_string());
        let reply = delta_gen
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");
        let payload = reply
            .nvext
            .expect("nvext present for routed_experts request");

        let expected_experts = serde_json::json!({"layer_0": [1, 3]});
        let absent = |key: &str| payload.get(key).is_none();
        assert_eq!(payload.get("routed_experts"), Some(&expected_experts));
        assert!(absent("worker_id"));
        assert!(absent("timing"));
        assert!(absent("token_ids"));
    }
}
......@@ -116,6 +116,117 @@ pub struct NvExtResponse {
pub routed_experts: Option<serde_json::Value>,
}
/// Response nvext fields requested for a given request.
///
/// The OpenAI-compatible API should only include `nvext` response fields when the
/// client explicitly opts in via `nvext.extra_fields`, except for the GAIE
/// `query_instance_id` flow which automatically returns `worker_id` and
/// `token_ids`. Note: timing is NOT auto-enabled for `query_instance_id`
/// because the query-only fast path returns no `finish_reason`, and timing
/// is only emitted on the final response chunk.
///
/// `Default` yields a selection with every flag `false`, i.e. no response
/// `nvext` fields are selected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NvExtResponseFieldSelection {
/// Set by `from_nvext` when `extra_fields` contains `"worker_id"` or the
/// request carries a `query_instance_id:` annotation.
pub worker_id: bool,
/// Set by `from_nvext` when `extra_fields` contains `"timing"`.
pub timing: bool,
/// Set by `from_nvext` only for the `query_instance_id:` annotation; there
/// is no `extra_fields` name that enables it.
pub token_ids: bool,
/// Set by `from_nvext` when `extra_fields` contains `"routed_experts"`.
pub routed_experts: bool,
}
impl NvExtResponseFieldSelection {
    /// Derive the response-field selection from a request's optional `nvext`.
    ///
    /// - `None` yields the default selection (all flags `false`).
    /// - Each recognized name in `extra_fields` (`"worker_id"`, `"timing"`,
    ///   `"routed_experts"`) sets the matching flag; any other name is ignored.
    /// - A `query_instance_id:` annotation additionally sets `worker_id` and
    ///   `token_ids`.
    pub fn from_nvext(nvext: Option<&NvExt>) -> Self {
        let Some(ext) = nvext else {
            return Self::default();
        };

        // Single pass over the requested names (no names when `extra_fields` is `None`).
        let requested = ext
            .extra_fields
            .iter()
            .flatten()
            .fold(Self::default(), |mut acc, field| {
                match field.as_str() {
                    "worker_id" => acc.worker_id = true,
                    "timing" => acc.timing = true,
                    "routed_experts" => acc.routed_experts = true,
                    _ => {}
                }
                acc
            });

        let query_instance_id = ext.has_query_instance_id_annotation();
        Self {
            worker_id: requested.worker_id || query_instance_id,
            token_ids: requested.token_ids || query_instance_id,
            ..requested
        }
    }

    /// Assemble the `nvext` payload for one response chunk, or `None` when no
    /// field ends up populated.
    ///
    /// Each field needs its selection flag plus available data:
    ///
    /// - `worker_id`: flag set and `tracker.get_worker_info()` returns `Some`.
    /// - `token_ids`: flag set and `disaggregated_params` has a `"token_ids"`
    ///   entry that deserializes into `Vec<u32>`; a value that fails to
    ///   deserialize is dropped silently.
    /// - `routed_experts`: flag set and `disaggregated_params` has a
    ///   `"routed_experts"` entry, which is cloned without validation.
    /// - `timing`: flag set, `finish_reason_present` is `true`, and a tracker
    ///   is supplied.
    ///
    /// This method does not call `RequestTracker::record_finish()` and does
    /// no tracing; both are left to the caller.
    pub fn build_response_nvext(
        &self,
        tracker: Option<&std::sync::Arc<crate::protocols::common::timing::RequestTracker>>,
        disaggregated_params: Option<&serde_json::Value>,
        finish_reason_present: bool,
    ) -> Option<NvExtResponse> {
        let worker_id = self
            .worker_id
            .then(|| tracker.and_then(|t| t.get_worker_info()))
            .flatten();

        let token_ids = self
            .token_ids
            .then(|| {
                disaggregated_params
                    .and_then(|params| params.get("token_ids"))
                    .and_then(|raw| serde_json::from_value::<Vec<u32>>(raw.clone()).ok())
            })
            .flatten();

        let routed_experts = self
            .routed_experts
            .then(|| {
                disaggregated_params
                    .and_then(|params| params.get("routed_experts"))
                    .cloned()
            })
            .flatten();

        // Timing is only reported on the chunk that carries a finish reason.
        let timing = (finish_reason_present && self.timing)
            .then(|| tracker.map(|t| t.get_timing_info()))
            .flatten();

        let has_any = worker_id.is_some()
            || token_ids.is_some()
            || routed_experts.is_some()
            || timing.is_some();

        has_any.then_some(NvExtResponse {
            worker_id,
            timing,
            token_ids,
            routed_experts,
        })
    }
}
/// NVIDIA LLM extensions to the OpenAI API
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
......@@ -285,6 +396,21 @@ impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
/// Report whether any annotation starts with `"query_instance_id:"`
/// (the GAIE Stage 1 marker).
///
/// The trailing colon is part of the prefix, so a different key such as
/// `query_instance_id_extra:...` does not match and cannot enable response
/// metadata by accident. Returns `false` when `annotations` is `None`.
///
/// Intended to agree with the key-prefix matching in
/// `PreprocessedRequest::get_annotation_value` and the KvPushRouter
/// query-only detection (neither is visible in this section — verify both
/// if this prefix ever changes).
pub fn has_query_instance_id_annotation(&self) -> bool {
    const KEY_PREFIX: &str = "query_instance_id:";
    match self.annotations.as_ref() {
        Some(annotations) => annotations.iter().any(|entry| entry.starts_with(KEY_PREFIX)),
        None => false,
    }
}
}
fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
......@@ -422,4 +548,302 @@ mod tests {
assert_eq!(result.dp_rank, Some(3));
assert_eq!(result.prefill_dp_rank, Some(5));
}
/// A request without `nvext` selects no response fields.
#[test]
fn test_nvext_response_field_selection_defaults_to_none() {
    let expected = NvExtResponseFieldSelection::default();
    assert_eq!(NvExtResponseFieldSelection::from_nvext(None), expected);
}
/// Only the names listed in `extra_fields` are selected.
#[test]
fn test_nvext_response_field_selection_respects_extra_fields() {
    let ext = NvExt::builder()
        .extra_fields(vec!["worker_id".to_string(), "routed_experts".to_string()])
        .build()
        .unwrap();

    let expected = NvExtResponseFieldSelection {
        worker_id: true,
        timing: false,
        token_ids: false,
        routed_experts: true,
    };
    assert_eq!(NvExtResponseFieldSelection::from_nvext(Some(&ext)), expected);
}
/// A bare `query_instance_id:` annotation selects `worker_id` and `token_ids`
/// without any `extra_fields`.
#[test]
fn test_nvext_response_field_selection_query_instance_id_exception() {
    let ext = NvExt::builder()
        .annotations(vec!["query_instance_id:".to_string()])
        .build()
        .unwrap();

    let expected = NvExtResponseFieldSelection {
        worker_id: true,
        // timing NOT auto-enabled: query-only fast path has no finish_reason
        timing: false,
        token_ids: true,
        routed_experts: false,
    };
    assert_eq!(NvExtResponseFieldSelection::from_nvext(Some(&ext)), expected);
}
/// An annotation like "query_instance_id_extra:foo" must NOT trigger the
/// query_instance_id exception — only the exact "query_instance_id:" key
/// prefix should match, consistent with PreprocessedRequest::get_annotation_value.
#[test]
fn test_nvext_response_field_selection_rejects_stray_annotation() {
    let stray = NvExt::builder()
        .annotations(vec!["query_instance_id_extra:foo".to_string()])
        .build()
        .unwrap();

    let selection = NvExtResponseFieldSelection::from_nvext(Some(&stray));
    let nothing_selected = NvExtResponseFieldSelection::default();
    assert_eq!(selection, nothing_selected);
}
/// `extra_fields = ["worker_id"]` selects `worker_id` and nothing else.
#[test]
fn test_nvext_response_field_selection_worker_id_only() {
    let ext = NvExt::builder()
        .extra_fields(vec!["worker_id".to_string()])
        .build()
        .unwrap();

    let NvExtResponseFieldSelection {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = NvExtResponseFieldSelection::from_nvext(Some(&ext));
    assert!(worker_id);
    assert!(!timing);
    assert!(!token_ids);
    assert!(!routed_experts);
}
/// `extra_fields = ["timing"]` selects `timing` and nothing else.
#[test]
fn test_nvext_response_field_selection_timing_only() {
    let ext = NvExt::builder()
        .extra_fields(vec!["timing".to_string()])
        .build()
        .unwrap();

    let NvExtResponseFieldSelection {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = NvExtResponseFieldSelection::from_nvext(Some(&ext));
    assert!(timing);
    assert!(!worker_id);
    assert!(!token_ids);
    assert!(!routed_experts);
}
/// `extra_fields = ["routed_experts"]` selects `routed_experts` and nothing else.
#[test]
fn test_nvext_response_field_selection_routed_experts_only() {
    let ext = NvExt::builder()
        .extra_fields(vec!["routed_experts".to_string()])
        .build()
        .unwrap();

    let NvExtResponseFieldSelection {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = NvExtResponseFieldSelection::from_nvext(Some(&ext));
    assert!(routed_experts);
    assert!(!worker_id);
    assert!(!timing);
    assert!(!token_ids);
}
// Helpers for build_response_nvext tests -----------------------------
fn sel_all_false() -> NvExtResponseFieldSelection {
NvExtResponseFieldSelection::default()
}
/// Test helper: a fresh `RequestTracker` with one prefill worker recorded
/// (`record_worker(42, Some(0), WORKER_TYPE_PREFILL)`), wrapped in an `Arc`.
fn tracker_with_prefill_worker()
-> std::sync::Arc<crate::protocols::common::timing::RequestTracker> {
    use crate::protocols::common::timing::{RequestTracker, WORKER_TYPE_PREFILL};

    let shared = std::sync::Arc::new(RequestTracker::new());
    shared.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
    shared
}
/// Test helper: `disaggregated_params` JSON carrying both a `token_ids` array
/// and a `routed_experts` object.
fn disagg_params_full() -> serde_json::Value {
    let token_ids: [u32; 3] = [11, 22, 33];
    serde_json::json!({
        "token_ids": token_ids,
        "routed_experts": {"layer_0": [1, 3]},
    })
}
// ---------------------------------------------------------------------
/// With no flag selected, nothing is emitted whether or not the chunk is final.
#[test]
fn test_build_response_nvext_all_false_returns_none() {
    let sel = sel_all_false();
    let cases = [
        (false, "no fields selected → None"),
        (true, "finish_reason alone does not force emission"),
    ];
    for (finished, why) in cases {
        assert!(sel.build_response_nvext(None, None, finished).is_none(), "{}", why);
    }
}
/// `worker_id` is emitted on a non-final chunk; only timing is finish-gated.
#[test]
fn test_build_response_nvext_worker_id_only_without_finish() {
    let tracker = tracker_with_prefill_worker();
    let sel = NvExtResponseFieldSelection {
        worker_id: true,
        ..Default::default()
    };

    // finish_reason=false: worker_id still emitted (only timing is finish-gated).
    let NvExtResponse {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = sel
        .build_response_nvext(Some(&tracker), None, false)
        .expect("worker_id should emit regardless of finish_reason");
    assert!(worker_id.is_some());
    assert!(timing.is_none());
    assert!(token_ids.is_none());
    assert!(routed_experts.is_none());
}
/// Timing alone on a non-final chunk leaves nothing to emit.
#[test]
fn test_build_response_nvext_timing_suppressed_without_finish() {
    let tracker = tracker_with_prefill_worker();
    let sel = NvExtResponseFieldSelection {
        timing: true,
        ..Default::default()
    };

    // timing alone + finish_reason=false → nothing to emit, returns None.
    let built = sel.build_response_nvext(Some(&tracker), None, false);
    assert!(built.is_none(), "timing is gated on finish_reason_present");
}
/// Timing is emitted, alone, on the final chunk when a tracker is supplied.
#[test]
fn test_build_response_nvext_timing_emitted_on_finish() {
    let tracker = tracker_with_prefill_worker();
    let sel = NvExtResponseFieldSelection {
        timing: true,
        ..Default::default()
    };

    let NvExtResponse {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = sel
        .build_response_nvext(Some(&tracker), None, true)
        .expect("timing should emit on finish");
    assert!(timing.is_some());
    assert!(worker_id.is_none());
    assert!(token_ids.is_none());
    assert!(routed_experts.is_none());
}
/// Without a tracker there is no timing to report, even on the final chunk.
#[test]
fn test_build_response_nvext_timing_requires_tracker() {
    let timing_only = NvExtResponseFieldSelection {
        timing: true,
        ..Default::default()
    };

    // finish=true but no tracker → timing not populated → None.
    let built = timing_only.build_response_nvext(None, None, true);
    assert!(built.is_none());
}
/// `token_ids` is read from `disaggregated_params` when its flag is set; no
/// tracker or finish reason is needed.
#[test]
fn test_build_response_nvext_token_ids_from_disagg_params() {
    let params = disagg_params_full();
    let sel = NvExtResponseFieldSelection {
        token_ids: true,
        ..Default::default()
    };

    let NvExtResponse {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = sel
        .build_response_nvext(None, Some(&params), false)
        .expect("token_ids should emit when present");
    assert_eq!(token_ids.as_deref(), Some(&[11u32, 22, 33][..]));
    assert!(worker_id.is_none());
    assert!(timing.is_none());
    assert!(routed_experts.is_none());
}
/// A `token_ids` value that is not an array of `u32` is dropped rather than
/// reported as an error.
#[test]
fn test_build_response_nvext_token_ids_malformed_falls_back_to_none() {
    // String payload cannot deserialize into Vec<u32> — matches existing `.ok()` behavior.
    let malformed = serde_json::json!({ "token_ids": "not-an-array" });
    let sel = NvExtResponseFieldSelection {
        token_ids: true,
        ..Default::default()
    };

    let built = sel.build_response_nvext(None, Some(&malformed), false);
    assert!(
        built.is_none(),
        "malformed token_ids silently suppressed; nothing else selected → None"
    );
}
/// `routed_experts` is copied verbatim from `disaggregated_params`.
#[test]
fn test_build_response_nvext_routed_experts_cloned_as_is() {
    let params = disagg_params_full();
    let sel = NvExtResponseFieldSelection {
        routed_experts: true,
        ..Default::default()
    };

    let built = sel
        .build_response_nvext(None, Some(&params), false)
        .expect("routed_experts should emit when present");
    let expected = serde_json::json!({"layer_0": [1, 3]});
    assert_eq!(built.routed_experts, Some(expected));
}
/// With every flag set and all data available on a final chunk, all four
/// fields are populated.
#[test]
fn test_build_response_nvext_combined_emission() {
    let tracker = tracker_with_prefill_worker();
    let params = disagg_params_full();
    let everything = NvExtResponseFieldSelection {
        worker_id: true,
        timing: true,
        token_ids: true,
        routed_experts: true,
    };

    let NvExtResponse {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = everything
        .build_response_nvext(Some(&tracker), Some(&params), true)
        .expect("all fields selected and available → Some");
    assert!(worker_id.is_some());
    assert!(timing.is_some());
    assert_eq!(token_ids, Some(vec![11u32, 22, 33]));
    assert_eq!(routed_experts, Some(serde_json::json!({"layer_0": [1, 3]})));
}
/// Listing all three supported `extra_fields` names selects those three
/// flags; `token_ids` stays off.
#[test]
fn test_nvext_response_field_selection_multiple_extra_fields() {
    let names = ["worker_id", "timing", "routed_experts"];
    let ext = NvExt::builder()
        .extra_fields(names.iter().map(|name| name.to_string()).collect::<Vec<String>>())
        .build()
        .unwrap();

    let NvExtResponseFieldSelection {
        worker_id,
        timing,
        token_ids,
        routed_experts,
    } = NvExtResponseFieldSelection::from_nvext(Some(&ext));
    assert!(worker_id);
    assert!(timing);
    // only enabled via query_instance_id
    assert!(!token_ids);
    assert!(routed_experts);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment