Unverified commit 13156361, authored by nv-nedelman-1 and committed by GitHub
Browse files

chore: relaxing constraints on metadata field, adding metadata field to completions API (#3240)


Signed-off-by: Nicholas Edelman <nedelman@nvidia.com>
parent e21dcf6c
......@@ -232,14 +232,9 @@ pub struct CreateResponse {
/// Any further attempts to call a tool by the model will be ignored.
pub max_tool_calls: Option<u32>,
/// Set of 16 key-value pairs that can be attached to an object. This can be
/// useful for storing additional information about the object in a structured
/// format, and querying for objects via API or the dashboard.
///
/// Keys are strings with a maximum length of 64 characters. Values are
/// strings with a maximum length of 512 characters.
/// Arbitrary JSON metadata used as a passthrough parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
pub metadata: Option<serde_json::Value>,
/// Whether to allow the model to run tool calls in parallel.
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1387,7 +1382,7 @@ pub struct Response {
/// Metadata tags/values that were attached to this response.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
pub metadata: Option<serde_json::Value>,
/// Model ID used to generate the response.
pub model: String,
......@@ -2121,7 +2116,7 @@ pub struct ResponseMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
pub metadata: Option<serde_json::Value>,
/// Prompt cache key for improved performance
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
......
......@@ -314,6 +314,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
},
common: Default::default(),
nvext: None,
metadata: None,
})
}
}
......
......@@ -976,11 +976,6 @@ pub fn validate_response_unsupported_fields(
"`max_tool_calls` is not supported.",
));
}
if inner.metadata.is_some() {
return Some(ErrorMessage::not_implemented_error(
"`metadata` is not supported.",
));
}
if inner.previous_response_id.is_some() {
return Some(ErrorMessage::not_implemented_error(
"`previous_response_id` is not supported.",
......@@ -1187,7 +1182,6 @@ pub fn responses_router(
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::discovery::ModelManagerError;
......@@ -1355,7 +1349,6 @@ mod tests {
Box::new(|r| r.instructions = Some("System prompt".into())),
),
("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))),
("metadata", Box::new(|r| r.metadata = Some(HashMap::new()))),
(
"previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())),
......@@ -1482,6 +1475,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
......@@ -1504,6 +1498,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1525,6 +1520,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1546,6 +1542,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1569,6 +1566,7 @@ mod tests {
.build()
.unwrap(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1590,6 +1588,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
metadata: None,
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1602,6 +1601,34 @@ mod tests {
}
}
/// A completion request carrying arbitrarily nested JSON in `metadata` must
/// pass generic field validation, and the nested values must stay reachable
/// on the request afterwards.
#[test]
fn test_metadata_field_nested() {
    use serde_json::json;

    // Two levels of nesting with mixed value types (numbers and strings).
    let nested_metadata = json!({
        "user": {"id": 1, "name": "user-1"},
        "session": {"id": "session-1", "timestamp": 1640995200}
    });

    let inner = CreateCompletionRequest {
        model: String::from("test-model"),
        prompt: "Hello".into(),
        ..Default::default()
    };
    let request = NvCreateCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        metadata: Some(nested_metadata),
    };

    // Validation must accept the request with nested metadata attached.
    let outcome = validate_completion_fields_generic(&request);
    assert!(outcome.is_ok());

    // The metadata is present and its nested fields can be indexed.
    assert!(request.metadata.is_some());
    let stored = request.metadata.as_ref().unwrap();
    assert_eq!(stored["user"]["id"], 1);
}
#[test]
fn test_bad_base_request_for_chatcompletion() {
// Frequency Penalty: Should be a float between -2.0 and 2.0
......
......@@ -331,7 +331,7 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
validate::validate_model(&self.inner.model)?;
// none for store
validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
validate::validate_metadata(&self.inner.metadata)?;
// none for metadata
validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
validate::validate_logit_bias(&self.inner.logit_bias)?;
// none for logprobs
......
......@@ -35,6 +35,10 @@ pub struct NvCreateCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
// metadata - passthrough parameter without restrictions
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
......@@ -442,6 +446,7 @@ impl ValidateRequest for NvCreateCompletionRequest {
validate::validate_logit_bias(&self.inner.logit_bias)?;
validate::validate_user(self.inner.user.as_deref())?;
// none for seed
// none for metadata
// Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
......
......@@ -182,6 +182,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
top_p: resp.inner.top_p,
max_completion_tokens: resp.inner.max_output_tokens,
top_logprobs,
metadata: resp.inner.metadata,
stream: Some(true), // Set this to Some(True) by default to aggregate stream
..Default::default()
},
......
......@@ -82,12 +82,7 @@ pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);
pub const MAX_STOP_SEQUENCES: usize = 4;
/// Maximum allowed number of tools
pub const MAX_TOOLS: usize = 128;
/// Maximum allowed number of metadata key-value pairs
pub const MAX_METADATA_PAIRS: usize = 16;
/// Maximum allowed length for metadata keys
pub const MAX_METADATA_KEY_LENGTH: usize = 64;
/// Maximum allowed length for metadata values
pub const MAX_METADATA_VALUE_LENGTH: usize = 512;
// Metadata validation constants removed: the metadata field's pair count and key/value lengths are no longer restricted.
/// Maximum allowed length for function names
pub const MAX_FUNCTION_NAME_LENGTH: usize = 64;
/// Maximum allowed value for Prompt IntegerArray elements
......@@ -364,45 +359,6 @@ pub fn validate_tools(
Ok(())
}
/// Validates the optional `metadata` value against the metadata size limits.
///
/// Only JSON objects are inspected: `None` and any non-object value are
/// accepted as-is. For an object the following limits apply:
///
/// * at most `MAX_METADATA_PAIRS` key-value pairs;
/// * every key is at most `MAX_METADATA_KEY_LENGTH` characters long;
/// * every *string* value is at most `MAX_METADATA_VALUE_LENGTH` characters
///   long (non-string values are not length-checked).
///
/// Lengths are measured in characters (`chars().count()`) rather than bytes,
/// so multi-byte UTF-8 text is not rejected before it actually reaches the
/// documented character limit.
///
/// # Errors
///
/// Returns an error describing the first limit that is exceeded.
pub fn validate_metadata(metadata: &Option<serde_json::Value>) -> Result<(), anyhow::Error> {
    // Nothing to validate when metadata is absent or is not a JSON object.
    let Some(obj) = metadata.as_ref().and_then(serde_json::Value::as_object) else {
        return Ok(());
    };

    if obj.len() > MAX_METADATA_PAIRS {
        anyhow::bail!(
            "Metadata cannot have more than {} key-value pairs, got {}",
            MAX_METADATA_PAIRS,
            obj.len()
        );
    }

    for (key, value) in obj {
        // `str::len` is a byte count; the limit is expressed in characters.
        if key.chars().count() > MAX_METADATA_KEY_LENGTH {
            anyhow::bail!(
                "Metadata key '{}' exceeds {} character limit",
                key,
                MAX_METADATA_KEY_LENGTH
            );
        }

        // Only string values carry a length limit; other JSON types pass through.
        let value_too_long = value
            .as_str()
            .is_some_and(|text| text.chars().count() > MAX_METADATA_VALUE_LENGTH);
        if value_too_long {
            anyhow::bail!(
                "Metadata value for key '{}' exceeds {} character limit",
                key,
                MAX_METADATA_VALUE_LENGTH
            );
        }
    }

    Ok(())
}
/// Validates reasoning effort parameter
pub fn validate_reasoning_effort(
_reasoning_effort: &Option<dynamo_async_openai::types::ReasoningEffort>,
......
......@@ -28,6 +28,7 @@ impl CompletionSample {
inner,
common: Default::default(),
nvext: None,
metadata: None,
};
Ok(Self {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment