Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
61a1f4ff
Unverified
Commit
61a1f4ff
authored
Jul 10, 2025
by
Graham King
Committed by
GitHub
Jul 10, 2025
Browse files
perf(tokenizer): Make de-tokenize ~50% faster (#1868)
parent
f242b455
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
317 additions
and
132 deletions
+317
-132
Cargo.lock
Cargo.lock
+168
-25
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+5
-0
lib/llm/benches/tokenizer.rs
lib/llm/benches/tokenizer.rs
+65
-0
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+3
-2
lib/llm/src/entrypoint/input/batch.rs
lib/llm/src/entrypoint/input/batch.rs
+2
-2
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+11
-26
lib/llm/src/tokenizers.rs
lib/llm/src/tokenizers.rs
+44
-29
lib/llm/src/tokenizers/hf.rs
lib/llm/src/tokenizers/hf.rs
+7
-23
lib/llm/src/tokenizers/sp.rs
lib/llm/src/tokenizers/sp.rs
+2
-15
lib/llm/tests/tokenizers.rs
lib/llm/tests/tokenizers.rs
+10
-10
No files found.
Cargo.lock
View file @
61a1f4ff
...
@@ -334,6 +334,17 @@ version = "1.1.2"
...
@@ -334,6 +334,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi 0.3.9",
]
[[package]]
[[package]]
name = "autocfg"
name = "autocfg"
version = "1.4.0"
version = "1.4.0"
...
@@ -794,7 +805,7 @@ dependencies = [
...
@@ -794,7 +805,7 @@ dependencies = [
"cudarc 0.13.9",
"cudarc 0.13.9",
"float8",
"float8",
"gemm 0.17.1",
"gemm 0.17.1",
"half",
"half
2.6.0
",
"memmap2",
"memmap2",
"metal",
"metal",
"num-traits",
"num-traits",
...
@@ -816,7 +827,7 @@ checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1"
...
@@ -816,7 +827,7 @@ checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1"
dependencies = [
dependencies = [
"byteorder",
"byteorder",
"gemm 0.17.1",
"gemm 0.17.1",
"half",
"half
2.6.0
",
"memmap2",
"memmap2",
"num-traits",
"num-traits",
"num_cpus",
"num_cpus",
...
@@ -856,7 +867,7 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=98c0436e#98c0436eaf
...
@@ -856,7 +867,7 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=98c0436e#98c0436eaf
dependencies = [
dependencies = [
"candle-core 0.8.0",
"candle-core 0.8.0",
"candle-metal-kernels",
"candle-metal-kernels",
"half",
"half
2.6.0
",
"metal",
"metal",
"num-traits",
"num-traits",
"rayon",
"rayon",
...
@@ -865,13 +876,19 @@ dependencies = [
...
@@ -865,13 +876,19 @@ dependencies = [
"thiserror 1.0.69",
"thiserror 1.0.69",
]
]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
[[package]]
name = "cbindgen"
name = "cbindgen"
version = "0.27.0"
version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb"
checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb"
dependencies = [
dependencies = [
"clap",
"clap
4.5.40
",
"heck 0.4.1",
"heck 0.4.1",
"indexmap 2.9.0",
"indexmap 2.9.0",
"log",
"log",
...
@@ -972,6 +989,17 @@ dependencies = [
...
@@ -972,6 +989,17 @@ dependencies = [
"libloading",
"libloading",
]
]
[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap",
"unicode-width 0.1.14",
]
[[package]]
[[package]]
name = "clap"
name = "clap"
version = "4.5.40"
version = "4.5.40"
...
@@ -1114,6 +1142,42 @@ dependencies = [
...
@@ -1114,6 +1142,42 @@ dependencies = [
"cfg-if 1.0.0",
"cfg-if 1.0.0",
]
]
[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
[[package]]
name = "crossbeam"
name = "crossbeam"
version = "0.8.4"
version = "0.8.4"
...
@@ -1261,7 +1325,7 @@ version = "0.13.9"
...
@@ -1261,7 +1325,7 @@ version = "0.13.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e"
checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e"
dependencies = [
dependencies = [
"half",
"half
2.6.0
",
"libloading",
"libloading",
]
]
...
@@ -1559,7 +1623,7 @@ version = "0.2.0"
...
@@ -1559,7 +1623,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09157630eece4139f6cc5a457556d308c3465ecd5af492f0e5aadc043997e2ce"
checksum = "09157630eece4139f6cc5a457556d308c3465ecd5af492f0e5aadc043997e2ce"
dependencies = [
dependencies = [
"half",
"half
2.6.0
",
]
]
[[package]]
[[package]]
...
@@ -1722,6 +1786,7 @@ dependencies = [
...
@@ -1722,6 +1786,7 @@ dependencies = [
"bytes",
"bytes",
"candle-core 0.8.4",
"candle-core 0.8.4",
"chrono",
"chrono",
"criterion",
"cudarc 0.16.2",
"cudarc 0.16.2",
"derive-getters",
"derive-getters",
"derive_builder",
"derive_builder",
...
@@ -1784,7 +1849,7 @@ dependencies = [
...
@@ -1784,7 +1849,7 @@ dependencies = [
"async-openai",
"async-openai",
"async-stream",
"async-stream",
"async-trait",
"async-trait",
"clap",
"clap
4.5.40
",
"dynamo-engine-llamacpp",
"dynamo-engine-llamacpp",
"dynamo-engine-mistralrs",
"dynamo-engine-mistralrs",
"dynamo-llm",
"dynamo-llm",
...
@@ -2099,7 +2164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -2099,7 +2164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [
dependencies = [
"bit_field",
"bit_field",
"half",
"half
2.6.0
",
"lebe",
"lebe",
"miniz_oxide",
"miniz_oxide",
"rayon-core",
"rayon-core",
...
@@ -2194,7 +2259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -2194,7 +2259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dee36245af1dccf978103fcd393582806db2a1d0bcd2f38c663cdbb4a363a01c"
checksum = "dee36245af1dccf978103fcd393582806db2a1d0bcd2f38c663cdbb4a363a01c"
dependencies = [
dependencies = [
"cudarc 0.13.9",
"cudarc 0.13.9",
"half",
"half
2.6.0
",
"num-traits",
"num-traits",
"rand 0.9.1",
"rand 0.9.1",
"rand_distr",
"rand_distr",
...
@@ -2496,7 +2561,7 @@ checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
...
@@ -2496,7 +2561,7 @@ checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
dependencies = [
dependencies = [
"bytemuck",
"bytemuck",
"dyn-stack 0.10.0",
"dyn-stack 0.10.0",
"half",
"half
2.6.0
",
"num-complex",
"num-complex",
"num-traits",
"num-traits",
"once_cell",
"once_cell",
...
@@ -2516,7 +2581,7 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
...
@@ -2516,7 +2581,7 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
dependencies = [
dependencies = [
"bytemuck",
"bytemuck",
"dyn-stack 0.13.0",
"dyn-stack 0.13.0",
"half",
"half
2.6.0
",
"libm",
"libm",
"num-complex",
"num-complex",
"num-traits",
"num-traits",
...
@@ -2538,7 +2603,7 @@ dependencies = [
...
@@ -2538,7 +2603,7 @@ dependencies = [
"dyn-stack 0.10.0",
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"gemm-common 0.17.1",
"gemm-f32 0.17.1",
"gemm-f32 0.17.1",
"half",
"half
2.6.0
",
"num-complex",
"num-complex",
"num-traits",
"num-traits",
"paste",
"paste",
...
@@ -2556,7 +2621,7 @@ dependencies = [
...
@@ -2556,7 +2621,7 @@ dependencies = [
"dyn-stack 0.13.0",
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"gemm-common 0.18.2",
"gemm-f32 0.18.2",
"gemm-f32 0.18.2",
"half",
"half
2.6.0
",
"num-complex",
"num-complex",
"num-traits",
"num-traits",
"paste",
"paste",
...
@@ -2678,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -2678,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a27693512784e0786212eb0bef841779a6337d2d04520ed475b4d5a864f98366"
checksum = "a27693512784e0786212eb0bef841779a6337d2d04520ed475b4d5a864f98366"
dependencies = [
dependencies = [
"digit-layout",
"digit-layout",
"half",
"half
2.6.0
",
"rayon",
"rayon",
]
]
...
@@ -2749,6 +2814,12 @@ dependencies = [
...
@@ -2749,6 +2814,12 @@ dependencies = [
"tracing",
"tracing",
]
]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
[[package]]
name = "half"
name = "half"
version = "2.6.0"
version = "2.6.0"
...
@@ -2802,6 +2873,15 @@ version = "0.5.0"
...
@@ -2802,6 +2873,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
[[package]]
name = "hermit-abi"
name = "hermit-abi"
version = "0.3.9"
version = "0.3.9"
...
@@ -2897,7 +2977,7 @@ dependencies = [
...
@@ -2897,7 +2977,7 @@ dependencies = [
name = "http"
name = "http"
version = "0.3.2"
version = "0.3.2"
dependencies = [
dependencies = [
"clap",
"clap
4.5.40
",
"dynamo-llm",
"dynamo-llm",
"dynamo-runtime",
"dynamo-runtime",
"serde",
"serde",
...
@@ -3655,7 +3735,7 @@ name = "llmctl"
...
@@ -3655,7 +3735,7 @@ name = "llmctl"
version = "0.3.2"
version = "0.3.2"
dependencies = [
dependencies = [
"anyhow",
"anyhow",
"clap",
"clap
4.5.40
",
"dynamo-llm",
"dynamo-llm",
"dynamo-runtime",
"dynamo-runtime",
"serde",
"serde",
...
@@ -3855,7 +3935,7 @@ name = "metrics"
...
@@ -3855,7 +3935,7 @@ name = "metrics"
version = "0.3.2"
version = "0.3.2"
dependencies = [
dependencies = [
"axum 0.6.20",
"axum 0.6.20",
"clap",
"clap
4.5.40
",
"dynamo-llm",
"dynamo-llm",
"dynamo-runtime",
"dynamo-runtime",
"futures",
"futures",
...
@@ -3985,7 +4065,7 @@ dependencies = [
...
@@ -3985,7 +4065,7 @@ dependencies = [
"anyhow",
"anyhow",
"candle-core 0.8.0",
"candle-core 0.8.0",
"candle-nn",
"candle-nn",
"clap",
"clap
4.5.40
",
"either",
"either",
"futures",
"futures",
"image",
"image",
...
@@ -4030,7 +4110,7 @@ dependencies = [
...
@@ -4030,7 +4110,7 @@ dependencies = [
"candle-nn",
"candle-nn",
"cfgrammar",
"cfgrammar",
"chrono",
"chrono",
"clap",
"clap
4.5.40
",
"csv",
"csv",
"derive-new",
"derive-new",
"derive_more 2.0.1",
"derive_more 2.0.1",
...
@@ -4039,7 +4119,7 @@ dependencies = [
...
@@ -4039,7 +4119,7 @@ dependencies = [
"float8",
"float8",
"futures",
"futures",
"galil-seiferas",
"galil-seiferas",
"half",
"half
2.6.0
",
"hashbrown 0.15.4",
"hashbrown 0.15.4",
"hf-hub",
"hf-hub",
"hound",
"hound",
...
@@ -4133,7 +4213,7 @@ dependencies = [
...
@@ -4133,7 +4213,7 @@ dependencies = [
"bindgen_cuda 0.1.6",
"bindgen_cuda 0.1.6",
"candle-core 0.8.0",
"candle-core 0.8.0",
"float8",
"float8",
"half",
"half
2.6.0
",
"metal",
"metal",
"once_cell",
"once_cell",
"thiserror 2.0.12",
"thiserror 2.0.12",
...
@@ -4149,7 +4229,7 @@ dependencies = [
...
@@ -4149,7 +4229,7 @@ dependencies = [
"candle-core 0.8.0",
"candle-core 0.8.0",
"candle-nn",
"candle-nn",
"float8",
"float8",
"half",
"half
2.6.0
",
"hf-hub",
"hf-hub",
"lazy_static",
"lazy_static",
"memmap2",
"memmap2",
...
@@ -4479,7 +4559,7 @@ version = "1.16.0"
...
@@ -4479,7 +4559,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
dependencies = [
"hermit-abi",
"hermit-abi
0.3.9
",
"libc",
"libc",
]
]
...
@@ -4591,6 +4671,12 @@ dependencies = [
...
@@ -4591,6 +4671,12 @@ dependencies = [
"pkg-config",
"pkg-config",
]
]
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
[[package]]
name = "openssl-probe"
name = "openssl-probe"
version = "0.1.6"
version = "0.1.6"
...
@@ -4872,6 +4958,34 @@ version = "0.3.32"
...
@@ -4872,6 +4958,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
[[package]]
name = "png"
name = "png"
version = "0.17.16"
version = "0.17.16"
...
@@ -5596,7 +5710,7 @@ dependencies = [
...
@@ -5596,7 +5710,7 @@ dependencies = [
name = "router"
name = "router"
version = "0.3.2"
version = "0.3.2"
dependencies = [
dependencies = [
"clap",
"clap
4.5.40
",
"dynamo-llm",
"dynamo-llm",
"dynamo-runtime",
"dynamo-runtime",
"rand 0.9.1",
"rand 0.9.1",
...
@@ -6078,6 +6192,16 @@ dependencies = [
...
@@ -6078,6 +6192,16 @@ dependencies = [
"serde",
"serde",
]
]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]]
[[package]]
name = "serde_derive"
name = "serde_derive"
version = "1.0.219"
version = "1.0.219"
...
@@ -6796,6 +6920,15 @@ dependencies = [
...
@@ -6796,6 +6920,15 @@ dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.59.0",
]
]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width 0.1.14",
]
[[package]]
[[package]]
name = "thiserror"
name = "thiserror"
version = "1.0.69"
version = "1.0.69"
...
@@ -6900,6 +7033,16 @@ dependencies = [
...
@@ -6900,6 +7033,16 @@ dependencies = [
"zerovec",
"zerovec",
]
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
[[package]]
name = "tinyvec"
name = "tinyvec"
version = "1.9.0"
version = "1.9.0"
...
@@ -7420,7 +7563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -7420,7 +7563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437"
checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437"
dependencies = [
dependencies = [
"gemm 0.18.2",
"gemm 0.18.2",
"half",
"half
2.6.0
",
"libloading",
"libloading",
"memmap2",
"memmap2",
"num",
"num",
...
...
lib/llm/Cargo.toml
View file @
61a1f4ff
...
@@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"]
...
@@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"]
block-manager
=
[
"dep:nixl-sys"
,
"dep:cudarc"
,
"dep:ndarray"
,
"dep:nix"
]
block-manager
=
[
"dep:nixl-sys"
,
"dep:cudarc"
,
"dep:ndarray"
,
"dep:nix"
]
sentencepiece
=
["dep:sentencepiece"]
sentencepiece
=
["dep:sentencepiece"]
[[bench]]
name
=
"tokenizer"
harness
=
false
[dependencies]
[dependencies]
# repo
# repo
dynamo-runtime
=
{
workspace
=
true
}
dynamo-runtime
=
{
workspace
=
true
}
...
@@ -126,6 +130,7 @@ rmp-serde = "1.3"
...
@@ -126,6 +130,7 @@ rmp-serde = "1.3"
[dev-dependencies]
[dev-dependencies]
assert_matches
=
"1.5"
assert_matches
=
"1.5"
criterion
=
{
version
=
"0.3"
,
features
=
["html_reports"]
}
hf-hub
=
{
workspace
=
true
}
hf-hub
=
{
workspace
=
true
}
proptest
=
"1.5.0"
proptest
=
"1.5.0"
reqwest
=
{
version
=
"0.12"
,
default-features
=
false
,
features
=
[
"json"
,
"stream"
,
"rustls-tls"
]
}
reqwest
=
{
version
=
"0.12"
,
default-features
=
false
,
features
=
[
"json"
,
"stream"
,
"rustls-tls"
]
}
...
...
lib/llm/benches/tokenizer.rs
0 → 100644
View file @
61a1f4ff
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
hint
::
black_box
;
use
std
::
sync
::
Arc
;
use
criterion
::{
criterion_group
,
criterion_main
,
Criterion
,
Throughput
};
use
dynamo_llm
::
backend
::
Decoder
;
use
dynamo_llm
::
protocols
::
common
::
StopConditions
;
use
dynamo_llm
::
tokenizers
::
hf
::
HuggingFaceTokenizer
;
use
dynamo_llm
::
tokenizers
::
traits
::{
Encoder
,
Tokenizer
};
use
dynamo_llm
::
tokenizers
::
DecodeStream
;
use
dynamo_llm
::
types
::
TokenIdType
;
const
TEST_TOKENIZER
:
&
str
=
concat!
(
env!
(
"CARGO_MANIFEST_DIR"
),
"/tests/data/sample-models/TinyLlama_v1.1/tokenizer.json"
);
/// Input Sequence Length for tokenizer
const
TARGET_ISL
:
usize
=
8_000
;
// A string of length exactly 128 bytes.
const
INPUT_STR
:
&
str
=
"The cat sat by the window, watching raindrops race down the glass. Far thunder rumbled. She purred softly, feeling safe at home."
;
/// Benchmarks tokenizing a long prompt with the Hugging Face tokenizer.
///
/// The prompt is `INPUT_STR` repeated `TARGET_ISL / INPUT_STR.len()` times, and
/// throughput is reported in input bytes.
///
/// `cargo bench -- encode` to run it
pub fn encode(c: &mut Criterion) {
    // Integer division: the prompt is at most TARGET_ISL bytes long.
    let repetitions = TARGET_ISL / INPUT_STR.len();
    let prompt = INPUT_STR.repeat(repetitions);
    let test_str: &str = prompt.as_str();
    let input_bytes = test_str.len() as u64;

    // Loading the tokenizer happens once, outside the measured closure.
    let encoder = HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap();

    let mut group = c.benchmark_group("encode-group");
    group.throughput(Throughput::Bytes(input_bytes));
    group.bench_function("tokenizer_encode", |bencher| {
        bencher.iter(|| {
            // black_box keeps the optimizer from treating the input as a constant.
            drop(encoder.encode(black_box(test_str)).unwrap());
        })
    });
    group.finish();
}
/// Benchmarks incremental de-tokenization through `Decoder::step`.
///
/// Each measured iteration feeds the 34 token ids in `TEST_TOKS` to the decoder
/// one at a time. Throughput is reported in tokens (criterion "elements"), since
/// the quantity processed per iteration is a token count, not a byte count.
///
/// `cargo bench -- decode` to run it
pub fn decode(c: &mut Criterion) {
    // Presumably the TinyLlama token ids for one copy of INPUT_STR — TODO confirm.
    const TEST_TOKS: [TokenIdType; 34] = [
        450, 6635, 3290, 491, 278, 3474, 29892, 21217, 1153, 513, 307, 567, 8175, 1623, 278, 12917,
        29889, 8413, 266, 5062, 364, 25443, 29889, 2296, 3708, 1127, 4964, 368, 29892, 11223, 9109,
        472, 3271, 29889,
    ];

    let tokenizer: Arc<dyn Tokenizer> =
        Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
    let ds = DecodeStream::new(tokenizer, false);
    // NOTE(review): this single decoder is reused by every benchmark iteration, so
    // whatever state `step` accumulates carries over between iterations — confirm
    // that the per-iteration cost stays constant as the run progresses.
    let mut decoder = Decoder::new(ds, StopConditions::default());

    let mut group = c.benchmark_group("decode-group");
    // The unit of work is one token, so report elements/s rather than bytes/s.
    group.throughput(Throughput::Elements(TEST_TOKS.len() as u64));
    group.bench_function("tokenizer_decoder", |b| {
        b.iter(|| {
            for tok in black_box(TEST_TOKS) {
                let _ = decoder.step(tok).unwrap();
            }
        })
    });
    group.finish();
}
criterion_group!
(
benches
,
encode
,
decode
);
criterion_main!
(
benches
);
lib/llm/src/backend.rs
View file @
61a1f4ff
...
@@ -466,7 +466,7 @@ impl Decoder {
...
@@ -466,7 +466,7 @@ impl Decoder {
pub
fn
process_token_ids
(
&
mut
self
,
token_ids
:
&
[
TokenIdType
])
->
Result
<
SeqResult
>
{
pub
fn
process_token_ids
(
&
mut
self
,
token_ids
:
&
[
TokenIdType
])
->
Result
<
SeqResult
>
{
let
mut
text
:
Option
<
String
>
=
None
;
let
mut
text
:
Option
<
String
>
=
None
;
let
mut
tokens
=
Vec
::
new
(
);
let
mut
tokens
=
Vec
::
with_capacity
(
token_ids
.len
()
);
for
token_id
in
token_ids
{
for
token_id
in
token_ids
{
let
StepResult
{
let
StepResult
{
...
@@ -481,7 +481,8 @@ impl Decoder {
...
@@ -481,7 +481,8 @@ impl Decoder {
if
!
hide_text
{
if
!
hide_text
{
if
let
Some
(
token
)
=
&
token
{
if
let
Some
(
token
)
=
&
token
{
text
.get_or_insert_with
(
String
::
new
)
.push_str
(
token
);
text
.get_or_insert_with
(||
String
::
with_capacity
(
token_ids
.len
()))
.push_str
(
token
);
}
}
}
}
tokens
.push
(
token
);
tokens
.push
(
token
);
...
...
lib/llm/src/entrypoint/input/batch.rs
View file @
61a1f4ff
...
@@ -142,14 +142,14 @@ pub async fn run(
...
@@ -142,14 +142,14 @@ pub async fn run(
if
let
Some
(
pre
)
=
pre_processor
{
if
let
Some
(
pre
)
=
pre_processor
{
// Note this does not include the prompt template. Probably TODO
// Note this does not include the prompt template. Probably TODO
entry
.tokens_in
=
match
pre
.tokenize
(
&
entry
.text
)
{
entry
.tokens_in
=
match
pre
.tokenize
(
&
entry
.text
)
{
Ok
(
encoding
)
=>
encoding
.token_ids
.len
(),
Ok
(
encoding
)
=>
encoding
.token_ids
()
.len
(),
Err
(
err
)
=>
{
Err
(
err
)
=>
{
tracing
::
warn!
(
%
err
,
entry
.text
,
"Failed tokenizing prompt"
);
tracing
::
warn!
(
%
err
,
entry
.text
,
"Failed tokenizing prompt"
);
0
0
}
}
};
};
entry
.tokens_out
=
match
pre
.tokenize
(
&
response
)
{
entry
.tokens_out
=
match
pre
.tokenize
(
&
response
)
{
Ok
(
encoding
)
=>
encoding
.token_ids
.len
(),
Ok
(
encoding
)
=>
encoding
.token_ids
()
.len
(),
Err
(
err
)
=>
{
Err
(
err
)
=>
{
tracing
::
warn!
(
%
err
,
response
,
"Failed tokenizing response"
);
tracing
::
warn!
(
%
err
,
response
,
"Failed tokenizing response"
);
0
0
...
...
lib/llm/src/preprocessor.rs
View file @
61a1f4ff
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The Preprocessor consists of the following modules
//! The Preprocessor consists of the following modules
//!
//!
...
@@ -205,9 +193,7 @@ impl OpenAIPreprocessor {
...
@@ -205,9 +193,7 @@ impl OpenAIPreprocessor {
self
.formatter
.render
(
request
)
?
self
.formatter
.render
(
request
)
?
};
};
let
encoding
=
tokio
::
task
::
block_in_place
(||
{
let
encoding
=
self
.tokenizer
.encode
(
&
formatted_prompt
)
?
;
self
.tokenizer
.encode
(
&
formatted_prompt
)
})
?
;
if
request
.has_annotation
(
ANNOTATION_FORMATTED_PROMPT
)
{
if
request
.has_annotation
(
ANNOTATION_FORMATTED_PROMPT
)
{
annotations
.insert
(
annotations
.insert
(
...
@@ -219,22 +205,21 @@ impl OpenAIPreprocessor {
...
@@ -219,22 +205,21 @@ impl OpenAIPreprocessor {
if
request
.has_annotation
(
ANNOTATION_TOKEN_IDS
)
{
if
request
.has_annotation
(
ANNOTATION_TOKEN_IDS
)
{
annotations
.insert
(
annotations
.insert
(
ANNOTATION_TOKEN_IDS
.to_string
(),
ANNOTATION_TOKEN_IDS
.to_string
(),
serde_json
::
to_string
(
&
encoding
.token_ids
)
?
,
serde_json
::
to_string
(
encoding
.token_ids
()
)
?
,
);
);
}
}
builder
.token_ids
(
encoding
.token_ids
);
builder
.token_ids
(
encoding
.token_ids
()
.to_vec
()
);
}
}
TextInput
::
Batch
(
texts
)
=>
{
TextInput
::
Batch
(
texts
)
=>
{
let
token_batches
:
Result
<
Vec
<
Vec
<
u32
>>
,
_
>
=
texts
let
token_batches
:
Vec
<
Vec
<
u32
>>
=
texts
.par_iter
()
.par_iter
()
.map
(|
text
|
{
.map
(|
text
|
{
tokio
::
task
::
block_in_place
(||
self
.tokenizer
.encode
(
text
))
self
.tokenizer
.map
(|
encoding
|
encoding
.token_ids
)
.encode
(
text
)
.map
(|
encoded
|
encoded
.token_ids
()
.to_vec
())
})
})
.collect
();
.collect
::
<
Result
<
Vec
<
_
>>>
()
?
;
let
token_batches
=
token_batches
?
;
builder
.batch_token_ids
(
Some
(
token_batches
));
builder
.batch_token_ids
(
Some
(
token_batches
));
builder
.token_ids
(
vec!
[]);
builder
.token_ids
(
vec!
[]);
}
}
...
@@ -285,8 +270,8 @@ impl OpenAIPreprocessor {
...
@@ -285,8 +270,8 @@ impl OpenAIPreprocessor {
let
all_token_ids
=
match
&
request
.inner.input
{
let
all_token_ids
=
match
&
request
.inner.input
{
async_openai
::
types
::
EmbeddingInput
::
String
(
s
)
=>
{
async_openai
::
types
::
EmbeddingInput
::
String
(
s
)
=>
{
let
encoding
=
tokio
::
task
::
block_in_place
(||
self
.tokenizer
.encode
(
s
)
)
?
;
let
encoding
=
self
.tokenizer
.encode
(
s
)
?
;
vec!
[
encoding
.token_ids
]
vec!
[
encoding
.token_ids
()
.to_vec
()
]
}
}
async_openai
::
types
::
EmbeddingInput
::
StringArray
(
arr
)
=>
{
async_openai
::
types
::
EmbeddingInput
::
StringArray
(
arr
)
=>
{
let
input_strs
:
Vec
<
String
>
=
arr
.to_vec
();
let
input_strs
:
Vec
<
String
>
=
arr
.to_vec
();
...
@@ -300,7 +285,7 @@ impl OpenAIPreprocessor {
...
@@ -300,7 +285,7 @@ impl OpenAIPreprocessor {
.await
??
;
.await
??
;
let
token_arrays
:
Vec
<
Vec
<
u32
>>
=
encodings
let
token_arrays
:
Vec
<
Vec
<
u32
>>
=
encodings
.into_iter
()
.into_iter
()
.map
(|
encoding
|
encoding
.token_ids
)
.map
(|
encoding
|
encoding
.token_ids
()
.to_vec
()
)
.collect
();
.collect
();
token_arrays
token_arrays
}
}
...
...
lib/llm/src/tokenizers.rs
View file @
61a1f4ff
...
@@ -46,11 +46,27 @@ pub enum TokenizerType {
...
@@ -46,11 +46,27 @@ pub enum TokenizerType {
pub
type
Offsets
=
(
usize
,
usize
);
pub
type
Offsets
=
(
usize
,
usize
);
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug,
Hash)]
#[derive(Debug,
Clone)]
pub
struct
Encoding
{
pub
enum
Encoding
{
pub
token_ids
:
Vec
<
TokenIdType
>
,
/// Hugging Face
pub
tokens
:
Vec
<
String
>
,
Hf
(
Box
<
tokenizers
::
tokenizer
::
Encoding
>
),
pub
spans
:
Vec
<
Offsets
>
,
/// Sentence Piece
Sp
(
Vec
<
TokenIdType
>
),
}
impl
Encoding
{
pub
fn
token_ids
(
&
self
)
->
&
[
u32
]
{
match
self
{
Encoding
::
Hf
(
inner
)
=>
inner
.get_ids
(),
Encoding
::
Sp
(
inner
)
=>
inner
,
}
}
}
impl
Hash
for
Encoding
{
fn
hash
<
H
:
Hasher
>
(
&
self
,
state
:
&
mut
H
)
{
self
.token_ids
()
.hash
(
state
);
}
}
}
pub
mod
traits
{
pub
mod
traits
{
...
@@ -194,8 +210,8 @@ impl DecodeStream {
...
@@ -194,8 +210,8 @@ impl DecodeStream {
Self
{
Self
{
tokenizer
,
tokenizer
,
skip_special_tokens
,
skip_special_tokens
,
ids
:
Vec
::
new
(
),
ids
:
Vec
::
with_capacity
(
64
),
prefix
:
""
.to_string
(
),
prefix
:
String
::
with_capacity
(
64
),
prefix_index
:
0
,
prefix_index
:
0
,
read_index
:
0
,
read_index
:
0
,
}
}
...
@@ -211,25 +227,23 @@ impl DecodeStream {
...
@@ -211,25 +227,23 @@ impl DecodeStream {
/// a valid chunk.
/// a valid chunk.
pub
fn
step
(
&
mut
self
,
id
:
u32
)
->
Result
<
Option
<
String
>>
{
pub
fn
step
(
&
mut
self
,
id
:
u32
)
->
Result
<
Option
<
String
>>
{
self
.ids
.push
(
id
);
self
.ids
.push
(
id
);
let
string
=
self
let
decoded
=
self
.tokenizer
.decode
(
&
self
.ids
,
self
.skip_special_tokens
)
?
;
.tokenizer
.decode
(
self
.ids
.as_slice
(),
self
.skip_special_tokens
)
?
;
if
string
.len
()
>
self
.prefix
.len
()
&&
!
string
.ends_with
(
'�'
)
{
if
decoded
.len
()
<=
self
.prefix
.len
()
||
decoded
.ends_with
(
'�'
)
{
if
!
(
string
.starts_with
(
&
self
.prefix
))
{
return
Ok
(
None
);
}
if
!
decoded
.starts_with
(
&
self
.prefix
)
{
anyhow
::
bail!
(
"Detokenizer failure: invalid prefix"
);
anyhow
::
bail!
(
"Detokenizer failure: invalid prefix"
);
}
}
let
new_text
=
&
string
[
self
.prefix
.len
()
..
]
.to_string
();
let
new_text
=
decoded
[
self
.prefix
.len
()
..
]
.to_string
();
let
new_prefix_index
=
self
.ids
.len
()
-
self
.prefix_index
;
self
.prefix
=
self
self
.prefix
=
decoded
;
.tokenizer
.decode
(
self
.ids
.as_slice
(),
self
.skip_special_tokens
)
?
;
self
.read_index
=
self
.prefix_index
;
self
.read_index
=
self
.prefix_index
;
let
new_prefix_index
=
self
.ids
.len
()
-
self
.prefix_index
;
self
.prefix_index
=
new_prefix_index
;
self
.prefix_index
=
new_prefix_index
;
Ok
(
Some
(
new_text
.to_string
()))
}
else
{
Ok
(
Some
(
new_text
))
Ok
(
None
)
}
}
}
}
}
...
@@ -255,11 +269,12 @@ impl std::fmt::Debug for Sequence {
...
@@ -255,11 +269,12 @@ impl std::fmt::Debug for Sequence {
.field
(
.field
(
"token_ids"
,
"token_ids"
,
&
format_args!
(
"{}"
,
{
&
format_args!
(
"{}"
,
{
if
self
.token_ids
.len
()
<=
20
{
let
token_ids
=
self
.token_ids
();
format!
(
"{:?}"
,
self
.token_ids
)
if
token_ids
.len
()
<=
20
{
format!
(
"{:?}"
,
token_ids
)
}
else
{
}
else
{
let
first_ten
=
&
self
.
token_ids
[
..
10
];
let
first_ten
=
&
token_ids
[
..
10
];
let
last_ten
=
&
self
.
token_ids
[
self
.
token_ids
.len
()
-
10
..
];
let
last_ten
=
&
token_ids
[
token_ids
.len
()
-
10
..
];
format!
(
"{:?} ... {:?}"
,
first_ten
,
last_ten
)
format!
(
"{:?} ... {:?}"
,
first_ten
,
last_ten
)
}
}
}),
}),
...
@@ -301,7 +316,7 @@ impl Sequence {
...
@@ -301,7 +316,7 @@ impl Sequence {
// })?;
// })?;
let
encoding
=
self
.tokenizer
.encode
(
input
)
?
;
let
encoding
=
self
.tokenizer
.encode
(
input
)
?
;
self
.token_ids
.extend
(
encoding
.token_ids
);
self
.token_ids
.extend
(
encoding
.token_ids
()
);
Ok
(())
Ok
(())
}
}
...
...
lib/llm/src/tokenizers/hf.rs
View file @
61a1f4ff
...
@@ -39,41 +39,24 @@ impl HuggingFaceTokenizer {
...
@@ -39,41 +39,24 @@ impl HuggingFaceTokenizer {
impl
Encoder
for
HuggingFaceTokenizer
{
impl
Encoder
for
HuggingFaceTokenizer
{
fn
encode
(
&
self
,
input
:
&
str
)
->
Result
<
Encoding
>
{
fn
encode
(
&
self
,
input
:
&
str
)
->
Result
<
Encoding
>
{
// This self.tokenizer is the library
let
encoding
=
self
let
encoding
=
self
.tokenizer
.tokenizer
.encode
(
input
,
false
)
.encode
(
input
,
false
)
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error
encod
ing input: {
}"
,
err
)))
?
;
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error
tokeniz
ing input: {err
}"
)))
?
;
let
token_ids
=
encoding
.get_ids
()
.to_vec
();
Ok
(
Encoding
::
Hf
(
Box
::
new
(
encoding
)))
let
tokens
=
encoding
.get_tokens
()
.to_vec
();
let
spans
=
encoding
.get_offsets
()
.to_vec
();
Ok
(
Encoding
{
token_ids
,
tokens
,
spans
,
})
}
}
fn
encode_batch
(
&
self
,
inputs
:
&
[
&
str
])
->
Result
<
Vec
<
Encoding
>>
{
fn
encode_batch
(
&
self
,
inputs
:
&
[
&
str
])
->
Result
<
Vec
<
Encoding
>>
{
let
hf_encodings
=
self
let
hf_encodings
=
self
.tokenizer
.tokenizer
.encode_batch
(
inputs
.to_vec
(),
false
)
.encode_batch
(
inputs
.to_vec
(),
false
)
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error
encod
ing input: {
}"
,
err
)))
?
;
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error
batch tokeniz
ing input: {err
}"
)))
?
;
let
encodings
=
hf_encodings
let
encodings
=
hf_encodings
.into_iter
()
.into_iter
()
.map
(|
encoding
|
{
.map
(|
enc
|
Encoding
::
Hf
(
Box
::
new
(
enc
)))
let
token_ids
=
encoding
.get_ids
()
.to_vec
();
let
tokens
=
encoding
.get_tokens
()
.to_vec
();
let
spans
=
encoding
.get_offsets
()
.to_vec
();
Encoding
{
token_ids
,
tokens
,
spans
,
}
})
.collect
();
.collect
();
Ok
(
encodings
)
Ok
(
encodings
)
...
@@ -82,10 +65,11 @@ impl Encoder for HuggingFaceTokenizer {
...
@@ -82,10 +65,11 @@ impl Encoder for HuggingFaceTokenizer {
impl
Decoder
for
HuggingFaceTokenizer
{
impl
Decoder
for
HuggingFaceTokenizer
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
{
// This calls into the library
let
text
=
self
let
text
=
self
.tokenizer
.tokenizer
.decode
(
token_ids
,
skip_special_tokens
)
.decode
(
token_ids
,
skip_special_tokens
)
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error de
cod
ing input: {
}"
,
err
)))
?
;
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error de
-tokeniz
ing input: {err
}"
)))
?
;
Ok
(
text
)
Ok
(
text
)
}
}
...
...
lib/llm/src/tokenizers/sp.rs
View file @
61a1f4ff
...
@@ -57,21 +57,8 @@ impl Encoder for SentencePieceTokenizer {
...
@@ -57,21 +57,8 @@ impl Encoder for SentencePieceTokenizer {
.encode
(
input
)
.encode
(
input
)
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error encoding input: {}"
,
err
)))
?
;
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error encoding input: {}"
,
err
)))
?
;
let
mut
token_ids
=
Vec
::
new
();
let
token_ids
=
encoding
.into_iter
()
.map
(|
piece
|
piece
.id
)
.collect
();
let
mut
tokens
=
Vec
::
new
();
Ok
(
Encoding
::
Sp
(
token_ids
))
let
mut
spans
=
Vec
::
new
();
for
piece
in
encoding
{
token_ids
.push
(
piece
.id
);
tokens
.push
(
piece
.piece
);
spans
.push
((
piece
.span
.0
as
usize
,
piece
.span
.1
as
usize
));
}
Ok
(
Encoding
{
token_ids
,
tokens
,
spans
,
})
}
}
/// Encodes multiple string inputs into tokens using the SentencePiece model.
/// Encodes multiple string inputs into tokens using the SentencePiece model.
...
...
lib/llm/tests/tokenizers.rs
View file @
61a1f4ff
...
@@ -44,10 +44,10 @@ const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
...
@@ -44,10 +44,10 @@ const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
const
HASHES
:
[(
&
str
,
[
u64
;
4
]);
1
]
=
[(
const
HASHES
:
[(
&
str
,
[
u64
;
4
]);
1
]
=
[(
TINYLLAMA_TOKENIZER_PATH
,
TINYLLAMA_TOKENIZER_PATH
,
[
[
771185775798505393
,
1209591529327510910
,
8538328482215529710
,
4181375434596349981
,
17087868772360018644
,
6245658446118930933
,
166021924023882657
7
,
509728569590218523
7
,
],
],
)];
)];
...
@@ -93,7 +93,7 @@ fn test_hf_lifecycle() {
...
@@ -93,7 +93,7 @@ fn test_hf_lifecycle() {
.expect
(
"Failed to encode prompt"
);
.expect
(
"Failed to encode prompt"
);
let
decoded
=
tokenizer
let
decoded
=
tokenizer
.decode
(
&
encoding
.token_ids
,
false
)
.decode
(
encoding
.token_ids
()
,
false
)
.expect
(
"Failed to decode token_ids"
);
.expect
(
"Failed to decode token_ids"
);
assert_eq!
(
decoded
,
TEST_PROMPTS
[
0
]);
assert_eq!
(
decoded
,
TEST_PROMPTS
[
0
]);
...
@@ -117,14 +117,14 @@ fn test_sequence() {
...
@@ -117,14 +117,14 @@ fn test_sequence() {
.append_text
(
TEST_PROMPTS
[
0
])
.append_text
(
TEST_PROMPTS
[
0
])
.expect
(
"Failed to append prompt"
);
.expect
(
"Failed to append prompt"
);
assert_eq!
(
sequence
.len
(),
encoding
.token_ids
.len
());
assert_eq!
(
sequence
.len
(),
encoding
.token_ids
()
.len
());
let
mut
decoder
=
Sequence
::
new
(
shared_tokenizer
.clone
()
.into
());
let
mut
decoder
=
Sequence
::
new
(
shared_tokenizer
.clone
()
.into
());
let
mut
output
=
String
::
new
();
let
mut
output
=
String
::
new
();
for
token_id
in
encoding
.token_ids
.clone
()
{
for
token_id
in
encoding
.token_ids
()
{
let
text
=
decoder
let
text
=
decoder
.append_token_id
(
token_id
)
.append_token_id
(
*
token_id
)
.expect
(
"Failed to decode token_id"
);
.expect
(
"Failed to decode token_id"
);
output
.push_str
(
text
.as_str
());
output
.push_str
(
text
.as_str
());
}
}
...
@@ -135,8 +135,8 @@ fn test_sequence() {
...
@@ -135,8 +135,8 @@ fn test_sequence() {
let
mut
decoder
=
DecodeStream
::
new
(
shared_tokenizer
.clone
(),
false
);
let
mut
decoder
=
DecodeStream
::
new
(
shared_tokenizer
.clone
(),
false
);
let
mut
output
=
String
::
new
();
let
mut
output
=
String
::
new
();
for
token_id
in
encoding
.token_ids
{
for
token_id
in
encoding
.token_ids
()
{
let
text
=
decoder
.step
(
token_id
)
.expect
(
"Failed to decode token_id"
);
let
text
=
decoder
.step
(
*
token_id
)
.expect
(
"Failed to decode token_id"
);
if
let
Some
(
text
)
=
text
{
if
let
Some
(
text
)
=
text
{
output
.push_str
(
text
.as_str
());
output
.push_str
(
text
.as_str
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment