//! Mock tokenizer implementation for testing

use std::collections::HashMap;

use anyhow::Result;

use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};

/// Mock tokenizer for testing purposes
pub struct MockTokenizer {
    // Maps token text -> token id; populated with a small fixed set in `new`.
    vocab: HashMap<String, u32>,
    // Maps token id -> token text; inverse of `vocab`, used by `decode`.
    reverse_vocab: HashMap<u32, String>,
    // BOS/EOS/UNK markers reported via `get_special_tokens`.
    special_tokens: SpecialTokens,
}

impl Default for MockTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl MockTokenizer {
    /// Builds a tokenizer with a small, fixed vocabulary: a handful of plain
    /// words, punctuation, and chat-template special tokens
    /// (`<bos>`, `<eos>`, `<|im_start|>`, `<|im_end|>`, `<|eot_id|>`).
    pub fn new() -> Self {
        let mut vocab = HashMap::new();
        let mut reverse_vocab = HashMap::new();

        // Fixed (token, id) pairs; the forward and reverse maps are filled
        // from the same list so they always stay consistent.
        let tokens = vec![
            ("Hello", 1),
            ("world", 2),
            ("test", 3),
            ("token", 4),
            (" ", 5),
            (".", 6),
            ("<eos>", 999),
            ("<bos>", 1000),
            ("<|im_start|>", 1001),
            ("<|im_end|>", 1002),
            ("<|eot_id|>", 1003),
            ("system", 7),
            ("user", 8),
            ("assistant", 9),
        ];

        for (token, id) in tokens {
            vocab.insert(token.to_string(), id);
            reverse_vocab.insert(id, token.to_string());
        }

        // NOTE(review): "<unk>" is declared as the unk token but has no entry
        // in `vocab`, so it can never be produced by `encode` — confirm this
        // is intentional for the mock.
        let special_tokens = SpecialTokens {
            bos_token: Some("<bos>".to_string()),
            eos_token: Some("<eos>".to_string()),
            unk_token: Some("<unk>".to_string()),
            sep_token: None,
            pad_token: None,
            cls_token: None,
            mask_token: None,
            additional_special_tokens: vec![],
        };

        Self {
            vocab,
            reverse_vocab,
            special_tokens,
        }
    }
}

impl Encoder for MockTokenizer {
    /// Whitespace-splits `input` and maps each word through the vocab.
    ///
    /// Words not present in the vocabulary are silently dropped (no unk
    /// substitution). Returns an [`Encoding::Sp`] wrapping the id sequence.
    fn encode(&self, input: &str) -> Result<Encoding> {
        // Simple word-based tokenization using the vocab.
        // Split by whitespace and look up each word (decoder adds spaces back).
        let tokens: Vec<u32> = input
            .split_whitespace()
            .filter_map(|word| self.vocab.get(word).copied())
            .collect();

        Ok(Encoding::Sp(tokens))
    }

    /// Encodes each input independently via [`Self::encode`]; fails on the
    /// first error.
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for MockTokenizer {
    /// Looks up each id in the reverse vocab and joins the resulting tokens
    /// with single spaces. Ids with no vocab entry are silently skipped, and
    /// `<bos>`/`<eos>` are dropped when `skip_special_tokens` is set.
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        let mut pieces: Vec<String> = Vec::new();

        for id in token_ids {
            if let Some(token) = self.reverse_vocab.get(id) {
                let is_special = token == "<eos>" || token == "<bos>";
                if !(skip_special_tokens && is_special) {
                    pieces.push(token.clone());
                }
            }
        }

        Ok(pieces.join(" "))
    }
}

impl TokenizerTrait for MockTokenizer {
    /// Number of entries in the forward vocabulary.
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Special-token configuration set up in [`MockTokenizer::new`].
    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Exact-match lookup of a token string; `None` if not in the vocab.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    /// Reverse lookup of an id; `None` if the id is unknown.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }

    // Allows callers to downcast back to the concrete MockTokenizer.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}