mod.rs 2.81 KB
Newer Older
1
2
3
4
use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;

5
pub mod factory;
6
pub mod mock;
7
pub mod sequence;
8
pub mod stop;
9
10
11
pub mod stream;
pub mod traits;

12
13
14
15
// Feature-gated modules
#[cfg(feature = "huggingface")]
pub mod huggingface;

16
17
18
#[cfg(feature = "tiktoken")]
pub mod tiktoken;

19
20
21
#[cfg(test)]
mod tests;

22
23
// Re-exports
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
24
pub use sequence::Sequence;
25
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
26
27
28
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};

29
30
31
#[cfg(feature = "huggingface")]
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};

32
33
34
#[cfg(feature = "tiktoken")]
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};

35
36
37
38
39
40
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations.
///
/// The implementation is held as a shared `Arc<dyn traits::Tokenizer>`, so the derived
/// `Clone` clones the `Arc` handle rather than the implementation behind it.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    /// Create a tokenizer from a file path
41
42
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    }

    /// Create a tokenizer from an Arc<dyn Tokenizer>
    pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }

    /// Create a stateful sequence object for decoding token_ids into text
    pub fn decode_stream(
        &self,
        prompt_token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> DecodeStream {
        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
    }

    /// Direct encode method
    pub fn encode(&self, input: &str) -> Result<Encoding> {
        self.0.encode(input)
    }

    /// Direct batch encode method
    pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        self.0.encode_batch(inputs)
    }

    /// Direct decode method
    pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.0.decode(token_ids, skip_special_tokens)
    }

    /// Get vocabulary size
    pub fn vocab_size(&self) -> usize {
        self.0.vocab_size()
    }

    /// Get special tokens
    pub fn get_special_tokens(&self) -> &SpecialTokens {
        self.0.get_special_tokens()
    }

    /// Convert token string to ID
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.0.token_to_id(token)
    }

    /// Convert ID to token string
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.0.id_to_token(id)
    }
}

/// Lets a `Tokenizer` be used wherever a reference to the inner
/// `Arc<dyn traits::Tokenizer>` is expected.
impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    /// Borrow the wrapped shared tokenizer handle.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

/// Wraps an existing shared tokenizer handle without copying the implementation.
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    /// Build the wrapper directly around `tokenizer`.
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        let inner = tokenizer;
        Self(inner)
    }
}