Unverified Commit 263f99dc authored by Zhongxuan (Daniel) Wang's avatar Zhongxuan (Daniel) Wang Committed by GitHub
Browse files

feat: nvext field to OpenAI APIs and add worker_id reporting (vLLM) (#4372)


Signed-off-by: default avatarZhongxuan Wang <daniewang@nvidia.com>
parent 684107c4
......@@ -1045,6 +1045,10 @@ pub struct CreateChatCompletionResponse {
/// The object type, which is always `chat.completion`.
pub object: String,
pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
/// Parsed server side events stream until an \[DONE\] is received from server.
......@@ -1136,6 +1140,10 @@ pub struct CreateChatCompletionStreamResponse {
/// An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.
/// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.
pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
#[cfg(test)]
......
......@@ -216,6 +216,10 @@ pub struct CreateCompletionResponse {
/// The object type, which is always "text_completion"
pub object: String,
pub usage: Option<CompletionUsage>,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
/// Parsed server side events stream until an \[DONE\] is received from server.
......
......@@ -419,6 +419,7 @@ impl
usage: None,
system_fingerprint: Some(c.system_fingerprint),
service_tier: None,
nvext: None,
};
let ann = Annotated{
id: None,
......
......@@ -98,6 +98,7 @@ where
system_fingerprint: None,
choices: vec![],
service_tier: None,
nvext: None,
}
})
}),
......@@ -132,6 +133,7 @@ where
system_fingerprint: None,
choices: vec![],
service_tier: None,
nvext: None,
};
let _ = tx.send(fallback.clone());
final_response_to_one_chunk_stream(fallback)
......@@ -151,6 +153,7 @@ where
system_fingerprint: None,
choices: vec![],
service_tier: None,
nvext: None,
}
})
});
......@@ -226,6 +229,7 @@ pub fn final_response_to_one_chunk_stream(
service_tier: resp.service_tier.clone(),
choices,
usage: resp.usage.clone(),
nvext: resp.nvext.clone(),
};
let annotated = Annotated {
......@@ -275,6 +279,7 @@ mod tests {
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
nvext: None,
};
Annotated {
......@@ -311,6 +316,7 @@ mod tests {
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
nvext: None,
};
Annotated {
......@@ -430,6 +436,7 @@ mod tests {
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
nvext: None,
}),
id: Some("correlation-123".to_string()),
event: Some("test-event".to_string()),
......
......@@ -243,6 +243,7 @@ impl
//mdcsum: mdcsum.clone(),
index: data.index,
completion_usage: data.completion_usage,
disaggregated_params: data.disaggregated_params,
})
})
});
......
......@@ -19,6 +19,7 @@ use dynamo_runtime::{
};
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::json;
pub mod approx;
pub mod indexer;
......@@ -583,6 +584,24 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
// Check if worker_id is requested in extra_fields
let should_populate_worker_id = backend_input
.extra_fields
.as_deref()
.unwrap_or(&[])
.iter()
.any(|s| s == "worker_id");
// Get prefill worker ID if available (stored by PrefillRouter)
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
let decode_worker_id = instance_id;
let prefill_worker_id = context
.get::<u64>("prefill_worker_id")
.ok()
.map(|arc| *arc)
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
......@@ -592,6 +611,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
let mut first_item = true;
loop {
tokio::select! {
......@@ -603,7 +623,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
item = response_stream.next() => {
let Some(item) = item else {
let Some(mut item) = item else {
break;
};
......@@ -613,7 +633,28 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
prefill_marked = true;
}
yield item;
yield item.clone();
// Inject worker_id in first item's disaggregated_params if requested
if first_item && should_populate_worker_id {
if let Some(ref mut data) = item.data {
// Add worker_id to disaggregated_params
let worker_id_json = json!({
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
});
if let Some(ref mut params) = data.disaggregated_params {
if let Some(obj) = params.as_object_mut() {
obj.insert("worker_id".to_string(), worker_id_json);
}
} else {
data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
}
}
first_item = false;
}
}
}
}
......
......@@ -176,11 +176,11 @@ impl PrefillRouter {
Ok(())
}
/// Call the prefill router and extract structured prefill result
/// Call the prefill router and extract structured prefill result and worker ID
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<PrefillResult, PrefillError> {
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated);
......@@ -239,10 +239,21 @@ impl PrefillRouter {
));
};
Ok(PrefillResult {
// Extract prefill worker ID from disaggregated_params
let prefill_worker_id = disaggregated_params
.get("worker_id")
.and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
disaggregated_params,
prompt_tokens_details,
})
},
prefill_worker_id,
))
}
}
......@@ -299,7 +310,7 @@ impl
// Handle prefill result
match prefill_result {
Ok(prefill_result) => {
Ok((prefill_result, prefill_worker_id)) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
let mut decode_req = req;
......@@ -315,8 +326,14 @@ impl
..existing_override.unwrap_or_default()
});
// Store prefill worker ID in context if available
let mut decode_context = context;
if let Some(worker_id) = prefill_worker_id {
decode_context.insert("prefill_worker_id", worker_id);
}
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
let decode_request = decode_context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(PrefillError::NotActivated) => {
......
......@@ -972,6 +972,7 @@ mod tests {
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
}
}
......@@ -1009,6 +1010,7 @@ mod tests {
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
}
}
......@@ -1349,6 +1351,7 @@ mod tests {
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
};
let logprobs = response.extract_logprobs_by_choice();
......@@ -1563,6 +1566,7 @@ mod tests {
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
}
}
......
......@@ -228,9 +228,10 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
// Extract backend_instance_id and extra_fields from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone());
}
Ok(builder)
......
......@@ -53,6 +53,10 @@ pub struct BackendOutput {
// Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
}
/// The LLM engine and backnd with manage it's own state, specifically translating how a
......
......@@ -92,6 +92,11 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
/// Extra fields requested to be included in the response's nvext
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>,
}
impl PreprocessedRequest {
......
......@@ -34,6 +34,8 @@ pub struct DeltaAggregator {
error: Option<String>,
/// Optional service tier information for the response.
service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
/// Aggregated nvext field from stream responses
nvext: Option<serde_json::Value>,
}
/// Represents the accumulated state of a single chat choice during streaming aggregation.
......@@ -97,6 +99,7 @@ impl DeltaAggregator {
choices: HashMap::new(),
error: None,
service_tier: None,
nvext: None,
}
}
......@@ -140,6 +143,11 @@ impl DeltaAggregator {
aggregator.system_fingerprint = Some(system_fingerprint);
}
// Aggregate nvext field (take the last non-None value)
if delta.nvext.is_some() {
aggregator.nvext = delta.nvext;
}
// Aggregate choices incrementally.
for choice in delta.choices {
let state_choice =
......@@ -247,6 +255,7 @@ impl DeltaAggregator {
system_fingerprint: aggregator.system_fingerprint,
choices,
service_tier: aggregator.service_tier,
nvext: aggregator.nvext,
};
Ok(response)
......@@ -411,6 +420,7 @@ mod tests {
system_fingerprint: None,
choices: vec![choice],
object: "chat.completion".to_string(),
nvext: None,
};
Annotated {
......@@ -633,6 +643,7 @@ mod tests {
},
],
object: "chat.completion".to_string(),
nvext: None,
};
// Wrap it in Annotated and create a stream
......
......@@ -258,6 +258,7 @@ impl DeltaGenerator {
choices,
usage: None, // Always None for chunks with content/choices
service_tier: self.service_tier.clone(),
nvext: None, // Will be populated by router layer if needed
}
}
......@@ -279,6 +280,7 @@ impl DeltaGenerator {
choices: vec![], // Empty choices for usage-only chunk
usage: Some(usage),
service_tier: self.service_tier.clone(),
nvext: None,
}
}
......@@ -358,7 +360,41 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
// Create the streaming response.
let index = 0;
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
stream_response.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
);
}
}
Ok(stream_response)
}
......
......@@ -582,6 +582,7 @@ impl JailedStream {
usage: None,
service_tier: None,
system_fingerprint: None,
nvext: None,
};
let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment);
......
......@@ -291,6 +291,7 @@ impl ResponseFactory {
choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(),
usage,
nvext: None, // Will be populated by router layer if needed
};
NvCreateCompletionResponse { inner }
}
......
......@@ -24,6 +24,7 @@ pub struct DeltaAggregator {
system_fingerprint: Option<String>,
choices: HashMap<u32, DeltaChoice>,
error: Option<String>,
nvext: Option<serde_json::Value>,
}
struct DeltaChoice {
......@@ -49,6 +50,7 @@ impl DeltaAggregator {
system_fingerprint: None,
choices: HashMap::new(),
error: None,
nvext: None,
}
}
......@@ -84,6 +86,10 @@ impl DeltaAggregator {
if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint);
}
// Aggregate nvext field (take the last non-None value)
if delta.inner.nvext.is_some() {
aggregator.nvext = delta.inner.nvext;
}
// handle the choices
for choice in delta.inner.choices {
......@@ -163,6 +169,7 @@ impl DeltaAggregator {
object: "text_completion".to_string(),
system_fingerprint: aggregator.system_fingerprint,
choices,
nvext: aggregator.nvext,
};
let response = NvCreateCompletionResponse { inner };
......@@ -250,6 +257,7 @@ mod tests {
logprobs,
}],
object: "text_completion".to_string(),
nvext: None,
};
let response = NvCreateCompletionResponse { inner };
......@@ -379,6 +387,7 @@ mod tests {
},
],
object: "text_completion".to_string(),
nvext: None,
};
let response = NvCreateCompletionResponse { inner };
......
......@@ -189,6 +189,7 @@ impl DeltaGenerator {
logprobs,
}],
usage: None, // Always None for chunks with content/choices
nvext: None, // Will be populated by router layer if needed
};
NvCreateCompletionResponse { inner }
......@@ -211,6 +212,7 @@ impl DeltaGenerator {
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![], // Empty choices for usage-only chunk
usage: Some(usage),
nvext: None, // Will be populated by router layer if needed
};
NvCreateCompletionResponse { inner }
......@@ -261,7 +263,42 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
// create choice
let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
// Extract worker_id from disaggregated_params and inject into nvext if present
if let Some(worker_id_json) = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
{
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
let prefill_worker_id = worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64());
let decode_worker_id = worker_id_json
.get("decode_worker_id")
.and_then(|v| v.as_u64());
let worker_id_info = WorkerIdInfo {
prefill_worker_id,
decode_worker_id,
};
let nvext_response = NvExtResponse {
worker_id: Some(worker_id_info),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json);
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
prefill_worker_id,
decode_worker_id
);
}
}
Ok(response)
}
......
......@@ -10,6 +10,26 @@ pub trait NvExtProvider {
fn raw_prompt(&self) -> Option<String>;
}
/// Worker ID information for disaggregated serving
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WorkerIdInfo {
/// The prefill worker ID that processed this request
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// The decode worker ID that processed this request
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
}
/// NVIDIA LLM response extensions
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct NvExtResponse {
/// Worker ID information (prefill and decode worker IDs)
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
}
/// NVIDIA LLM extensions to the OpenAI API
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
......@@ -53,6 +73,13 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>,
/// Extra fields to be included in the response's nvext
/// This is a list of field names that should be populated in the response
/// Supported fields: "worker_id"
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
}
impl Default for NvExt {
......@@ -98,6 +125,7 @@ mod tests {
assert_eq!(nv_ext.backend_instance_id, None);
assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None);
}
// Test valid builder configurations
......@@ -109,6 +137,7 @@ mod tests {
.backend_instance_id(42)
.token_data(vec![1, 2, 3, 4])
.max_thinking_tokens(1024)
.extra_fields(vec!["worker_id".to_string()])
.build()
.unwrap();
......@@ -117,6 +146,7 @@ mod tests {
assert_eq!(nv_ext.backend_instance_id, Some(42));
assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()]));
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
......
......@@ -31,6 +31,10 @@ pub struct NvCreateResponse {
pub struct NvResponse {
#[serde(flatten)]
pub inner: dynamo_async_openai::types::responses::Response,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
/// Implements `NvExtProvider` for `NvCreateResponse`,
......@@ -203,6 +207,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> {
let chat_resp = nv_resp;
// Preserve nvext field from chat completion response
let nvext = chat_resp.nvext.clone();
let content_text = chat_resp
.choices
.into_iter()
......@@ -253,7 +261,10 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
user: None,
};
Ok(NvResponse { inner: response })
Ok(NvResponse {
inner: response,
nvext,
})
}
}
......@@ -365,6 +376,7 @@ mod tests {
system_fingerprint: None,
object: "chat.completion".to_string(),
usage: None,
nvext: None,
};
let wrapped: NvResponse = chat_resp.try_into().unwrap();
......
......@@ -404,6 +404,7 @@ fn create_response_with_linear_probs(
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
}
}
......@@ -484,5 +485,6 @@ fn create_multi_choice_response(
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
nvext: None,
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment