use std::collections::HashMap;

use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

use super::{
    chat_template::{
        detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
        ChatTemplateProcessor,
    },
    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
};

/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
    vocab: HashMap<String, TokenIdType>,
    reverse_vocab: HashMap<TokenIdType, String>,
    chat_template: Option<String>,
    /// Detected chat template content format (computed once at initialization)
    content_format: ChatTemplateContentFormat,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
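    ///
    /// A minimal usage sketch (not compiled as a doctest; the path is
    /// illustrative):
    ///
    /// ```ignore
    /// let tokenizer = HuggingFaceTokenizer::from_file("/path/to/tokenizer.json")?;
    /// let encoding = tokenizer.encode("Hello, world!")?;
    /// ```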
    pub fn from_file(file_path: &str) -> Result<Self> {
        // Try to auto-discover chat template if not explicitly provided
        let path = std::path::Path::new(file_path);
        let chat_template_path = path
            .parent()
            .and_then(crate::tokenizer::factory::discover_chat_template_in_dir);
        Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
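    ///
    /// A usage sketch (ignored doctest; paths are illustrative):
    ///
    /// ```ignore
    /// // Pair a tokenizer.json with a standalone Jinja chat template.
    /// let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
    ///     "/path/to/tokenizer.json",
    ///     Some("/path/to/chat_template.jinja"),
    /// )?;
    /// ```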
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        // Load chat template
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

        // Detect content format once at initialization
        let content_format = if let Some(ref template) = chat_template {
            detect_chat_template_content_format(template)
        } else {
            ChatTemplateContentFormat::String // Default if no template
        };

        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
            chat_template,
            content_format,
        })
    }

    /// Create from an existing HuggingFace tokenizer
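    ///
    /// A sketch of wrapping a pre-built tokenizer (ignored doctest; the path
    /// is illustrative):
    ///
    /// ```ignore
    /// let hf = HfTokenizer::from_file("/path/to/tokenizer.json")?;
    /// let tokenizer = HuggingFaceTokenizer::from_tokenizer(hf);
    /// ```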
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
            chat_template: None,
            content_format: ChatTemplateContentFormat::String, // Default
        }
    }

    /// Extract special tokens from the tokenizer
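    ///
    /// Well-known tokens are found by probing the vocab for common spellings
    /// (e.g. `<s>`/`</s>` for Llama-style models, `[CLS]`/`[SEP]` for
    /// BERT-style models); additional special tokens come from the tokenizer's
    /// added-tokens decoder.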
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Get vocab with special tokens included (added_tokens like <|im_start|>)
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        // Extract additional special tokens using the tokenizers library API
        let additional_special_tokens: Vec<String> = tokenizer
            .get_added_tokens_decoder()
            .iter()
            .filter(|(_id, token)| token.special) // Only tokens marked as special: true
            .map(|(_id, token)| token.content.clone())
            .collect();

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens,
        }
    }

    /// Try to load chat template from tokenizer_config.json
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a file (.jinja or .json containing Jinja)
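    ///
    /// For `.json` inputs, two shapes are accepted (illustrative examples):
    /// a bare string such as `"{% for message in messages %}...{% endfor %}"`,
    /// or an object with a `chat_template` key, e.g.
    /// `{"chat_template": "{% ... %}"}`.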
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Check if it's a JSON file containing a Jinja template
        if template_path.ends_with(".json") {
            // Parse JSON and extract the template string
            let json_value: serde_json::Value = serde_json::from_str(&content)
                .map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;

            if let Some(template_str) = json_value.as_str() {
                return Ok(Some(template_str.to_string()));
            } else if let Some(obj) = json_value.as_object() {
                if let Some(template_value) = obj.get("chat_template") {
                    if let Some(template_str) = template_value.as_str() {
                        return Ok(Some(template_str.to_string()));
                    }
                }
            }

            return Err(Error::msg(
                "chat_template.json does not contain a valid template",
            ));
        }

        // Otherwise it's a plain .jinja file
        // Clean up the template: trim surrounding whitespace and turn literal
        // "\n" escape sequences into real newlines (as the Python loader does)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
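    ///
    /// A sketch of overriding the template at runtime (ignored doctest; the
    /// template body is illustrative):
    ///
    /// ```ignore
    /// tokenizer.set_chat_template(
    ///     "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}".to_string(),
    /// );
    /// ```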
    pub fn set_chat_template(&mut self, template: String) {
        // Detect format for the new template
        self.content_format = detect_chat_template_content_format(&template);
        self.chat_template = Some(template);
    }

    /// Get the content format expected by the chat template
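    ///
    /// `String` templates expect each message's `content` to be a plain
    /// string, while structured formats expect a list of content parts;
    /// callers should transform messages to match before rendering.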
    pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        self.content_format
    }

    /// Apply chat template if available
    ///
    /// Takes transformed JSON Values (already transformed based on content format)
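    ///
    /// A usage sketch (ignored doctest; assumes `ChatTemplateParams`
    /// implements `Default`, which is not confirmed here):
    ///
    /// ```ignore
    /// let messages = vec![serde_json::json!({
    ///     "role": "user",
    ///     "content": "Hello!",
    /// })];
    /// let prompt = tokenizer.apply_chat_template(&messages, ChatTemplateParams::default())?;
    /// ```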
    pub fn apply_chat_template(
        &self,
        messages: &[serde_json::Value],
        params: ChatTemplateParams,
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
            let processor = ChatTemplateProcessor::new(template.clone());
            processor.apply_chat_template(messages, params)
        } else {
            Err(Error::msg(
                "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
                argument was passed! For information about writing templates and setting the \
                tokenizer.chat_template attribute, please see the documentation at \
                https://huggingface.co/docs/transformers/main/en/chat_templating"
            ))
        }
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
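        // Second argument `false` = do not add special tokens during encoding.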
        self.tokenizer
            .encode(input, false)
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
            .map(|encoding| Encoding::Hf(Box::new(encoding)))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
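        // `false` = base vocabulary only; added/special tokens are excluded,
        // so this can be smaller than `self.vocab.len()`.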
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

#[cfg(test)]
mod tests {
    // Note: Actual tokenizer tests would require a real tokenizer file
    // These would be integration tests rather than unit tests
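
    // A hedged sketch of such an integration test: the path is hypothetical,
    // and the test is `#[ignore]`d so it only runs when a real tokenizer.json
    // is supplied at that location.
    use super::*;

    #[test]
    #[ignore = "requires a real tokenizer.json on disk"]
    fn vocab_round_trips_through_id_lookup() {
        let tokenizer = HuggingFaceTokenizer::from_file("tests/data/tokenizer.json")
            .expect("failed to load tokenizer");

        assert!(tokenizer.vocab_size() > 0);

        // Every id we can resolve should map back to the same id via its token.
        for id in 0..32 {
            if let Some(token) = tokenizer.id_to_token(id) {
                assert_eq!(tokenizer.token_to_id(&token), Some(id));
            }
        }
    }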
}