template.rs 3.6 KB
Newer Older
Biswa Panda's avatar
Biswa Panda committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::HashSet, sync::Arc};

use anyhow::{Ok, Result};
use minijinja::Environment;

use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};

mod context;
mod formatters;
mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
29
use tokcfg::ChatTemplate;
Biswa Panda's avatar
Biswa Panda committed
30
31
32
33
34
35
36
37
38

impl PromptFormatter {
    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
        match mdc
            .prompt_formatter
            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
        {
            PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                let content = std::fs::read_to_string(file)?;
39
40
                let config: ChatTemplate = serde_json::from_str(&content)?;
                Self::from_parts(
Biswa Panda's avatar
Biswa Panda committed
41
42
43
                    config,
                    mdc.prompt_context
                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
44
45
46
47
48
                )
            }
            PromptFormatterArtifact::GGUF(gguf_path) => {
                let config = ChatTemplate::from_gguf(&gguf_path)?;
                Self::from_parts(config, ContextMixins::default())
Biswa Panda's avatar
Biswa Panda committed
49
50
51
            }
        }
    }
52
53
54
55
56

    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
        Ok(Self::OAI(Arc::new(formatter)))
    }
Biswa Panda's avatar
Biswa Panda committed
57
58
59
60
61
62
63
64
65
66
67
}

/// Chat Template Jinja Renderer
///
/// Manages a Jinja environment with registered templates for chat formatting.
/// Handles two types of ChatTemplateValue templates:
///
/// 1. String template: Registered as the 'default' template
/// 2. Map template: Contains 'tool_use' and/or 'default' templates
///    - tool_use: Template for tool-based interactions
///    - default: Template for standard chat interactions
68
///
Biswa Panda's avatar
Biswa Panda committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
///   If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
///   and the `default` template is registered as the `default` template.
struct JinjaEnvironment {
    env: Environment<'static>,
}

/// Formatter for HuggingFace tokenizer config JSON templates
///
/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
/// Supports:
/// - Tool usage templates
/// - Generation prompts
/// - Context mixins for template customization
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
    env: Environment<'static>,
85
    config: ChatTemplate,
Biswa Panda's avatar
Biswa Panda committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
    mixins: Arc<ContextMixins>,
    supports_add_generation_prompt: bool,
}

// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }

/// Set of prompt context mixins used to customize chat-template rendering.
///
/// The derived `Default` value holds an empty set, i.e. no mixins.
#[derive(Clone, Debug, Default)]
pub struct ContextMixins {
    /// The distinct mixins to apply; stored in a `HashSet`, so each appears at most once.
    context_mixins: HashSet<PromptContextMixin>,
}