Unverified Commit 61a1f4ff authored by Graham King's avatar Graham King Committed by GitHub
Browse files

perf(tokenizer): Make de-tokenize ~50% faster (#1868)

parent f242b455
...@@ -334,6 +334,17 @@ version = "1.1.2" ...@@ -334,6 +334,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.4.0" version = "1.4.0"
...@@ -794,7 +805,7 @@ dependencies = [ ...@@ -794,7 +805,7 @@ dependencies = [
"cudarc 0.13.9", "cudarc 0.13.9",
"float8", "float8",
"gemm 0.17.1", "gemm 0.17.1",
"half", "half 2.6.0",
"memmap2", "memmap2",
"metal", "metal",
"num-traits", "num-traits",
...@@ -816,7 +827,7 @@ checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1" ...@@ -816,7 +827,7 @@ checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"gemm 0.17.1", "gemm 0.17.1",
"half", "half 2.6.0",
"memmap2", "memmap2",
"num-traits", "num-traits",
"num_cpus", "num_cpus",
...@@ -856,7 +867,7 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=98c0436e#98c0436eaf ...@@ -856,7 +867,7 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=98c0436e#98c0436eaf
dependencies = [ dependencies = [
"candle-core 0.8.0", "candle-core 0.8.0",
"candle-metal-kernels", "candle-metal-kernels",
"half", "half 2.6.0",
"metal", "metal",
"num-traits", "num-traits",
"rayon", "rayon",
...@@ -865,13 +876,19 @@ dependencies = [ ...@@ -865,13 +876,19 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cbindgen" name = "cbindgen"
version = "0.27.0" version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb" checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb"
dependencies = [ dependencies = [
"clap", "clap 4.5.40",
"heck 0.4.1", "heck 0.4.1",
"indexmap 2.9.0", "indexmap 2.9.0",
"log", "log",
...@@ -972,6 +989,17 @@ dependencies = [ ...@@ -972,6 +989,17 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap",
"unicode-width 0.1.14",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.40" version = "4.5.40"
...@@ -1114,6 +1142,42 @@ dependencies = [ ...@@ -1114,6 +1142,42 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]] [[package]]
name = "crossbeam" name = "crossbeam"
version = "0.8.4" version = "0.8.4"
...@@ -1261,7 +1325,7 @@ version = "0.13.9" ...@@ -1261,7 +1325,7 @@ version = "0.13.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e" checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e"
dependencies = [ dependencies = [
"half", "half 2.6.0",
"libloading", "libloading",
] ]
...@@ -1559,7 +1623,7 @@ version = "0.2.0" ...@@ -1559,7 +1623,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09157630eece4139f6cc5a457556d308c3465ecd5af492f0e5aadc043997e2ce" checksum = "09157630eece4139f6cc5a457556d308c3465ecd5af492f0e5aadc043997e2ce"
dependencies = [ dependencies = [
"half", "half 2.6.0",
] ]
[[package]] [[package]]
...@@ -1722,6 +1786,7 @@ dependencies = [ ...@@ -1722,6 +1786,7 @@ dependencies = [
"bytes", "bytes",
"candle-core 0.8.4", "candle-core 0.8.4",
"chrono", "chrono",
"criterion",
"cudarc 0.16.2", "cudarc 0.16.2",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
...@@ -1784,7 +1849,7 @@ dependencies = [ ...@@ -1784,7 +1849,7 @@ dependencies = [
"async-openai", "async-openai",
"async-stream", "async-stream",
"async-trait", "async-trait",
"clap", "clap 4.5.40",
"dynamo-engine-llamacpp", "dynamo-engine-llamacpp",
"dynamo-engine-mistralrs", "dynamo-engine-mistralrs",
"dynamo-llm", "dynamo-llm",
...@@ -2099,7 +2164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2099,7 +2164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0" checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [ dependencies = [
"bit_field", "bit_field",
"half", "half 2.6.0",
"lebe", "lebe",
"miniz_oxide", "miniz_oxide",
"rayon-core", "rayon-core",
...@@ -2194,7 +2259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2194,7 +2259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dee36245af1dccf978103fcd393582806db2a1d0bcd2f38c663cdbb4a363a01c" checksum = "dee36245af1dccf978103fcd393582806db2a1d0bcd2f38c663cdbb4a363a01c"
dependencies = [ dependencies = [
"cudarc 0.13.9", "cudarc 0.13.9",
"half", "half 2.6.0",
"num-traits", "num-traits",
"rand 0.9.1", "rand 0.9.1",
"rand_distr", "rand_distr",
...@@ -2496,7 +2561,7 @@ checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" ...@@ -2496,7 +2561,7 @@ checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"dyn-stack 0.10.0", "dyn-stack 0.10.0",
"half", "half 2.6.0",
"num-complex", "num-complex",
"num-traits", "num-traits",
"once_cell", "once_cell",
...@@ -2516,7 +2581,7 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" ...@@ -2516,7 +2581,7 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"dyn-stack 0.13.0", "dyn-stack 0.13.0",
"half", "half 2.6.0",
"libm", "libm",
"num-complex", "num-complex",
"num-traits", "num-traits",
...@@ -2538,7 +2603,7 @@ dependencies = [ ...@@ -2538,7 +2603,7 @@ dependencies = [
"dyn-stack 0.10.0", "dyn-stack 0.10.0",
"gemm-common 0.17.1", "gemm-common 0.17.1",
"gemm-f32 0.17.1", "gemm-f32 0.17.1",
"half", "half 2.6.0",
"num-complex", "num-complex",
"num-traits", "num-traits",
"paste", "paste",
...@@ -2556,7 +2621,7 @@ dependencies = [ ...@@ -2556,7 +2621,7 @@ dependencies = [
"dyn-stack 0.13.0", "dyn-stack 0.13.0",
"gemm-common 0.18.2", "gemm-common 0.18.2",
"gemm-f32 0.18.2", "gemm-f32 0.18.2",
"half", "half 2.6.0",
"num-complex", "num-complex",
"num-traits", "num-traits",
"paste", "paste",
...@@ -2678,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2678,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a27693512784e0786212eb0bef841779a6337d2d04520ed475b4d5a864f98366" checksum = "a27693512784e0786212eb0bef841779a6337d2d04520ed475b4d5a864f98366"
dependencies = [ dependencies = [
"digit-layout", "digit-layout",
"half", "half 2.6.0",
"rayon", "rayon",
] ]
...@@ -2749,6 +2814,12 @@ dependencies = [ ...@@ -2749,6 +2814,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]] [[package]]
name = "half" name = "half"
version = "2.6.0" version = "2.6.0"
...@@ -2802,6 +2873,15 @@ version = "0.5.0" ...@@ -2802,6 +2873,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.3.9" version = "0.3.9"
...@@ -2897,7 +2977,7 @@ dependencies = [ ...@@ -2897,7 +2977,7 @@ dependencies = [
name = "http" name = "http"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"clap", "clap 4.5.40",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
"serde", "serde",
...@@ -3655,7 +3735,7 @@ name = "llmctl" ...@@ -3655,7 +3735,7 @@ name = "llmctl"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap", "clap 4.5.40",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
"serde", "serde",
...@@ -3855,7 +3935,7 @@ name = "metrics" ...@@ -3855,7 +3935,7 @@ name = "metrics"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"axum 0.6.20", "axum 0.6.20",
"clap", "clap 4.5.40",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
"futures", "futures",
...@@ -3985,7 +4065,7 @@ dependencies = [ ...@@ -3985,7 +4065,7 @@ dependencies = [
"anyhow", "anyhow",
"candle-core 0.8.0", "candle-core 0.8.0",
"candle-nn", "candle-nn",
"clap", "clap 4.5.40",
"either", "either",
"futures", "futures",
"image", "image",
...@@ -4030,7 +4110,7 @@ dependencies = [ ...@@ -4030,7 +4110,7 @@ dependencies = [
"candle-nn", "candle-nn",
"cfgrammar", "cfgrammar",
"chrono", "chrono",
"clap", "clap 4.5.40",
"csv", "csv",
"derive-new", "derive-new",
"derive_more 2.0.1", "derive_more 2.0.1",
...@@ -4039,7 +4119,7 @@ dependencies = [ ...@@ -4039,7 +4119,7 @@ dependencies = [
"float8", "float8",
"futures", "futures",
"galil-seiferas", "galil-seiferas",
"half", "half 2.6.0",
"hashbrown 0.15.4", "hashbrown 0.15.4",
"hf-hub", "hf-hub",
"hound", "hound",
...@@ -4133,7 +4213,7 @@ dependencies = [ ...@@ -4133,7 +4213,7 @@ dependencies = [
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
"candle-core 0.8.0", "candle-core 0.8.0",
"float8", "float8",
"half", "half 2.6.0",
"metal", "metal",
"once_cell", "once_cell",
"thiserror 2.0.12", "thiserror 2.0.12",
...@@ -4149,7 +4229,7 @@ dependencies = [ ...@@ -4149,7 +4229,7 @@ dependencies = [
"candle-core 0.8.0", "candle-core 0.8.0",
"candle-nn", "candle-nn",
"float8", "float8",
"half", "half 2.6.0",
"hf-hub", "hf-hub",
"lazy_static", "lazy_static",
"memmap2", "memmap2",
...@@ -4479,7 +4559,7 @@ version = "1.16.0" ...@@ -4479,7 +4559,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi 0.3.9",
"libc", "libc",
] ]
...@@ -4591,6 +4671,12 @@ dependencies = [ ...@@ -4591,6 +4671,12 @@ dependencies = [
"pkg-config", "pkg-config",
] ]
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]] [[package]]
name = "openssl-probe" name = "openssl-probe"
version = "0.1.6" version = "0.1.6"
...@@ -4872,6 +4958,34 @@ version = "0.3.32" ...@@ -4872,6 +4958,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.16" version = "0.17.16"
...@@ -5596,7 +5710,7 @@ dependencies = [ ...@@ -5596,7 +5710,7 @@ dependencies = [
name = "router" name = "router"
version = "0.3.2" version = "0.3.2"
dependencies = [ dependencies = [
"clap", "clap 4.5.40",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
"rand 0.9.1", "rand 0.9.1",
...@@ -6078,6 +6192,16 @@ dependencies = [ ...@@ -6078,6 +6192,16 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.219" version = "1.0.219"
...@@ -6796,6 +6920,15 @@ dependencies = [ ...@@ -6796,6 +6920,15 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width 0.1.14",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.69" version = "1.0.69"
...@@ -6900,6 +7033,16 @@ dependencies = [ ...@@ -6900,6 +7033,16 @@ dependencies = [
"zerovec", "zerovec",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.9.0" version = "1.9.0"
...@@ -7420,7 +7563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -7420,7 +7563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437" checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437"
dependencies = [ dependencies = [
"gemm 0.18.2", "gemm 0.18.2",
"half", "half 2.6.0",
"libloading", "libloading",
"memmap2", "memmap2",
"num", "num",
......
...@@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"] ...@@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
[[bench]]
name = "tokenizer"
harness = false
[dependencies] [dependencies]
# repo # repo
dynamo-runtime = { workspace = true } dynamo-runtime = { workspace = true }
...@@ -126,6 +130,7 @@ rmp-serde = "1.3" ...@@ -126,6 +130,7 @@ rmp-serde = "1.3"
[dev-dependencies] [dev-dependencies]
assert_matches = "1.5" assert_matches = "1.5"
criterion = { version = "0.3", features = ["html_reports"] }
hf-hub = { workspace = true } hf-hub = { workspace = true }
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::hint::black_box;
use std::sync::Arc;
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use dynamo_llm::backend::Decoder;
use dynamo_llm::protocols::common::StopConditions;
use dynamo_llm::tokenizers::hf::HuggingFaceTokenizer;
use dynamo_llm::tokenizers::traits::{Encoder, Tokenizer};
use dynamo_llm::tokenizers::DecodeStream;
use dynamo_llm::types::TokenIdType;
const TEST_TOKENIZER: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/data/sample-models/TinyLlama_v1.1/tokenizer.json"
);
/// Input Sequence Length for tokenizer
const TARGET_ISL: usize = 8_000;
// A string of length exactly 128 bytes.
const INPUT_STR: &str = "The cat sat by the window, watching raindrops race down the glass. Far thunder rumbled. She purred softly, feeling safe at home.";
/// Benchmarks tokenizing a long prompt with the Hugging Face tokenizer.
///
/// The prompt is `INPUT_STR` repeated as many whole times as fit in
/// `TARGET_ISL` bytes, and throughput is reported in bytes of input text.
///
/// Run it with `cargo bench -- encode`.
pub fn encode(c: &mut Criterion) {
    // Whole repetitions only, so the prompt may be slightly shorter than TARGET_ISL.
    let repetitions = TARGET_ISL / INPUT_STR.len();
    let prompt = INPUT_STR.repeat(repetitions);
    let prompt_bytes = prompt.len() as u64;

    let hf_tokenizer = HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap();

    let mut bench_group = c.benchmark_group("encode-group");
    bench_group.throughput(Throughput::Bytes(prompt_bytes));
    bench_group.bench_function("tokenizer_encode", |bencher| {
        bencher.iter(|| {
            // black_box keeps the optimizer from treating the input as a constant.
            let _ = hf_tokenizer.encode(black_box(prompt.as_str())).unwrap();
        })
    });
    bench_group.finish();
}
pub fn decode(c: &mut Criterion) {
const TEST_TOKS: [TokenIdType; 34] = [
450, 6635, 3290, 491, 278, 3474, 29892, 21217, 1153, 513, 307, 567, 8175, 1623, 278, 12917,
29889, 8413, 266, 5062, 364, 25443, 29889, 2296, 3708, 1127, 4964, 368, 29892, 11223, 9109,
472, 3271, 29889,
];
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, false);
let mut decoder = Decoder::new(ds, StopConditions::default());
let mut group = c.benchmark_group("decode-group");
group.throughput(Throughput::Bytes(TEST_TOKS.len() as u64));
group.bench_function("tokenizer_decoder", |b| {
b.iter(|| {
for tok in black_box(TEST_TOKS) {
let _ = decoder.step(tok).unwrap();
}
})
});
group.finish();
}
criterion_group!(benches, encode, decode);
criterion_main!(benches);
...@@ -466,7 +466,7 @@ impl Decoder { ...@@ -466,7 +466,7 @@ impl Decoder {
pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> { pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
let mut text: Option<String> = None; let mut text: Option<String> = None;
let mut tokens = Vec::new(); let mut tokens = Vec::with_capacity(token_ids.len());
for token_id in token_ids { for token_id in token_ids {
let StepResult { let StepResult {
...@@ -481,7 +481,8 @@ impl Decoder { ...@@ -481,7 +481,8 @@ impl Decoder {
if !hide_text { if !hide_text {
if let Some(token) = &token { if let Some(token) = &token {
text.get_or_insert_with(String::new).push_str(token); text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
.push_str(token);
} }
} }
tokens.push(token); tokens.push(token);
......
...@@ -142,14 +142,14 @@ pub async fn run( ...@@ -142,14 +142,14 @@ pub async fn run(
if let Some(pre) = pre_processor { if let Some(pre) = pre_processor {
// Note this does not include the prompt template. Probably TODO // Note this does not include the prompt template. Probably TODO
entry.tokens_in = match pre.tokenize(&entry.text) { entry.tokens_in = match pre.tokenize(&entry.text) {
Ok(encoding) => encoding.token_ids.len(), Ok(encoding) => encoding.token_ids().len(),
Err(err) => { Err(err) => {
tracing::warn!(%err, entry.text, "Failed tokenizing prompt"); tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
0 0
} }
}; };
entry.tokens_out = match pre.tokenize(&response) { entry.tokens_out = match pre.tokenize(&response) {
Ok(encoding) => encoding.token_ids.len(), Ok(encoding) => encoding.token_ids().len(),
Err(err) => { Err(err) => {
tracing::warn!(%err, response, "Failed tokenizing response"); tracing::warn!(%err, response, "Failed tokenizing response");
0 0
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The Preprocessor consists of the following modules //! The Preprocessor consists of the following modules
//! //!
...@@ -205,9 +193,7 @@ impl OpenAIPreprocessor { ...@@ -205,9 +193,7 @@ impl OpenAIPreprocessor {
self.formatter.render(request)? self.formatter.render(request)?
}; };
let encoding = tokio::task::block_in_place(|| { let encoding = self.tokenizer.encode(&formatted_prompt)?;
self.tokenizer.encode(&formatted_prompt)
})?;
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) { if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert( annotations.insert(
...@@ -219,22 +205,21 @@ impl OpenAIPreprocessor { ...@@ -219,22 +205,21 @@ impl OpenAIPreprocessor {
if request.has_annotation(ANNOTATION_TOKEN_IDS) { if request.has_annotation(ANNOTATION_TOKEN_IDS) {
annotations.insert( annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(), ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&encoding.token_ids)?, serde_json::to_string(encoding.token_ids())?,
); );
} }
builder.token_ids(encoding.token_ids); builder.token_ids(encoding.token_ids().to_vec());
} }
TextInput::Batch(texts) => { TextInput::Batch(texts) => {
let token_batches: Result<Vec<Vec<u32>>, _> = texts let token_batches: Vec<Vec<u32>> = texts
.par_iter() .par_iter()
.map(|text| { .map(|text| {
tokio::task::block_in_place(|| self.tokenizer.encode(text)) self.tokenizer
.map(|encoding| encoding.token_ids) .encode(text)
.map(|encoded| encoded.token_ids().to_vec())
}) })
.collect(); .collect::<Result<Vec<_>>>()?;
let token_batches = token_batches?;
builder.batch_token_ids(Some(token_batches)); builder.batch_token_ids(Some(token_batches));
builder.token_ids(vec![]); builder.token_ids(vec![]);
} }
...@@ -285,8 +270,8 @@ impl OpenAIPreprocessor { ...@@ -285,8 +270,8 @@ impl OpenAIPreprocessor {
let all_token_ids = match &request.inner.input { let all_token_ids = match &request.inner.input {
async_openai::types::EmbeddingInput::String(s) => { async_openai::types::EmbeddingInput::String(s) => {
let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(s))?; let encoding = self.tokenizer.encode(s)?;
vec![encoding.token_ids] vec![encoding.token_ids().to_vec()]
} }
async_openai::types::EmbeddingInput::StringArray(arr) => { async_openai::types::EmbeddingInput::StringArray(arr) => {
let input_strs: Vec<String> = arr.to_vec(); let input_strs: Vec<String> = arr.to_vec();
...@@ -300,7 +285,7 @@ impl OpenAIPreprocessor { ...@@ -300,7 +285,7 @@ impl OpenAIPreprocessor {
.await??; .await??;
let token_arrays: Vec<Vec<u32>> = encodings let token_arrays: Vec<Vec<u32>> = encodings
.into_iter() .into_iter()
.map(|encoding| encoding.token_ids) .map(|encoding| encoding.token_ids().to_vec())
.collect(); .collect();
token_arrays token_arrays
} }
......
...@@ -46,11 +46,27 @@ pub enum TokenizerType { ...@@ -46,11 +46,27 @@ pub enum TokenizerType {
pub type Offsets = (usize, usize); pub type Offsets = (usize, usize);
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans /// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Hash)] #[derive(Debug, Clone)]
pub struct Encoding { pub enum Encoding {
pub token_ids: Vec<TokenIdType>, /// Hugging Face
pub tokens: Vec<String>, Hf(Box<tokenizers::tokenizer::Encoding>),
pub spans: Vec<Offsets>, /// Sentence Piece
Sp(Vec<TokenIdType>),
}
impl Encoding {
pub fn token_ids(&self) -> &[u32] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Encoding::Sp(inner) => inner,
}
}
}
impl Hash for Encoding {
fn hash<H: Hasher>(&self, state: &mut H) {
self.token_ids().hash(state);
}
} }
pub mod traits { pub mod traits {
...@@ -194,8 +210,8 @@ impl DecodeStream { ...@@ -194,8 +210,8 @@ impl DecodeStream {
Self { Self {
tokenizer, tokenizer,
skip_special_tokens, skip_special_tokens,
ids: Vec::new(), ids: Vec::with_capacity(64),
prefix: "".to_string(), prefix: String::with_capacity(64),
prefix_index: 0, prefix_index: 0,
read_index: 0, read_index: 0,
} }
...@@ -211,25 +227,23 @@ impl DecodeStream { ...@@ -211,25 +227,23 @@ impl DecodeStream {
/// a valid chunk. /// a valid chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> { pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id); self.ids.push(id);
let string = self let decoded = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') { if decoded.len() <= self.prefix.len() || decoded.ends_with('�') {
if !(string.starts_with(&self.prefix)) { return Ok(None);
}
if !decoded.starts_with(&self.prefix) {
anyhow::bail!("Detokenizer failure: invalid prefix"); anyhow::bail!("Detokenizer failure: invalid prefix");
} }
let new_text = &string[self.prefix.len()..].to_string(); let new_text = decoded[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix = self self.prefix = decoded;
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
self.read_index = self.prefix_index; self.read_index = self.prefix_index;
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix_index = new_prefix_index; self.prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else { Ok(Some(new_text))
Ok(None)
}
} }
} }
...@@ -255,11 +269,12 @@ impl std::fmt::Debug for Sequence { ...@@ -255,11 +269,12 @@ impl std::fmt::Debug for Sequence {
.field( .field(
"token_ids", "token_ids",
&format_args!("{}", { &format_args!("{}", {
if self.token_ids.len() <= 20 { let token_ids = self.token_ids();
format!("{:?}", self.token_ids) if token_ids.len() <= 20 {
format!("{:?}", token_ids)
} else { } else {
let first_ten = &self.token_ids[..10]; let first_ten = &token_ids[..10];
let last_ten = &self.token_ids[self.token_ids.len() - 10..]; let last_ten = &token_ids[token_ids.len() - 10..];
format!("{:?} ... {:?}", first_ten, last_ten) format!("{:?} ... {:?}", first_ten, last_ten)
} }
}), }),
...@@ -301,7 +316,7 @@ impl Sequence { ...@@ -301,7 +316,7 @@ impl Sequence {
// })?; // })?;
let encoding = self.tokenizer.encode(input)?; let encoding = self.tokenizer.encode(input)?;
self.token_ids.extend(encoding.token_ids); self.token_ids.extend(encoding.token_ids());
Ok(()) Ok(())
} }
......
...@@ -39,41 +39,24 @@ impl HuggingFaceTokenizer { ...@@ -39,41 +39,24 @@ impl HuggingFaceTokenizer {
impl Encoder for HuggingFaceTokenizer { impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> { fn encode(&self, input: &str) -> Result<Encoding> {
// This self.tokenizer is the library
let encoding = self let encoding = self
.tokenizer .tokenizer
.encode(input, false) .encode(input, false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?; .map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?;
let token_ids = encoding.get_ids().to_vec(); Ok(Encoding::Hf(Box::new(encoding)))
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Ok(Encoding {
token_ids,
tokens,
spans,
})
} }
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> { fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let hf_encodings = self let hf_encodings = self
.tokenizer .tokenizer
.encode_batch(inputs.to_vec(), false) .encode_batch(inputs.to_vec(), false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?; .map_err(|err| Error::msg(format!("Error batch tokenizing input: {err}")))?;
let encodings = hf_encodings let encodings = hf_encodings
.into_iter() .into_iter()
.map(|encoding| { .map(|enc| Encoding::Hf(Box::new(enc)))
let token_ids = encoding.get_ids().to_vec();
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Encoding {
token_ids,
tokens,
spans,
}
})
.collect(); .collect();
Ok(encodings) Ok(encodings)
...@@ -82,10 +65,11 @@ impl Encoder for HuggingFaceTokenizer { ...@@ -82,10 +65,11 @@ impl Encoder for HuggingFaceTokenizer {
impl Decoder for HuggingFaceTokenizer { impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> { fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
// This calls into the library
let text = self let text = self
.tokenizer .tokenizer
.decode(token_ids, skip_special_tokens) .decode(token_ids, skip_special_tokens)
.map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?; .map_err(|err| Error::msg(format!("Error de-tokenizing input: {err}")))?;
Ok(text) Ok(text)
} }
......
...@@ -57,21 +57,8 @@ impl Encoder for SentencePieceTokenizer { ...@@ -57,21 +57,8 @@ impl Encoder for SentencePieceTokenizer {
.encode(input) .encode(input)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?; .map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
let mut token_ids = Vec::new(); let token_ids = encoding.into_iter().map(|piece| piece.id).collect();
let mut tokens = Vec::new(); Ok(Encoding::Sp(token_ids))
let mut spans = Vec::new();
for piece in encoding {
token_ids.push(piece.id);
tokens.push(piece.piece);
spans.push((piece.span.0 as usize, piece.span.1 as usize));
}
Ok(Encoding {
token_ids,
tokens,
spans,
})
} }
/// Encodes multiple string inputs into tokens using the SentencePiece model. /// Encodes multiple string inputs into tokens using the SentencePiece model.
......
...@@ -44,10 +44,10 @@ const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH]; ...@@ -44,10 +44,10 @@ const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
const HASHES: [(&str, [u64; 4]); 1] = [( const HASHES: [(&str, [u64; 4]); 1] = [(
TINYLLAMA_TOKENIZER_PATH, TINYLLAMA_TOKENIZER_PATH,
[ [
771185775798505393, 1209591529327510910,
8538328482215529710, 4181375434596349981,
17087868772360018644, 6245658446118930933,
1660219240238826577, 5097285695902185237,
], ],
)]; )];
...@@ -93,7 +93,7 @@ fn test_hf_lifecycle() { ...@@ -93,7 +93,7 @@ fn test_hf_lifecycle() {
.expect("Failed to encode prompt"); .expect("Failed to encode prompt");
let decoded = tokenizer let decoded = tokenizer
.decode(&encoding.token_ids, false) .decode(encoding.token_ids(), false)
.expect("Failed to decode token_ids"); .expect("Failed to decode token_ids");
assert_eq!(decoded, TEST_PROMPTS[0]); assert_eq!(decoded, TEST_PROMPTS[0]);
...@@ -117,14 +117,14 @@ fn test_sequence() { ...@@ -117,14 +117,14 @@ fn test_sequence() {
.append_text(TEST_PROMPTS[0]) .append_text(TEST_PROMPTS[0])
.expect("Failed to append prompt"); .expect("Failed to append prompt");
assert_eq!(sequence.len(), encoding.token_ids.len()); assert_eq!(sequence.len(), encoding.token_ids().len());
let mut decoder = Sequence::new(shared_tokenizer.clone().into()); let mut decoder = Sequence::new(shared_tokenizer.clone().into());
let mut output = String::new(); let mut output = String::new();
for token_id in encoding.token_ids.clone() { for token_id in encoding.token_ids() {
let text = decoder let text = decoder
.append_token_id(token_id) .append_token_id(*token_id)
.expect("Failed to decode token_id"); .expect("Failed to decode token_id");
output.push_str(text.as_str()); output.push_str(text.as_str());
} }
...@@ -135,8 +135,8 @@ fn test_sequence() { ...@@ -135,8 +135,8 @@ fn test_sequence() {
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), false); let mut decoder = DecodeStream::new(shared_tokenizer.clone(), false);
let mut output = String::new(); let mut output = String::new();
for token_id in encoding.token_ids { for token_id in encoding.token_ids() {
let text = decoder.step(token_id).expect("Failed to decode token_id"); let text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(text) = text { if let Some(text) = text {
output.push_str(text.as_str()); output.push_str(text.as_str());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment