Unverified commit b165ec4a, authored by Ayush Agarwal and committed by GitHub
Browse files

chore: guided decoding support for nvext (#2339)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent a3f7a39f
......@@ -329,6 +329,105 @@ pub struct SamplingOptions {
/// The seed to use when sampling
pub seed: Option<i64>,
/// Guided Decoding Options
pub guided_decoding: Option<GuidedDecodingOptions>,
}
/// Guided Decoding Options
///
/// Describes a single constraint on generated output: a JSON schema, a regex,
/// a fixed list of choices, or a context-free grammar, plus an optional
/// backend name.
///
/// Only one of `json`, `regex`, `choice`, or `grammar` should be set. The
/// derived `Deserialize` impl does not enforce this on its own; call
/// `GuidedDecodingOptions::validate` (or build through
/// `GuidedDecodingOptions::validated`) to check it. Validation treats an empty
/// `choice` list as "not set".
///
/// Every field that is `None` is omitted from the serialized form.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct GuidedDecodingOptions {
/// If specified, the output will follow the JSON schema. Can be a string, an object, or null.
#[serde(skip_serializing_if = "Option::is_none")]
pub json: Option<serde_json::Value>,
/// If specified, the output will follow the regex pattern. Can be a string or null.
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// If specified, the output will be exactly one of the choices.
/// An empty list is treated as unset by `validate`.
#[serde(skip_serializing_if = "Option::is_none")]
pub choice: Option<Vec<String>>,
/// If specified, the output will follow the context-free grammar. Can be a string or null.
#[serde(skip_serializing_if = "Option::is_none")]
pub grammar: Option<String>,
/// If specified, the backend to use for guided decoding, e.g. `xgrammar` or a
/// custom guided decoding backend. Not counted as a constraint by `validate`.
#[serde(skip_serializing_if = "Option::is_none")]
pub backend: Option<String>,
}
impl GuidedDecodingOptions {
    /// Builds the options from the given fields without any validation.
    ///
    /// Use [`Self::validated`] or [`Self::from_optional`] when the inputs come
    /// from a request and the "at most one constraint" rule must be checked.
    pub fn new(
        json: Option<serde_json::Value>,
        regex: Option<String>,
        choice: Option<Vec<String>>,
        grammar: Option<String>,
        backend: Option<String>,
    ) -> Self {
        Self {
            json,
            regex,
            choice,
            grammar,
            backend,
        }
    }

    /// Builds the options and checks them with [`Self::validate`].
    ///
    /// # Errors
    ///
    /// Returns an error when more than one of `json`, `regex`, a non-empty
    /// `choice`, or `grammar` is set.
    pub fn validated(
        json: Option<serde_json::Value>,
        regex: Option<String>,
        choice: Option<Vec<String>>,
        grammar: Option<String>,
        backend: Option<String>,
    ) -> Result<Self> {
        let instance = Self::new(json, regex, choice, grammar, backend);
        instance.validate()?;
        Ok(instance)
    }

    /// Builds validated options only when at least one constraint is present.
    ///
    /// Returns `Ok(None)` when none of `json`, `regex`, a non-empty `choice`,
    /// or `grammar` is provided; a `backend` on its own does not produce
    /// options. An empty `choice` list is normalized to `None`, so the returned
    /// value never carries an empty choice list alongside another constraint.
    ///
    /// # Errors
    ///
    /// Returns an error when more than one constraint is set (see
    /// [`Self::validate`]).
    pub fn from_optional(
        json: Option<serde_json::Value>,
        regex: Option<String>,
        choice: Option<Vec<String>>,
        grammar: Option<String>,
        backend: Option<String>,
    ) -> Result<Option<Self>> {
        // An empty list constrains nothing. Drop it here so that it is treated
        // as unset both for the "anything to build?" check and in the value
        // handed to downstream consumers.
        let choice = choice.filter(|options| !options.is_empty());
        if json.is_none() && regex.is_none() && choice.is_none() && grammar.is_none() {
            return Ok(None);
        }
        let instance = Self::validated(json, regex, choice, grammar, backend)?;
        Ok(Some(instance))
    }

    /// Validates that at most one guided decoding constraint is set.
    ///
    /// An empty `choice` list does not count as a constraint, and `backend`
    /// is never counted.
    ///
    /// # Errors
    ///
    /// Returns an error describing the full options value when two or more of
    /// `json`, `regex`, a non-empty `choice`, or `grammar` are set.
    pub fn validate(&self) -> Result<()> {
        let count = [
            self.json.is_some(),
            self.regex.is_some(),
            self.choice.as_ref().is_some_and(|v| !v.is_empty()),
            self.grammar.is_some(),
        ]
        .iter()
        .filter(|&&v| v)
        .count();
        if count > 1 {
            Err(anyhow::anyhow!(
                "Only one of json, regex, choice, or grammar can be set, but multiple are specified: {:?}",
                self
            ))
        } else {
            Ok(())
        }
    }
}
impl SamplingOptions {
......@@ -571,4 +670,135 @@ mod tests {
panic!("Expected a Completion variant");
}
}
#[test]
fn test_guided_decoding_options_new_and_exclusive() {
    // A JSON schema alone validates, and the backend is carried through unchanged.
    let schema = serde_json::json!({"type": "object"});
    let backend = Some("xgrammar".to_string());
    let built = GuidedDecodingOptions::validated(
        Some(schema.clone()),
        None,
        None,
        None,
        backend.clone(),
    )
    .unwrap();
    assert_eq!(built.json, Some(schema));
    assert!(built.regex.is_none());
    assert!(built.choice.is_none());
    assert!(built.grammar.is_none());
    assert_eq!(built.backend, backend);

    // A regex alone validates.
    let pattern = Some(r"\d+".to_string());
    let built =
        GuidedDecodingOptions::validated(None, pattern.clone(), None, None, None).unwrap();
    assert_eq!(built.regex, pattern);
    assert!(built.json.is_none());
    assert!(built.choice.is_none());
    assert!(built.grammar.is_none());

    // A non-empty choice list alone validates.
    let alternatives = Some(vec!["A".to_string(), "B".to_string()]);
    let built =
        GuidedDecodingOptions::validated(None, None, alternatives.clone(), None, None).unwrap();
    assert_eq!(built.choice, alternatives);
    assert!(built.json.is_none());
    assert!(built.regex.is_none());
    assert!(built.grammar.is_none());

    // A grammar alone validates.
    let cfg = Some("root ::= 'yes' | 'no'".to_string());
    let built = GuidedDecodingOptions::validated(None, None, None, cfg.clone(), None).unwrap();
    assert_eq!(built.grammar, cfg);
    assert!(built.json.is_none());
    assert!(built.regex.is_none());
    assert!(built.choice.is_none());

    // Any combination of two or more constraints is rejected.
    assert!(GuidedDecodingOptions::validated(
        Some(serde_json::json!({})),
        Some(r"\d+".to_string()),
        None,
        None,
        None,
    )
    .is_err());
    assert!(GuidedDecodingOptions::validated(
        None,
        Some(r"\d+".to_string()),
        Some(vec!["A".to_string()]),
        None,
        None,
    )
    .is_err());
    assert!(GuidedDecodingOptions::validated(
        Some(serde_json::json!({})),
        None,
        Some(vec!["A".to_string()]),
        Some("root ::= 'yes'".to_string()),
        None,
    )
    .is_err());

    // No constraint at all is accepted (valid, though not useful).
    assert!(GuidedDecodingOptions::validated(None, None, None, None, None).is_ok());
}
#[test]
fn test_guided_decoding_options_from_optional() {
    // With nothing set there is nothing to build.
    let nothing = GuidedDecodingOptions::from_optional(None, None, None, None, None).unwrap();
    assert!(nothing.is_none());

    // A single constraint yields a populated value.
    let pattern = Some(r"\w+".to_string());
    let single = GuidedDecodingOptions::from_optional(None, pattern.clone(), None, None, None)
        .unwrap()
        .unwrap();
    assert_eq!(single.regex, pattern);

    // Two constraints at once are rejected.
    assert!(GuidedDecodingOptions::from_optional(
        Some(serde_json::json!({})),
        Some(r"\d+".to_string()),
        None,
        None,
        None,
    )
    .is_err());

    // An empty choice list does not count as a constraint.
    let empty_choice =
        GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None).unwrap();
    assert!(empty_choice.is_none());

    // A non-empty choice list is kept as given.
    let with_choice = GuidedDecodingOptions::from_optional(
        None,
        None,
        Some(vec!["A".to_string()]),
        None,
        None,
    )
    .unwrap()
    .unwrap();
    assert_eq!(with_choice.choice, Some(vec!["A".to_string()]));
}
}
......@@ -88,6 +88,30 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
}
}
// Guided decoding stays `None` unless the request carries an `nvext`
// extension with at least one guided decoding constraint.
let mut guided_decoding = None;
if let Some(nvext) = self.nvext() {
// Clone the guided decoding fields out of `nvext` so they can be moved
// into `from_optional`.
let guided_decoding_backend = nvext.guided_decoding_backend.clone();
let guided_json = nvext.guided_json.clone();
let guided_regex = nvext.guided_regex.clone();
let guided_grammar = nvext.guided_grammar.clone();
let guided_choice = nvext.guided_choice.clone();
// Argument order is json, regex, choice, grammar, backend.
match common::GuidedDecodingOptions::from_optional(
guided_json,
guided_regex,
guided_choice,
guided_grammar,
guided_decoding_backend,
) {
Ok(options) => guided_decoding = options,
Err(e) => {
// Log the validation failure, then propagate it to the caller.
tracing::error!("Invalid guided decoding options: {}", e);
return Err(e);
}
}
}
Ok(common::SamplingOptions {
n: None,
best_of: None,
......@@ -101,6 +125,7 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
seed: None,
use_beam_search: None,
length_penalty: None,
guided_decoding,
})
}
}
......
......@@ -61,6 +61,32 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
/// Guided decoding: if specified, the output will follow this JSON schema.
/// Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_json: Option<serde_json::Value>,
/// Guided decoding: if specified, the output will follow the regex pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_regex: Option<String>,
/// Guided decoding: if specified, the output will follow the context-free grammar. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_grammar: Option<String>,
/// Guided decoding: if specified, the output will be exactly one of the choices.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_choice: Option<Vec<String>>,
/// If specified, the backend to use for guided decoding, e.g. `xgrammar` or a
/// custom guided decoding backend.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,
}
impl Default for NvExt {
......@@ -114,6 +140,10 @@ mod tests {
assert_eq!(nv_ext.top_k, None);
assert_eq!(nv_ext.repetition_penalty, None);
assert_eq!(nv_ext.greed_sampling, None);
assert_eq!(nv_ext.guided_json, None);
assert_eq!(nv_ext.guided_regex, None);
assert_eq!(nv_ext.guided_grammar, None);
assert_eq!(nv_ext.guided_choice, None);
}
// Test valid builder configurations
......@@ -124,6 +154,11 @@ mod tests {
.top_k(10)
.repetition_penalty(1.5)
.greed_sampling(true)
.guided_json(serde_json::json!({"type": "object"}))
.guided_regex("^[0-9]+$".to_string())
.guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
.guided_decoding_backend("xgrammar".to_string())
.build()
.unwrap();
......@@ -131,7 +166,20 @@ mod tests {
assert_eq!(nv_ext.top_k, Some(10));
assert_eq!(nv_ext.repetition_penalty, Some(1.5));
assert_eq!(nv_ext.greed_sampling, Some(true));
assert_eq!(
nv_ext.guided_json,
Some(serde_json::json!({"type": "object"}))
);
assert_eq!(nv_ext.guided_regex, Some("^[0-9]+$".to_string()));
assert_eq!(
nv_ext.guided_grammar,
Some("S -> 'a' S 'b' | 'c'".to_string())
);
assert_eq!(
nv_ext.guided_choice,
Some(vec!["choice1".to_string(), "choice2".to_string()])
);
assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment