Unverified commit 13156361, authored by nv-nedelman-1 and committed by GitHub
Browse files

chore: relaxing constraints on metadata field, adding metadata field to completions API (#3240)


Signed-off-by: Nicholas Edelman <nedelman@nvidia.com>
parent e21dcf6c
...@@ -232,14 +232,9 @@ pub struct CreateResponse { ...@@ -232,14 +232,9 @@ pub struct CreateResponse {
/// Any further attempts to call a tool by the model will be ignored. /// Any further attempts to call a tool by the model will be ignored.
pub max_tool_calls: Option<u32>, pub max_tool_calls: Option<u32>,
/// Set of 16 key-value pairs that can be attached to an object. This can be /// Arbitrary JSON metadata used as a passthrough parameter
/// useful for storing additional information about the object in a structured
/// format, and querying for objects via API or the dashboard.
///
/// Keys are strings with a maximum length of 64 characters. Values are
/// strings with a maximum length of 512 characters.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>, pub metadata: Option<serde_json::Value>,
/// Whether to allow the model to run tool calls in parallel. /// Whether to allow the model to run tool calls in parallel.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -1387,7 +1382,7 @@ pub struct Response { ...@@ -1387,7 +1382,7 @@ pub struct Response {
/// Metadata tags/values that were attached to this response. /// Metadata tags/values that were attached to this response.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>, pub metadata: Option<serde_json::Value>,
/// Model ID used to generate the response. /// Model ID used to generate the response.
pub model: String, pub model: String,
...@@ -2121,7 +2116,7 @@ pub struct ResponseMetadata { ...@@ -2121,7 +2116,7 @@ pub struct ResponseMetadata {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>, pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>, pub metadata: Option<serde_json::Value>,
/// Prompt cache key for improved performance /// Prompt cache key for improved performance
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>, pub prompt_cache_key: Option<String>,
......
...@@ -314,6 +314,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest { ...@@ -314,6 +314,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}) })
} }
} }
......
...@@ -976,11 +976,6 @@ pub fn validate_response_unsupported_fields( ...@@ -976,11 +976,6 @@ pub fn validate_response_unsupported_fields(
"`max_tool_calls` is not supported.", "`max_tool_calls` is not supported.",
)); ));
} }
if inner.metadata.is_some() {
return Some(ErrorMessage::not_implemented_error(
"`metadata` is not supported.",
));
}
if inner.previous_response_id.is_some() { if inner.previous_response_id.is_some() {
return Some(ErrorMessage::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`previous_response_id` is not supported.", "`previous_response_id` is not supported.",
...@@ -1187,7 +1182,6 @@ pub fn responses_router( ...@@ -1187,7 +1182,6 @@ pub fn responses_router(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::discovery::ModelManagerError; use crate::discovery::ModelManagerError;
...@@ -1355,7 +1349,6 @@ mod tests { ...@@ -1355,7 +1349,6 @@ mod tests {
Box::new(|r| r.instructions = Some("System prompt".into())), Box::new(|r| r.instructions = Some("System prompt".into())),
), ),
("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))), ("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))),
("metadata", Box::new(|r| r.metadata = Some(HashMap::new()))),
( (
"previous_response_id", "previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())), Box::new(|r| r.previous_response_id = Some("prev-id".into())),
...@@ -1482,6 +1475,7 @@ mod tests { ...@@ -1482,6 +1475,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
...@@ -1504,6 +1498,7 @@ mod tests { ...@@ -1504,6 +1498,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1525,6 +1520,7 @@ mod tests { ...@@ -1525,6 +1520,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1546,6 +1542,7 @@ mod tests { ...@@ -1546,6 +1542,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1569,6 +1566,7 @@ mod tests { ...@@ -1569,6 +1566,7 @@ mod tests {
.build() .build()
.unwrap(), .unwrap(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1590,6 +1588,7 @@ mod tests { ...@@ -1590,6 +1588,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1602,6 +1601,34 @@ mod tests { ...@@ -1602,6 +1601,34 @@ mod tests {
} }
} }
/// A completion request carrying arbitrarily nested JSON in `metadata` must pass
/// generic field validation, and the nested values must remain addressable.
#[test]
fn test_metadata_field_nested() {
    use serde_json::json;

    // Nested objects with mixed value types (integers and strings).
    let nested_metadata = json!({
        "user": {"id": 1, "name": "user-1"},
        "session": {"id": "session-1", "timestamp": 1640995200}
    });

    let inner = CreateCompletionRequest {
        model: String::from("test-model"),
        prompt: "Hello".into(),
        ..Default::default()
    };
    let request = NvCreateCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        metadata: Some(nested_metadata),
    };

    // Validation places no restrictions on the metadata payload.
    let outcome = validate_completion_fields_generic(&request);
    assert!(outcome.is_ok());

    // The metadata is still present and can be indexed into after validation.
    assert!(request.metadata.is_some());
    let stored = request.metadata.as_ref().unwrap();
    assert_eq!(stored["user"]["id"], 1);
}
#[test] #[test]
fn test_bad_base_request_for_chatcompletion() { fn test_bad_base_request_for_chatcompletion() {
// Frequency Penalty: Should be a float between -2.0 and 2.0 // Frequency Penalty: Should be a float between -2.0 and 2.0
......
...@@ -331,7 +331,7 @@ impl ValidateRequest for NvCreateChatCompletionRequest { ...@@ -331,7 +331,7 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
validate::validate_model(&self.inner.model)?; validate::validate_model(&self.inner.model)?;
// none for store // none for store
validate::validate_reasoning_effort(&self.inner.reasoning_effort)?; validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
validate::validate_metadata(&self.inner.metadata)?; // none for metadata
validate::validate_frequency_penalty(self.inner.frequency_penalty)?; validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
validate::validate_logit_bias(&self.inner.logit_bias)?; validate::validate_logit_bias(&self.inner.logit_bias)?;
// none for logprobs // none for logprobs
......
...@@ -35,6 +35,10 @@ pub struct NvCreateCompletionRequest { ...@@ -35,6 +35,10 @@ pub struct NvCreateCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
// metadata - passthrough parameter without restrictions
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
...@@ -442,6 +446,7 @@ impl ValidateRequest for NvCreateCompletionRequest { ...@@ -442,6 +446,7 @@ impl ValidateRequest for NvCreateCompletionRequest {
validate::validate_logit_bias(&self.inner.logit_bias)?; validate::validate_logit_bias(&self.inner.logit_bias)?;
validate::validate_user(self.inner.user.as_deref())?; validate::validate_user(self.inner.user.as_deref())?;
// none for seed // none for seed
// none for metadata
// Common Ext // Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?; validate::validate_repetition_penalty(self.get_repetition_penalty())?;
......
...@@ -182,6 +182,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest { ...@@ -182,6 +182,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
top_p: resp.inner.top_p, top_p: resp.inner.top_p,
max_completion_tokens: resp.inner.max_output_tokens, max_completion_tokens: resp.inner.max_output_tokens,
top_logprobs, top_logprobs,
metadata: resp.inner.metadata,
stream: Some(true), // Set this to Some(True) by default to aggregate stream stream: Some(true), // Set this to Some(True) by default to aggregate stream
..Default::default() ..Default::default()
}, },
......
...@@ -82,12 +82,7 @@ pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF); ...@@ -82,12 +82,7 @@ pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);
pub const MAX_STOP_SEQUENCES: usize = 4; pub const MAX_STOP_SEQUENCES: usize = 4;
/// Maximum allowed number of tools /// Maximum allowed number of tools
pub const MAX_TOOLS: usize = 128; pub const MAX_TOOLS: usize = 128;
/// Maximum allowed number of metadata key-value pairs // Metadata validation constants removed - we are no longer restricting the metadata field char limits
pub const MAX_METADATA_PAIRS: usize = 16;
/// Maximum allowed length for metadata keys
pub const MAX_METADATA_KEY_LENGTH: usize = 64;
/// Maximum allowed length for metadata values
pub const MAX_METADATA_VALUE_LENGTH: usize = 512;
/// Maximum allowed length for function names /// Maximum allowed length for function names
pub const MAX_FUNCTION_NAME_LENGTH: usize = 64; pub const MAX_FUNCTION_NAME_LENGTH: usize = 64;
/// Maximum allowed value for Prompt IntegerArray elements /// Maximum allowed value for Prompt IntegerArray elements
...@@ -364,45 +359,6 @@ pub fn validate_tools( ...@@ -364,45 +359,6 @@ pub fn validate_tools(
Ok(()) Ok(())
} }
/// Validates the optional request `metadata` against the metadata size limits.
///
/// Only JSON objects are inspected: `None` and any non-object value are accepted
/// as-is, and non-string values inside the object are not length-checked.
///
/// Key and value lengths are measured in characters (`chars().count()`) rather
/// than UTF-8 bytes, so the "character limit" reported in the error messages is
/// the limit that is actually enforced for non-ASCII text.
///
/// # Errors
///
/// Returns an error when the object
/// - has more than `MAX_METADATA_PAIRS` entries,
/// - contains a key longer than `MAX_METADATA_KEY_LENGTH` characters, or
/// - contains a string value longer than `MAX_METADATA_VALUE_LENGTH` characters.
pub fn validate_metadata(metadata: &Option<serde_json::Value>) -> Result<(), anyhow::Error> {
    // Nothing to validate when metadata is absent or is not a JSON object.
    let Some(obj) = metadata.as_ref().and_then(|value| value.as_object()) else {
        return Ok(());
    };

    if obj.len() > MAX_METADATA_PAIRS {
        anyhow::bail!(
            "Metadata cannot have more than {} key-value pairs, got {}",
            MAX_METADATA_PAIRS,
            obj.len()
        );
    }

    for (key, value) in obj {
        // `str::len` is a byte count; count characters to honor the stated limit.
        if key.chars().count() > MAX_METADATA_KEY_LENGTH {
            anyhow::bail!(
                "Metadata key '{}' exceeds {} character limit",
                key,
                MAX_METADATA_KEY_LENGTH
            );
        }

        // Only string values carry a length limit; other JSON types pass through.
        let value_too_long = value
            .as_str()
            .is_some_and(|text| text.chars().count() > MAX_METADATA_VALUE_LENGTH);
        if value_too_long {
            anyhow::bail!(
                "Metadata value for key '{}' exceeds {} character limit",
                key,
                MAX_METADATA_VALUE_LENGTH
            );
        }
    }

    Ok(())
}
/// Validates reasoning effort parameter /// Validates reasoning effort parameter
pub fn validate_reasoning_effort( pub fn validate_reasoning_effort(
_reasoning_effort: &Option<dynamo_async_openai::types::ReasoningEffort>, _reasoning_effort: &Option<dynamo_async_openai::types::ReasoningEffort>,
......
...@@ -28,6 +28,7 @@ impl CompletionSample { ...@@ -28,6 +28,7 @@ impl CompletionSample {
inner, inner,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None,
}; };
Ok(Self { Ok(Self {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment