//! HuggingFace tokenizer wrapper built on the HuggingFace `tokenizers` crate.
1
use std::collections::HashMap;
2
3

use anyhow::{Error, Result};
4
5
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

6
use super::chat_template::{
7
8
    detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
    ChatTemplateProcessor,
9
10
11
12
};
use super::traits::{
    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
13

14
15
16
17
/// HuggingFace tokenizer wrapper
///
/// Wraps a `tokenizers::Tokenizer` and caches vocabulary lookups plus an
/// optional Jinja chat template alongside it.
pub struct HuggingFaceTokenizer {
    // Underlying HuggingFace tokenizer that performs the actual encode/decode.
    tokenizer: HfTokenizer,
    // Special tokens discovered from the vocabulary at construction time.
    special_tokens: SpecialTokens,
    // Token string -> id map, built from `get_vocab(false)` (added special
    // tokens excluded) at construction time.
    vocab: HashMap<String, TokenIdType>,
    // Id -> token string map; the inverse of `vocab`.
    reverse_vocab: HashMap<TokenIdType, String>,
    // Optional chat template, loaded from a `.jinja` file or from
    // `tokenizer_config.json` next to the tokenizer file.
    chat_template: Option<String>,
    /// Detected chat template content format (computed once at initialization)
    content_format: ChatTemplateContentFormat,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
28
29
30
31
32
33
34
35
        Self::from_file_with_chat_template(file_path, None)
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
36
37
38
39
40
41
42
43
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
44
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
45
46
47
48
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

49
50
51
52
53
54
55
56
57
        // Load chat template
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

58
59
60
61
62
63
64
        // Detect content format once at initialization
        let content_format = if let Some(ref template) = chat_template {
            detect_chat_template_content_format(template)
        } else {
            ChatTemplateContentFormat::String // Default if no template
        };

65
66
67
68
69
        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
70
            chat_template,
71
            content_format,
72
73
74
75
76
77
78
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
79
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
80
81
82
83
84
85
86
87
88
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
89
            chat_template: None,
90
            content_format: ChatTemplateContentFormat::String, // Default
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    /// Try to load chat template from tokenizer_config.json
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a .jinja file
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Clean up the template (similar to Python implementation)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
    pub fn set_chat_template(&mut self, template: String) {
153
154
        // Detect format for the new template
        self.content_format = detect_chat_template_content_format(&template);
155
156
157
        self.chat_template = Some(template);
    }

158
159
160
161
162
    /// Get the content format expected by the chat template
    pub fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        self.content_format
    }

163
    /// Apply chat template if available
164
165
    ///
    /// Takes transformed JSON Values (already transformed based on content format)
166
167
    pub fn apply_chat_template(
        &self,
168
        messages: &[serde_json::Value],
169
        params: ChatTemplateParams,
170
171
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
172
173
            let processor = ChatTemplateProcessor::new(template.clone());
            processor.apply_chat_template(messages, params)
174
        } else {
175
176
177
178
179
180
            Err(Error::msg(
                "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
                argument was passed! For information about writing templates and setting the \
                tokenizer.chat_template attribute, please see the documentation at \
                https://huggingface.co/docs/transformers/main/en/chat_templating"
            ))
181
182
        }
    }
183
184
185
186
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
187
188
        self.tokenizer
            .encode(input, false)
189
190
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
            .map(|encoding| Encoding::Hf(Box::new(encoding)))
191
192
193
194
195
196
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
197
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
198
199
200
201
202
203
204
205
206

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
207
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
208
209
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
210
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
211
212
213
214
215
216
217
218
219
220
221
222
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

223
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
224
225
226
        self.vocab.get(token).copied()
    }

227
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
228
229
        self.reverse_vocab.get(&id).cloned()
    }
230
231
232
233

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
234
235
236
237
238
239
240
}

#[cfg(test)]
mod tests {
    // Note: Actual tokenizer tests would require a real tokenizer file
    // (tokenizer.json), so they belong in integration tests with fixtures
    // rather than unit tests here.
}