Unverified Commit 263f99dc authored by Zhongxuan (Daniel) Wang's avatar Zhongxuan (Daniel) Wang Committed by GitHub
Browse files

feat: nvext field to OpenAI APIs and add worker_id reporting (vLLM) (#4372)


Signed-off-by: Zhongxuan Wang <daniewang@nvidia.com>
parent 684107c4
...@@ -1045,6 +1045,10 @@ pub struct CreateChatCompletionResponse { ...@@ -1045,6 +1045,10 @@ pub struct CreateChatCompletionResponse {
/// The object type, which is always `chat.completion`. /// The object type, which is always `chat.completion`.
pub object: String, pub object: String,
pub usage: Option<CompletionUsage>, pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
} }
/// Parsed server side events stream until an \[DONE\] is received from server. /// Parsed server side events stream until an \[DONE\] is received from server.
...@@ -1136,6 +1140,10 @@ pub struct CreateChatCompletionStreamResponse { ...@@ -1136,6 +1140,10 @@ pub struct CreateChatCompletionStreamResponse {
/// An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. /// An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.
/// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. /// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.
pub usage: Option<CompletionUsage>, pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -216,6 +216,10 @@ pub struct CreateCompletionResponse { ...@@ -216,6 +216,10 @@ pub struct CreateCompletionResponse {
/// The object type, which is always "text_completion" /// The object type, which is always "text_completion"
pub object: String, pub object: String,
pub usage: Option<CompletionUsage>, pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
} }
/// Parsed server side events stream until an \[DONE\] is received from server. /// Parsed server side events stream until an \[DONE\] is received from server.
......
...@@ -419,6 +419,7 @@ impl ...@@ -419,6 +419,7 @@ impl
usage: None, usage: None,
system_fingerprint: Some(c.system_fingerprint), system_fingerprint: Some(c.system_fingerprint),
service_tier: None, service_tier: None,
nvext: None,
}; };
let ann = Annotated{ let ann = Annotated{
id: None, id: None,
......
...@@ -98,6 +98,7 @@ where ...@@ -98,6 +98,7 @@ where
system_fingerprint: None, system_fingerprint: None,
choices: vec![], choices: vec![],
service_tier: None, service_tier: None,
nvext: None,
} }
}) })
}), }),
...@@ -132,6 +133,7 @@ where ...@@ -132,6 +133,7 @@ where
system_fingerprint: None, system_fingerprint: None,
choices: vec![], choices: vec![],
service_tier: None, service_tier: None,
nvext: None,
}; };
let _ = tx.send(fallback.clone()); let _ = tx.send(fallback.clone());
final_response_to_one_chunk_stream(fallback) final_response_to_one_chunk_stream(fallback)
...@@ -151,6 +153,7 @@ where ...@@ -151,6 +153,7 @@ where
system_fingerprint: None, system_fingerprint: None,
choices: vec![], choices: vec![],
service_tier: None, service_tier: None,
nvext: None,
} }
}) })
}); });
...@@ -226,6 +229,7 @@ pub fn final_response_to_one_chunk_stream( ...@@ -226,6 +229,7 @@ pub fn final_response_to_one_chunk_stream(
service_tier: resp.service_tier.clone(), service_tier: resp.service_tier.clone(),
choices, choices,
usage: resp.usage.clone(), usage: resp.usage.clone(),
nvext: resp.nvext.clone(),
}; };
let annotated = Annotated { let annotated = Annotated {
...@@ -275,6 +279,7 @@ mod tests { ...@@ -275,6 +279,7 @@ mod tests {
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
service_tier: None, service_tier: None,
nvext: None,
}; };
Annotated { Annotated {
...@@ -311,6 +316,7 @@ mod tests { ...@@ -311,6 +316,7 @@ mod tests {
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
service_tier: None, service_tier: None,
nvext: None,
}; };
Annotated { Annotated {
...@@ -430,6 +436,7 @@ mod tests { ...@@ -430,6 +436,7 @@ mod tests {
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
service_tier: None, service_tier: None,
nvext: None,
}), }),
id: Some("correlation-123".to_string()), id: Some("correlation-123".to_string()),
event: Some("test-event".to_string()), event: Some("test-event".to_string()),
......
...@@ -243,6 +243,7 @@ impl ...@@ -243,6 +243,7 @@ impl
//mdcsum: mdcsum.clone(), //mdcsum: mdcsum.clone(),
index: data.index, index: data.index,
completion_usage: data.completion_usage, completion_usage: data.completion_usage,
disaggregated_params: data.disaggregated_params,
}) })
}) })
}); });
......
...@@ -19,6 +19,7 @@ use dynamo_runtime::{ ...@@ -19,6 +19,7 @@ use dynamo_runtime::{
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
...@@ -583,6 +584,24 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -583,6 +584,24 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank); backend_input.dp_rank = Some(dp_rank);
// Check if worker_id is requested in extra_fields
let should_populate_worker_id = backend_input
.extra_fields
.as_deref()
.unwrap_or(&[])
.iter()
.any(|s| s == "worker_id");
// Get prefill worker ID if available (stored by PrefillRouter)
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
let decode_worker_id = instance_id;
let prefill_worker_id = context
.get::<u64>("prefill_worker_id")
.ok()
.map(|arc| *arc)
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
...@@ -592,6 +611,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -592,6 +611,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false; let mut prefill_marked = false;
let mut first_item = true;
loop { loop {
tokio::select! { tokio::select! {
...@@ -603,7 +623,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -603,7 +623,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
item = response_stream.next() => { item = response_stream.next() => {
let Some(item) = item else { let Some(mut item) = item else {
break; break;
}; };
...@@ -613,7 +633,28 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -613,7 +633,28 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
prefill_marked = true; prefill_marked = true;
} }
yield item;
yield item.clone();
// Inject worker_id in first item's disaggregated_params if requested
if first_item && should_populate_worker_id {
if let Some(ref mut data) = item.data {
// Add worker_id to disaggregated_params
let worker_id_json = json!({
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
});
if let Some(ref mut params) = data.disaggregated_params {
if let Some(obj) = params.as_object_mut() {
obj.insert("worker_id".to_string(), worker_id_json);
}
} else {
data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
}
}
first_item = false;
}
} }
} }
} }
......
...@@ -176,11 +176,11 @@ impl PrefillRouter { ...@@ -176,11 +176,11 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Call the prefill router and extract structured prefill result /// Call the prefill router and extract structured prefill result and worker ID
async fn call_prefill( async fn call_prefill(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<PrefillResult, PrefillError> { ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
// Get the prefill router, error if not activated // Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else { let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated); return Err(PrefillError::NotActivated);
...@@ -239,10 +239,21 @@ impl PrefillRouter { ...@@ -239,10 +239,21 @@ impl PrefillRouter {
)); ));
}; };
Ok(PrefillResult { // Extract prefill worker ID from disaggregated_params
disaggregated_params, let prefill_worker_id = disaggregated_params
prompt_tokens_details, .get("worker_id")
}) .and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
disaggregated_params,
prompt_tokens_details,
},
prefill_worker_id,
))
} }
} }
...@@ -299,7 +310,7 @@ impl ...@@ -299,7 +310,7 @@ impl
// Handle prefill result // Handle prefill result
match prefill_result { match prefill_result {
Ok(prefill_result) => { Ok((prefill_result, prefill_worker_id)) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode"); tracing::debug!("Prefill succeeded, using disaggregated params for decode");
let mut decode_req = req; let mut decode_req = req;
...@@ -315,8 +326,14 @@ impl ...@@ -315,8 +326,14 @@ impl
..existing_override.unwrap_or_default() ..existing_override.unwrap_or_default()
}); });
// Store prefill worker ID in context if available
let mut decode_context = context;
if let Some(worker_id) = prefill_worker_id {
decode_context.insert("prefill_worker_id", worker_id);
}
// Map the modified request through with preserved context // Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req); let decode_request = decode_context.map(|_| decode_req);
next.generate(decode_request).await next.generate(decode_request).await
} }
Err(PrefillError::NotActivated) => { Err(PrefillError::NotActivated) => {
......
...@@ -972,6 +972,7 @@ mod tests { ...@@ -972,6 +972,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
} }
} }
...@@ -1009,6 +1010,7 @@ mod tests { ...@@ -1009,6 +1010,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
} }
} }
...@@ -1349,6 +1351,7 @@ mod tests { ...@@ -1349,6 +1351,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
}; };
let logprobs = response.extract_logprobs_by_choice(); let logprobs = response.extract_logprobs_by_choice();
...@@ -1563,6 +1566,7 @@ mod tests { ...@@ -1563,6 +1566,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
} }
} }
......
...@@ -228,9 +228,10 @@ impl OpenAIPreprocessor { ...@@ -228,9 +228,10 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None); builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present // Extract backend_instance_id and extra_fields from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id); builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone());
} }
Ok(builder) Ok(builder)
......
...@@ -53,6 +53,10 @@ pub struct BackendOutput { ...@@ -53,6 +53,10 @@ pub struct BackendOutput {
// Token usage information // Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>, pub completion_usage: Option<CompletionUsage>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
} }
/// The LLM engine and backend will manage its own state, specifically translating how a /// The LLM engine and backend will manage its own state, specifically translating how a
......
...@@ -92,6 +92,11 @@ pub struct PreprocessedRequest { ...@@ -92,6 +92,11 @@ pub struct PreprocessedRequest {
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>, pub extra_args: Option<serde_json::Value>,
/// Extra fields requested to be included in the response's nvext
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>,
} }
impl PreprocessedRequest { impl PreprocessedRequest {
......
...@@ -34,6 +34,8 @@ pub struct DeltaAggregator { ...@@ -34,6 +34,8 @@ pub struct DeltaAggregator {
error: Option<String>, error: Option<String>,
/// Optional service tier information for the response. /// Optional service tier information for the response.
service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>, service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
/// Aggregated nvext field from stream responses
nvext: Option<serde_json::Value>,
} }
/// Represents the accumulated state of a single chat choice during streaming aggregation. /// Represents the accumulated state of a single chat choice during streaming aggregation.
...@@ -97,6 +99,7 @@ impl DeltaAggregator { ...@@ -97,6 +99,7 @@ impl DeltaAggregator {
choices: HashMap::new(), choices: HashMap::new(),
error: None, error: None,
service_tier: None, service_tier: None,
nvext: None,
} }
} }
...@@ -140,6 +143,11 @@ impl DeltaAggregator { ...@@ -140,6 +143,11 @@ impl DeltaAggregator {
aggregator.system_fingerprint = Some(system_fingerprint); aggregator.system_fingerprint = Some(system_fingerprint);
} }
// Aggregate nvext field (take the last non-None value)
if delta.nvext.is_some() {
aggregator.nvext = delta.nvext;
}
// Aggregate choices incrementally. // Aggregate choices incrementally.
for choice in delta.choices { for choice in delta.choices {
let state_choice = let state_choice =
...@@ -247,6 +255,7 @@ impl DeltaAggregator { ...@@ -247,6 +255,7 @@ impl DeltaAggregator {
system_fingerprint: aggregator.system_fingerprint, system_fingerprint: aggregator.system_fingerprint,
choices, choices,
service_tier: aggregator.service_tier, service_tier: aggregator.service_tier,
nvext: aggregator.nvext,
}; };
Ok(response) Ok(response)
...@@ -411,6 +420,7 @@ mod tests { ...@@ -411,6 +420,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
choices: vec![choice], choices: vec![choice],
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
nvext: None,
}; };
Annotated { Annotated {
...@@ -633,6 +643,7 @@ mod tests { ...@@ -633,6 +643,7 @@ mod tests {
}, },
], ],
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
nvext: None,
}; };
// Wrap it in Annotated and create a stream // Wrap it in Annotated and create a stream
......
...@@ -258,6 +258,7 @@ impl DeltaGenerator { ...@@ -258,6 +258,7 @@ impl DeltaGenerator {
choices, choices,
usage: None, // Always None for chunks with content/choices usage: None, // Always None for chunks with content/choices
service_tier: self.service_tier.clone(), service_tier: self.service_tier.clone(),
nvext: None, // Will be populated by router layer if needed
} }
} }
...@@ -279,6 +280,7 @@ impl DeltaGenerator { ...@@ -279,6 +280,7 @@ impl DeltaGenerator {
choices: vec![], // Empty choices for usage-only chunk choices: vec![], // Empty choices for usage-only chunk
usage: Some(usage), usage: Some(usage),
service_tier: self.service_tier.clone(), service_tier: self.service_tier.clone(),
nvext: None,
} }
} }
...@@ -358,7 +360,41 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -358,7 +360,41 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
// Create the streaming response. // Create the streaming response.
let index = 0; let index = 0;
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs); let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
);
}
}
Ok(stream_response) Ok(stream_response)
} }
......
...@@ -582,6 +582,7 @@ impl JailedStream { ...@@ -582,6 +582,7 @@ impl JailedStream {
usage: None, usage: None,
service_tier: None, service_tier: None,
system_fingerprint: None, system_fingerprint: None,
nvext: None,
}; };
let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment); let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment);
......
...@@ -291,6 +291,7 @@ impl ResponseFactory { ...@@ -291,6 +291,7 @@ impl ResponseFactory {
choices: vec![choice], choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
usage, usage,
nvext: None, // Will be populated by router layer if needed
}; };
NvCreateCompletionResponse { inner } NvCreateCompletionResponse { inner }
} }
......
...@@ -24,6 +24,7 @@ pub struct DeltaAggregator { ...@@ -24,6 +24,7 @@ pub struct DeltaAggregator {
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
choices: HashMap<u32, DeltaChoice>, choices: HashMap<u32, DeltaChoice>,
error: Option<String>, error: Option<String>,
nvext: Option<serde_json::Value>,
} }
struct DeltaChoice { struct DeltaChoice {
...@@ -49,6 +50,7 @@ impl DeltaAggregator { ...@@ -49,6 +50,7 @@ impl DeltaAggregator {
system_fingerprint: None, system_fingerprint: None,
choices: HashMap::new(), choices: HashMap::new(),
error: None, error: None,
nvext: None,
} }
} }
...@@ -84,6 +86,10 @@ impl DeltaAggregator { ...@@ -84,6 +86,10 @@ impl DeltaAggregator {
if let Some(system_fingerprint) = delta.inner.system_fingerprint { if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint); aggregator.system_fingerprint = Some(system_fingerprint);
} }
// Aggregate nvext field (take the last non-None value)
if delta.inner.nvext.is_some() {
aggregator.nvext = delta.inner.nvext;
}
// handle the choices // handle the choices
for choice in delta.inner.choices { for choice in delta.inner.choices {
...@@ -163,6 +169,7 @@ impl DeltaAggregator { ...@@ -163,6 +169,7 @@ impl DeltaAggregator {
object: "text_completion".to_string(), object: "text_completion".to_string(),
system_fingerprint: aggregator.system_fingerprint, system_fingerprint: aggregator.system_fingerprint,
choices, choices,
nvext: aggregator.nvext,
}; };
let response = NvCreateCompletionResponse { inner }; let response = NvCreateCompletionResponse { inner };
...@@ -250,6 +257,7 @@ mod tests { ...@@ -250,6 +257,7 @@ mod tests {
logprobs, logprobs,
}], }],
object: "text_completion".to_string(), object: "text_completion".to_string(),
nvext: None,
}; };
let response = NvCreateCompletionResponse { inner }; let response = NvCreateCompletionResponse { inner };
...@@ -379,6 +387,7 @@ mod tests { ...@@ -379,6 +387,7 @@ mod tests {
}, },
], ],
object: "text_completion".to_string(), object: "text_completion".to_string(),
nvext: None,
}; };
let response = NvCreateCompletionResponse { inner }; let response = NvCreateCompletionResponse { inner };
......
...@@ -189,6 +189,7 @@ impl DeltaGenerator { ...@@ -189,6 +189,7 @@ impl DeltaGenerator {
logprobs, logprobs,
}], }],
usage: None, // Always None for chunks with content/choices usage: None, // Always None for chunks with content/choices
nvext: None, // Will be populated by router layer if needed
}; };
NvCreateCompletionResponse { inner } NvCreateCompletionResponse { inner }
...@@ -211,6 +212,7 @@ impl DeltaGenerator { ...@@ -211,6 +212,7 @@ impl DeltaGenerator {
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices: vec![], // Empty choices for usage-only chunk choices: vec![], // Empty choices for usage-only chunk
usage: Some(usage), usage: Some(usage),
nvext: None, // Will be populated by router layer if needed
}; };
NvCreateCompletionResponse { inner } NvCreateCompletionResponse { inner }
...@@ -261,7 +263,42 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -261,7 +263,42 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
// create choice // create choice
let index = delta.index.unwrap_or(0); let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs); let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
);
}
}
Ok(response) Ok(response)
} }
......
...@@ -10,6 +10,26 @@ pub trait NvExtProvider { ...@@ -10,6 +10,26 @@ pub trait NvExtProvider {
fn raw_prompt(&self) -> Option<String>; fn raw_prompt(&self) -> Option<String>;
} }
/// Worker ID information for disaggregated serving.
///
/// Both identifiers are optional. A field that is `None` is omitted from the
/// serialized output, and a field that is absent from the input deserializes
/// to `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkerIdInfo {
    /// The prefill worker ID that processed this request.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prefill_worker_id: Option<u64>,

    /// The decode worker ID that processed this request.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decode_worker_id: Option<u64>,
}
/// NVIDIA LLM response extensions.
///
/// Carries optional response metadata. A field that is `None` is omitted from
/// the serialized output, so a default (empty) value serializes to an empty
/// object.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
pub struct NvExtResponse {
    /// Worker ID information (prefill and decode worker IDs).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub worker_id: Option<WorkerIdInfo>,
}
/// NVIDIA LLM extensions to the OpenAI API /// NVIDIA LLM extensions to the OpenAI API
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))] #[validate(schema(function = "validate_nv_ext"))]
...@@ -53,6 +73,13 @@ pub struct NvExt { ...@@ -53,6 +73,13 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>, pub max_thinking_tokens: Option<u32>,
/// Extra fields to be included in the response's nvext
/// This is a list of field names that should be populated in the response
/// Supported fields: "worker_id"
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
} }
impl Default for NvExt { impl Default for NvExt {
...@@ -98,6 +125,7 @@ mod tests { ...@@ -98,6 +125,7 @@ mod tests {
assert_eq!(nv_ext.backend_instance_id, None); assert_eq!(nv_ext.backend_instance_id, None);
assert_eq!(nv_ext.token_data, None); assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None); assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None);
} }
// Test valid builder configurations // Test valid builder configurations
...@@ -109,6 +137,7 @@ mod tests { ...@@ -109,6 +137,7 @@ mod tests {
.backend_instance_id(42) .backend_instance_id(42)
.token_data(vec![1, 2, 3, 4]) .token_data(vec![1, 2, 3, 4])
.max_thinking_tokens(1024) .max_thinking_tokens(1024)
.extra_fields(vec!["worker_id".to_string()])
.build() .build()
.unwrap(); .unwrap();
...@@ -117,6 +146,7 @@ mod tests { ...@@ -117,6 +146,7 @@ mod tests {
assert_eq!(nv_ext.backend_instance_id, Some(42)); assert_eq!(nv_ext.backend_instance_id, Some(42));
assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4])); assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
assert_eq!(nv_ext.max_thinking_tokens, Some(1024)); assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()]));
// Validate the built struct // Validate the built struct
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
} }
......
...@@ -31,6 +31,10 @@ pub struct NvCreateResponse { ...@@ -31,6 +31,10 @@ pub struct NvCreateResponse {
pub struct NvResponse { pub struct NvResponse {
#[serde(flatten)] #[serde(flatten)]
pub inner: dynamo_async_openai::types::responses::Response, pub inner: dynamo_async_openai::types::responses::Response,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
} }
/// Implements `NvExtProvider` for `NvCreateResponse`, /// Implements `NvExtProvider` for `NvCreateResponse`,
...@@ -203,6 +207,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse { ...@@ -203,6 +207,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> { fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> {
let chat_resp = nv_resp; let chat_resp = nv_resp;
// Preserve nvext field from chat completion response
let nvext = chat_resp.nvext.clone();
let content_text = chat_resp let content_text = chat_resp
.choices .choices
.into_iter() .into_iter()
...@@ -253,7 +261,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse { ...@@ -253,7 +261,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
user: None, user: None,
}; };
Ok(NvResponse { inner: response }) Ok(NvResponse {
inner: response,
nvext,
})
} }
} }
...@@ -365,6 +376,7 @@ mod tests { ...@@ -365,6 +376,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
usage: None, usage: None,
nvext: None,
}; };
let wrapped: NvResponse = chat_resp.try_into().unwrap(); let wrapped: NvResponse = chat_resp.try_into().unwrap();
......
...@@ -404,6 +404,7 @@ fn create_response_with_linear_probs( ...@@ -404,6 +404,7 @@ fn create_response_with_linear_probs(
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
} }
} }
...@@ -484,5 +485,6 @@ fn create_multi_choice_response( ...@@ -484,5 +485,6 @@ fn create_multi_choice_response(
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
nvext: None,
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment