Unverified Commit 78ae1758 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add tokenizer benchmark (#9427)

parent dae9a80f
......@@ -73,12 +73,18 @@ tower = { version = "0.5", features = ["util"] }
http-body-util = "0.1"
portpicker = "0.1"
tempfile = "3.8"
lazy_static = "1.4"
[[bench]]
name = "request_processing"
harness = false
path = "benches/request_processing.rs"
[[bench]]
name = "tokenizer_benchmark"
harness = false
path = "benches/tokenizer_benchmark.rs"
[profile.release]
lto = "thin"
codegen-units = 1
......
This diff is collapsed.
// Mock worker for testing - these functions are used by integration tests
#![allow(dead_code)]
use axum::{
extract::{Json, State},
http::StatusCode,
......@@ -25,7 +28,6 @@ pub struct MockWorkerConfig {
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub enum WorkerType {
Regular,
Prefill,
......@@ -33,7 +35,6 @@ pub enum WorkerType {
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub enum HealthStatus {
Healthy,
Unhealthy,
......
// These modules are used by tests and benchmarks
#![allow(dead_code)]
pub mod mock_worker;
pub mod test_app;
use sglang_router_rs::config::RouterConfig;
use sglang_router_rs::server::AppContext;
use std::sync::Arc;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};
/// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
......@@ -13,3 +18,80 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
config.max_concurrent_requests,
))
}
// Tokenizer download configuration
const TINYLLAMA_TOKENIZER_URL: &str =
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
const CACHE_DIR: &str = ".tokenizer_cache";
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";
// Global mutex to prevent concurrent downloads
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
/// Returns the path to the cached tokenizer file.
///
/// This function is thread-safe and will only download the tokenizer once
/// even if called from multiple threads concurrently.
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
/// Returns the path to the cached tokenizer file.
///
/// This function is thread-safe and will only download the tokenizer once
/// even if called from multiple threads concurrently.
pub fn ensure_tokenizer_cached() -> PathBuf {
    // Serialize all callers: only one thread may probe/download at a time.
    let _guard = DOWNLOAD_MUTEX
        .get_or_init(|| Mutex::new(()))
        .lock()
        .unwrap();

    let cache_dir = PathBuf::from(CACHE_DIR);
    let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);

    // Ensure the cache directory exists before any write.
    if !cache_dir.exists() {
        fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
    }

    // Cache hit: nothing left to do.
    if tokenizer_path.exists() {
        return tokenizer_path;
    }

    println!("Downloading TinyLlama tokenizer from HuggingFace...");

    // A blocking client is acceptable here: this runs from tests/benchmarks,
    // not from async code.
    let response = reqwest::blocking::Client::new()
        .get(TINYLLAMA_TOKENIZER_URL)
        .send()
        .expect("Failed to download tokenizer");
    if !response.status().is_success() {
        panic!("Failed to download tokenizer: HTTP {}", response.status());
    }

    let content = response.bytes().expect("Failed to read tokenizer content");
    // Sanity check: a real tokenizer.json is far larger than 100 bytes.
    if content.len() < 100 {
        panic!("Downloaded content too small: {} bytes", content.len());
    }

    fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
    println!(
        "Tokenizer downloaded and cached successfully ({} bytes)",
        tokenizer_path.metadata().unwrap().len()
    );

    tokenizer_path
}
/// Common test prompts for consistency across tests.
/// Kept deliberately varied (case, length) to exercise tokenizer behavior.
pub const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];

/// Pre-computed hashes for verification.
/// One entry per prompt in [`TEST_PROMPTS`], in the same order.
/// NOTE(review): these values appear tied to the TinyLlama tokenizer's output —
/// regenerate them if the tokenizer file or the hashing scheme changes.
pub const EXPECTED_HASHES: [u64; 4] = [
    1209591529327510910,
    4181375434596349981,
    6245658446118930933,
    5097285695902185237,
];
......@@ -3,20 +3,14 @@
//! These tests download the TinyLlama tokenizer from HuggingFace to verify our tokenizer
//! implementation works correctly with real-world tokenizer files.
mod common;
use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS};
use sglang_router_rs::tokenizer::{
factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream,
traits::*,
};
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};
// Prompts shared by the tokenizer tests; deliberately varied (case, length)
// to exercise tokenizer behavior.
const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];
use std::sync::Arc;
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
......@@ -34,70 +28,6 @@ const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
];
// Source URL for the TinyLlama tokenizer.json exercised by these tests.
const TINYLLAMA_TOKENIZER_URL: &str =
    "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
// Local directory where the downloaded tokenizer is cached between test runs.
const CACHE_DIR: &str = ".tokenizer_cache";
// File name of the cached tokenizer inside CACHE_DIR.
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";

// Global mutex to prevent concurrent downloads
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

// Pre-computed hashes for verification.
// One entry per prompt in TEST_PROMPTS, in the same order.
// NOTE(review): values appear tied to the TinyLlama tokenizer's output —
// regenerate if the tokenizer file or hashing scheme changes.
const EXPECTED_HASHES: [u64; 4] = [
    1209591529327510910,
    4181375434596349981,
    6245658446118930933,
    5097285695902185237,
];
/// Downloads the tokenizer from HuggingFace if not already cached
/// Downloads the tokenizer from HuggingFace if not already cached.
/// Returns the path to the cached tokenizer file.
///
/// Thread-safe: a global mutex serializes callers, so the download happens
/// at most once even under concurrent test execution.
fn ensure_tokenizer_cached() -> PathBuf {
    // Get or initialize the mutex
    let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));

    // Lock to ensure only one thread downloads at a time
    let _guard = mutex.lock().unwrap();

    let cache_dir = PathBuf::from(CACHE_DIR);
    let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);

    // Create cache directory if it doesn't exist
    if !cache_dir.exists() {
        fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
    }

    // Download tokenizer if not already cached
    if !tokenizer_path.exists() {
        println!("Downloading TinyLlama tokenizer from HuggingFace...");

        // Use blocking reqwest client since we're in tests
        let client = reqwest::blocking::Client::new();
        let response = client
            .get(TINYLLAMA_TOKENIZER_URL)
            .send()
            .expect("Failed to download tokenizer");

        if !response.status().is_success() {
            panic!("Failed to download tokenizer: HTTP {}", response.status());
        }

        let content = response.bytes().expect("Failed to read tokenizer content");

        // Verify we got actual JSON content
        // (a real tokenizer.json is far larger than 100 bytes)
        if content.len() < 100 {
            panic!("Downloaded content too small: {} bytes", content.len());
        }

        fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");

        println!(
            "Tokenizer downloaded and cached successfully ({} bytes)",
            tokenizer_path.metadata().unwrap().len()
        );
    }

    tokenizer_path
}
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
prompts
.iter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment