Unverified commit eb6722e3, authored by Chi McIsaac, committed by GitHub
Browse files

feat: add chat_template_kwargs param to v1/chat/completion (#3016)


Signed-off-by: Chi McIsaac <chixie.mcisaac@gmail.com>
parent 9060ce12
......@@ -228,6 +228,7 @@ async fn evaluate(
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new();
......
......@@ -111,6 +111,7 @@ async fn main_loop(
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
};
// Call the model
......
......@@ -1350,6 +1350,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err());
......@@ -1377,6 +1378,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok());
......@@ -1549,6 +1551,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -1576,6 +1579,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1602,6 +1606,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1628,6 +1633,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1656,6 +1662,7 @@ mod tests {
.build()
.unwrap(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1682,6 +1689,7 @@ mod tests {
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......
......@@ -20,6 +20,7 @@
use anyhow::Result;
use minijinja::value::Value;
use std::collections::HashMap;
use std::sync::Arc;
mod template;
......@@ -57,6 +58,11 @@ pub trait OAIChatLikeRequest {
fn should_add_generation_prompt(&self) -> bool;
/// Optional additional args to merge into the chat template context.
///
/// The default implementation returns `None`; request types that carry
/// such args override this method to expose them.
fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
None
}
/// Returns the type of input for the prompt. Default is Text.
fn prompt_input_type(&self) -> PromptInput {
PromptInput::Text(TextInput::Single(String::new()))
......
......@@ -114,6 +114,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
/// Always returns a single, empty text input.
fn extract_text(&self) -> Option<TextInput> {
Some(TextInput::Single(String::new()))
}
/// Returns a borrowed view of this request's `chat_template_args` field,
/// or `None` when the field is unset.
fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
self.chat_template_args.as_ref()
}
}
impl OAIChatLikeRequest for NvCreateCompletionRequest {
......@@ -207,9 +211,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
..mixins
};
let ctx = context! { ..ctx, ..context! {
}};
// Merge any request-supplied chat template args into the rendering context so
// that they take precedence over the values assembled above.
//
// `context!` resolves a key by consulting its spread sources from left to
// right and returning the first match, so `extra` must be listed BEFORE `ctx`
// for the caller's args to override existing keys. Listing it after `ctx`
// would let the existing context shadow the caller's args.
let ctx = if let Some(args) = req.chat_template_args() {
    let extra = Value::from_serialize(args);
    context! { ..extra, ..ctx }
} else {
    // No extra args on the request: render with the context unchanged.
    ctx
};
let tmpl: minijinja::Template<'_, '_> = if has_tools {
self.env.get_template("tool_use")?
......
......@@ -41,6 +41,10 @@ pub struct NvCreateChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
/// Extra args to pass to the chat template rendering context.
///
/// On input this field is also accepted under the name
/// `chat_template_kwargs`; it is always serialized as `chat_template_args`
/// and omitted entirely when unset.
#[serde(
    default,
    alias = "chat_template_kwargs",
    skip_serializing_if = "Option::is_none"
)]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
}
/// A response structure for unary chat completion responses, embedding OpenAI's
......
......@@ -175,6 +175,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
},
common: Default::default(),
nvext: resp.nvext,
chat_template_args: None,
})
}
}
......
......@@ -768,6 +768,7 @@ async fn test_nv_custom_client() {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = nv_custom_client.chat_stream(request).await;
......@@ -807,6 +808,7 @@ async fn test_nv_custom_client() {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = nv_custom_client.chat_stream(request).await;
......@@ -847,6 +849,7 @@ async fn test_nv_custom_client() {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
};
let result = nv_custom_client
......
......@@ -270,6 +270,7 @@ impl Request {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
}
}
}
......
......@@ -67,6 +67,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.build()
.unwrap(),
nvext: None,
chat_template_args: None,
};
let sampling = request.extract_sampling_options().unwrap();
......@@ -327,6 +328,7 @@ fn test_serialization_preserves_structure() {
ignore_eos: Some(false),
..Default::default()
}),
chat_template_args: None,
};
let json = serde_json::to_value(&request).unwrap();
......@@ -376,6 +378,7 @@ fn test_sampling_parameters_extraction() {
.build()
.unwrap(),
nvext: None,
chat_template_args: None,
};
let sampling_options = request.extract_sampling_options().unwrap();
......
......@@ -146,6 +146,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment