Unverified Commit b62e633c authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: support separate chat_template.jinja file (#1853)

parent 8ae37196
......@@ -86,6 +86,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
gen_config: None, // AFAICT there is no equivalent in a GGUF
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
chat_template_file: None,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
......@@ -124,6 +125,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
......@@ -157,6 +159,19 @@ impl PromptFormatterArtifact {
.ok())
}
/// Best-effort lookup of a stand-alone `chat_template.jinja` in `repo_id`.
///
/// Returns `Ok(Some(artifact))` when the lookup succeeds and `Ok(None)` when it
/// fails for any reason. This function never returns `Err`: the lookup error
/// (together with the context message attached below) is discarded, because the
/// separate chat template file is treated as optional.
pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
    let lookup = Self::chat_template_try_is_hf_repo(repo_id)
        .await
        .with_context(|| format!("unable to extract prompt format from repo {}", repo_id));
    match lookup {
        Ok(artifact) => Ok(Some(artifact)),
        // A missing template is not an error; report it as "no artifact".
        Err(_) => Ok(None),
    }
}
/// Resolves `chat_template.jinja` inside `repo` via `check_for_file` and wraps
/// the result in [`PromptFormatterArtifact::HfChatTemplate`].
///
/// # Errors
/// Propagates whatever error `check_for_file` returns.
async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
    let template_file = check_for_file(repo, "chat_template.jinja").await?;
    Ok(Self::HfChatTemplate(template_file))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson(
check_for_file(repo, "tokenizer_config.json").await?,
......
......@@ -62,6 +62,7 @@ pub enum TokenizerKind {
/// File-backed artifact from which a model's prompt / chat template is loaded.
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
/// A Hugging Face `tokenizer_config.json`. The string is passed to
/// `std::fs::read_to_string` when the formatter is built, and the content is
/// parsed as JSON into a `ChatTemplate`.
HfTokenizerConfigJson(String),
/// A stand-alone `chat_template.jinja` file. It is carried in the card's
/// `chat_template_file` field; `PromptFormatter::from_mdc` rejects it when it
/// appears as the `prompt_formatter` itself.
HfChatTemplate(String),
/// Path to a GGUF file; the chat template is read from it with
/// `ChatTemplate::from_gguf`.
GGUF(PathBuf),
}
......@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>,
/// Chat template, when it is stored as a separate file instead of in `prompt_formatter`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_file: Option<PromptFormatterArtifact>,
/// Generation config - default sampling params
#[serde(default, skip_serializing_if = "Option::is_none")]
pub gen_config: Option<GenerationConfig>,
......@@ -259,6 +264,11 @@ impl ModelDeploymentCard {
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_upload!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_upload!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
......@@ -308,6 +318,11 @@ impl ModelDeploymentCard {
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_download!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_download!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
......
......@@ -26,7 +26,7 @@ mod oai;
mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::ChatTemplate;
use tokcfg::{ChatTemplate, ChatTemplateValue};
impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
......@@ -37,13 +37,28 @@ impl PromptFormatter {
PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(&file)
.with_context(|| format!("fs:read_to_string '{file}'"))?;
let config: ChatTemplate = serde_json::from_str(&content)?;
let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF models (e.g. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// store the chat template in a separate file. If the card references such a
// file, load it and place the template into `config` to normalize the result.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file
{
let chat_template = std::fs::read_to_string(&chat_template_file)
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
// clean up the string to remove newlines
let chat_template = chat_template.replace('\n', "");
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
}
Self::from_parts(
config,
mdc.prompt_context
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
)
}
PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate"
)),
PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?;
Self::from_parts(config, ContextMixins::default())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment