//! HuggingFace tokenizer wrapper implementation.
use std::collections::HashMap;
use std::time::Instant;

use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

use super::traits::{
    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use crate::metrics::TokenizerMetrics;

// NOTE(review): `ChatMessage` is also referenced by the
// `#[cfg(not(feature = "minijinja"))]` `apply_chat_template` variant below;
// if the `chat_template` module is compiled without the feature, this
// import should be ungated for `ChatMessage` — confirm against the module.
#[cfg(feature = "minijinja")]
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
17
18
19
20
    vocab: HashMap<String, TokenIdType>,
    reverse_vocab: HashMap<TokenIdType, String>,
    #[cfg(feature = "minijinja")]
    chat_template: Option<String>,
21
22
23
24
25
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
26
27
28
29
30
31
32
33
        Self::from_file_with_chat_template(file_path, None)
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
34
35
36
37
38
39
40
41
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
42
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
43
44
45
46
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

47
48
49
50
51
52
53
54
55
56
        // Load chat template
        #[cfg(feature = "minijinja")]
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

57
58
59
60
61
        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
62
63
            #[cfg(feature = "minijinja")]
            chat_template,
64
65
66
67
68
69
70
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
71
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
72
73
74
75
76
77
78
79
80
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
81
82
            #[cfg(feature = "minijinja")]
            chat_template: None,
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    /// Try to load chat template from tokenizer_config.json
    #[cfg(feature = "minijinja")]
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a .jinja file
    #[cfg(feature = "minijinja")]
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Clean up the template (similar to Python implementation)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
    #[cfg(feature = "minijinja")]
    pub fn set_chat_template(&mut self, template: String) {
        self.chat_template = Some(template);
    }

151
    /// Apply chat template if available
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    #[cfg(feature = "minijinja")]
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
            let processor = ChatTemplateProcessor::new(
                template.clone(),
                self.special_tokens.bos_token.clone(),
                self.special_tokens.eos_token.clone(),
            );
            processor.apply_chat_template(messages, add_generation_prompt)
        } else {
            // Fallback to simple formatting if no template is available
            let mut result = String::new();
            for msg in messages {
                result.push_str(&format!("{}: {}\n", msg.role, msg.content));
            }
            if add_generation_prompt {
                result.push_str("assistant: ");
            }
            Ok(result)
        }
    }

    /// Apply chat template if available (without minijinja feature)
    #[cfg(not(feature = "minijinja"))]
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        // Fallback to simple formatting
186
187
188
189
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
190
191
192
        if add_generation_prompt {
            result.push_str("assistant: ");
        }
193
194
195
196
197
198
        Ok(result)
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
199
200
201
202
        let start = Instant::now();

        TokenizerMetrics::record_encode_request("huggingface");
        TokenizerMetrics::record_chars_per_encode(input.len());
203

204
205
206
207
208
209
210
211
212
213
214
        self.tokenizer
            .encode(input, false)
            .map_err(|e| {
                TokenizerMetrics::record_encode_error("encoding_failed");
                Error::msg(format!("Encoding failed: {}", e))
            })
            .map(|encoding| {
                TokenizerMetrics::record_tokens_per_encode(encoding.get_ids().len());
                TokenizerMetrics::record_encode_duration(start.elapsed());
                Encoding::Hf(Box::new(encoding))
            })
215
216
217
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
218
219
        let start = Instant::now();

220
221
222
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
223
224
225
226
227
228
            .map_err(|e| {
                TokenizerMetrics::record_encode_error("batch_encoding_failed");
                Error::msg(format!("Batch encoding failed: {}", e))
            })?;

        TokenizerMetrics::record_encode_batch_duration(start.elapsed(), inputs.len());
229
230
231
232
233
234
235
236
237

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
238
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
239
240
241
242
243
        let start = Instant::now();

        TokenizerMetrics::record_decode_request("huggingface");
        TokenizerMetrics::record_tokens_per_decode(token_ids.len());

244
245
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
246
247
248
249
250
251
252
            .map_err(|e| {
                TokenizerMetrics::record_decode_error("decoding_failed");
                Error::msg(format!("Decoding failed: {}", e))
            })
            .inspect(|_| {
                TokenizerMetrics::record_decode_duration(start.elapsed());
            })
253
254
255
256
257
258
259
260
261
262
263
264
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

265
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
266
267
268
        self.vocab.get(token).copied()
    }

269
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
270
271
272
273
274
275
        self.reverse_vocab.get(&id).cloned()
    }
}

#[cfg(test)]
mod tests {
    #[cfg(feature = "minijinja")]
    use super::ChatMessage;

    /// `ChatMessage` constructors set the expected role and content.
    #[cfg(feature = "minijinja")]
    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    // Note: Actual tokenizer tests would require a real tokenizer file.
    // These would be integration tests rather than unit tests.
}