use super::traits::{
    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

use super::chat_template::{ChatMessage, ChatTemplateProcessor};

/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
    vocab: HashMap<String, TokenIdType>,
    reverse_vocab: HashMap<TokenIdType, String>,
    chat_template: Option<String>,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
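    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled): the path is hypothetical and must
    /// point at a real `tokenizer.json` exported by HuggingFace `tokenizers`.
    ///
    /// ```ignore
    /// let tokenizer = HuggingFaceTokenizer::from_file("models/llama/tokenizer.json")?;
    /// ```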
    pub fn from_file(file_path: &str) -> Result<Self> {
        Self::from_file_with_chat_template(file_path, None)
    }

    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
    pub fn from_file_with_chat_template(
        file_path: &str,
        chat_template_path: Option<&str>,
    ) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        // Load chat template
        let chat_template = if let Some(template_path) = chat_template_path {
            // Load from the specified .jinja file
            Self::load_chat_template_from_file(template_path)?
        } else {
            // Try to load from tokenizer_config.json
            Self::load_chat_template(file_path)
        };

        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
            chat_template,
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<TokenIdType, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
            chat_template: None,
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Look for well-known special tokens in the vocabulary.
        // This is a simplified lookup; a full implementation would need to
        // handle the various formats tokenizers use to declare special tokens.
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

    /// Try to load chat template from tokenizer_config.json
    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
        // Try to find tokenizer_config.json in the same directory
        let path = std::path::Path::new(tokenizer_path);
        let dir = path.parent()?;
        let config_path = dir.join("tokenizer_config.json");

        if config_path.exists() {
            if let Ok(template) =
                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
            {
                return template;
            }
        }
        None
    }

    /// Load chat template from a .jinja file
    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
        use std::fs;

        let content = fs::read_to_string(template_path)
            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;

        // Clean up the template: trim surrounding whitespace and turn literal
        // "\n" escape sequences into real newlines (mirrors the Python implementation)
        let template = content.trim().replace("\\n", "\n");

        Ok(Some(template))
    }

    /// Set or override the chat template
    pub fn set_chat_template(&mut self, template: String) {
        self.chat_template = Some(template);
    }

    /// Apply the chat template if one is available, falling back to a simple
    /// `role: content` rendering otherwise
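    ///
    /// A sketch (not compiled); `tokenizer` is assumed to be an already-loaded
    /// `HuggingFaceTokenizer`:
    ///
    /// ```ignore
    /// let messages = vec![
    ///     ChatMessage::system("You are a helpful assistant"),
    ///     ChatMessage::user("Hello!"),
    /// ];
    /// let prompt = tokenizer.apply_chat_template(&messages, true)?;
    /// ```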
    pub fn apply_chat_template(
        &self,
        messages: &[ChatMessage],
        add_generation_prompt: bool,
    ) -> Result<String> {
        if let Some(ref template) = self.chat_template {
            let processor = ChatTemplateProcessor::new(
                template.clone(),
                self.special_tokens.bos_token.clone(),
                self.special_tokens.eos_token.clone(),
            );
            processor.apply_chat_template(messages, add_generation_prompt)
        } else {
            // Fallback to simple formatting if no template is available
            let mut result = String::new();
            for msg in messages {
                result.push_str(&format!("{}: {}\n", msg.role, msg.content));
            }
            if add_generation_prompt {
                result.push_str("assistant: ");
            }
            Ok(result)
        }
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        self.tokenizer
            .encode(input, false)
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
            .map(|encoding| Encoding::Hf(Box::new(encoding)))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }
}

#[cfg(test)]
mod tests {
    use super::ChatMessage;

    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }
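
    // A runnable sketch of the fallback formatting path in apply_chat_template.
    // Assumption: an empty BPE model is enough to construct a tokenizer here,
    // since the fallback path never actually tokenizes anything.
    #[test]
    fn test_apply_chat_template_fallback() {
        use super::{HfTokenizer, HuggingFaceTokenizer};
        use tokenizers::models::bpe::BPE;

        let tokenizer =
            HuggingFaceTokenizer::from_tokenizer(HfTokenizer::new(BPE::default()));
        let messages = vec![ChatMessage::user("Hello!")];

        // No chat template was set, so the simple "role: content" fallback applies
        let prompt = tokenizer.apply_chat_template(&messages, true).unwrap();
        assert_eq!(prompt, "user: Hello!\nassistant: ");
    }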

    // Note: actual tokenizer tests would require a real tokenizer file and
    // would be integration tests rather than unit tests; a sketch follows.
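
    // Sketch of such an integration test; #[ignore]d because it assumes a real
    // tokenizer.json at a hypothetical path (tests/data/tokenizer.json).
    #[test]
    #[ignore]
    fn test_encode_decode_round_trip() {
        use super::{Decoder, Encoder, Encoding, HuggingFaceTokenizer};

        let tokenizer = HuggingFaceTokenizer::from_file("tests/data/tokenizer.json")
            .expect("hypothetical tokenizer.json must exist to run this test");

        let ids = match tokenizer.encode("Hello, world!").unwrap() {
            Encoding::Hf(enc) => enc.get_ids().to_vec(),
            // any non-HF variants of the wrapper enum are out of scope here
            _ => panic!("expected an HF encoding"),
        };
        let decoded = tokenizer.decode(&ids, true).unwrap();
        assert!(decoded.contains("Hello"));
    }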
}