Unverified Commit eb6722e3 authored by Chi McIsaac's avatar Chi McIsaac Committed by GitHub
Browse files

feat: add chat_template_kwargs param to v1/chat/completion (#3016)


Signed-off-by: default avatarChi McIsaac <chixie.mcisaac@gmail.com>
parent 9060ce12
...@@ -228,6 +228,7 @@ async fn evaluate( ...@@ -228,6 +228,7 @@ async fn evaluate(
inner, inner,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new(); let mut output = String::new();
......
...@@ -111,6 +111,7 @@ async fn main_loop( ...@@ -111,6 +111,7 @@ async fn main_loop(
inner, inner,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
// Call the model // Call the model
......
...@@ -1350,6 +1350,7 @@ mod tests { ...@@ -1350,6 +1350,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1377,6 +1378,7 @@ mod tests { ...@@ -1377,6 +1378,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -1549,6 +1551,7 @@ mod tests { ...@@ -1549,6 +1551,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -1576,6 +1579,7 @@ mod tests { ...@@ -1576,6 +1579,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1602,6 +1606,7 @@ mod tests { ...@@ -1602,6 +1606,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1628,6 +1633,7 @@ mod tests { ...@@ -1628,6 +1633,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1656,6 +1662,7 @@ mod tests { ...@@ -1656,6 +1662,7 @@ mod tests {
.build() .build()
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1682,6 +1689,7 @@ mod tests { ...@@ -1682,6 +1689,7 @@ mod tests {
}, },
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
use anyhow::Result; use anyhow::Result;
use minijinja::value::Value; use minijinja::value::Value;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
mod template; mod template;
...@@ -57,6 +58,11 @@ pub trait OAIChatLikeRequest { ...@@ -57,6 +58,11 @@ pub trait OAIChatLikeRequest {
fn should_add_generation_prompt(&self) -> bool; fn should_add_generation_prompt(&self) -> bool;
/// Optional additional args to merge into the chat template context
fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
None
}
/// Returns the type of input for the prompt. Default is Text. /// Returns the type of input for the prompt. Default is Text.
fn prompt_input_type(&self) -> PromptInput { fn prompt_input_type(&self) -> PromptInput {
PromptInput::Text(TextInput::Single(String::new())) PromptInput::Text(TextInput::Single(String::new()))
......
...@@ -114,6 +114,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -114,6 +114,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn extract_text(&self) -> Option<TextInput> { fn extract_text(&self) -> Option<TextInput> {
Some(TextInput::Single(String::new())) Some(TextInput::Single(String::new()))
} }
fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
self.chat_template_args.as_ref()
}
} }
impl OAIChatLikeRequest for NvCreateCompletionRequest { impl OAIChatLikeRequest for NvCreateCompletionRequest {
...@@ -207,9 +211,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -207,9 +211,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
..mixins ..mixins
}; };
let ctx = context! { ..ctx, ..context! { // Merge any additional args into the context last so they take precedence
let ctx = if let Some(args) = req.chat_template_args() {
}}; let extra = Value::from_serialize(args);
context! { ..ctx, ..extra }
} else {
ctx
};
let tmpl: minijinja::Template<'_, '_> = if has_tools { let tmpl: minijinja::Template<'_, '_> = if has_tools {
self.env.get_template("tool_use")? self.env.get_template("tool_use")?
......
...@@ -41,6 +41,10 @@ pub struct NvCreateChatCompletionRequest { ...@@ -41,6 +41,10 @@ pub struct NvCreateChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
/// Extra args to pass to the chat template rendering context
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
} }
/// A response structure for unary chat completion responses, embedding OpenAI's /// A response structure for unary chat completion responses, embedding OpenAI's
......
...@@ -175,6 +175,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest { ...@@ -175,6 +175,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
}, },
common: Default::default(), common: Default::default(),
nvext: resp.nvext, nvext: resp.nvext,
chat_template_args: None,
}) })
} }
} }
......
...@@ -768,6 +768,7 @@ async fn test_nv_custom_client() { ...@@ -768,6 +768,7 @@ async fn test_nv_custom_client() {
inner: inner_request, inner: inner_request,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = nv_custom_client.chat_stream(request).await; let result = nv_custom_client.chat_stream(request).await;
...@@ -807,6 +808,7 @@ async fn test_nv_custom_client() { ...@@ -807,6 +808,7 @@ async fn test_nv_custom_client() {
inner: inner_request, inner: inner_request,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = nv_custom_client.chat_stream(request).await; let result = nv_custom_client.chat_stream(request).await;
...@@ -847,6 +849,7 @@ async fn test_nv_custom_client() { ...@@ -847,6 +849,7 @@ async fn test_nv_custom_client() {
inner: inner_request, inner: inner_request,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let result = nv_custom_client let result = nv_custom_client
......
...@@ -270,6 +270,7 @@ impl Request { ...@@ -270,6 +270,7 @@ impl Request {
inner, inner,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
} }
} }
} }
......
...@@ -67,6 +67,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() { ...@@ -67,6 +67,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.build() .build()
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let sampling = request.extract_sampling_options().unwrap(); let sampling = request.extract_sampling_options().unwrap();
...@@ -327,6 +328,7 @@ fn test_serialization_preserves_structure() { ...@@ -327,6 +328,7 @@ fn test_serialization_preserves_structure() {
ignore_eos: Some(false), ignore_eos: Some(false),
..Default::default() ..Default::default()
}), }),
chat_template_args: None,
}; };
let json = serde_json::to_value(&request).unwrap(); let json = serde_json::to_value(&request).unwrap();
...@@ -376,6 +378,7 @@ fn test_sampling_parameters_extraction() { ...@@ -376,6 +378,7 @@ fn test_sampling_parameters_extraction() {
.build() .build()
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None,
}; };
let sampling_options = request.extract_sampling_options().unwrap(); let sampling_options = request.extract_sampling_options().unwrap();
......
...@@ -146,6 +146,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq ...@@ -146,6 +146,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
inner, inner,
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None,
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment