//! Core tokenizer abstractions: the [`Encoder`], [`Decoder`], and [`Tokenizer`]
//! traits, the backend-specific [`Encoding`] result type, and [`SpecialTokens`].
use anyhow::Result;

/// Core encoding trait: converts text into [`Encoding`]s.
///
/// Kept separate from [`Decoder`] so that encoding and decoding can be
/// implemented independently. Implementors must be `Send + Sync`.
pub trait Encoder: Send + Sync {
    /// Encodes a single input string.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails to encode `input`.
    fn encode(&self, input: &str) -> Result<Encoding>;

    /// Encodes every string in `inputs`, returning one [`Encoding`] per input
    /// in the same order.
    ///
    /// The default implementation calls [`Encoder::encode`] on each input in
    /// turn and stops at the first error. Backends with a native (e.g.
    /// parallel) batch API should override it.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while encoding any input.
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        // Collecting an iterator of `Result`s into `Result<Vec<_>>`
        // short-circuits on the first `Err`.
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

/// Core decoding trait: turns token IDs back into text.
///
/// Independent of [`Encoder`], so a type may implement decoding on its own.
/// Implementors must be `Send + Sync`.
pub trait Decoder
where
    Self: Send + Sync,
{
    /// Decodes `token_ids` into a string.
    ///
    /// `skip_special_tokens` is forwarded to the implementation, which
    /// presumably omits special tokens from the output when it is `true`.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails to decode the IDs.
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
}

/// Combined tokenizer trait: an [`Encoder`] and a [`Decoder`] together with
/// vocabulary lookups and access to the special-token configuration.
pub trait Tokenizer
where
    Self: Encoder + Decoder,
{
    /// Number of entries in the vocabulary.
    fn vocab_size(&self) -> usize;

    /// Borrows this tokenizer's [`SpecialTokens`].
    fn get_special_tokens(&self) -> &SpecialTokens;

    /// Looks up the ID of `token`; `None` presumably means the token is not
    /// in the vocabulary.
    fn token_to_id(&self, token: &str) -> Option<u32>;

    /// Looks up the token string for `id`; `None` presumably means `id` is
    /// not a known token ID.
    fn id_to_token(&self, id: u32) -> Option<String>;
}

/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Sentence Piece
    Sp(Vec<u32>),
29
30
    /// Tiktoken (for GPT models)
    Tiktoken(Vec<usize>),
31
32
33
}

impl Encoding {
34
35
36
37
38
39
40
41
42
    pub fn token_ids(&self) -> Vec<u32> {
        match self {
            Encoding::Hf(inner) => inner.get_ids().to_vec(),
            Encoding::Sp(inner) => inner.clone(),
            Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
        }
    }

    pub fn token_ids_ref(&self) -> &[u32] {
43
44
45
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
46
47
48
49
50
            Encoding::Tiktoken(_) => {
                // Tiktoken uses usize, we can't return a reference to u32
                // This is a limitation - callers should use token_ids() for Tiktoken
                &[]
            }
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
        }
    }
}

/// Special-token configuration of a tokenizer.
///
/// Each named slot holds the token's string form, presumably `None` when the
/// tokenizer defines no such token; any further special tokens are listed in
/// `additional_special_tokens`.
#[derive(Clone, Debug)]
pub struct SpecialTokens {
    /// Beginning-of-sequence token.
    pub bos_token: Option<String>,
    /// End-of-sequence token.
    pub eos_token: Option<String>,
    /// Unknown-token placeholder.
    pub unk_token: Option<String>,
    /// Separator token.
    pub sep_token: Option<String>,
    /// Padding token.
    pub pad_token: Option<String>,
    /// Classification token.
    pub cls_token: Option<String>,
    /// Mask token.
    pub mask_token: Option<String>,
    /// Special tokens beyond the named slots above.
    pub additional_special_tokens: Vec<String>,
}