// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //based on: https://github.com/EricLBuehler/mistral.rs/blob/d970bb5feb863acf8e8ec90de97e18221fb959f1/mistralrs-core/src/pipeline/chat_template.rs use std::{collections::HashMap, fs::File, path::Path}; use either::Either; use ggus::{GGufMetaKV, GGufReader}; use memmap2::Mmap; use minijinja::{value::Kwargs, Error, ErrorKind, Value}; use serde::{Deserialize, Serialize}; #[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct AddedTokensDecoder { __type: Option, pub content: String, lstrip: bool, normalized: bool, rstrip: bool, single_word: bool, special: Option, } pub fn raise_exception(msg: String) -> Result { Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg)) } #[derive(Debug, Deserialize)] pub struct BeginEndUnkTok( #[serde(with = "either::serde_untagged")] pub Either, ); /// Support older tool use patterns where the tool use template was separate from the default/chat template. /// Modern patterns use a single template with a `tool_use` key, e.g. /// /// ```jinja /// {%- if tools is not none and tool_choice is not none %} /// ``` #[derive(Debug, Deserialize)] pub struct ChatTemplateValue( #[serde(with = "either::serde_untagged")] pub Either>>, ); /// If present, pad_token is usually a single value. Deepseek R1 and it's distill's use a map. #[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct PadTokenValue( #[serde(with = "either::serde_untagged")] pub Either, ); #[allow(dead_code)] #[derive(Debug, Deserialize, Default)] /// Template for chat models including bos/eos/unk as well as the chat template. pub struct ChatTemplate { pub bos_token: Option, pub eos_token: Option, pub unk_token: Option, /// Jinja format [chat templating] for chat completion. /// /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating pub chat_template: Option, // future add_bos_token: Option, add_eos_token: Option, added_tokens_decoder: Option>, additional_special_tokens: Option>, clean_up_tokenization_spaces: Option, device_map: Option, legacy: Option, model_max_length: Option, pad_token: Option, sp_model_kwargs: Option>, spaces_between_special_tokens: Option, tokenizer_class: Option, truncation_size: Option, use_default_system_prompt: Option, } impl ChatTemplate { pub fn from_gguf(path: &Path) -> anyhow::Result { let file = match File::open(path) { Ok(f) => unsafe { Mmap::map(&f)? }, Err(e) => { anyhow::bail!("Failed to open file '{}': {e:?}", path.display()); } }; let mut reader = GGufReader::new(&file); let header = match reader.read_header() { Ok(header) => header, Err(e) => { anyhow::bail!("Failed to read GGUF header of {}: {e:?}", path.display()); } }; let num_metadata = header.metadata_kv_count; let mut out = ChatTemplate::default(); // bos/eos/unk token conversion fn convert(kv: GGufMetaKV) -> Option { let id_string: String = kv.read_unsigned().to_string(); Some(BeginEndUnkTok(Either::Left(id_string))) } let mut num_found = 0; let num_expected = 4; // How many fields we need for _ in 0..num_metadata { let kv = match reader.read_meta_kv() { Ok(kv) => kv, Err(err) => anyhow::bail!("read_meta_kv error in '{}': {err:?}", path.display()), }; match kv.key() { "tokenizer.ggml.bos_token_id" => { out.bos_token = convert(kv); num_found += 1; } "tokenizer.ggml.eos_token_id" => { out.eos_token = convert(kv); num_found += 1; } "tokenizer.ggml.unknown_token_id" => { out.unk_token = convert(kv); num_found += 1; } "tokenizer.chat_template" => { out.chat_template = kv .value_reader() .read_str() .ok() .map(|s| ChatTemplateValue(Either::Left(s.to_string()))); num_found += 1; } _ => {} } if num_found == num_expected { // No need to look at any more keys break; } } Ok(out) } // pub fn has_chat_template(&self) -> bool { // self.chat_template.is_some() // } pub fn eos_tok(&self) -> Option { match self.eos_token.as_ref()?.0 { Either::Left(ref lit) => Some(lit.clone()), Either::Right(ref added) => Some(added.content.clone()), } } pub fn bos_tok(&self) -> Option { match self.bos_token.as_ref()?.0 { Either::Left(ref lit) => Some(lit.clone()), Either::Right(ref added) => Some(added.content.clone()), } } pub fn unk_tok(&self) -> Option { match self.unk_token.as_ref()?.0 { Either::Left(ref lit) => Some(lit.clone()), Either::Right(ref added) => Some(added.content.clone()), } } } #[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct GenerationConfig { #[serde(with = "either::serde_untagged")] bos_token_id: Either>, #[serde(with = "either::serde_untagged")] eos_token_id: Either>, } pub fn tojson(value: Value, kwargs: Kwargs) -> Result { if let Ok(indent) = kwargs.get("indent") { let mut buf = Vec::new(); let repeat = b" ".repeat(indent); let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat); let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter); value.serialize(&mut serializer).unwrap(); String::from_utf8(buf).map_err(|err| { Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err) }) } else { serde_json::to_string(&value).map_err(|err| { Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err) }) } .map_err(|err| { Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err) }) .map(|s| { // When this filter is used the return value is safe for both HTML and JSON let mut rv = String::with_capacity(s.len()); for c in s.chars() { match c { '<' => rv.push_str("\\u003c"), '>' => rv.push_str("\\u003e"), '&' => rv.push_str("\\u0026"), '\'' => rv.push_str("\\u0027"), _ => rv.push(c), } } Value::from_safe_string(rv) }) }