template.rs 5.88 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Biswa Panda's avatar
Biswa Panda committed
2
3
4
5
// SPDX-License-Identifier: Apache-2.0

use std::{collections::HashSet, sync::Arc};

6
use anyhow::{Context, Ok, Result};
Biswa Panda's avatar
Biswa Panda committed
7
8
use minijinja::Environment;

9
use crate::model_card::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
Biswa Panda's avatar
Biswa Panda committed
10
11
12
13
14
15
16

mod context;
mod formatters;
mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
17
18
pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
Biswa Panda's avatar
Biswa Panda committed
19
20

impl PromptFormatter {
21
    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
22
23
24
25
26
27
28
29
30
31
32
33
        // Special handling for DeepSeek-V3.2(-Speciale) which doesn't provide Jinja chat_template
        let name_lower = mdc.display_name.to_lowercase();
        if name_lower.contains("deepseek")
            && name_lower.contains("v3.2")
            && !name_lower.contains("exp")
        {
            tracing::info!("Detected DeepSeek V3.2 model (non-Exp), using native Rust formatter");
            return Ok(Self::OAI(Arc::new(
                super::deepseek_v32::DeepSeekV32Formatter::new_thinking(),
            )));
        }

Biswa Panda's avatar
Biswa Panda committed
34
35
        match mdc
            .prompt_formatter
36
            .as_ref()
Biswa Panda's avatar
Biswa Panda committed
37
38
            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
        {
39
40
41
42
43
44
45
            PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => {
                let Some(file) = checked_file.path() else {
                    anyhow::bail!(
                        "HfTokenizerConfigJson for {} is a URL, cannot load",
                        mdc.display_name
                    );
                };
46
47
48
49
50
51
                let contents = std::fs::read_to_string(file).with_context(|| {
                    format!(
                        "PromptFormatter.from_mdc fs:read_to_string '{}'",
                        file.display()
                    )
                })?;
52
53
54
55
                let mut config: ChatTemplate =
                    serde_json::from_str(&contents).inspect_err(|err| {
                        crate::log_json_err(&file.display().to_string(), &contents, err)
                    })?;
56

57
58
59
                // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
                // stores the chat template in a separate file, we check if the file exists and
                // put the chat template into config as normalization.
60
                // This may also be a custom template provided via CLI flag.
61
62
63
                if let Some(PromptFormatterArtifact::HfChatTemplate {
                    file: checked_file, ..
                }) = mdc.chat_template_file.as_ref()
64
                {
65
66
67
68
69
70
71
72
73
74
                    let Some(chat_template_file) = checked_file.path() else {
                        anyhow::bail!(
                            "HfChatTemplate for {} is a URL, cannot load",
                            mdc.display_name
                        );
                    };
                    let chat_template =
                        std::fs::read_to_string(chat_template_file).with_context(|| {
                            format!("fs:read_to_string '{}'", chat_template_file.display())
                        })?;
75
76
                    config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
                }
77
                Self::from_parts(
Biswa Panda's avatar
Biswa Panda committed
78
79
                    config,
                    mdc.prompt_context
80
                        .clone()
Biswa Panda's avatar
Biswa Panda committed
81
                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
82
                    mdc.runtime_config.exclude_tools_when_tool_choice_none,
83
84
                )
            }
85
            PromptFormatterArtifact::HfChatTemplate { .. } => Err(anyhow::anyhow!(
86
87
                "prompt_formatter should not have type HfChatTemplate"
            )),
Biswa Panda's avatar
Biswa Panda committed
88
89
        }
    }
90

91
92
93
94
95
96
97
98
99
100
    pub fn from_parts(
        config: ChatTemplate,
        context: ContextMixins,
        exclude_tools_when_tool_choice_none: bool,
    ) -> Result<PromptFormatter> {
        let formatter = HfTokenizerConfigJsonFormatter::with_options(
            config,
            context,
            exclude_tools_when_tool_choice_none,
        )?;
101
102
        Ok(Self::OAI(Arc::new(formatter)))
    }
Biswa Panda's avatar
Biswa Panda committed
103
104
105
106
107
108
109
110
111
112
113
}

/// Chat Template Jinja Renderer
///
/// Manages a Jinja environment with registered templates for chat formatting.
/// Handles two types of ChatTemplateValue templates:
///
/// 1. String template: Registered as the 'default' template
/// 2. Map template: Contains 'tool_use' and/or 'default' templates
///    - tool_use: Template for tool-based interactions
///    - default: Template for standard chat interactions
///
///   If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
///   and the `default` template is registered as the `default` template.
struct JinjaEnvironment {
    // The wrapped minijinja environment; this is the struct's only state.
    env: Environment<'static>,
}

/// Formatter for HuggingFace tokenizer config JSON templates
///
/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
/// Supports:
/// - Tool usage templates
/// - Generation prompts
/// - Context mixins for template customization
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
    env: Environment<'static>,
131
    config: ChatTemplate,
Biswa Panda's avatar
Biswa Panda committed
132
133
    mixins: Arc<ContextMixins>,
    supports_add_generation_prompt: bool,
134
    requires_content_arrays: bool,
135
136
137
    /// When true, strip tool definitions from the chat template when tool_choice is "none".
    /// This prevents models from generating raw XML tool calls in the content field.
    exclude_tools_when_tool_choice_none: bool,
Biswa Panda's avatar
Biswa Panda committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
}

// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }

/// A set of [`PromptContextMixin`] values.
///
/// The `Default` value holds an empty set.
#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
    // Stored as a set, so each mixin appears at most once.
    context_mixins: HashSet<PromptContextMixin>,
}