Unverified Commit b62e633c authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: support separate chat_template.jinja file (#1853)

parent 8ae37196
...@@ -86,6 +86,7 @@ impl ModelDeploymentCard { ...@@ -86,6 +86,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?), tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
gen_config: None, // AFAICT there is no equivalent in a GGUF gen_config: None, // AFAICT there is no equivalent in a GGUF
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())), prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
chat_template_file: None,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
last_published: None, last_published: None,
...@@ -124,6 +125,7 @@ impl ModelDeploymentCard { ...@@ -124,6 +125,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?), tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?, prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
last_published: None, last_published: None,
...@@ -157,6 +159,19 @@ impl PromptFormatterArtifact { ...@@ -157,6 +159,19 @@ impl PromptFormatterArtifact {
.ok()) .ok())
} }
pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
Ok(Self::chat_template_try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
.ok())
}
async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate(
check_for_file(repo, "chat_template.jinja").await?,
))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson( Ok(Self::HfTokenizerConfigJson(
check_for_file(repo, "tokenizer_config.json").await?, check_for_file(repo, "tokenizer_config.json").await?,
......
...@@ -62,6 +62,7 @@ pub enum TokenizerKind { ...@@ -62,6 +62,7 @@ pub enum TokenizerKind {
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact { pub enum PromptFormatterArtifact {
HfTokenizerConfigJson(String), HfTokenizerConfigJson(String),
HfChatTemplate(String),
GGUF(PathBuf), GGUF(PathBuf),
} }
...@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard { ...@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>, pub prompt_formatter: Option<PromptFormatterArtifact>,
/// chat template may be stored as a separate file instead of in `prompt_formatter`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_file: Option<PromptFormatterArtifact>,
/// Generation config - default sampling params /// Generation config - default sampling params
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub gen_config: Option<GenerationConfig>, pub gen_config: Option<GenerationConfig>,
...@@ -259,6 +264,11 @@ impl ModelDeploymentCard { ...@@ -259,6 +264,11 @@ impl ModelDeploymentCard {
PromptFormatterArtifact::HfTokenizerConfigJson, PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json" "tokenizer_config.json"
); );
nats_upload!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_upload!( nats_upload!(
self.tokenizer, self.tokenizer,
TokenizerKind::HfTokenizerJson, TokenizerKind::HfTokenizerJson,
...@@ -308,6 +318,11 @@ impl ModelDeploymentCard { ...@@ -308,6 +318,11 @@ impl ModelDeploymentCard {
PromptFormatterArtifact::HfTokenizerConfigJson, PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json" "tokenizer_config.json"
); );
nats_download!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_download!( nats_download!(
self.tokenizer, self.tokenizer,
TokenizerKind::HfTokenizerJson, TokenizerKind::HfTokenizerJson,
......
...@@ -26,7 +26,7 @@ mod oai; ...@@ -26,7 +26,7 @@ mod oai;
mod tokcfg; mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter}; use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::ChatTemplate; use tokcfg::{ChatTemplate, ChatTemplateValue};
impl PromptFormatter { impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> { pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
...@@ -37,13 +37,28 @@ impl PromptFormatter { ...@@ -37,13 +37,28 @@ impl PromptFormatter {
PromptFormatterArtifact::HfTokenizerConfigJson(file) => { PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(&file) let content = std::fs::read_to_string(&file)
.with_context(|| format!("fs:read_to_string '{file}'"))?; .with_context(|| format!("fs:read_to_string '{file}'"))?;
let config: ChatTemplate = serde_json::from_str(&content)?; let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file
{
let chat_template = std::fs::read_to_string(&chat_template_file)
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
// clean up the string to remove newlines
let chat_template = chat_template.replace('\n', "");
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
}
Self::from_parts( Self::from_parts(
config, config,
mdc.prompt_context mdc.prompt_context
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)), .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
) )
} }
PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate"
)),
PromptFormatterArtifact::GGUF(gguf_path) => { PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?; let config = ChatTemplate::from_gguf(&gguf_path)?;
Self::from_parts(config, ContextMixins::default()) Self::from_parts(config, ContextMixins::default())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment