// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{collections::HashSet, sync::Arc};

6
use anyhow::{Context, Ok, Result};
Biswa Panda's avatar
Biswa Panda committed
7
8
use minijinja::Environment;

9
use crate::model_card::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
Biswa Panda's avatar
Biswa Panda committed
10
11
12
13
14
15
16

mod context;
mod formatters;
mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
17
18
pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
Biswa Panda's avatar
Biswa Panda committed
19
20

impl PromptFormatter {
21
    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
22
23
24
25
26
27
28
29
30
31
32
33
        // Special handling for DeepSeek-V3.2(-Speciale) which doesn't provide Jinja chat_template
        let name_lower = mdc.display_name.to_lowercase();
        if name_lower.contains("deepseek")
            && name_lower.contains("v3.2")
            && !name_lower.contains("exp")
        {
            tracing::info!("Detected DeepSeek V3.2 model (non-Exp), using native Rust formatter");
            return Ok(Self::OAI(Arc::new(
                super::deepseek_v32::DeepSeekV32Formatter::new_thinking(),
            )));
        }

Biswa Panda's avatar
Biswa Panda committed
34
35
        match mdc
            .prompt_formatter
36
            .as_ref()
Biswa Panda's avatar
Biswa Panda committed
37
38
            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
        {
39
40
41
42
43
44
45
            PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => {
                let Some(file) = checked_file.path() else {
                    anyhow::bail!(
                        "HfTokenizerConfigJson for {} is a URL, cannot load",
                        mdc.display_name
                    );
                };
46
47
48
49
50
51
                let contents = std::fs::read_to_string(file).with_context(|| {
                    format!(
                        "PromptFormatter.from_mdc fs:read_to_string '{}'",
                        file.display()
                    )
                })?;
52
53
54
55
                let mut config: ChatTemplate =
                    serde_json::from_str(&contents).inspect_err(|err| {
                        crate::log_json_err(&file.display().to_string(), &contents, err)
                    })?;
56

57
58
59
                // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
                // stores the chat template in a separate file, we check if the file exists and
                // put the chat template into config as normalization.
60
                // This may also be a custom template provided via CLI flag.
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
                match mdc.chat_template_file.as_ref() {
                    Some(PromptFormatterArtifact::HfChatTemplateJinja {
                        file: checked_file,
                        ..
                    }) => {
                        let Some(path) = checked_file.path() else {
                            anyhow::bail!(
                                "HfChatTemplateJinja for {} is a URL, cannot load",
                                mdc.display_name
                            );
                        };
                        let chat_template = std::fs::read_to_string(path)
                            .with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
                        config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
                    }
                    Some(PromptFormatterArtifact::HfChatTemplateJson {
                        file: checked_file,
                        ..
                    }) => {
                        let Some(path) = checked_file.path() else {
                            anyhow::bail!(
                                "HfChatTemplateJson for {} is a URL, cannot load",
                                mdc.display_name
                            );
                        };
                        let raw = std::fs::read_to_string(path)
                            .with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
                        let wrapper: serde_json::Value =
                            serde_json::from_str(&raw).with_context(|| {
                                format!("Failed to parse '{}' as JSON", path.display())
                            })?;
                        let field = wrapper.get("chat_template").ok_or_else(|| {
                            anyhow::anyhow!(
                                "'{}' does not contain a 'chat_template' field",
                                path.display()
                            )
97
                        })?;
98
99
100
101
102
103
104
105
106
107
                        let value = serde_json::from_value::<ChatTemplateValue>(field.clone())
                            .with_context(|| {
                                format!(
                                    "Failed to deserialize 'chat_template' in '{}'",
                                    path.display()
                                )
                            })?;
                        config.chat_template = Some(value);
                    }
                    _ => {}
108
                }
109
                Self::from_parts(
Biswa Panda's avatar
Biswa Panda committed
110
111
                    config,
                    mdc.prompt_context
112
                        .clone()
Biswa Panda's avatar
Biswa Panda committed
113
                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
114
                    mdc.runtime_config.exclude_tools_when_tool_choice_none,
115
116
                )
            }
117
118
119
            PromptFormatterArtifact::HfChatTemplateJinja { .. }
            | PromptFormatterArtifact::HfChatTemplateJson { .. } => Err(anyhow::anyhow!(
                "prompt_formatter should not have type HfChatTemplate*"
120
            )),
Biswa Panda's avatar
Biswa Panda committed
121
122
        }
    }
123

124
125
126
127
128
129
130
131
132
133
    pub fn from_parts(
        config: ChatTemplate,
        context: ContextMixins,
        exclude_tools_when_tool_choice_none: bool,
    ) -> Result<PromptFormatter> {
        let formatter = HfTokenizerConfigJsonFormatter::with_options(
            config,
            context,
            exclude_tools_when_tool_choice_none,
        )?;
134
135
        Ok(Self::OAI(Arc::new(formatter)))
    }
Biswa Panda's avatar
Biswa Panda committed
136
137
138
139
140
141
142
143
144
145
146
}

/// Chat Template Jinja Renderer
///
/// Manages a Jinja environment with registered templates for chat formatting.
/// Handles two types of ChatTemplateValue templates:
///
/// 1. String template: Registered as the 'default' template
/// 2. Map template: Contains 'tool_use' and/or 'default' templates
///    - tool_use: Template for tool-based interactions
///    - default: Template for standard chat interactions
147
///
Biswa Panda's avatar
Biswa Panda committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
///   If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
///   and the `default` template is registered as the `default` template.
struct JinjaEnvironment {
    env: Environment<'static>,
}

/// Formatter for HuggingFace tokenizer config JSON templates
///
/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
/// Supports:
/// - Tool usage templates
/// - Generation prompts
/// - Context mixins for template customization
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
    env: Environment<'static>,
164
    config: ChatTemplate,
Biswa Panda's avatar
Biswa Panda committed
165
166
    mixins: Arc<ContextMixins>,
    supports_add_generation_prompt: bool,
167
    requires_content_arrays: bool,
168
169
170
    /// When true, strip tool definitions from the chat template when tool_choice is "none".
    /// This prevents models from generating raw XML tool calls in the content field.
    exclude_tools_when_tool_choice_none: bool,
171
172
173
    /// True if the chat template natively references `reasoning_content`.
    /// When true, skip injection — the template handles it.
    template_handles_reasoning: bool,
Biswa Panda's avatar
Biswa Panda committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
}

// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }

/// Set of [`PromptContextMixin`] values attached to a prompt formatter.
///
/// The derived `Default` yields an empty set.
#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
    /// Mixins stored without duplicates.
    context_mixins: HashSet<PromptContextMixin>,
}