Unverified Commit 53e2cd46 authored by Chang Su, committed by GitHub

[router] remove all tokenizer metrics for performance (#9474)

parent 9708d353
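This commit strips the per-call `TokenizerMetrics` instrumentation (counters plus `Instant`-based timers) out of the tokenizer factory, the HuggingFace encode/decode paths, the stop-sequence decoder, and the streaming detokenizer, and switches `Encoding::token_ids()` to return a borrowed slice. A minimal before/after sketch of the hot-path pattern being removed (illustrative only, not code from this repository; types and metric names mirror the calls deleted in the diff below):

```rust
use std::time::Instant;

// Before: every encode paid for a timestamp plus several metric updates
// wrapped around the actual tokenizer call.
fn encode_instrumented(tok: &HuggingFaceTokenizer, input: &str) -> Result<Encoding> {
    let start = Instant::now();
    TokenizerMetrics::record_encode_request("huggingface");
    TokenizerMetrics::record_chars_per_encode(input.len());
    let encoding = tok.encode(input)?;
    TokenizerMetrics::record_encode_duration(start.elapsed());
    Ok(encoding)
}

// After: the hot path is just the tokenizer call itself.
fn encode_plain(tok: &HuggingFaceTokenizer, input: &str) -> Result<Encoding> {
    tok.encode(input)
}
```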
-use super::traits::{self, Tokenizer as TokenizerTrait};
+use super::traits;
-use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::fs::File;
 use std::io::Read;
 use std::path::Path;
 use std::sync::Arc;
-use std::time::Instant;
 #[cfg(feature = "huggingface")]
 use super::huggingface::HuggingFaceTokenizer;
@@ -34,8 +32,6 @@ pub fn create_tokenizer_with_chat_template(
     file_path: &str,
     chat_template_path: Option<&str>,
 ) -> Result<Arc<dyn traits::Tokenizer>> {
-    let start_time = Instant::now();
     // Special case for testing
     if file_path == "mock" || file_path == "test" {
         return Ok(Arc::new(super::mock::MockTokenizer::new()));
@@ -45,7 +41,6 @@ pub fn create_tokenizer_with_chat_template(
     // Check if file exists
     if !path.exists() {
-        TokenizerMetrics::record_factory_error("file_not_found");
         return Err(Error::msg(format!("File not found: {}", file_path)));
     }
@@ -64,14 +59,10 @@ pub fn create_tokenizer_with_chat_template(
                 chat_template_path,
             )?;
-            TokenizerMetrics::record_factory_load("json");
-            TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
             Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
         }
         #[cfg(not(feature = "huggingface"))]
         {
-            TokenizerMetrics::record_factory_error("huggingface_disabled");
             Err(Error::msg(
                 "HuggingFace support not enabled. Enable the 'huggingface' feature.",
             ))
@@ -79,26 +70,18 @@ pub fn create_tokenizer_with_chat_template(
         }
         Some("model") => {
             // SentencePiece model file
-            TokenizerMetrics::record_factory_error("unsupported_sentencepiece");
             Err(Error::msg("SentencePiece models not yet supported"))
         }
         Some("gguf") => {
             // GGUF format
-            TokenizerMetrics::record_factory_error("unsupported_gguf");
             Err(Error::msg("GGUF format not yet supported"))
         }
         _ => {
             // Try to auto-detect by reading file content
-            auto_detect_tokenizer(file_path).inspect(|tokenizer| {
-                TokenizerMetrics::record_factory_load("auto_detected");
-                TokenizerMetrics::set_vocab_size("auto_detected", tokenizer.vocab_size());
-            })
+            auto_detect_tokenizer(file_path)
         }
     };
-    if result.is_ok() {
-        TokenizerMetrics::record_factory_load_duration(start_time.elapsed());
-    }
     result
 }
@@ -190,8 +173,6 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
     {
         use super::tiktoken::TiktokenTokenizer;
         let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
-        TokenizerMetrics::record_factory_load("tiktoken");
-        TokenizerMetrics::set_vocab_size("tiktoken", tokenizer.vocab_size());
         return Ok(Arc::new(tokenizer));
     }
 }
@@ -286,7 +267,7 @@ mod tests {
         // Test encoding and decoding
         let text = "Hello, world!";
         let encoding = tokenizer.encode(text).unwrap();
-        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
         assert_eq!(decoded, text);
     }
 }
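The factory's public call shape is unchanged by this removal. A minimal round-trip sketch, written as if it sat next to the factory's own tests above (module paths and the exact assertion are assumed, mirroring the existing "Hello, world!" test rather than taken verbatim from this diff):

```rust
#[test]
fn mock_roundtrip_without_metrics() {
    // "mock" hits the special-cased test tokenizer in the factory above.
    let tokenizer = create_tokenizer_with_chat_template("mock", None)
        .expect("mock tokenizer should always load");

    let encoding = tokenizer.encode("Hello, world!").unwrap();
    // token_ids() now returns &[TokenIdType], so the slice is passed directly
    // instead of borrowing a temporary Vec.
    let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
    assert_eq!(decoded, "Hello, world!");
}
```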
 use super::traits::{
     Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
 };
-use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::collections::HashMap;
-use std::time::Instant;
 use tokenizers::tokenizer::Tokenizer as HfTokenizer;
 #[cfg(feature = "minijinja")]
@@ -196,36 +194,17 @@ impl HuggingFaceTokenizer {
 impl Encoder for HuggingFaceTokenizer {
     fn encode(&self, input: &str) -> Result<Encoding> {
-        let start = Instant::now();
-        TokenizerMetrics::record_encode_request("huggingface");
-        TokenizerMetrics::record_chars_per_encode(input.len());
         self.tokenizer
             .encode(input, false)
-            .map_err(|e| {
-                TokenizerMetrics::record_encode_error("encoding_failed");
-                Error::msg(format!("Encoding failed: {}", e))
-            })
-            .map(|encoding| {
-                TokenizerMetrics::record_tokens_per_encode(encoding.get_ids().len());
-                TokenizerMetrics::record_encode_duration(start.elapsed());
-                Encoding::Hf(Box::new(encoding))
-            })
+            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))
+            .map(|encoding| Encoding::Hf(Box::new(encoding)))
     }
     fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
-        let start = Instant::now();
         let encodings = self
             .tokenizer
             .encode_batch(inputs.to_vec(), false)
-            .map_err(|e| {
-                TokenizerMetrics::record_encode_error("batch_encoding_failed");
-                Error::msg(format!("Batch encoding failed: {}", e))
-            })?;
-        TokenizerMetrics::record_encode_batch_duration(start.elapsed(), inputs.len());
+            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
         Ok(encodings
             .into_iter()
@@ -236,20 +215,9 @@ impl Encoder for HuggingFaceTokenizer {
 impl Decoder for HuggingFaceTokenizer {
     fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
-        let start = Instant::now();
-        TokenizerMetrics::record_decode_request("huggingface");
-        TokenizerMetrics::record_tokens_per_decode(token_ids.len());
         self.tokenizer
             .decode(token_ids, skip_special_tokens)
-            .map_err(|e| {
-                TokenizerMetrics::record_decode_error("decoding_failed");
-                Error::msg(format!("Decoding failed: {}", e))
-            })
-            .inspect(|_| {
-                TokenizerMetrics::record_decode_duration(start.elapsed());
-            })
+            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
     }
 }
...
 use super::traits::{self, TokenIdType};
-use crate::metrics::TokenizerMetrics;
 use anyhow::Result;
 use std::collections::HashSet;
 use std::sync::Arc;
-use std::time::Instant;
 /// Output from the sequence decoder
 #[derive(Debug, Clone, PartialEq)]
@@ -95,8 +93,6 @@ impl StopSequenceDecoder {
     /// Process a single token
     pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
-        let start = Instant::now();
         if self.stopped {
             return Ok(SequenceDecoderOutput::Stopped);
         }
@@ -104,22 +100,18 @@ impl StopSequenceDecoder {
         // Check for token-level stops first
         if self.config.stop_tokens.contains(&token_id) {
             self.stopped = true;
-            TokenizerMetrics::record_stop_sequence_detected("token");
             // Flush any jailed text before stopping
             if !self.jail_buffer.is_empty() {
                 let output = self.jail_buffer.clone();
                 self.jail_buffer.clear();
-                TokenizerMetrics::record_stop_detection_duration(start.elapsed());
                 return Ok(SequenceDecoderOutput::StoppedWithText(output));
             }
-            TokenizerMetrics::record_stop_detection_duration(start.elapsed());
             return Ok(SequenceDecoderOutput::Stopped);
         }
         if self.config.visible_stop_tokens.contains(&token_id) {
             self.stopped = true;
-            TokenizerMetrics::record_stop_sequence_detected("visible_token");
             // Include jailed text plus the stop token
             let stop_text = self
@@ -127,7 +119,6 @@ impl StopSequenceDecoder {
                 .decode(&[token_id], self.skip_special_tokens)?;
             let output = format!("{}{}", self.jail_buffer, stop_text);
             self.jail_buffer.clear();
-            TokenizerMetrics::record_stop_detection_duration(start.elapsed());
             return Ok(SequenceDecoderOutput::StoppedWithText(output));
         }
@@ -172,12 +163,10 @@ impl StopSequenceDecoder {
         for stop_seq in &self.config.stop_sequences {
             if let Some(pos) = check_text.find(stop_seq) {
                 self.stopped = true;
-                TokenizerMetrics::record_stop_sequence_detected("string");
                 // Output text before the stop sequence
                 let output = check_text[..pos].to_string();
                 self.jail_buffer.clear();
-                TokenizerMetrics::record_stop_detection_duration(start.elapsed());
                 return Ok(if output.is_empty() {
                     SequenceDecoderOutput::Stopped
                 } else {
@@ -190,13 +179,11 @@ impl StopSequenceDecoder {
         for stop_seq in &self.config.visible_stop_sequences {
             if let Some(pos) = check_text.find(stop_seq) {
                 self.stopped = true;
-                TokenizerMetrics::record_stop_sequence_detected("visible_string");
                 // Include the stop sequence in output
                 let end_pos = pos + stop_seq.len();
                 let output = check_text[..end_pos].to_string();
                 self.jail_buffer.clear();
-                TokenizerMetrics::record_stop_detection_duration(start.elapsed());
                 return Ok(SequenceDecoderOutput::StoppedWithText(output));
             }
         }
@@ -219,8 +206,6 @@ impl StopSequenceDecoder {
         }
         if partial_match_len > 0 {
-            TokenizerMetrics::record_partial_match();
             // Split: output safe text, jail the potential match
             let safe_end = check_text.len() - partial_match_len;
             let safe_text = &check_text[..safe_end];
@@ -230,8 +215,6 @@ impl StopSequenceDecoder {
             self.prefix_offset = self.read_offset;
             self.read_offset = self.token_buffer.len();
-            TokenizerMetrics::record_stop_detection_duration(start.elapsed());
             if safe_text.is_empty() {
                 Ok(SequenceDecoderOutput::Held)
             } else {
@@ -245,8 +228,6 @@ impl StopSequenceDecoder {
             self.prefix_offset = self.read_offset;
             self.read_offset = self.token_buffer.len();
-            TokenizerMetrics::record_stop_detection_duration(start.elapsed());
             Ok(SequenceDecoderOutput::Text(check_text))
         }
     }
...
 // src/tokenizer/stream.rs
 use super::traits::{self, TokenIdType};
-use crate::metrics::TokenizerMetrics;
 use anyhow::Result;
 use std::sync::Arc;
-use std::time::Instant;
 const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
@@ -45,12 +43,8 @@ impl DecodeStream {
     /// Step appends a token_id to the internal state and tries to produce a text chunk.
     /// Returning `None` means the given id is not enough to produce a chunk.
     pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
-        let start = Instant::now();
         self.all_token_ids.push(id);
-        TokenizerMetrics::record_stream_token();
         let prefix_text = self.tokenizer.decode(
             &self.all_token_ids[self.prefix_offset..self.read_offset],
             self.skip_special_tokens,
@@ -67,16 +61,8 @@ impl DecodeStream {
             self.prefix_offset = self.read_offset;
             self.read_offset = self.all_token_ids.len();
-            TokenizerMetrics::record_stream_step_duration(start.elapsed());
             Ok(Some(new_text))
         } else {
-            if new_text.ends_with("�") {
-                TokenizerMetrics::record_incomplete_utf8();
-            }
-            TokenizerMetrics::record_stream_step_duration(start.elapsed());
             Ok(None)
         }
     }
...
@@ -129,9 +129,7 @@ fn test_thread_safety() {
         thread::spawn(move || {
             let text = "Hello test".to_string();
             let encoding = tokenizer_clone.encode(&text).unwrap();
-            let decoded = tokenizer_clone
-                .decode(&encoding.token_ids(), false)
-                .unwrap();
+            let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
             assert!(decoded.contains("Hello") || decoded.contains("test"));
             i
         })
...
@@ -213,7 +213,7 @@ mod tests {
         let text = "Hello, world!";
         let encoding = tokenizer.encode(text).unwrap();
-        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
         assert_eq!(decoded, text);
     }
@@ -226,7 +226,7 @@ mod tests {
         assert_eq!(encodings.len(), 3);
         for (i, encoding) in encodings.iter().enumerate() {
-            let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
             assert_eq!(decoded, texts[i]);
         }
     }
...
@@ -36,22 +36,19 @@ pub enum Encoding {
 }
 impl Encoding {
-    /// Returns a reference to token IDs when possible, owned Vec for compatibility
-    pub fn token_ids(&self) -> Vec<TokenIdType> {
+    /// Returns a reference to token IDs - zero-copy operation
+    pub fn token_ids(&self) -> &[TokenIdType] {
         match self {
-            Encoding::Hf(inner) => inner.get_ids().to_vec(),
-            Encoding::Sp(inner) => inner.clone(),
-            Encoding::Tiktoken(inner) => inner.clone(),
+            Encoding::Hf(inner) => inner.get_ids(),
+            Encoding::Sp(inner) => inner,
+            Encoding::Tiktoken(inner) => inner,
         }
     }
-    /// Returns a reference to token IDs where possible
+    /// Deprecated: Use token_ids() instead (kept for compatibility)
+    #[deprecated(since = "0.1.0", note = "Use token_ids() instead")]
     pub fn token_ids_ref(&self) -> &[TokenIdType] {
-        match self {
-            Encoding::Hf(inner) => inner.get_ids(),
-            Encoding::Sp(inner) => inner,
-            Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
-        }
+        self.token_ids()
     }
     /// Get a hash of the token IDs for caching purposes
...
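Because `token_ids()` now returns `&[TokenIdType]` instead of an owned `Vec`, call sites change in two small ways, which is exactly what the test updates below do: drop the `&` when passing the ids to `decode`, and dereference when iterating, since iterating a slice yields references. A hypothetical caller, for illustration only (the function itself is not part of this diff):

```rust
use anyhow::Result;

// `tokenizer`, `decoder`, and `encoding` stand in for any Tokenizer,
// DecodeStream, and Encoding from this crate.
fn drain(
    tokenizer: &dyn Tokenizer,
    decoder: &mut DecodeStream,
    encoding: &Encoding,
) -> Result<()> {
    // No `&` around the call any more: the method already returns a slice.
    let _full = tokenizer.decode(encoding.token_ids(), false)?;

    // Iterating &[TokenIdType] yields &TokenIdType, hence the added `*`.
    for token_id in encoding.token_ids() {
        let _ = decoder.step(*token_id)?;
    }
    Ok(())
}
```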
@@ -66,7 +66,7 @@ fn test_tokenizer_encode_decode_lifecycle() {
     let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
     let decoded = tokenizer
-        .decode(&encoding.token_ids(), false)
+        .decode(encoding.token_ids(), false)
         .expect("Failed to decode token_ids");
     assert_eq!(decoded, *prompt, "Encode-decode mismatch for: {}", prompt);
@@ -101,7 +101,7 @@ fn test_sequence_operations() {
     for token_id in encoding.token_ids() {
         let text = decoder
-            .append_token(token_id)
+            .append_token(*token_id)
             .expect("Failed to append token");
         output.push_str(&text);
     }
@@ -131,7 +131,7 @@ fn test_decode_stream() {
     let mut output = String::new();
     for token_id in encoding.token_ids() {
-        if let Some(text) = decoder.step(token_id).expect("Failed to decode token") {
+        if let Some(text) = decoder.step(*token_id).expect("Failed to decode token") {
             output.push_str(&text);
         }
     }
@@ -157,11 +157,11 @@ fn test_long_sequence_incremental_decode_with_prefill() {
         .encode(output_text)
         .expect("Failed to encode output");
-    let mut decoder = DecodeStream::new(tokenizer.clone(), &input_encoding.token_ids(), false);
+    let mut decoder = DecodeStream::new(tokenizer.clone(), input_encoding.token_ids(), false);
     let mut output = String::new();
     for token_id in output_encoding.token_ids() {
-        if let Some(text) = decoder.step(token_id).expect("Failed to decode token") {
+        if let Some(text) = decoder.step(*token_id).expect("Failed to decode token") {
             output.push_str(&text);
         }
     }
@@ -199,7 +199,7 @@ fn test_stop_sequence_decoder() {
     let mut stopped = false;
     for token_id in encoding.token_ids() {
-        match decoder.process_token(token_id).unwrap() {
+        match decoder.process_token(*token_id).unwrap() {
             SequenceDecoderOutput::Text(text) => output.push_str(&text),
             SequenceDecoderOutput::StoppedWithText(text) => {
                 output.push_str(&text);
@@ -245,7 +245,7 @@ fn test_factory_creation() {
     let encoding = tokenizer.encode(TEST_PROMPTS[0]).expect("Failed to encode");
     let decoded = tokenizer
-        .decode(&encoding.token_ids(), false)
+        .decode(encoding.token_ids(), false)
         .expect("Failed to decode");
     assert_eq!(decoded, TEST_PROMPTS[0]);
@@ -265,7 +265,7 @@ fn test_batch_encoding() {
     for (i, encoding) in encodings.iter().enumerate() {
         let decoded = tokenizer
-            .decode(&encoding.token_ids(), false)
+            .decode(encoding.token_ids(), false)
             .expect("Failed to decode");
         assert_eq!(decoded, TEST_PROMPTS[i]);
     }
@@ -307,7 +307,7 @@ fn test_thread_safety() {
             .encode(prompt)
             .expect("Failed to encode in thread");
         let decoded = tokenizer_clone
-            .decode(&encoding.token_ids(), false)
+            .decode(encoding.token_ids(), false)
             .expect("Failed to decode in thread");
         assert_eq!(decoded, prompt);
     })
...