"official/legacy/bert/common_flags.py" did not exist on "0203278a60fc9cebf65384d23d1f68a6128bf951"
factory.rs 9.03 KB
Newer Older
1
use super::traits::{self, Tokenizer as TokenizerTrait};
use crate::metrics::TokenizerMetrics;
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

#[cfg(feature = "huggingface")]
use super::huggingface::HuggingFaceTokenizer;

/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
    HuggingFace(String),
    Mock,
    #[cfg(feature = "tiktoken")]
    Tiktoken(String),
    // Future: SentencePiece, GGUF
}

/// Create a tokenizer from a path to a tokenizer file.
/// The file extension determines the tokenizer type; files with an
/// unknown extension fall back to content-based auto-detection.
/// Supported inputs:
/// - `json`: HuggingFace tokenizer
/// - `"mock"` / `"test"`: returns a mock tokenizer for testing
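///
/// A minimal usage sketch (the path below is hypothetical):
///
/// ```ignore
/// let tokenizer = create_tokenizer_from_file("models/tokenizer.json")?;
/// println!("vocab size: {}", tokenizer.vocab_size());
/// ```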
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let start_time = Instant::now();

    // Special case for testing
    if file_path == "mock" || file_path == "test" {
        return Ok(Arc::new(super::mock::MockTokenizer::new()));
    }

    let path = Path::new(file_path);

    // Check if file exists
    if !path.exists() {
        TokenizerMetrics::record_factory_error("file_not_found");
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    // Try to determine tokenizer type from extension
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    let result = match extension.as_deref() {
        Some("json") => {
            #[cfg(feature = "huggingface")]
            {
                let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;

                TokenizerMetrics::record_factory_load("json");
                TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());

                Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
            }
            #[cfg(not(feature = "huggingface"))]
            {
                TokenizerMetrics::record_factory_error("huggingface_disabled");
                Err(Error::msg(
                    "HuggingFace support not enabled. Enable the 'huggingface' feature.",
                ))
            }
        }
        Some("model") => {
            // SentencePiece model file
            TokenizerMetrics::record_factory_error("unsupported_sentencepiece");
            Err(Error::msg("SentencePiece models not yet supported"))
        }
        Some("gguf") => {
            // GGUF format
            TokenizerMetrics::record_factory_error("unsupported_gguf");
            Err(Error::msg("GGUF format not yet supported"))
        }
        _ => {
            // Try to auto-detect by reading file content
            auto_detect_tokenizer(file_path).inspect(|tokenizer| {
                TokenizerMetrics::record_factory_load("auto_detected");
                TokenizerMetrics::set_vocab_size("auto_detected", tokenizer.vocab_size());
            })
        }
    };

    if result.is_ok() {
        TokenizerMetrics::record_factory_load_duration(start_time.elapsed());
    }
    result
}

/// Auto-detect tokenizer type by examining file content
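///
/// Only the first 512 bytes are inspected, in this order: a JSON opener
/// (`{` or `[`), then the GGUF magic bytes, then SentencePiece heuristics.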
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let mut file = File::open(file_path)?;
    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
    let bytes_read = file.read(&mut buffer)?;
    buffer.truncate(bytes_read);

    // Check for JSON (HuggingFace format)
    if is_likely_json(&buffer) {
        #[cfg(feature = "huggingface")]
        {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            return Ok(Arc::new(tokenizer));
        }
        #[cfg(not(feature = "huggingface"))]
        {
            return Err(Error::msg(
                "File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
            ));
        }
    }

    // Check for GGUF magic number
    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
        return Err(Error::msg("GGUF format detected but not yet supported"));
    }

    // Check for SentencePiece model
    if is_likely_sentencepiece(&buffer) {
        return Err(Error::msg(
            "SentencePiece model detected but not yet supported",
        ));
    }

    Err(Error::msg(format!(
        "Unable to determine tokenizer type for file: {}",
        file_path
    )))
}

/// Check if the buffer likely contains JSON data
fn is_likely_json(buffer: &[u8]) -> bool {
    // Skip UTF-8 BOM if present
    let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
        &buffer[3..]
    } else {
        buffer
    };

    // Find first non-whitespace character without allocation
    if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
        *first_byte == b'{' || *first_byte == b'['
    } else {
        false
    }
}

/// Check if the buffer likely contains a SentencePiece model
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    // SentencePiece models are serialized protobufs; they often begin with a
    // length-delimited field tag (0x0a) or contain the default special tokens.
    // This is a simplified heuristic, not a full protobuf parse.
    // Note: each `windows(n)` length must match its needle's byte length.
    buffer.len() >= 12
        && (buffer.starts_with(b"\x0a\x09")
            || buffer.starts_with(b"\x08\x00")
            || buffer.windows(4).any(|w| w == b"<unk")
            || buffer.windows(3).any(|w| w == b"<s>")
            || buffer.windows(4).any(|w| w == b"</s>"))
}

/// Factory function to create tokenizer from a model name or path
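///
/// A hedged sketch of both call styles (the file path is hypothetical;
/// model-name resolution requires the `tiktoken` feature):
///
/// ```ignore
/// // Local file: dispatches to `create_tokenizer_from_file`.
/// let from_file = create_tokenizer("models/tokenizer.json")?;
/// // Recognized GPT model name: resolved via Tiktoken.
/// let from_name = create_tokenizer("gpt-4")?;
/// ```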
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Check if it's a file path
    let path = Path::new(model_name_or_path);
    if path.exists() {
        return create_tokenizer_from_file(model_name_or_path);
    }

    // Check if it's a GPT model name that should use Tiktoken
    #[cfg(feature = "tiktoken")]
    {
        if model_name_or_path.contains("gpt-")
            || model_name_or_path.contains("davinci")
            || model_name_or_path.contains("curie")
            || model_name_or_path.contains("babbage")
            || model_name_or_path.contains("ada")
        {
            use super::tiktoken::TiktokenTokenizer;
            let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
            TokenizerMetrics::record_factory_load("tiktoken");
            TokenizerMetrics::set_vocab_size("tiktoken", tokenizer.vocab_size());
            return Ok(Arc::new(tokenizer));
        }
    }

    // Otherwise, try to load from HuggingFace Hub
    #[cfg(feature = "huggingface")]
    {
        // This would download from HF Hub - not implemented yet
        Err(Error::msg(
            "Loading from HuggingFace Hub not yet implemented",
        ))
    }

    #[cfg(not(feature = "huggingface"))]
    {
        Err(Error::msg(format!(
            "Model '{}' not found locally and HuggingFace support is not enabled",
            model_name_or_path
        )))
    }
}

/// Get information about a tokenizer file
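///
/// Example sketch (the path is hypothetical):
///
/// ```ignore
/// match get_tokenizer_info("models/tokenizer.json")? {
///     TokenizerType::HuggingFace(path) => println!("HuggingFace file: {}", path),
///     other => println!("detected: {:?}", other),
/// }
/// ```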
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
    let path = Path::new(file_path);

    if !path.exists() {
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    match extension.as_deref() {
        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
        _ => {
            // Try auto-detection
            let mut file = File::open(file_path)?;
            let mut buffer = vec![0u8; 512];
            let bytes_read = file.read(&mut buffer)?;
            buffer.truncate(bytes_read);

            if is_likely_json(&buffer) {
                Ok(TokenizerType::HuggingFace(file_path.to_string()))
            } else {
                Err(Error::msg("Unknown tokenizer type"))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_detection() {
        assert!(is_likely_json(b"{\"test\": \"value\"}"));
        assert!(is_likely_json(b"  \n\t{\"test\": \"value\"}"));
        assert!(is_likely_json(b"[1, 2, 3]"));
        assert!(!is_likely_json(b"not json"));
        assert!(!is_likely_json(b""));
    }

    #[test]
    fn test_mock_tokenizer_creation() {
        let tokenizer = create_tokenizer_from_file("mock").unwrap();
        assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
    }

    #[test]
    fn test_file_not_found() {
        let result = create_tokenizer_from_file("/nonexistent/file.json");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("File not found"));
        }
    }
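
    // A small sanity check for the SentencePiece heuristic. The byte strings
    // are synthetic stand-ins, not bytes from a real SentencePiece model.
    #[test]
    fn test_sentencepiece_detection() {
        // Starts with the 0x0a/0x09 pattern and is at least 12 bytes long.
        assert!(is_likely_sentencepiece(b"\x0a\x09<unk><s></s>"));
        // Too short to qualify.
        assert!(!is_likely_sentencepiece(b"plain text"));
        // Long enough, but contains none of the expected markers.
        assert!(!is_likely_sentencepiece(b"just some ordinary text here"));
    }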

    #[cfg(feature = "tiktoken")]
    #[test]
    fn test_create_tiktoken_tokenizer() {
        // Test creating tokenizer for GPT models
        let tokenizer = create_tokenizer("gpt-4").unwrap();
        assert!(tokenizer.vocab_size() > 0);

        // Test encoding and decoding
        let text = "Hello, world!";
        let encoding = tokenizer.encode(text).unwrap();
        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}