//! Tokenizer traits (`Encoder`, `Decoder`, `Tokenizer`) and the `Encoding` result type.

use anyhow::Result;
// Hashing support for `Encoding::get_hash` and the `Hash` impl on `Encoding`.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Type alias for token IDs.
///
/// Every signature in this module that takes or returns token IDs uses this
/// alias, so the ID width is defined in exactly one place.
pub type TokenIdType = u32;

/// Core encoding trait - separate from decoding for modularity
///
/// Implementors must be `Send + Sync`.
pub trait Encoder: Send + Sync {
    /// Encodes a single input string into an [`Encoding`], or returns an error.
    fn encode(&self, input: &str) -> Result<Encoding>;
    /// Encodes a batch of input strings into a `Vec` of [`Encoding`] values,
    /// or returns an error.
    // NOTE(review): presumably yields one `Encoding` per input, in input
    // order - the signature alone does not guarantee it; confirm with implementors.
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
}

/// Core decoding trait - can be implemented independently
pub trait Decoder: Send + Sync {
16
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
17
18
19
20
21
22
}

/// Combined tokenizer trait
pub trait Tokenizer: Encoder + Decoder {
    fn vocab_size(&self) -> usize;
    fn get_special_tokens(&self) -> &SpecialTokens;
23
24
    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
25
26
27
28
29
30
31
32
}

/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Sentence Piece
33
34
35
    Sp(Vec<TokenIdType>),
    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
    Tiktoken(Vec<TokenIdType>),
36
37
38
}

impl Encoding {
39
40
    /// Returns a reference to token IDs when possible, owned Vec for compatibility
    pub fn token_ids(&self) -> Vec<TokenIdType> {
41
42
43
        match self {
            Encoding::Hf(inner) => inner.get_ids().to_vec(),
            Encoding::Sp(inner) => inner.clone(),
44
            Encoding::Tiktoken(inner) => inner.clone(),
45
46
47
        }
    }

48
49
    /// Returns a reference to token IDs where possible
    pub fn token_ids_ref(&self) -> &[TokenIdType] {
50
51
52
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
            Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
        }
    }

    /// Get a hash of the token IDs for caching purposes
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}

/// Hash implementation for Encoding
impl Hash for Encoding {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Encoding::Hf(inner) => inner.get_ids().hash(state),
            Encoding::Sp(inner) => inner.hash(state),
            Encoding::Tiktoken(inner) => inner.hash(state),
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        }
    }
}

/// Special-token strings exposed by a tokenizer via `Tokenizer::get_special_tokens`.
///
/// Each named token is optional.
// NOTE(review): field meanings below are inferred from the conventional
// BOS/EOS/UNK/SEP/PAD/CLS/MASK abbreviations - confirm against the code that
// populates this struct.
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Presumably the beginning-of-sequence token.
    pub bos_token: Option<String>,
    /// Presumably the end-of-sequence token.
    pub eos_token: Option<String>,
    /// Presumably the unknown-token placeholder.
    pub unk_token: Option<String>,
    /// Presumably the separator token.
    pub sep_token: Option<String>,
    /// Presumably the padding token.
    pub pad_token: Option<String>,
    /// Presumably the classification token.
    pub cls_token: Option<String>,
    /// Presumably the mask token.
    pub mask_token: Option<String>,
    /// Any further special tokens beyond the named fields above.
    pub additional_special_tokens: Vec<String>,
}