//! HuggingFace tokenizer wrapper.
use std::collections::HashMap;
use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

use super::chat_template::{
    detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor,
};
use super::traits::{
    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
/// HuggingFace tokenizer wrapper
///
/// Wraps a `tokenizers`-crate tokenizer together with cached vocabulary
/// lookup maps and an optional chat template resolved at construction time.
pub struct HuggingFaceTokenizer {
    // Underlying HuggingFace tokenizer instance.
    tokenizer: HfTokenizer,
    // Special tokens (BOS/EOS/UNK/...) discovered from the tokenizer vocab.
    special_tokens: SpecialTokens,
    // token -> id map, built once from `tokenizer.get_vocab(false)`.
    vocab: HashMap<String, TokenIdType>,
    // id -> token map, the inverse of `vocab`.
    reverse_vocab: HashMap<TokenIdType, String>,
    // Optional Jinja chat template (from a .jinja file or tokenizer_config.json).
    chat_template: Option<String>,
    /// Detected chat template content format (computed once at initialization)
    content_format: ChatTemplateContentFormat,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
27
28
29
30
31
32
33
34
        Self::from_file_with_chat_template(file_path, None)
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
35
36
37
38
39
40
41
42
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
43
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
44
45
46
47
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

48
49
50
51
52
53
54
55
56
        // Load chat template
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

57
58
59
60
61
62
63
        // Detect content format once at initialization
        let content_format = if let Some(ref template) = chat_template {
            detect_chat_template_content_format(template)
        } else {
            ChatTemplateContentFormat::String // Default if no template
        };

64
65
66
67
68
        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
69
            chat_template,
70
            content_format,
71
72
73
74
75
76
77
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
78
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
79
80
81
82
83
84
85
86
87
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
88
            chat_template: None,
89
            content_format: ChatTemplateContentFormat::String, // Default
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    /// Try to load chat template from tokenizer_config.json
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a .jinja file
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Clean up the template (similar to Python implementation)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
    pub fn set_chat_template(&mut self, template: String) {
152
153
        // Detect format for the new template
        self.content_format = detect_chat_template_content_format(&template);
154
155
156
        self.chat_template = Some(template);
    }

157
158
159
160
161
    /// Get the content format expected by the chat template
    pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        self.content_format
    }

162
    /// Apply chat template if available
163
164
    ///
    /// Takes transformed JSON Values (already transformed based on content format)
165
166
    pub fn apply_chat_template(
        &self,
167
        messages: &[serde_json::Value],
168
169
170
171
172
173
174
175
        add_generation_prompt: bool,
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
            let processor = ChatTemplateProcessor::new(
                template.clone(),
                self.special_tokens.bos_token.clone(),
                self.special_tokens.eos_token.clone(),
            );
176

177
178
            processor.apply_chat_template(messages, add_generation_prompt)
        } else {
179
180
181
182
183
184
            Err(Error::msg(
                "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
                argument was passed! For information about writing templates and setting the \
                tokenizer.chat_template attribute, please see the documentation at \
                https://huggingface.co/docs/transformers/main/en/chat_templating"
            ))
185
186
        }
    }
187
188
189
190
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
191
192
        self.tokenizer
            .encode(input, false)
193
194
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
            .map(|encoding| Encoding::Hf(Box::new(encoding)))
195
196
197
198
199
200
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
201
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
202
203
204
205
206
207
208
209
210

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
211
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
212
213
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
214
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
215
216
217
218
219
220
221
222
223
224
225
226
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

227
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
228
229
230
        self.vocab.get(token).copied()
    }

231
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
232
233
        self.reverse_vocab.get(&id).cloned()
    }
234
235
236
237

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
238
239
240
241
242
243
244
}

#[cfg(test)]
mod tests {
    // Note: exercising this wrapper requires a real tokenizer JSON file,
    // so coverage lives in integration tests rather than unit tests here.
}