// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Fastokens backend using the `fastokens` crate for high-performance BPE encoding.
//!
//! `fastokens` only supports encoding, so this module provides a hybrid tokenizer that
//! uses `fastokens` for encoding and falls back to `HuggingFaceTokenizer` for decoding.
//! Both are loaded from the same `tokenizer.json` file.

use std::path::Path;

use rayon::prelude::*;

use super::{
    Encoding, Error, Result, TokenIdType,
    hf::HuggingFaceTokenizer,
    traits::{DecodeResult, Decoder, Encoder, Tokenizer},
};

/// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace.
///
/// Both backends are loaded from the same `tokenizer.json` file.
pub struct FastTokenizer {
    /// Encoding backend; `Encoder::encode` is served by this tokenizer.
    fast_encoder: fastokens::Tokenizer,
    /// Decoding backend; `Decoder::decode` is delegated to this tokenizer.
    hf_decoder: HuggingFaceTokenizer,
}

impl FastTokenizer {
    pub fn from_file(path: &str) -> Result<Self> {
        let fast_encoder = fastokens::Tokenizer::from_file(Path::new(path))
            .map_err(|e| Error::msg(format!("Error loading fastokens tokenizer: {e}")))?;
        let hf_decoder = HuggingFaceTokenizer::from_file(path)?;
        Ok(Self {
            fast_encoder,
            hf_decoder,
        })
    }
}

impl Encoder for FastTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let ids = self
            .fast_encoder
            .encode(input)
            .map_err(|e| Error::msg(format!("Fastokens encode error: {e}")))?;
        Ok(Encoding::Sp(ids))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.par_iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for FastTokenizer {
55
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<DecodeResult> {
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
        self.hf_decoder.decode(token_ids, skip_special_tokens)
    }
}

// Empty impl: nothing is supplied here, so `Tokenizer` has no required items
// of its own. Presumably it only bundles `Encoder` + `Decoder` -- confirm
// against the trait definition in `traits`.
impl Tokenizer for FastTokenizer {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizers::HuggingFaceTokenizer;

    // Minimal synthetic BPE tokenizer with no normalizer or post-processor --
    // compatible with fastokens. Vocab covers: H,T,a,d,e,h,i,l,o,r,s,t,w + punctuation.
    const TOKENIZER_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/data/sample-models/minimal-bpe/tokenizer.json"
    );

    /// Encoding followed by decoding must succeed and preserve every
    /// non-whitespace character of the input.
    #[test]
    fn test_fast_encode_decode_roundtrip() {
        let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        // Encode then decode: verifies both paths execute without error.
        // With a null decoder, HF inserts spaces between tokens, so exact
        // equality is not expected here -- we verify that the operations
        // succeed, produce non-empty results, and keep the visible characters.
        let text = "Hello, world!";
        let encoding = tokenizer.encode(text).unwrap();
        assert!(!encoding.token_ids().is_empty());
        let decoded: String = tokenizer.decode(encoding.token_ids(), true).unwrap().into();
        assert!(!decoded.is_empty());
        // The decoded text should contain the same non-space characters.
        let text_chars: String = text.chars().filter(|c| !c.is_whitespace()).collect();
        let decoded_chars: String = decoded.chars().filter(|c| !c.is_whitespace()).collect();
        assert_eq!(
            text_chars, decoded_chars,
            "non-space characters must be preserved"
        );
    }

    /// The fastokens encoder must agree token-for-token with the HuggingFace
    /// encoder loaded from the same file.
    #[test]
    fn test_fast_matches_hf_encoding() {
        let fast = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        let hf = HuggingFaceTokenizer::from_file(TOKENIZER_PATH).unwrap();

        for text in &["Hello, world!", "Hello", " world", "He llo"] {
            let fast_ids = fast.encode(text).unwrap();
            let hf_ids = hf.encode(text).unwrap();
            assert_eq!(
                fast_ids.token_ids(),
                hf_ids.token_ids(),
                "fastokens and HuggingFace must produce identical token IDs for '{text}'"
            );
        }
    }

    /// Batch encoding must return one non-empty encoding per input.
    #[test]
    fn test_fast_batch_encode() {
        let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        let inputs = &["Hello", " world", "Hello, world!"];
        let encodings = tokenizer.encode_batch(inputs).unwrap();
        assert_eq!(encodings.len(), inputs.len());
        for (enc, input) in encodings.iter().zip(inputs.iter()) {
            assert!(
                !enc.token_ids().is_empty(),
                "encoding for '{input}' must be non-empty"
            );
        }
    }

    /// Stepping a decode stream over continuation tokens must yield the same
    /// text as decoding prompt + continuation and dropping the prompt part.
    #[test]
    fn test_fast_with_decode_stream() {
        use crate::tokenizers::Tokenizer as TokenizerWrapper;
        use std::sync::Arc;

        let tokenizer = Arc::new(FastTokenizer::from_file(TOKENIZER_PATH).unwrap());
        let wrapper = TokenizerWrapper::from(tokenizer);

        // Encode a prompt and a continuation, then step through the decode stream.
        let prompt_ids = wrapper.encode("Hello").unwrap().token_ids().to_vec();
        let continuation = ", world!";
        let cont_ids = wrapper.encode(continuation).unwrap().token_ids().to_vec();

        let mut stream = wrapper.decode_stream(&prompt_ids, true);
        // Accumulate incremental chunks from decode_stream.
        let mut accumulated = String::new();
        for id in &cont_ids {
            if let Some(chunk) = stream.step(*id).unwrap() {
                accumulated.push_str(&chunk);
            }
        }

        // DecodeStream uses prompt tokens as context, so the expected text is
        // decode(prompt + continuation) minus decode(prompt) -- not a bare
        // decode(continuation) which lacks the surrounding context.
        let mut all_ids = prompt_ids.clone();
        all_ids.extend_from_slice(&cont_ids);
        let full_text: String = wrapper.decode(&all_ids, true).unwrap().into();
        let prompt_text: String = wrapper.decode(&prompt_ids, true).unwrap().into();
        let expected = &full_text[prompt_text.len()..];
        assert_eq!(
            accumulated, expected,
            "streamed chunks must equal context-aware decoded continuation"
        );
    }
}