Unverified Commit 21fce9ba authored by Nikita's avatar Nikita Committed by GitHub
Browse files

feat: Tiktoken support (#6460)


Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
parent 8483e4a0
......@@ -1974,6 +1974,7 @@ dependencies = [
"reqwest 0.12.28",
"rmp-serde",
"rstest 0.18.2",
"rustc-hash 1.1.0",
"rustls",
"serde",
"serde_json",
......@@ -1982,6 +1983,7 @@ dependencies = [
"temp-env",
"tempfile",
"thiserror 2.0.18",
"tiktoken-rs",
"tmq",
"tokenizers",
"tokio",
......@@ -7504,6 +7506,21 @@ dependencies = [
"zune-jpeg 0.4.21",
]
[[package]]
name = "tiktoken-rs"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d"
dependencies = [
"anyhow",
"base64 0.22.1",
"bstr",
"fancy-regex 0.13.0",
"lazy_static",
"regex",
"rustc-hash 1.1.0",
]
[[package]]
name = "time"
version = "0.3.47"
......
......@@ -1593,12 +1593,14 @@ dependencies = [
"rayon",
"reqwest",
"rmp-serde",
"rustc-hash 1.1.0",
"rustls",
"serde",
"serde_json",
"strum",
"tempfile",
"thiserror 2.0.18",
"tiktoken-rs",
"tmq",
"tokenizers",
"tokio",
......@@ -6433,6 +6435,21 @@ dependencies = [
"zune-jpeg 0.4.21",
]
[[package]]
name = "tiktoken-rs"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d"
dependencies = [
"anyhow",
"base64 0.22.1",
"bstr",
"fancy-regex",
"lazy_static",
"regex",
"rustc-hash 1.1.0",
]
[[package]]
name = "time"
version = "0.3.47"
......
......@@ -1603,12 +1603,14 @@ dependencies = [
"rayon",
"reqwest",
"rmp-serde",
"rustc-hash 1.1.0",
"rustls",
"serde",
"serde_json",
"strum",
"tempfile",
"thiserror 2.0.18",
"tiktoken-rs",
"tmq",
"tokenizers",
"tokio",
......@@ -6491,6 +6493,21 @@ dependencies = [
"zune-jpeg 0.4.21",
]
[[package]]
name = "tiktoken-rs"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d"
dependencies = [
"anyhow",
"base64 0.22.1",
"bstr",
"fancy-regex",
"lazy_static",
"regex",
"rustc-hash 1.1.0",
]
[[package]]
name = "time"
version = "0.3.47"
......
......@@ -141,6 +141,8 @@ tokenizers = { version = "0.21.4", default-features = false, features = [
"esaxx_fast",
"rustls-tls",
] }
tiktoken-rs = { version = "0.9", default-features = false }
rustc-hash = "1.1"
# backend
galil-seiferas = { version = "0.1" }
......
......@@ -10,6 +10,7 @@ use dynamo_llm::backend::Decoder;
use dynamo_llm::protocols::common::StopConditions;
use dynamo_llm::tokenizers::DecodeStream;
use dynamo_llm::tokenizers::hf::HuggingFaceTokenizer;
use dynamo_llm::tokenizers::tiktoken::TikTokenTokenizer;
use dynamo_llm::tokenizers::traits::{Encoder, Tokenizer};
use dynamo_llm::types::TokenIdType;
......@@ -18,6 +19,11 @@ const TEST_TOKENIZER: &str = concat!(
"/tests/data/sample-models/TinyLlama_v1.1/tokenizer.json"
);
const TEST_TIKTOKEN: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/data/sample-models/mock-tiktoken/tiktoken.model"
);
/// Input Sequence Length for tokenizer
const TARGET_ISL: usize = 8_000;
......@@ -90,5 +96,53 @@ pub fn decode_big(c: &mut Criterion) {
group.finish();
}
criterion_group!(benches, encode, decode, decode_big);
/// Benchmark tiktoken encoding, reporting throughput in bytes of input text.
///
/// The input is `INPUT_STR` repeated a whole number of times so that its length
/// does not exceed `TARGET_ISL` bytes.
pub fn tiktoken_encode(c: &mut Criterion) {
    let repeats = TARGET_ISL / INPUT_STR.len();
    let input = INPUT_STR.repeat(repeats);
    let tokenizer = TikTokenTokenizer::from_file_auto(TEST_TIKTOKEN).unwrap();

    let mut group = c.benchmark_group("tiktoken-encode-group");
    group.throughput(Throughput::Bytes(input.len() as u64));
    group.bench_function("tiktoken_encode", |bencher| {
        bencher.iter(|| {
            // black_box keeps the optimizer from specializing on the constant input.
            let _ = tokenizer.encode(black_box(input.as_str())).unwrap();
        })
    });
    group.finish();
}
/// Benchmark incremental (token-by-token) decoding through `Decoder::step`,
/// reporting throughput in tokens.
pub fn tiktoken_decode(c: &mut Criterion) {
    // Encode the sample text first so the benchmark feeds realistic token IDs
    // for this particular tokenizer.
    let tokenizer = TikTokenTokenizer::from_file_auto(TEST_TIKTOKEN).unwrap();
    let token_ids: Vec<TokenIdType> = tokenizer
        .encode(INPUT_STR)
        .unwrap()
        .token_ids()
        .to_vec();

    let mut group = c.benchmark_group("tiktoken-decode-group");
    group.throughput(Throughput::Elements(token_ids.len() as u64));
    group.bench_function("tiktoken_decoder", |bencher| {
        bencher.iter_with_setup(
            // Setup (not measured): a fresh tokenizer and decoder for every run.
            || {
                let fresh: Arc<dyn Tokenizer> =
                    Arc::new(TikTokenTokenizer::from_file_auto(TEST_TIKTOKEN).unwrap());
                let stream = DecodeStream::new(fresh, &[], false);
                Decoder::new(stream, StopConditions::default(), false, None)
            },
            // Measured routine: feed every token through the decoder one at a time.
            |mut decoder| {
                for &tok in black_box(&token_ids) {
                    let _ = decoder.step(tok).unwrap();
                }
            },
        )
    });
    group.finish();
}
criterion_group!(
benches,
encode,
decode,
decode_big,
tiktoken_encode,
tiktoken_decode
);
criterion_main!(benches);
......@@ -41,9 +41,8 @@ use crate::protocols::{
timing::RequestTracker,
},
};
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use crate::tokenizers::{DecodeStream, Tokenizer};
use dynamo_async_openai::types::StopReason;
use tokenizers::Tokenizer as HfTokenizer;
/// Represents the output stream from the execution engine
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;
......@@ -69,10 +68,7 @@ struct DecoderUnfoldState {
}
impl Backend {
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Arc<Self> {
let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
let tokenizer = Tokenizer::from(Arc::new(tokenizer));
pub fn from_tokenizer(tokenizer: Tokenizer) -> Arc<Self> {
Arc::new(Self {
tokenizer: Some(tokenizer),
validate_engine_decode: false,
......@@ -80,10 +76,10 @@ impl Backend {
}
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
match mdc.tokenizer_hf() {
match mdc.tokenizer() {
Ok(tokenizer) => Self::from_tokenizer(tokenizer),
Err(err) => {
tracing::warn!(%err, "tokenizer_hf error converting ModelDeploymentCard to HF tokenizer");
tracing::warn!(%err, "error loading tokenizer from ModelDeploymentCard");
Arc::new(Self {
tokenizer: None,
validate_engine_decode: false,
......
......@@ -472,7 +472,7 @@ impl ModelWatcher {
};
// This is expensive, we are loading ~10MiB JSON, so only do it once
let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
let tokenizer = card.tokenizer().context("tokenizer")?;
// Create prefill chooser once if we're building pipelines
// Both chat and completions will share the same prefill chooser instance
......@@ -534,7 +534,7 @@ impl ModelWatcher {
self.router_config.router_mode,
worker_monitor.clone(),
kv_chooser.clone(),
tokenizer_hf.clone(),
tokenizer.clone(),
prefill_chooser.clone(),
self.router_config.decode_fallback,
self.migration_limit,
......@@ -551,12 +551,9 @@ impl ModelWatcher {
if card.model_type.supports_completions() {
let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter;
let preprocessor = OpenAIPreprocessor::new_with_parts(
card.clone(),
formatter,
tokenizer_hf.clone(),
)
.context("OpenAIPreprocessor::new_with_parts")?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, tokenizer.clone())
.context("OpenAIPreprocessor::new_with_parts")?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
......@@ -568,7 +565,7 @@ impl ModelWatcher {
worker_monitor,
kv_chooser,
preprocessor,
tokenizer_hf,
tokenizer,
prefill_chooser,
self.router_config.decode_fallback,
self.migration_limit,
......
......@@ -131,7 +131,7 @@ pub async fn prepare_engine(
let pipeline = build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine, model.card().tokenizer_hf()?)
>(model.card(), inner_engine, model.card().tokenizer()?)
.await?;
let service_name = model.service_name().to_string();
......@@ -150,7 +150,7 @@ pub async fn prepare_engine(
pub async fn build_pipeline<Req, Resp>(
card: &ModelDeploymentCard,
engine: ExecutionContext,
hf_tokenizer: tokenizers::Tokenizer,
tokenizer: crate::tokenizers::Tokenizer,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
Req: Data,
......@@ -165,9 +165,9 @@ where
let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, tokenizer.clone())?
.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let backend = Backend::from_tokenizer(tokenizer).into_operator();
let engine = ServiceBackend::from_engine(engine);
Ok(frontend
......@@ -187,7 +187,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
router_mode: RouterMode,
worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer,
tokenizer: crate::tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
decode_fallback: bool,
migration_limit: u32,
......@@ -206,7 +206,7 @@ where
let PromptFormatter::OAI(formatter) =
PromptFormatter::from_mdc(card).context("PromptFormatter.from_mdc")?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, tokenizer.clone())
.context("OpenAIPreprocessor.new_with_parts")?;
build_routed_pipeline_with_preprocessor(
card,
......@@ -216,7 +216,7 @@ where
worker_monitor,
chooser,
preprocessor,
hf_tokenizer,
tokenizer,
prefill_chooser,
decode_fallback,
migration_limit,
......@@ -234,7 +234,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer,
tokenizer: crate::tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
decode_fallback: bool,
migration_limit: u32,
......@@ -252,7 +252,7 @@ where
{
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let backend = Backend::from_tokenizer(tokenizer).into_operator();
let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
// For KV routing, use the client from the chooser to ensure shared state
......
......@@ -70,19 +70,18 @@ pub async fn run(
let manager = grpc_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
let tokenizer = model.card().tokenizer()?;
let chat_pipeline = common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
>(model.card(), inner_engine, tokenizer)
.await?;
manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
grpc_service
......
......@@ -118,19 +118,18 @@ pub async fn run(
let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
let tokenizer = model.card().tokenizer()?;
let chat_pipeline = common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer.clone())
.await?;
manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
>(model.card(), inner_engine, tokenizer)
.await?;
manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
// Enable all endpoints
......
......@@ -26,8 +26,10 @@ fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathB
let config_path = repo.get("config.json")?;
// Check for tokenizer files (at least one must exist)
let has_tokenizer =
repo.get("tokenizer.json").is_some() || repo.get("tokenizer_config.json").is_some();
let has_tokenizer = repo.get("tokenizer.json").is_some()
|| repo.get("tokenizer_config.json").is_some()
|| repo.get("tiktoken.model").is_some()
|| has_tiktoken_file(config_path.parent()?);
if !has_tokenizer {
return None;
......@@ -52,6 +54,15 @@ fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathB
Some(snapshot_path)
}
/// Report whether `dir` holds at least one entry with a `.tiktoken` extension
/// (e.g. `qwen.tiktoken`).
///
/// An unreadable or missing directory yields `false`; individual entries that
/// fail to read are skipped.
fn has_tiktoken_file(dir: &Path) -> bool {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    for entry in entries {
        let Ok(entry) = entry else {
            continue;
        };
        if entry.path().extension().is_some_and(|ext| ext == "tiktoken") {
            return true;
        }
    }
    false
}
/// Check if offline mode is enabled via HF_HUB_OFFLINE environment variable.
fn is_offline_mode() -> bool {
env::var(env_model::huggingface::HF_HUB_OFFLINE)
......
......@@ -61,24 +61,29 @@ impl ModelInfoType {
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
HfTokenizerJson(CheckedFile),
TikTokenModel(CheckedFile),
}
impl TokenizerKind {
pub fn checksum(&self) -> String {
match self {
TokenizerKind::HfTokenizerJson(c) => c.checksum().to_string(),
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => {
c.checksum().to_string()
}
}
}
pub fn is_local(&self) -> bool {
match self {
TokenizerKind::HfTokenizerJson(c) => c.is_local(),
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
TokenizerKind::HfTokenizerJson(c) => c.update_dir(dir),
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => {
c.update_dir(dir)
}
}
}
}
......@@ -371,13 +376,15 @@ impl ModelDeploymentCard {
self.tokenizer.is_some()
}
pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
/// Load the tokenizer as a generic, backend-agnostic `Tokenizer` trait object.
/// This supports both HuggingFace `tokenizer.json` and tiktoken `.model`/`.tiktoken` files.
pub fn tokenizer(&self) -> anyhow::Result<crate::tokenizers::Tokenizer> {
match &self.tokenizer {
Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
let p = checked_file.path().ok_or_else(|| {
anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
})?;
HfTokenizer::from_file(p)
let hf = HfTokenizer::from_file(p)
.inspect_err(|err| {
if let Some(serde_err) = err.downcast_ref::<serde_json::Error>()
&& let Ok(contents) = std::fs::read_to_string(p)
......@@ -386,7 +393,23 @@ impl ModelDeploymentCard {
}
})
.map_err(anyhow::Error::msg)
.with_context(|| p.display().to_string())
.with_context(|| p.display().to_string())?;
Ok(crate::tokenizers::Tokenizer::from(Arc::new(
crate::tokenizers::HuggingFaceTokenizer::from_tokenizer(hf),
)))
}
Some(TokenizerKind::TikTokenModel(checked_file)) => {
let p = checked_file.path().ok_or_else(|| {
anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
})?;
let path_str = p.to_str().ok_or_else(|| {
anyhow::anyhow!("Tokenizer path contains invalid UTF-8: {}", p.display())
})?;
let tokenizer = crate::tokenizers::TikTokenTokenizer::from_file_auto(path_str)
.with_context(|| {
format!("Failed to load tiktoken tokenizer from {}", p.display())
})?;
Ok(crate::tokenizers::Tokenizer::from(Arc::new(tokenizer)))
}
None => {
anyhow::bail!(
......@@ -480,8 +503,9 @@ impl ModelDeploymentCard {
PromptFormatterArtifact::HfTokenizerConfigJson
);
// tokenizer.json
// tokenizer.json or tiktoken.model
change!(self.tokenizer, TokenizerKind::HfTokenizerJson);
change!(self.tokenizer, TokenizerKind::TikTokenModel);
// We only "move" the chat template if it came from the repo. If we have a custom template
// file we cannot download that from HF.
......@@ -945,13 +969,44 @@ impl PromptFormatterArtifact {
impl TokenizerKind {
pub fn from_disk(directory: &Path) -> Result<Self> {
let f = CheckedFile::from_disk(directory.join("tokenizer.json")).with_context(|| {
format!(
"unable to extract tokenizer kind from directory {}",
directory.display()
)
})?;
Ok(Self::HfTokenizerJson(f))
// 1. Try tokenizer.json (HuggingFace)
if let Ok(f) = CheckedFile::from_disk(directory.join("tokenizer.json")) {
return Ok(Self::HfTokenizerJson(f));
}
// 2. Try tiktoken.model
if let Ok(f) = CheckedFile::from_disk(directory.join("tiktoken.model")) {
return Ok(Self::TikTokenModel(f));
}
// 3. Search for any *.tiktoken file
let tiktoken_files: Vec<_> = std::fs::read_dir(directory)
.into_iter()
.flatten()
.flatten()
.filter(|entry| entry.path().extension().is_some_and(|e| e == "tiktoken"))
.collect();
if tiktoken_files.len() == 1 {
if let Ok(f) = CheckedFile::from_disk(tiktoken_files[0].path()) {
return Ok(Self::TikTokenModel(f));
}
} else if tiktoken_files.len() > 1 {
let names: Vec<_> = tiktoken_files
.iter()
.map(|e| e.path().display().to_string())
.collect();
anyhow::bail!(
"Multiple .tiktoken files found in {}: {:?}. Cannot determine which to use.",
directory.display(),
names
);
}
anyhow::bail!(
"No tokenizer.json or tiktoken model file found in {}",
directory.display()
)
}
}
......
......@@ -56,7 +56,7 @@ use crate::protocols::{
nvext::NvExtProvider,
},
};
use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};
use crate::tokenizers::traits::Tokenizer;
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
......@@ -151,7 +151,7 @@ pub struct OpenAIPreprocessor {
impl OpenAIPreprocessor {
pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let formatter = PromptFormatter::from_mdc(&mdc)?;
let tokenizer = mdc.tokenizer_hf()?;
let tokenizer = mdc.tokenizer()?;
match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
}
......@@ -160,10 +160,10 @@ impl OpenAIPreprocessor {
pub fn new_with_parts(
mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer,
tokenizer: crate::tokenizers::Tokenizer,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let tokenizer: Arc<dyn Tokenizer> = (*tokenizer).clone();
let lora_name = mdc.lora.as_ref().map(|l| l.name.clone());
let Some(ref model_info) = mdc.model_info else {
anyhow::bail!(
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
pub mod hf;
pub mod tiktoken;
// TODO: Add tokenizer benchmarks
// TODO: Enable README.md as a module doc
......@@ -15,11 +16,13 @@ use crate::protocols::TokenIdType;
pub use anyhow::{Error, Result};
pub use hf::HuggingFaceTokenizer;
pub use tiktoken::TikTokenTokenizer;
/// Represents the type of tokenizer being used
#[derive(Debug)]
pub enum TokenizerType {
HuggingFace(String),
TikToken(String),
}
/// character offsets in the original text
......@@ -121,6 +124,8 @@ where
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - model, tiktoken: tiktoken BPE tokenizer (requires `config.json` with a supported
/// `model_type` in the same directory; currently: kimi, kimi_k2, kimi_k25)
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
let path = Path::new(file_path);
let extension = path
......@@ -133,7 +138,13 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
Ok(Arc::new(tokenizer))
}
_ => Err(Error::msg("Unsupported file type".to_string())),
"model" | "tiktoken" => {
let tokenizer = TikTokenTokenizer::from_file_auto(file_path)?;
Ok(Arc::new(tokenizer))
}
_ => Err(Error::msg(format!(
"Unsupported tokenizer file type: .{extension}"
))),
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::path::Path;
use base64::Engine as _;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use tiktoken_rs::CoreBPE;
use super::{
Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer},
};
/// Number of reserved special-token slots to generate when filling gaps in the vocabulary.
/// Most tiktoken-based models reserve 256 IDs above the base vocabulary for special tokens.
///
/// `load_special_tokens` treats this as the length of the ID range, starting at the base
/// vocabulary size, in which unassigned IDs receive `<|reserved_token_{i}|>` placeholders.
const DEFAULT_NUM_RESERVED_SPECIAL_TOKENS: u32 = 256;
/// Kimi BPE pattern from moonshotai/Kimi-K2-Instruct/tokenization_kimi.py
///
/// Pre-tokenization split regex selected by `detect_bpe_pattern` for the `kimi`,
/// `kimi_k2` and `kimi_k25` model types. The pattern uses character-class
/// intersection (`&&`) and a negative look-ahead (`(?!\S)`), so it needs a regex
/// engine that supports both.
const KIMI_PATTERN: &str = r#"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"#;
/// Tokenizer backed by a `tiktoken_rs::CoreBPE` instance.
///
/// Build one with [`TikTokenTokenizer::from_file`] (explicit pattern and special
/// tokens) or [`TikTokenTokenizer::from_file_auto`] (both derived from the model
/// directory's JSON config files).
pub struct TikTokenTokenizer {
    /// BPE engine that performs the actual encoding and decoding.
    bpe: CoreBPE,
    /// IDs of every special token; `decode` drops these when `skip_special_tokens` is set.
    special_token_ids: HashSet<u32>,
}
impl TikTokenTokenizer {
/// Create a TikTokenTokenizer from a tiktoken model file.
///
/// # Arguments
/// * `path` - Path to the `.model` or `.tiktoken` file (base64 rank-per-line format)
/// * `pattern` - BPE regex pattern string
/// * `special_tokens` - Map of special token strings to their IDs
pub fn from_file(
path: &str,
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self> {
let encoder = parse_tiktoken_file(path)?;
let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
let bpe = CoreBPE::new(encoder, special_tokens, pattern)
.map_err(|err| Error::msg(format!("Error creating tiktoken BPE: {err}")))?;
Ok(Self {
bpe,
special_token_ids,
})
}
/// Create a TikTokenTokenizer from a tiktoken model file, auto-detecting
/// the BPE pattern from `config.json` and special tokens from `tokenizer_config.json`.
///
/// The tiktoken file and config files must be in the same directory.
pub fn from_file_auto(path: &str) -> Result<Self> {
let file_path = Path::new(path);
let directory = file_path
.parent()
.ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
let pattern = detect_bpe_pattern(directory)?;
let encoder = parse_tiktoken_file(path)?;
// Use max rank + 1 (not len) to avoid ID collisions with sparse/non-contiguous ranks
let num_base_tokens = encoder.values().max().map_or(0, |&m| m + 1) as usize;
let special_tokens = load_special_tokens(directory, num_base_tokens)?;
let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
let bpe = CoreBPE::new(encoder, special_tokens, pattern)
.map_err(|err| Error::msg(format!("Error creating tiktoken BPE: {err}")))?;
Ok(Self {
bpe,
special_token_ids,
})
}
}
impl Encoder for TikTokenTokenizer {
    /// Encode `input` into token IDs, returned in the `Encoding::Sp` variant.
    ///
    /// Delegates to `CoreBPE::encode_with_special_tokens`.
    fn encode(&self, input: &str) -> Result<Encoding> {
        Ok(Encoding::Sp(self.bpe.encode_with_special_tokens(input)))
    }

    /// Encode every input in parallel (rayon); stops with the first error, if any.
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.par_iter().map(|&text| self.encode(text)).collect()
    }
}
impl Decoder for TikTokenTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
let ids: Vec<u32> = if skip_special_tokens {
token_ids
.iter()
.filter(|&&id| !self.special_token_ids.contains(&id))
.copied()
.collect()
} else {
token_ids.to_vec()
};
self.bpe
.decode(ids)
.map_err(|err| Error::msg(format!("Error decoding tiktoken tokens: {err}")))
}
}
// Empty impl: nothing is overridden here; the behaviour comes from the `Encoder` and
// `Decoder` impls for this type.
impl Tokenizer for TikTokenTokenizer {}
/// Parse a tiktoken model file (base64-encoded token + rank per line).
fn parse_tiktoken_file(path: &str) -> Result<FxHashMap<Vec<u8>, u32>> {
let contents = std::fs::read_to_string(path)
.map_err(|err| Error::msg(format!("Failed to read tiktoken file '{path}': {err}")))?;
let engine = base64::engine::general_purpose::STANDARD;
let mut encoder = FxHashMap::default();
for line in contents.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
let mut parts = line.split_whitespace();
let token_b64 = parts
.next()
.ok_or_else(|| Error::msg(format!("Invalid tiktoken line (no token): {line}")))?;
let rank_str = parts
.next()
.ok_or_else(|| Error::msg(format!("Invalid tiktoken line (no rank): {line}")))?;
let token_bytes = engine
.decode(token_b64)
.map_err(|err| Error::msg(format!("Invalid base64 in tiktoken file: {err}")))?;
let rank: u32 = rank_str
.parse()
.map_err(|err| Error::msg(format!("Invalid rank in tiktoken file: {err}")))?;
encoder.insert(token_bytes, rank);
}
Ok(encoder)
}
/// Detect the BPE pattern for a model by reading `model_type` from `config.json`.
fn detect_bpe_pattern(directory: &Path) -> Result<&'static str> {
let model_type: String = crate::file_json_field(&directory.join("config.json"), "model_type")
.map_err(|err| {
Error::msg(format!("Failed to read model_type from config.json: {err}"))
})?;
match model_type.as_str() {
"kimi" | "kimi_k2" | "kimi_k25" => Ok(KIMI_PATTERN),
_ => Err(Error::msg(format!(
"Unsupported tiktoken model_type '{model_type}'. \
Currently supported: kimi, kimi_k2, kimi_k25. \
To add a new model type, extend detect_bpe_pattern() in tokenizers/tiktoken.rs \
with the appropriate BPE regex pattern. \
Alternatively, provide a tokenizer.json (HuggingFace format) instead."
))),
}
}
/// Load special tokens from `tokenizer_config.json` in the model directory.
///
/// Reads the `added_tokens_decoder` field which maps string token IDs to token definitions.
/// Falls back to generating `<|reserved_token_{i}|>` names for unmapped IDs in the range
/// `num_base_tokens .. num_base_tokens + DEFAULT_NUM_RESERVED_SPECIAL_TOKENS`, where `i` is
/// the offset within that range. If the config file or the `added_tokens_decoder` field is
/// absent, the whole range is filled with placeholders.
///
/// Entries with a missing `content` field are logged and skipped; entries with empty
/// content are skipped silently.
///
/// # Errors
/// Fails if `num_base_tokens` or any reserved ID does not fit in a `u32`, if the config
/// file exists but cannot be read or parsed, or if a key of `added_tokens_decoder` is not
/// a valid `u32`.
fn load_special_tokens(directory: &Path, num_base_tokens: usize) -> Result<FxHashMap<String, u32>> {
    // Token IDs are u32; refuse a base size that would silently truncate.
    let first_reserved_id = u32::try_from(num_base_tokens).map_err(|_| {
        Error::msg(format!(
            "Base vocabulary size {num_base_tokens} does not fit in a u32 token ID"
        ))
    })?;

    let config_path = directory.join("tokenizer_config.json");
    let mut special_tokens = FxHashMap::default();

    if config_path.exists() {
        let contents = std::fs::read_to_string(&config_path)
            .map_err(|err| Error::msg(format!("Failed to read tokenizer_config.json: {err}")))?;
        let config: serde_json::Value = serde_json::from_str(&contents)
            .map_err(|err| Error::msg(format!("Failed to parse tokenizer_config.json: {err}")))?;

        if let Some(added_tokens) = config
            .get("added_tokens_decoder")
            .and_then(|v| v.as_object())
        {
            for (id_str, token_def) in added_tokens {
                let id: u32 = id_str.parse().map_err(|err| {
                    Error::msg(format!(
                        "Invalid token ID '{id_str}' in added_tokens_decoder: {err}"
                    ))
                })?;
                match token_def.get("content").and_then(|v| v.as_str()) {
                    Some(content) if !content.is_empty() => {
                        special_tokens.insert(content.to_string(), id);
                    }
                    // Empty content: nothing usable to register.
                    Some(_) => {}
                    // This shouldn't happen in well-formed configs, but handle gracefully
                    None => tracing::warn!("Missing 'content' field for token ID {id}"),
                }
            }
        }
    }

    // With no config (or no added_tokens_decoder) the map is still empty here, so this
    // generates the full default reserved range; otherwise it only fills the gaps.
    fill_reserved_tokens(&mut special_tokens, first_reserved_id)?;
    Ok(special_tokens)
}

/// Add a `<|reserved_token_{i}|>` placeholder for every ID in the reserved range that no
/// entry of `special_tokens` already uses.
///
/// A placeholder never replaces an existing entry with the same name, so a token configured
/// under that name keeps its configured ID.
fn fill_reserved_tokens(
    special_tokens: &mut FxHashMap<String, u32>,
    first_reserved_id: u32,
) -> Result<()> {
    let used_ids: HashSet<u32> = special_tokens.values().copied().collect();
    for i in 0..DEFAULT_NUM_RESERVED_SPECIAL_TOKENS {
        let id = first_reserved_id.checked_add(i).ok_or_else(|| {
            Error::msg(format!(
                "Reserved special token ID overflows u32 (base {first_reserved_id} + offset {i})"
            ))
        })?;
        if !used_ids.contains(&id) {
            special_tokens
                .entry(format!("<|reserved_token_{i}|>"))
                .or_insert(id);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
fn create_test_tiktoken_file(dir: &Path) -> String {
let engine = base64::engine::general_purpose::STANDARD;
let mut content = String::new();
// Create some simple token entries: single bytes with sequential ranks
let tokens: Vec<(&[u8], u32)> = vec![
(b"h", 0),
(b"e", 1),
(b"l", 2),
(b"o", 3),
(b" ", 4),
(b"w", 5),
(b"r", 6),
(b"d", 7),
(b"he", 8),
(b"ll", 9),
(b"lo", 10),
(b"wo", 11),
(b"rl", 12),
(b"hel", 13),
(b"llo", 14),
(b"wor", 15),
(b"hell", 16),
(b"ello", 17),
(b"worl", 18),
(b"hello", 19),
(b"world", 20),
];
for (token, rank) in tokens {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let file_path = dir.join("tiktoken.model");
let mut file = std::fs::File::create(&file_path).unwrap();
file.write_all(content.as_bytes()).unwrap();
file_path.to_str().unwrap().to_string()
}
fn create_test_config(dir: &Path, model_type: &str) {
let config = serde_json::json!({
"model_type": model_type,
"max_position_embeddings": 32768,
"eos_token_id": [21]
});
let file_path = dir.join("config.json");
std::fs::write(file_path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
}
fn create_test_tokenizer_config(dir: &Path, num_base_tokens: usize) {
let mut added_tokens = serde_json::Map::new();
let bos_id = num_base_tokens;
let eos_id = num_base_tokens + 1;
added_tokens.insert(
bos_id.to_string(),
serde_json::json!({"content": "[BOS]", "special": true}),
);
added_tokens.insert(
eos_id.to_string(),
serde_json::json!({"content": "[EOS]", "special": true}),
);
let config = serde_json::json!({
"added_tokens_decoder": added_tokens
});
let file_path = dir.join("tokenizer_config.json");
std::fs::write(file_path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
}
#[test]
fn test_parse_tiktoken_file() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let encoder = parse_tiktoken_file(&file_path).unwrap();
assert_eq!(encoder.len(), 21);
assert_eq!(encoder[b"hello".as_slice()], 19);
assert_eq!(encoder[b"world".as_slice()], 20);
}
#[test]
fn test_parse_tiktoken_file_missing() {
let result = parse_tiktoken_file("/nonexistent/path/tiktoken.model");
assert!(result.is_err());
}
#[test]
fn test_tiktoken_from_file() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let mut special_tokens = FxHashMap::default();
special_tokens.insert("[BOS]".to_string(), 21_u32);
special_tokens.insert("[EOS]".to_string(), 22_u32);
// Use a simple pattern for testing
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
// Test encode
let encoding = tokenizer.encode("hello world").unwrap();
let ids = encoding.token_ids();
assert!(!ids.is_empty());
// Test decode roundtrip
let decoded = tokenizer.decode(ids, false).unwrap();
assert_eq!(decoded, "hello world");
}
#[test]
fn test_tiktoken_encoding_variant() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let special_tokens = FxHashMap::default();
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
let encoding = tokenizer.encode("hello").unwrap();
// Verify it produces the Sp variant
match &encoding {
Encoding::Sp(_) => {}
other => panic!("Expected Encoding::Sp, got {:?}", other),
}
}
/// Decoding with `skip_special_tokens = true` drops the registered special tokens,
/// while decoding with `false` still contains the plain text.
#[test]
fn test_tiktoken_skip_special_tokens() {
    const BOS_ID: u32 = 21;
    const EOS_ID: u32 = 22;

    let tmp = tempfile::tempdir().unwrap();
    let model_file = create_test_tiktoken_file(tmp.path());

    // Register [BOS] and [EOS] as the only special tokens.
    let mut special_tokens = FxHashMap::default();
    for (content, id) in [("[BOS]", BOS_ID), ("[EOS]", EOS_ID)] {
        special_tokens.insert(content.to_string(), id);
    }
    let tokenizer =
        TikTokenTokenizer::from_file(&model_file, r"[\w]+|[^\w\s]+|\s+", special_tokens).unwrap();

    // Build the sequence: [BOS] + ids("hello") + [EOS].
    let encoding = tokenizer.encode("hello").unwrap();
    let mut ids = vec![BOS_ID];
    ids.extend(encoding.token_ids());
    ids.push(EOS_ID);

    // skip_special_tokens = true should strip the special tokens.
    let without_special = tokenizer.decode(&ids, true).unwrap();
    assert_eq!(without_special, "hello");

    // skip_special_tokens = false should keep them around the text.
    let decoded_all = tokenizer.decode(&ids, false).unwrap();
    assert!(decoded_all.contains("hello"));
}
/// Loads a tokenizer through `from_file_auto` from a temp directory prepared by the
/// test fixture helpers and checks an encode/decode round trip of "hello world".
#[test]
fn test_tiktoken_from_file_auto() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();

    // Populate the directory: model file first, then the two config fixtures.
    let model_file = create_test_tiktoken_file(root);
    create_test_config(root, "kimi");
    create_test_tokenizer_config(root, 21);

    let tokenizer = TikTokenTokenizer::from_file_auto(&model_file).unwrap();

    // Encoding must produce at least one id ...
    let encoding = tokenizer.encode("hello world").unwrap();
    let ids = encoding.token_ids();
    assert!(!ids.is_empty());

    // ... and decoding those ids must give back the input text.
    let round_trip = tokenizer.decode(ids, false).unwrap();
    assert_eq!(round_trip, "hello world");
}
/// `detect_bpe_pattern` must return an error when the test config names the
/// model type "unknown_model".
#[test]
fn test_detect_bpe_pattern_unknown() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    create_test_config(root, "unknown_model");

    let result = detect_bpe_pattern(root);
    assert!(result.is_err());
}
/// With an empty directory, `load_special_tokens(_, 100)` yields 256 reserved tokens
/// whose ids run from 100 through 355.
#[test]
fn test_load_special_tokens_no_config() {
    let tmp = tempfile::tempdir().unwrap();
    let tokens = load_special_tokens(tmp.path(), 100).unwrap();

    assert_eq!(tokens.len(), 256);

    // First and last reserved token names and the ids they must map to.
    for (name, expected_id) in [
        ("<|reserved_token_0|>", 100),
        ("<|reserved_token_255|>", 355),
    ] {
        assert_eq!(tokens[name], expected_id);
    }
}
/// With a tokenizer config fixture present, `[BOS]` and `[EOS]` map to ids 100 and 101
/// and the resulting map holds more than those two entries.
#[test]
fn test_load_special_tokens_with_config() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    create_test_tokenizer_config(root, 100);

    let tokens = load_special_tokens(root, 100).unwrap();

    for (name, expected_id) in [("[BOS]", 100), ("[EOS]", 101)] {
        assert_eq!(tokens[name], expected_id);
    }

    // Reserved tokens should fill the remaining gaps, so there are more than two entries.
    assert!(tokens.len() > 2);
}
/// `encode_batch` returns one encoding per input, and each encoding decodes back
/// to the input it came from.
#[test]
fn test_tiktoken_encode_batch() {
    let tmp = tempfile::tempdir().unwrap();
    let model_file = create_test_tiktoken_file(tmp.path());
    let tokenizer =
        TikTokenTokenizer::from_file(&model_file, r"[\w]+|[^\w\s]+|\s+", FxHashMap::default())
            .unwrap();

    let inputs = &["hello", "world"];
    let encodings = tokenizer.encode_batch(inputs).unwrap();
    assert_eq!(encodings.len(), 2);

    // Pair every input with its encoding and verify the round trip.
    for (input, encoding) in inputs.iter().zip(encodings.iter()) {
        let round_trip = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(round_trip, *input);
    }
}
}
{
"model_type": "kimi",
"max_position_embeddings": 32768,
"architectures": ["KimiForCausalLM"],
"eos_token_id": [370],
"vocab_size": 375
}
AA== 0
AQ== 1
Ag== 2
Aw== 3
BA== 4
BQ== 5
Bg== 6
Bw== 7
CA== 8
CQ== 9
Cg== 10
Cw== 11
DA== 12
DQ== 13
Dg== 14
Dw== 15
EA== 16
EQ== 17
Eg== 18
Ew== 19
FA== 20
FQ== 21
Fg== 22
Fw== 23
GA== 24
GQ== 25
Gg== 26
Gw== 27
HA== 28
HQ== 29
Hg== 30
Hw== 31
IA== 32
IQ== 33
Ig== 34
Iw== 35
JA== 36
JQ== 37
Jg== 38
Jw== 39
KA== 40
KQ== 41
Kg== 42
Kw== 43
LA== 44
LQ== 45
Lg== 46
Lw== 47
MA== 48
MQ== 49
Mg== 50
Mw== 51
NA== 52
NQ== 53
Ng== 54
Nw== 55
OA== 56
OQ== 57
Og== 58
Ow== 59
PA== 60
PQ== 61
Pg== 62
Pw== 63
QA== 64
QQ== 65
Qg== 66
Qw== 67
RA== 68
RQ== 69
Rg== 70
Rw== 71
SA== 72
SQ== 73
Sg== 74
Sw== 75
TA== 76
TQ== 77
Tg== 78
Tw== 79
UA== 80
UQ== 81
Ug== 82
Uw== 83
VA== 84
VQ== 85
Vg== 86
Vw== 87
WA== 88
WQ== 89
Wg== 90
Ww== 91
XA== 92
XQ== 93
Xg== 94
Xw== 95
YA== 96
YQ== 97
Yg== 98
Yw== 99
ZA== 100
ZQ== 101
Zg== 102
Zw== 103
aA== 104
aQ== 105
ag== 106
aw== 107
bA== 108
bQ== 109
bg== 110
bw== 111
cA== 112
cQ== 113
cg== 114
cw== 115
dA== 116
dQ== 117
dg== 118
dw== 119
eA== 120
eQ== 121
eg== 122
ew== 123
fA== 124
fQ== 125
fg== 126
fw== 127
gA== 128
gQ== 129
gg== 130
gw== 131
hA== 132
hQ== 133
hg== 134
hw== 135
iA== 136
iQ== 137
ig== 138
iw== 139
jA== 140
jQ== 141
jg== 142
jw== 143
kA== 144
kQ== 145
kg== 146
kw== 147
lA== 148
lQ== 149
lg== 150
lw== 151
mA== 152
mQ== 153
mg== 154
mw== 155
nA== 156
nQ== 157
ng== 158
nw== 159
oA== 160
oQ== 161
og== 162
ow== 163
pA== 164
pQ== 165
pg== 166
pw== 167
qA== 168
qQ== 169
qg== 170
qw== 171
rA== 172
rQ== 173
rg== 174
rw== 175
sA== 176
sQ== 177
sg== 178
sw== 179
tA== 180
tQ== 181
tg== 182
tw== 183
uA== 184
uQ== 185
ug== 186
uw== 187
vA== 188
vQ== 189
vg== 190
vw== 191
wA== 192
wQ== 193
wg== 194
ww== 195
xA== 196
xQ== 197
xg== 198
xw== 199
yA== 200
yQ== 201
yg== 202
yw== 203
zA== 204
zQ== 205
zg== 206
zw== 207
0A== 208
0Q== 209
0g== 210
0w== 211
1A== 212
1Q== 213
1g== 214
1w== 215
2A== 216
2Q== 217
2g== 218
2w== 219
3A== 220
3Q== 221
3g== 222
3w== 223
4A== 224
4Q== 225
4g== 226
4w== 227
5A== 228
5Q== 229
5g== 230
5w== 231
6A== 232
6Q== 233
6g== 234
6w== 235
7A== 236
7Q== 237
7g== 238
7w== 239
8A== 240
8Q== 241
8g== 242
8w== 243
9A== 244
9Q== 245
9g== 246
9w== 247
+A== 248
+Q== 249
+g== 250
+w== 251
/A== 252
/Q== 253
/g== 254
/w== 255
dGg= 256
aGU= 257
aW4= 258
ZXI= 259
YW4= 260
cmU= 261
b24= 262
ZW4= 263
YXQ= 264
ZXM= 265
b3I= 266
dGU= 267
b2Y= 268
ZWQ= 269
aXM= 270
aXQ= 271
YWw= 272
YXI= 273
c3Q= 274
dG8= 275
bnQ= 276
bmc= 277
c2U= 278
aGE= 279
YXM= 280
b3U= 281
aW8= 282
bGU= 283
dmU= 284
Y28= 285
bWU= 286
ZGU= 287
aGk= 288
cmk= 289
cm8= 290
aWM= 291
bmU= 292
ZWE= 293
cmE= 294
Y2U= 295
bGk= 296
Y2g= 297
bGw= 298
YmU= 299
bWE= 300
c2k= 301
b20= 302
dXI= 303
dGhl 304
YW5k 305
aW5n 306
aGVy 307
aGF0 308
aGlz 309
dGhh 310
ZXJl 311
Zm9y 312
ZW50 313
aW9u 314
dGVy 315
d2Fz 316
eW91 317
YWxs 318
d2l0 319
dGhp 320
dmVy 321
aGVs 322
ZWxs 323
bGxv 324
bG8g 325
d29y 326
b3Js 327
cmxk 328
aGVsbA== 329
ZWxsbw== 330
bGxvIA== 331
d29ybA== 332
b3JsZA== 333
aGVsbG8= 334
d29ybGQ= 335
IHdv 336
IHdvcg== 337
IHdvcmw= 338
IHdvcmxk 339
aGVsbG8g 340
aGVsbG8gdw== 341
ZGVl 342
ZWVw 343
ZXAg 344
ZGVlcA== 345
ZGVlcCA= 346
bGVh 347
ZWFy 348
YXJu 349
bGVhcg== 350
ZWFybg== 351
YXJuaQ== 352
cm5pbg== 353
bmluZw== 354
bGVhcm4= 355
ZWFybmk= 356
YXJuaW4= 357
cm5pbmc= 358
bGVhcm5p 359
ZWFybmlu 360
YXJuaW5n 361
bGVhcm5pbg== 362
ZWFybmluZw== 363
bGVhcm5pbmc= 364
ZGVlcCBs 365
ZGVlcCBsZQ== 366
IGlz 367
IGxl 368
{
"added_tokens_decoder": {
"369": {
"content": "[BOS]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"370": {
"content": "[EOS]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"371": {
"content": "[PAD]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"372": {
"content": "[UNK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"373": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"374": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
}
}
......@@ -30,6 +30,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
// Verify tokenizer file was found
match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (),
TokenizerKind::TikTokenModel(_) => panic!("Expected HfTokenizerJson, got TikTokenModel"),
}
}
......
......@@ -16,6 +16,7 @@
use dynamo_llm::tokenizers::traits::{Decoder, Encoder};
use dynamo_llm::tokenizers::*;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
const TEST_PROMPTS: [&str; 4] = [
......@@ -206,3 +207,111 @@ fn test_decode_with_skip_special_tokens() {
assert_eq!(decoded_with_special, "<s> Hello world</s>");
assert_eq!(decoded_without_special, "Hello world");
}
// --- tiktoken tests ---
/// Directory holding the mock tiktoken fixture, relative to the crate's manifest directory.
const MOCK_TIKTOKEN_DIR: &str = "tests/data/sample-models/mock-tiktoken";
/// Builds the path to the mock `tiktoken.model` fixture under the crate's manifest
/// directory and returns it as an owned `String`.
///
/// # Panics
/// Panics if the assembled path is not valid UTF-8.
fn mock_tiktoken_model_path() -> String {
    let fixture = Path::new(env!("CARGO_MANIFEST_DIR"))
        .join(MOCK_TIKTOKEN_DIR)
        .join("tiktoken.model");
    let as_utf8 = fixture.to_str().unwrap();
    as_utf8.to_owned()
}
/// Loads the mock tiktoken model, encodes a short text, checks the encoding variant,
/// and verifies that decoding returns the original text.
#[test]
fn test_tiktoken_lifecycle() {
    let model_path = mock_tiktoken_model_path();
    let tokenizer =
        TikTokenTokenizer::from_file_auto(&model_path).expect("Failed to load tiktoken tokenizer");

    // Encode and make sure something came out.
    let text = "hello world";
    let encoding = tokenizer.encode(text).expect("Failed to encode");
    let token_ids = encoding.token_ids();
    assert!(!token_ids.is_empty(), "Token IDs should not be empty");

    // The tiktoken tokenizer is expected to produce the Sp variant.
    assert!(
        matches!(&encoding, Encoding::Sp(_)),
        "Expected Encoding::Sp, got {:?}",
        encoding
    );

    // Round trip back to text.
    let round_trip = tokenizer
        .decode(token_ids, false)
        .expect("Failed to decode token_ids");
    assert_eq!(round_trip, text);
}
/// Feeds the token ids of an encoded prompt one at a time through a `DecodeStream`
/// and checks that the concatenated pieces equal the original text.
#[test]
fn test_tiktoken_decode_stream() {
    let model_path = mock_tiktoken_model_path();
    let tokenizer =
        TikTokenTokenizer::from_file_auto(&model_path).expect("Failed to load tiktoken tokenizer");
    let shared_tokenizer: Arc<dyn dynamo_llm::tokenizers::traits::Tokenizer> = Arc::new(tokenizer);

    let text = "hello world";
    let encoding = shared_tokenizer
        .encode(text)
        .expect("Failed to encode prompt");

    // Start the stream with no prompt ids; the flag is passed as `false`.
    let mut decoder = DecodeStream::new(Arc::clone(&shared_tokenizer), &[], false);

    let mut output = String::new();
    for &token_id in encoding.token_ids() {
        // A step may yield no text yet; only append when a piece is returned.
        if let Some(piece) = decoder.step(token_id).expect("Failed to decode token_id") {
            output.push_str(&piece);
        }
    }

    assert_eq!(output, text);
}
/// Hashing the same prompts twice with the tiktoken tokenizer must give identical,
/// non-zero hashes.
#[test]
fn compute_hashes_tiktoken() {
    let model_path = mock_tiktoken_model_path();
    let tokenizer =
        TikTokenTokenizer::from_file_auto(&model_path).expect("Failed to load tiktoken tokenizer");

    let simple_prompts = &["hello world", "hello", "world"];

    // Two independent passes over the same prompts.
    let first_pass = compute_hashes_for_tokenizer(&tokenizer, simple_prompts);
    let second_pass = compute_hashes_for_tokenizer(&tokenizer, simple_prompts);

    // Only determinism and non-zero values are checked, not specific hash values.
    assert_eq!(first_pass, second_pass, "Hashes should be deterministic");
    assert!(
        !first_pass.iter().any(|&h| h == 0),
        "Hashes should be non-zero"
    );
}
/// `create_tokenizer_from_file` must accept the mock tiktoken model path and return
/// a tokenizer that produces a non-empty encoding.
#[test]
fn test_tiktoken_create_from_file() {
    let model_path = mock_tiktoken_model_path();

    // Factory function used by the Tokenizer wrapper.
    let tokenizer = create_tokenizer_from_file(&model_path).expect("Failed to create tokenizer");

    let encoding = tokenizer
        .encode("hello")
        .expect("Failed to encode with factory-created tokenizer");

    assert!(!encoding.token_ids().is_empty());
}
/// Batch-encodes two inputs with the mock tiktoken model and checks that each
/// encoding decodes back to its own input.
#[test]
fn test_tiktoken_batch_encode() {
    let model_path = mock_tiktoken_model_path();
    let tokenizer =
        TikTokenTokenizer::from_file_auto(&model_path).expect("Failed to load tiktoken tokenizer");

    let inputs = &["hello", "world"];
    let encodings = tokenizer
        .encode_batch(inputs)
        .expect("Failed to batch encode");
    assert_eq!(encodings.len(), 2);

    // Walk inputs and encodings in lockstep and verify each round trip.
    for (input, encoding) in inputs.iter().zip(encodings.iter()) {
        let round_trip = tokenizer
            .decode(encoding.token_ids(), false)
            .expect("Failed to decode");
        assert_eq!(round_trip, *input);
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment