Unverified Commit 61a1f4ff authored by Graham King's avatar Graham King Committed by GitHub
Browse files

perf(tokenizer): Make de-tokenize ~50% faster (#1868)

parent f242b455
......@@ -334,6 +334,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi 0.3.9",
]
[[package]]
name = "autocfg"
version = "1.4.0"
......@@ -794,7 +805,7 @@ dependencies = [
"cudarc 0.13.9",
"float8",
"gemm 0.17.1",
"half",
"half 2.6.0",
"memmap2",
"metal",
"num-traits",
......@@ -816,7 +827,7 @@ checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1"
dependencies = [
"byteorder",
"gemm 0.17.1",
"half",
"half 2.6.0",
"memmap2",
"num-traits",
"num_cpus",
......@@ -856,7 +867,7 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=98c0436e#98c0436eaf
dependencies = [
"candle-core 0.8.0",
"candle-metal-kernels",
"half",
"half 2.6.0",
"metal",
"num-traits",
"rayon",
......@@ -865,13 +876,19 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb"
dependencies = [
"clap",
"clap 4.5.40",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
......@@ -972,6 +989,17 @@ dependencies = [
"libloading",
]
[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap",
"unicode-width 0.1.14",
]
[[package]]
name = "clap"
version = "4.5.40"
......@@ -1114,6 +1142,42 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
......@@ -1261,7 +1325,7 @@ version = "0.13.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486c221362668c63a1636cfa51463b09574433b39029326cff40864b3ba12b6e"
dependencies = [
"half",
"half 2.6.0",
"libloading",
]
......@@ -1559,7 +1623,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09157630eece4139f6cc5a457556d308c3465ecd5af492f0e5aadc043997e2ce"
dependencies = [
"half",
"half 2.6.0",
]
[[package]]
......@@ -1722,6 +1786,7 @@ dependencies = [
"bytes",
"candle-core 0.8.4",
"chrono",
"criterion",
"cudarc 0.16.2",
"derive-getters",
"derive_builder",
......@@ -1784,7 +1849,7 @@ dependencies = [
"async-openai",
"async-stream",
"async-trait",
"clap",
"clap 4.5.40",
"dynamo-engine-llamacpp",
"dynamo-engine-mistralrs",
"dynamo-llm",
......@@ -2099,7 +2164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [
"bit_field",
"half",
"half 2.6.0",
"lebe",
"miniz_oxide",
"rayon-core",
......@@ -2194,7 +2259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dee36245af1dccf978103fcd393582806db2a1d0bcd2f38c663cdbb4a363a01c"
dependencies = [
"cudarc 0.13.9",
"half",
"half 2.6.0",
"num-traits",
"rand 0.9.1",
"rand_distr",
......@@ -2496,7 +2561,7 @@ checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
dependencies = [
"bytemuck",
"dyn-stack 0.10.0",
"half",
"half 2.6.0",
"num-complex",
"num-traits",
"once_cell",
......@@ -2516,7 +2581,7 @@ checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
dependencies = [
"bytemuck",
"dyn-stack 0.13.0",
"half",
"half 2.6.0",
"libm",
"num-complex",
"num-traits",
......@@ -2538,7 +2603,7 @@ dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"gemm-f32 0.17.1",
"half",
"half 2.6.0",
"num-complex",
"num-traits",
"paste",
......@@ -2556,7 +2621,7 @@ dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"gemm-f32 0.18.2",
"half",
"half 2.6.0",
"num-complex",
"num-traits",
"paste",
......@@ -2678,7 +2743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a27693512784e0786212eb0bef841779a6337d2d04520ed475b4d5a864f98366"
dependencies = [
"digit-layout",
"half",
"half 2.6.0",
"rayon",
]
......@@ -2749,6 +2814,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
version = "2.6.0"
......@@ -2802,6 +2873,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.3.9"
......@@ -2897,7 +2977,7 @@ dependencies = [
name = "http"
version = "0.3.2"
dependencies = [
"clap",
"clap 4.5.40",
"dynamo-llm",
"dynamo-runtime",
"serde",
......@@ -3655,7 +3735,7 @@ name = "llmctl"
version = "0.3.2"
dependencies = [
"anyhow",
"clap",
"clap 4.5.40",
"dynamo-llm",
"dynamo-runtime",
"serde",
......@@ -3855,7 +3935,7 @@ name = "metrics"
version = "0.3.2"
dependencies = [
"axum 0.6.20",
"clap",
"clap 4.5.40",
"dynamo-llm",
"dynamo-runtime",
"futures",
......@@ -3985,7 +4065,7 @@ dependencies = [
"anyhow",
"candle-core 0.8.0",
"candle-nn",
"clap",
"clap 4.5.40",
"either",
"futures",
"image",
......@@ -4030,7 +4110,7 @@ dependencies = [
"candle-nn",
"cfgrammar",
"chrono",
"clap",
"clap 4.5.40",
"csv",
"derive-new",
"derive_more 2.0.1",
......@@ -4039,7 +4119,7 @@ dependencies = [
"float8",
"futures",
"galil-seiferas",
"half",
"half 2.6.0",
"hashbrown 0.15.4",
"hf-hub",
"hound",
......@@ -4133,7 +4213,7 @@ dependencies = [
"bindgen_cuda 0.1.6",
"candle-core 0.8.0",
"float8",
"half",
"half 2.6.0",
"metal",
"once_cell",
"thiserror 2.0.12",
......@@ -4149,7 +4229,7 @@ dependencies = [
"candle-core 0.8.0",
"candle-nn",
"float8",
"half",
"half 2.6.0",
"hf-hub",
"lazy_static",
"memmap2",
......@@ -4479,7 +4559,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"hermit-abi 0.3.9",
"libc",
]
......@@ -4591,6 +4671,12 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl-probe"
version = "0.1.6"
......@@ -4872,6 +4958,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.17.16"
......@@ -5596,7 +5710,7 @@ dependencies = [
name = "router"
version = "0.3.2"
dependencies = [
"clap",
"clap 4.5.40",
"dynamo-llm",
"dynamo-runtime",
"rand 0.9.1",
......@@ -6078,6 +6192,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
......@@ -6796,6 +6920,15 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width 0.1.14",
]
[[package]]
name = "thiserror"
version = "1.0.69"
......@@ -6900,6 +7033,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.9.0"
......@@ -7420,7 +7563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437"
dependencies = [
"gemm 0.18.2",
"half",
"half 2.6.0",
"libloading",
"memmap2",
"num",
......
......@@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
sentencepiece = ["dep:sentencepiece"]
[[bench]]
name = "tokenizer"
harness = false
[dependencies]
# repo
dynamo-runtime = { workspace = true }
......@@ -126,6 +130,7 @@ rmp-serde = "1.3"
[dev-dependencies]
assert_matches = "1.5"
criterion = { version = "0.3", features = ["html_reports"] }
hf-hub = { workspace = true }
proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::hint::black_box;
use std::sync::Arc;
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use dynamo_llm::backend::Decoder;
use dynamo_llm::protocols::common::StopConditions;
use dynamo_llm::tokenizers::hf::HuggingFaceTokenizer;
use dynamo_llm::tokenizers::traits::{Encoder, Tokenizer};
use dynamo_llm::tokenizers::DecodeStream;
use dynamo_llm::types::TokenIdType;
const TEST_TOKENIZER: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/data/sample-models/TinyLlama_v1.1/tokenizer.json"
);
/// Input Sequence Length for tokenizer
const TARGET_ISL: usize = 8_000;
// A string of length exactly 128 bytes.
const INPUT_STR: &str = "The cat sat by the window, watching raindrops race down the glass. Far thunder rumbled. She purred softly, feeling safe at home.";
/// `cargo bench -- encode` to run it
pub fn encode(c: &mut Criterion) {
let test_str: &str = &INPUT_STR.repeat(TARGET_ISL / INPUT_STR.len());
let encoder = HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap();
let mut group = c.benchmark_group("encode-group");
group.throughput(Throughput::Bytes(test_str.len() as u64));
group.bench_function("tokenizer_encode", |b| {
b.iter(|| {
let _ = encoder.encode(black_box(test_str)).unwrap();
})
});
group.finish();
}
pub fn decode(c: &mut Criterion) {
const TEST_TOKS: [TokenIdType; 34] = [
450, 6635, 3290, 491, 278, 3474, 29892, 21217, 1153, 513, 307, 567, 8175, 1623, 278, 12917,
29889, 8413, 266, 5062, 364, 25443, 29889, 2296, 3708, 1127, 4964, 368, 29892, 11223, 9109,
472, 3271, 29889,
];
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, false);
let mut decoder = Decoder::new(ds, StopConditions::default());
let mut group = c.benchmark_group("decode-group");
group.throughput(Throughput::Bytes(TEST_TOKS.len() as u64));
group.bench_function("tokenizer_decoder", |b| {
b.iter(|| {
for tok in black_box(TEST_TOKS) {
let _ = decoder.step(tok).unwrap();
}
})
});
group.finish();
}
criterion_group!(benches, encode, decode);
criterion_main!(benches);
......@@ -466,7 +466,7 @@ impl Decoder {
pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
let mut text: Option<String> = None;
let mut tokens = Vec::new();
let mut tokens = Vec::with_capacity(token_ids.len());
for token_id in token_ids {
let StepResult {
......@@ -481,7 +481,8 @@ impl Decoder {
if !hide_text {
if let Some(token) = &token {
text.get_or_insert_with(String::new).push_str(token);
text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
.push_str(token);
}
}
tokens.push(token);
......
......@@ -142,14 +142,14 @@ pub async fn run(
if let Some(pre) = pre_processor {
// Note this does not include the prompt template. Probably TODO
entry.tokens_in = match pre.tokenize(&entry.text) {
Ok(encoding) => encoding.token_ids.len(),
Ok(encoding) => encoding.token_ids().len(),
Err(err) => {
tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
0
}
};
entry.tokens_out = match pre.tokenize(&response) {
Ok(encoding) => encoding.token_ids.len(),
Ok(encoding) => encoding.token_ids().len(),
Err(err) => {
tracing::warn!(%err, response, "Failed tokenizing response");
0
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The Preprocessor consists of the following modules
//!
......@@ -205,9 +193,7 @@ impl OpenAIPreprocessor {
self.formatter.render(request)?
};
let encoding = tokio::task::block_in_place(|| {
self.tokenizer.encode(&formatted_prompt)
})?;
let encoding = self.tokenizer.encode(&formatted_prompt)?;
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert(
......@@ -219,22 +205,21 @@ impl OpenAIPreprocessor {
if request.has_annotation(ANNOTATION_TOKEN_IDS) {
annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&encoding.token_ids)?,
serde_json::to_string(encoding.token_ids())?,
);
}
builder.token_ids(encoding.token_ids);
builder.token_ids(encoding.token_ids().to_vec());
}
TextInput::Batch(texts) => {
let token_batches: Result<Vec<Vec<u32>>, _> = texts
let token_batches: Vec<Vec<u32>> = texts
.par_iter()
.map(|text| {
tokio::task::block_in_place(|| self.tokenizer.encode(text))
.map(|encoding| encoding.token_ids)
self.tokenizer
.encode(text)
.map(|encoded| encoded.token_ids().to_vec())
})
.collect();
let token_batches = token_batches?;
.collect::<Result<Vec<_>>>()?;
builder.batch_token_ids(Some(token_batches));
builder.token_ids(vec![]);
}
......@@ -285,8 +270,8 @@ impl OpenAIPreprocessor {
let all_token_ids = match &request.inner.input {
async_openai::types::EmbeddingInput::String(s) => {
let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(s))?;
vec![encoding.token_ids]
let encoding = self.tokenizer.encode(s)?;
vec![encoding.token_ids().to_vec()]
}
async_openai::types::EmbeddingInput::StringArray(arr) => {
let input_strs: Vec<String> = arr.to_vec();
......@@ -300,7 +285,7 @@ impl OpenAIPreprocessor {
.await??;
let token_arrays: Vec<Vec<u32>> = encodings
.into_iter()
.map(|encoding| encoding.token_ids)
.map(|encoding| encoding.token_ids().to_vec())
.collect();
token_arrays
}
......
......@@ -46,11 +46,27 @@ pub enum TokenizerType {
pub type Offsets = (usize, usize);
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Hash)]
pub struct Encoding {
pub token_ids: Vec<TokenIdType>,
pub tokens: Vec<String>,
pub spans: Vec<Offsets>,
#[derive(Debug, Clone)]
pub enum Encoding {
/// Hugging Face
Hf(Box<tokenizers::tokenizer::Encoding>),
/// Sentence Piece
Sp(Vec<TokenIdType>),
}
impl Encoding {
pub fn token_ids(&self) -> &[u32] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Encoding::Sp(inner) => inner,
}
}
}
impl Hash for Encoding {
fn hash<H: Hasher>(&self, state: &mut H) {
self.token_ids().hash(state);
}
}
pub mod traits {
......@@ -194,8 +210,8 @@ impl DecodeStream {
Self {
tokenizer,
skip_special_tokens,
ids: Vec::new(),
prefix: "".to_string(),
ids: Vec::with_capacity(64),
prefix: String::with_capacity(64),
prefix_index: 0,
read_index: 0,
}
......@@ -211,25 +227,23 @@ impl DecodeStream {
/// a valid chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let string = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
let decoded = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') {
if !(string.starts_with(&self.prefix)) {
if decoded.len() <= self.prefix.len() || decoded.ends_with('�') {
return Ok(None);
}
if !decoded.starts_with(&self.prefix) {
anyhow::bail!("Detokenizer failure: invalid prefix");
}
let new_text = &string[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
let new_text = decoded[self.prefix.len()..].to_string();
self.prefix = decoded;
self.read_index = self.prefix_index;
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
Ok(None)
}
Ok(Some(new_text))
}
}
......@@ -255,11 +269,12 @@ impl std::fmt::Debug for Sequence {
.field(
"token_ids",
&format_args!("{}", {
if self.token_ids.len() <= 20 {
format!("{:?}", self.token_ids)
let token_ids = self.token_ids();
if token_ids.len() <= 20 {
format!("{:?}", token_ids)
} else {
let first_ten = &self.token_ids[..10];
let last_ten = &self.token_ids[self.token_ids.len() - 10..];
let first_ten = &token_ids[..10];
let last_ten = &token_ids[token_ids.len() - 10..];
format!("{:?} ... {:?}", first_ten, last_ten)
}
}),
......@@ -301,7 +316,7 @@ impl Sequence {
// })?;
let encoding = self.tokenizer.encode(input)?;
self.token_ids.extend(encoding.token_ids);
self.token_ids.extend(encoding.token_ids());
Ok(())
}
......
......@@ -39,41 +39,24 @@ impl HuggingFaceTokenizer {
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// This self.tokenizer is the library
let encoding = self
.tokenizer
.encode(input, false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
.map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?;
let token_ids = encoding.get_ids().to_vec();
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Ok(Encoding {
token_ids,
tokens,
spans,
})
Ok(Encoding::Hf(Box::new(encoding)))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let hf_encodings = self
.tokenizer
.encode_batch(inputs.to_vec(), false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
.map_err(|err| Error::msg(format!("Error batch tokenizing input: {err}")))?;
let encodings = hf_encodings
.into_iter()
.map(|encoding| {
let token_ids = encoding.get_ids().to_vec();
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Encoding {
token_ids,
tokens,
spans,
}
})
.map(|enc| Encoding::Hf(Box::new(enc)))
.collect();
Ok(encodings)
......@@ -82,10 +65,11 @@ impl Encoder for HuggingFaceTokenizer {
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
// This calls into the library
let text = self
.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?;
.map_err(|err| Error::msg(format!("Error de-tokenizing input: {err}")))?;
Ok(text)
}
......
......@@ -57,21 +57,8 @@ impl Encoder for SentencePieceTokenizer {
.encode(input)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
let mut token_ids = Vec::new();
let mut tokens = Vec::new();
let mut spans = Vec::new();
for piece in encoding {
token_ids.push(piece.id);
tokens.push(piece.piece);
spans.push((piece.span.0 as usize, piece.span.1 as usize));
}
Ok(Encoding {
token_ids,
tokens,
spans,
})
let token_ids = encoding.into_iter().map(|piece| piece.id).collect();
Ok(Encoding::Sp(token_ids))
}
/// Encodes multiple string inputs into tokens using the SentencePiece model.
......
......@@ -44,10 +44,10 @@ const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
const HASHES: [(&str, [u64; 4]); 1] = [(
TINYLLAMA_TOKENIZER_PATH,
[
771185775798505393,
8538328482215529710,
17087868772360018644,
1660219240238826577,
1209591529327510910,
4181375434596349981,
6245658446118930933,
5097285695902185237,
],
)];
......@@ -93,7 +93,7 @@ fn test_hf_lifecycle() {
.expect("Failed to encode prompt");
let decoded = tokenizer
.decode(&encoding.token_ids, false)
.decode(encoding.token_ids(), false)
.expect("Failed to decode token_ids");
assert_eq!(decoded, TEST_PROMPTS[0]);
......@@ -117,14 +117,14 @@ fn test_sequence() {
.append_text(TEST_PROMPTS[0])
.expect("Failed to append prompt");
assert_eq!(sequence.len(), encoding.token_ids.len());
assert_eq!(sequence.len(), encoding.token_ids().len());
let mut decoder = Sequence::new(shared_tokenizer.clone().into());
let mut output = String::new();
for token_id in encoding.token_ids.clone() {
for token_id in encoding.token_ids() {
let text = decoder
.append_token_id(token_id)
.append_token_id(*token_id)
.expect("Failed to decode token_id");
output.push_str(text.as_str());
}
......@@ -135,8 +135,8 @@ fn test_sequence() {
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), false);
let mut output = String::new();
for token_id in encoding.token_ids {
let text = decoder.step(token_id).expect("Failed to decode token_id");
for token_id in encoding.token_ids() {
let text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(text) = text {
output.push_str(text.as_str());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment