//! HuggingFace tokenizer wrapper implementing the `Encoder`, `Decoder` and `Tokenizer` traits.

use super::traits::{
    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

#[cfg(feature = "minijinja")]
use super::chat_template::{ChatMessage, ChatTemplateProcessor};

/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
15
16
17
18
    vocab: HashMap<String, TokenIdType>,
    reverse_vocab: HashMap<TokenIdType, String>,
    #[cfg(feature = "minijinja")]
    chat_template: Option<String>,
19
20
21
22
23
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
24
25
26
27
28
29
30
31
        Self::from_file_with_chat_template(file_path, None)
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
32
33
34
35
36
37
38
39
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
40
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
41
42
43
44
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

45
46
47
48
49
50
51
52
53
54
        // Load chat template
        #[cfg(feature = "minijinja")]
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

55
56
57
58
59
        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
60
61
            #[cfg(feature = "minijinja")]
            chat_template,
62
63
64
65
66
67
68
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
69
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
70
71
72
73
74
75
76
77
78
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
79
80
            #[cfg(feature = "minijinja")]
            chat_template: None,
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    /// Try to load chat template from tokenizer_config.json
    #[cfg(feature = "minijinja")]
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a .jinja file
    #[cfg(feature = "minijinja")]
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Clean up the template (similar to Python implementation)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
    #[cfg(feature = "minijinja")]
    pub fn set_chat_template(&mut self, template: String) {
        self.chat_template = Some(template);
    }

149
    /// Apply chat template if available
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    #[cfg(feature = "minijinja")]
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
            let processor = ChatTemplateProcessor::new(
                template.clone(),
                self.special_tokens.bos_token.clone(),
                self.special_tokens.eos_token.clone(),
            );
            processor.apply_chat_template(messages, add_generation_prompt)
        } else {
            // Fallback to simple formatting if no template is available
            let mut result = String::new();
            for msg in messages {
                result.push_str(&format!("{}: {}\n", msg.role, msg.content));
            }
            if add_generation_prompt {
                result.push_str("assistant: ");
            }
            Ok(result)
        }
    }

    /// Apply chat template if available (without minijinja feature)
    #[cfg(not(feature = "minijinja"))]
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        // Fallback to simple formatting
184
185
186
187
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
188
189
190
        if add_generation_prompt {
            result.push_str("assistant: ");
        }
191
192
193
194
195
196
        Ok(result)
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
197
198
        self.tokenizer
            .encode(input, false)
199
200
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
            .map(|encoding| Encoding::Hf(Box::new(encoding)))
201
202
203
204
205
206
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
207
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
208
209
210
211
212
213
214
215
216

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
217
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
218
219
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
220
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
221
222
223
224
225
226
227
228
229
230
231
232
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

233
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
234
235
236
        self.vocab.get(token).copied()
    }

237
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
238
239
240
241
242
243
        self.reverse_vocab.get(&id).cloned()
    }
}

#[cfg(test)]
mod tests {
    #[cfg(feature = "minijinja")]
    use super::ChatMessage;

    /// The role constructors set the expected role string and keep the content verbatim.
    #[cfg(feature = "minijinja")]
    #[test]
    fn test_chat_message_creation() {
        let system = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system.role, "system");
        assert_eq!(system.content, "You are a helpful assistant");

        let user = ChatMessage::user("Hello!");
        assert_eq!(user.role, "user");

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, "assistant");
    }

    // Exercising the tokenizer itself needs a real tokenizer JSON file, so such checks
    // belong in integration tests rather than in these unit tests.
}