// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Fastokens backend using the `fastokens` crate for high-performance BPE encoding. //! //! `fastokens` only supports encoding, so this module provides a hybrid tokenizer that //! uses `fastokens` for encoding and falls back to `HuggingFaceTokenizer` for decoding. //! Both are loaded from the same `tokenizer.json` file. use std::path::Path; use rayon::prelude::*; use super::{ Encoding, Error, Result, TokenIdType, hf::HuggingFaceTokenizer, traits::{DecodeResult, Decoder, Encoder, Tokenizer}, }; /// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace. /// /// Both backends are loaded from the same `tokenizer.json` file. pub struct FastTokenizer { fast_encoder: fastokens::Tokenizer, hf_decoder: HuggingFaceTokenizer, } impl FastTokenizer { pub fn from_file(path: &str) -> Result { let fast_encoder = fastokens::Tokenizer::from_file(Path::new(path)) .map_err(|e| Error::msg(format!("Error loading fastokens tokenizer: {e}")))?; let hf_decoder = HuggingFaceTokenizer::from_file(path)?; Ok(Self { fast_encoder, hf_decoder, }) } } impl Encoder for FastTokenizer { fn encode(&self, input: &str) -> Result { let ids = self .fast_encoder .encode(input) .map_err(|e| Error::msg(format!("Fastokens encode error: {e}")))?; Ok(Encoding::Sp(ids)) } fn encode_batch(&self, inputs: &[&str]) -> Result> { inputs.par_iter().map(|input| self.encode(input)).collect() } } impl Decoder for FastTokenizer { fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result { self.hf_decoder.decode(token_ids, skip_special_tokens) } } impl Tokenizer for FastTokenizer {} #[cfg(test)] mod tests { use super::*; use crate::tokenizers::HuggingFaceTokenizer; // Minimal synthetic BPE tokenizer with no normalizer or post-processor -- // compatible with fastokens. Vocab covers: H,T,a,d,e,h,i,l,o,r,s,t,w + punctuation. const TOKENIZER_PATH: &str = concat!( env!("CARGO_MANIFEST_DIR"), "/tests/data/sample-models/minimal-bpe/tokenizer.json" ); #[test] fn test_fast_encode_decode_roundtrip() { let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap(); // Encode then decode: verifies both paths execute without error. // With a null decoder, HF inserts spaces between tokens so exact equality // is not expected here -- we just verify the operations succeed and produce // non-empty results. let text = "Hello, world!"; let encoding = tokenizer.encode(text).unwrap(); assert!(!encoding.token_ids().is_empty()); let decoded: String = tokenizer.decode(encoding.token_ids(), true).unwrap().into(); assert!(!decoded.is_empty()); // The decoded text should contain the same non-space characters let enc_chars: String = text.chars().filter(|c| !c.is_whitespace()).collect(); let dec_chars: String = decoded.chars().filter(|c| !c.is_whitespace()).collect(); assert_eq!( enc_chars, dec_chars, "non-space characters must be preserved" ); } #[test] fn test_fast_matches_hf_encoding() { let fast = FastTokenizer::from_file(TOKENIZER_PATH).unwrap(); let hf = HuggingFaceTokenizer::from_file(TOKENIZER_PATH).unwrap(); for text in &["Hello, world!", "Hello", " world", "He llo"] { let fast_ids = fast.encode(text).unwrap(); let hf_ids = hf.encode(text).unwrap(); assert_eq!( fast_ids.token_ids(), hf_ids.token_ids(), "fastokens and HuggingFace must produce identical token IDs for '{text}'" ); } } #[test] fn test_fast_batch_encode() { let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap(); let inputs = &["Hello", " world", "Hello, world!"]; let encodings = tokenizer.encode_batch(inputs).unwrap(); assert_eq!(encodings.len(), inputs.len()); for (enc, input) in encodings.iter().zip(inputs.iter()) { assert!( !enc.token_ids().is_empty(), "encoding for '{input}' must be non-empty" ); } } #[test] fn test_fast_with_decode_stream() { use crate::tokenizers::Tokenizer as TokenizerWrapper; use std::sync::Arc; let tokenizer = Arc::new(FastTokenizer::from_file(TOKENIZER_PATH).unwrap()); let wrapper = TokenizerWrapper::from(tokenizer); // Encode a prompt and a continuation, then step through the decode stream let prompt_ids = wrapper.encode("Hello").unwrap().token_ids().to_vec(); let continuation = ", world!"; let cont_ids = wrapper.encode(continuation).unwrap().token_ids().to_vec(); let mut stream = wrapper.decode_stream(&prompt_ids, true); // Accumulate incremental chunks from decode_stream let mut accumulated = String::new(); for id in &cont_ids { if let Some(chunk) = stream.step(*id).unwrap() { accumulated.push_str(&chunk); } } // DecodeStream uses prompt tokens as context, so the expected text is // decode(prompt + continuation) minus decode(prompt) -- not a bare // decode(continuation) which lacks the surrounding context. let mut all_ids = prompt_ids.clone(); all_ids.extend_from_slice(&cont_ids); let full_text: String = wrapper.decode(&all_ids, true).unwrap().into(); let prompt_text: String = wrapper.decode(&prompt_ids, true).unwrap().into(); let expected = &full_text[prompt_text.len()..]; assert_eq!( accumulated, expected, "streamed chunks must equal context-aware decoded continuation" ); } }