"tests/python/vscode:/vscode.git/clone" did not exist on "e6226e826d7679c405ab0fdb59cdab8fce571c41"
Unverified Commit 598c0bc1, authored by Chang Su and committed by GitHub

[router] add tokenizer download support from hf hub (#9882)

parent b361750a
Cargo.toml:

@@ -4,9 +4,7 @@ version = "0.0.0"
 edition = "2021"

 [features]
-default = ["huggingface", "grpc-client"]
-huggingface = ["tokenizers", "minijinja"]
-tiktoken = ["tiktoken-rs"]
+default = ["grpc-client"]
 grpc-client = []
 grpc-server = []
@@ -52,10 +50,11 @@ regex = "1.10"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
-tokenizers = { version = "0.21.4", optional = true }
-tiktoken-rs = { version = "0.7.0", optional = true }
-minijinja = { version = "2.0", optional = true }
+tokenizers = { version = "0.22.0" }
+tiktoken-rs = { version = "0.7.0" }
+minijinja = { version = "2.0" }
 rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
+hf-hub = { version = "0.4.3", features = ["tokio"] }

 # gRPC and Protobuf dependencies
 tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
......
Tokenizer layer documentation:

@@ -8,6 +8,7 @@ The SGL Router tokenizer layer provides a unified interface for text tokenization
 **Key Components:**
 - **Factory Pattern**: Auto-detection and creation of appropriate tokenizer types from files or model names
+- **HuggingFace Hub Integration**: Automatic downloading of tokenizer files from HuggingFace Hub for model IDs
 - **Trait System**: `Encoder`, `Decoder`, and `Tokenizer` traits for implementation flexibility
 - **Streaming**: Incremental decoding with UTF-8 boundary handling and buffering
 - **Stop Sequences**: Complex pattern matching for stop tokens and sequences with "jail" buffering
@@ -16,7 +17,7 @@ The SGL Router tokenizer layer provides a unified interface for text tokenization
 - **Metrics Integration**: Comprehensive performance and error tracking across all operations

 **Data Flow:**
-1. Request → Factory (type detection) → Concrete Tokenizer Creation
+1. Request → Factory (type detection/HF download) → Concrete Tokenizer Creation
 2. Encode: Text → Tokenizer → Encoding (token IDs)
 3. Stream: Token IDs → DecodeStream → Incremental Text Chunks
 4. Stop Detection: Tokens → StopSequenceDecoder → Text/Held/Stopped
@@ -25,8 +26,9 @@ The SGL Router tokenizer layer provides a unified interface for text tokenization
 ### Architecture Highlights
 - **Extended Backend Support**: HuggingFace, Tiktoken (GPT models), and Mock for testing
+- **HuggingFace Hub Integration**: Automatic tokenizer downloads with caching
 - **Comprehensive Metrics**: Full TokenizerMetrics integration for observability
-- **Feature Gating**: Conditional compilation for tokenizer backends
+- **Unified Dependencies**: All tokenizer backends included by default (no feature gates)
 - **Stop Sequence Detection**: Sophisticated partial matching with jail buffer
 - **Chat Template Support**: Full Jinja2 rendering with HuggingFace compatibility
 - **Thread Safety**: Arc-based sharing with Send + Sync guarantees
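To make the Arc-based sharing concrete, here is a minimal sketch of concurrent use. It assumes the `create_tokenizer` factory described below and an `encode(&str) -> Result<Encoding>` method on the trait object, so treat it as illustrative rather than verbatim router code:

```rust
use std::sync::Arc;
use std::thread;

use sglang_router_rs::tokenizer::create_tokenizer;

fn main() -> anyhow::Result<()> {
    // The factory returns Arc<dyn Tokenizer>; clones are cheap pointer bumps.
    let tokenizer = create_tokenizer("./tokenizer.json")?;

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let tok = Arc::clone(&tokenizer);
            // Send + Sync guarantees let each worker encode without locking.
            thread::spawn(move || tok.encode(&format!("request {i}")))
        })
        .collect();

    for handle in handles {
        let encoding = handle.join().expect("worker panicked")?;
        println!("{} tokens", encoding.token_ids().len());
    }
    Ok(())
}
```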
@@ -92,9 +94,14 @@ sequenceDiagram
     participant SD as StopDecoder
     participant M as Metrics

-    C->>F: create_tokenizer(path)
+    C->>F: create_tokenizer(path_or_model_id)
     F->>F: detect_type()
-    F->>T: new HF/Tiktoken/Mock
+    alt local file
+        F->>T: new HF/Tiktoken/Mock
+    else HuggingFace model ID
+        F->>F: download_tokenizer_from_hf()
+        F->>T: new from downloaded files
+    end
     F->>M: record_factory_load()
     F-->>C: Arc<dyn Tokenizer>
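The two `alt` branches above correspond to call sites like the following sketch (the path and model ID are placeholders; `vocab_size()` is the trait method exercised in the factory tests):

```rust
use sglang_router_rs::tokenizer::create_tokenizer;

fn demo() -> anyhow::Result<()> {
    // Local file: extension/content sniffing selects the backend.
    let from_file = create_tokenizer("./models/llama/tokenizer.json")?;

    // Model ID: tokenizer files are fetched from the Hub and cached first.
    let from_hub = create_tokenizer("bert-base-uncased")?;

    assert!(from_file.vocab_size() > 0);
    assert!(from_hub.vocab_size() > 0);
    Ok(())
}
```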
@@ -287,11 +294,11 @@ impl Tokenizer {
 - Single field: `Arc<dyn traits::Tokenizer>` for polymorphic dispatch
 - Immutable after creation, Clone via Arc

-**Re-exports** (mod.rs:25-39):
-- Factory functions: `create_tokenizer`, `create_tokenizer_from_file`, `create_tokenizer_with_chat_template`
-- Types: `Sequence`, `StopSequenceConfig`, `DecodeStream`, `Encoding`
-- Chat template: `ChatMessage` (when huggingface feature enabled)
-- Conditional: `HuggingFaceTokenizer`, `TiktokenTokenizer` based on features
+**Re-exports** (mod.rs:26-43):
+- Factory functions: `create_tokenizer`, `create_tokenizer_async`, `create_tokenizer_from_file`, `create_tokenizer_with_chat_template`
+- Types: `Sequence`, `StopSequenceConfig`, `DecodeStream`, `Encoding`, `TokenizerType`
+- Chat template: `ChatMessage`
+- Tokenizer implementations: `HuggingFaceTokenizer`, `TiktokenTokenizer`

 ### 3.2 traits.rs (Trait Definitions)
@@ -350,6 +357,7 @@ pub fn create_tokenizer_with_chat_template(
     chat_template_path: Option<&str>
 ) -> Result<Arc<dyn traits::Tokenizer>>
 pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
+pub async fn create_tokenizer_async(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
 pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType>
 ```
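A hedged usage sketch for the two entry points; the `#[tokio::main]` harness is illustrative and not part of the router:

```rust
use sglang_router_rs::tokenizer::{create_tokenizer, create_tokenizer_async};

#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
    // In async code, await the download directly instead of blocking a worker.
    let tok = create_tokenizer_async("bert-base-uncased").await?;
    println!("vocab size: {}", tok.vocab_size());
    Ok(())
}

// Synchronous callers use the blocking wrapper, which re-enters or creates a
// tokio runtime internally (so avoid calling it from a current-thread runtime).
fn sync_caller() -> anyhow::Result<()> {
    let tok = create_tokenizer("gpt-4")?; // Tiktoken path, no download needed
    println!("vocab size: {}", tok.vocab_size());
    Ok(())
}
```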
@@ -364,10 +372,16 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType>
 - SentencePiece: Check for specific byte patterns
 - GGUF: Check magic number "GGUF"

-**Model Name Routing** (factory.rs:163-203):
+**Model Name Routing** (factory.rs:145-193):
 - GPT models → Tiktoken (gpt-4, gpt-3.5, davinci, curie, etc.)
 - File paths → file-based creation
-- HuggingFace Hub → Not implemented (returns error)
+- HuggingFace model IDs → Automatic download from Hub
+
+**HuggingFace Hub Integration**:
+- Downloads tokenizer files (tokenizer.json, tokenizer_config.json, etc.)
+- Respects HF_TOKEN environment variable for private models
+- Caches downloaded files using hf-hub crate
+- Async and blocking versions available
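For gated or private repos, exporting `HF_TOKEN` before starting the process is all the integration needs. A sketch of the expected caching behavior (the model ID is a placeholder):

```rust
use sglang_router_rs::tokenizer::create_tokenizer_async;

async fn load_twice(model_id: &str) -> anyhow::Result<()> {
    // First call hits the network (authenticated via HF_TOKEN if set)...
    let first = create_tokenizer_async(model_id).await?;
    // ...second call resolves the same files from the hf-hub cache.
    let second = create_tokenizer_async(model_id).await?;
    assert_eq!(first.vocab_size(), second.vocab_size());
    Ok(())
}
```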
 **Metrics Integration:**
 - Records factory load/error events (factory.rs:56-57, 82-83)
@@ -613,7 +627,32 @@ pub enum TiktokenModel {
 - Decode: Join tokens with spaces
 - Skips special tokens when requested

-### 3.10 chat_template.rs (Chat Template Support)
+### 3.10 hub.rs (HuggingFace Hub Download)
+
+**Location**: `src/tokenizer/hub.rs`
+
+**Purpose:** Download tokenizer files from HuggingFace Hub when given a model ID.
+
+**Key Functions:**
+```rust
+pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> Result<PathBuf>
+pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> Result<PathBuf>
+```
+
+**Features:**
+- Downloads only tokenizer-related files by default
+- Filters out model weights, images, and documentation
+- Uses HF_TOKEN environment variable for authentication
+- Returns cached directory path for subsequent use
+- Progress indication during download
+
+**File Detection:**
+- Tokenizer files: tokenizer.json, tokenizer_config.json, special_tokens_map.json
+- Vocabulary files: vocab.json, merges.txt
+- SentencePiece models: *.model files
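A direct-usage sketch of the hub API (the model ID is a placeholder; the factory performs the same lookup internally):

```rust
use sglang_router_rs::tokenizer::hub::download_tokenizer_from_hf;

async fn fetch(model_id: &str) -> anyhow::Result<()> {
    // Returns the hf-hub cache directory that now holds the tokenizer files.
    let cache_dir = download_tokenizer_from_hf(model_id).await?;

    let tokenizer_json = cache_dir.join("tokenizer.json");
    if tokenizer_json.exists() {
        println!("tokenizer at {}", tokenizer_json.display());
    } else {
        // Fall back to tokenizer_config.json / vocab.json, as the factory does.
        println!("no tokenizer.json under {}", cache_dir.display());
    }
    Ok(())
}
```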
+### 3.11 chat_template.rs (Chat Template Support)

 **Location**: `src/tokenizer/chat_template.rs`
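Based on the `apply_chat_template(&self, messages: &[ChatMessage], add_generation_prompt: bool)` signature visible later in this commit, a hedged end-to-end sketch (paths are placeholders; the `ChatMessage::user` helper is assumed to mirror `ChatMessage::system`):

```rust
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

fn render_prompt() -> anyhow::Result<()> {
    // Loads tokenizer.json; an explicit .jinja file overrides any template
    // found in a sibling tokenizer_config.json.
    let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
        "./models/llama/tokenizer.json",
        Some("./templates/chat.jinja"),
    )?;

    let messages = vec![
        ChatMessage::system("You are a helpful assistant"),
        ChatMessage::user("Hello!"),
    ];
    // `true` appends the generation prompt for inference.
    let prompt = tokenizer.apply_chat_template(&messages, true)?;
    println!("{prompt}");
    Ok(())
}
```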
@@ -894,11 +933,11 @@ The `Encoding` enum must:
 ### Configuration

 **Environment Variables:**
-- None currently defined
+- `HF_TOKEN`: HuggingFace authentication token for private models

-**Feature Flags:**
-- `huggingface`: Enable HF tokenizer
-- `tiktoken`: Enable Tiktoken support
+**Dependencies:**
+- All tokenizer backends included by default
+- No feature flags required

 **Model Mapping:**
 - Hardcoded in factory.rs
@@ -961,26 +1000,22 @@ The `Encoding` enum must:
    - File: `src/tokenizer/traits.rs`
    - Symbol: `pub type Offsets = (usize, usize)`

-3. **TODO:** Implement HuggingFace Hub downloading
-   - File: `src/tokenizer/factory.rs:191`
-   - Symbol: `create_tokenizer()` function
-
-4. **TODO:** Support SentencePiece models
+3. **TODO:** Support SentencePiece models
    - File: `src/tokenizer/factory.rs:69-72`
    - Symbol: Extension match arm for "model"

-5. **TODO:** Support GGUF format
+4. **TODO:** Support GGUF format
    - File: `src/tokenizer/factory.rs:74-78`
    - Symbol: Extension match arm for "gguf"

-6. **TODO:** Add token↔ID mapping for Tiktoken
+5. **TODO:** Add token↔ID mapping for Tiktoken
    - File: `src/tokenizer/tiktoken.rs:151-161`
    - Symbol: `token_to_id()` and `id_to_token()` methods

-7. **TODO:** Fix `token_ids_ref()` for Tiktoken
+6. **TODO:** Fix `token_ids_ref()` for Tiktoken
    - File: `src/tokenizer/traits.rs:46-50`
    - Symbol: `Encoding::Tiktoken` match arm

-8. **TODO:** Make model→tokenizer mapping configurable
+7. **TODO:** Make model→tokenizer mapping configurable
    - File: `src/tokenizer/factory.rs:174-184`
    - Symbol: GPT model detection logic
src/tokenizer/chat_template.rs:

@@ -4,7 +4,6 @@
 //! similar to HuggingFace transformers' apply_chat_template method.

 use anyhow::{anyhow, Result};
-#[cfg(feature = "huggingface")]
 use minijinja::{context, Environment, Value};
 use serde::{Deserialize, Serialize};
 use serde_json;
@@ -38,14 +37,12 @@ impl ChatMessage {
 }

 /// Chat template processor using Jinja2
-#[cfg(feature = "huggingface")]
 pub struct ChatTemplateProcessor {
     template: String,
     bos_token: Option<String>,
     eos_token: Option<String>,
 }

-#[cfg(feature = "huggingface")]
 impl ChatTemplateProcessor {
     /// Create a new chat template processor
     pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
@@ -102,7 +99,6 @@ impl ChatTemplateProcessor {
 }

 /// Load chat template from tokenizer config JSON
-#[cfg(feature = "huggingface")]
 pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
     use std::fs;
@@ -136,7 +132,6 @@ mod tests {
         assert_eq!(assistant_msg.role, "assistant");
     }

-    #[cfg(feature = "huggingface")]
     #[test]
     fn test_simple_chat_template() {
         // Simple template that formats messages
@@ -162,7 +157,6 @@ assistant:
         assert!(result.contains("assistant:"));
     }

-    #[cfg(feature = "huggingface")]
     #[test]
     fn test_chat_template_with_tokens() {
         // Template that uses special tokens
......
src/tokenizer/factory.rs:

@@ -5,15 +5,15 @@ use std::io::Read;
 use std::path::Path;
 use std::sync::Arc;

-#[cfg(feature = "huggingface")]
 use super::huggingface::HuggingFaceTokenizer;
+use super::tiktoken::TiktokenTokenizer;
+use crate::tokenizer::hub::download_tokenizer_from_hf;

 /// Represents the type of tokenizer being used
 #[derive(Debug, Clone)]
 pub enum TokenizerType {
     HuggingFace(String),
     Mock,
-    #[cfg(feature = "tiktoken")]
     Tiktoken(String),
     // Future: SentencePiece, GGUF
 }
@@ -52,21 +52,10 @@ pub fn create_tokenizer_with_chat_template(
     let result = match extension.as_deref() {
         Some("json") => {
-            #[cfg(feature = "huggingface")]
-            {
-                let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
-                    file_path,
-                    chat_template_path,
-                )?;
-                Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
-            }
-            #[cfg(not(feature = "huggingface"))]
-            {
-                Err(Error::msg(
-                    "HuggingFace support not enabled. Enable the 'huggingface' feature.",
-                ))
-            }
+            let tokenizer =
+                HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
+            Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
         }
         Some("model") => {
             // SentencePiece model file
@@ -94,17 +83,8 @@ fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
     // Check for JSON (HuggingFace format)
     if is_likely_json(&buffer) {
-        #[cfg(feature = "huggingface")]
-        {
-            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
-            return Ok(Arc::new(tokenizer));
-        }
-        #[cfg(not(feature = "huggingface"))]
-        {
-            return Err(Error::msg(
-                "File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
-            ));
-        }
+        let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
+        return Ok(Arc::new(tokenizer));
     }

     // Check for GGUF magic number
@@ -154,8 +134,10 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
         || buffer.windows(4).any(|w| w == b"</s>"))
 }

-/// Factory function to create tokenizer from a model name or path
-pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+/// Factory function to create tokenizer from a model name or path (async version)
+pub async fn create_tokenizer_async(
+    model_name_or_path: &str,
+) -> Result<Arc<dyn traits::Tokenizer>> {
     // Check if it's a file path
     let path = Path::new(model_name_or_path);
     if path.exists() {
@@ -163,35 +145,73 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
     }

     // Check if it's a GPT model name that should use Tiktoken
-    #[cfg(feature = "tiktoken")]
-    {
-        if model_name_or_path.contains("gpt-")
-            || model_name_or_path.contains("davinci")
-            || model_name_or_path.contains("curie")
-            || model_name_or_path.contains("babbage")
-            || model_name_or_path.contains("ada")
-        {
-            use super::tiktoken::TiktokenTokenizer;
-            let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
-            return Ok(Arc::new(tokenizer));
-        }
-    }
+    if model_name_or_path.contains("gpt-")
+        || model_name_or_path.contains("davinci")
+        || model_name_or_path.contains("curie")
+        || model_name_or_path.contains("babbage")
+        || model_name_or_path.contains("ada")
+    {
+        let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
+        return Ok(Arc::new(tokenizer));
+    }

-    // Otherwise, try to load from HuggingFace Hub
-    #[cfg(feature = "huggingface")]
-    {
-        // This would download from HF Hub - not implemented yet
-        Err(Error::msg(
-            "Loading from HuggingFace Hub not yet implemented",
-        ))
-    }
-    #[cfg(not(feature = "huggingface"))]
-    {
-        Err(Error::msg(format!(
-            "Model '{}' not found locally and HuggingFace support is not enabled",
-            model_name_or_path
-        )))
-    }
+    // Try to download tokenizer files from HuggingFace
+    match download_tokenizer_from_hf(model_name_or_path).await {
+        Ok(cache_dir) => {
+            // Look for tokenizer.json in the cache directory
+            let tokenizer_path = cache_dir.join("tokenizer.json");
+            if tokenizer_path.exists() {
+                create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
+            } else {
+                // Try other common tokenizer file names
+                let possible_files = ["tokenizer_config.json", "vocab.json"];
+                for file_name in &possible_files {
+                    let file_path = cache_dir.join(file_name);
+                    if file_path.exists() {
+                        return create_tokenizer_from_file(file_path.to_str().unwrap());
+                    }
+                }
+                Err(Error::msg(format!(
+                    "Downloaded model '{}' but couldn't find a suitable tokenizer file",
+                    model_name_or_path
+                )))
+            }
+        }
+        Err(e) => Err(Error::msg(format!(
+            "Failed to download tokenizer from HuggingFace: {}",
+            e
+        ))),
+    }
+}
+
+/// Factory function to create tokenizer from a model name or path (blocking version)
+pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    // Check if it's a file path
+    let path = Path::new(model_name_or_path);
+    if path.exists() {
+        return create_tokenizer_from_file(model_name_or_path);
+    }
+
+    // Check if it's a GPT model name that should use Tiktoken
+    if model_name_or_path.contains("gpt-")
+        || model_name_or_path.contains("davinci")
+        || model_name_or_path.contains("curie")
+        || model_name_or_path.contains("babbage")
+        || model_name_or_path.contains("ada")
+    {
+        let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
+        return Ok(Arc::new(tokenizer));
+    }
+
+    // Only use tokio for HuggingFace downloads
+    // Check if we're already in a tokio runtime
+    if let Ok(handle) = tokio::runtime::Handle::try_current() {
+        // We're in a runtime, use block_in_place
+        tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
+    } else {
+        // No runtime, create a temporary one
+        let rt = tokio::runtime::Runtime::new()?;
+        rt.block_on(create_tokenizer_async(model_name_or_path))
+    }
 }
@@ -257,7 +277,6 @@
         }
     }

-    #[cfg(feature = "tiktoken")]
     #[test]
     fn test_create_tiktoken_tokenizer() {
         // Test creating tokenizer for GPT models
@@ -270,4 +289,30 @@ mod tests {
         let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
         assert_eq!(decoded, text);
     }
+
+    #[tokio::test]
+    async fn test_download_tokenizer_from_hf() {
+        // Test with a small model that should have tokenizer files
+        // Skip this test if HF_TOKEN is not set and we're in CI
+        if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
+            println!("Skipping HF download test in CI without HF_TOKEN");
+            return;
+        }
+
+        // Try to create tokenizer for a known small model
+        let result = create_tokenizer_async("bert-base-uncased").await;
+
+        // The test might fail due to network issues or rate limiting
+        // so we just check that the function executes without panic
+        match result {
+            Ok(tokenizer) => {
+                assert!(tokenizer.vocab_size() > 0);
+                println!("Successfully downloaded and created tokenizer");
+            }
+            Err(e) => {
+                println!("Download failed (this might be expected): {}", e);
+                // Don't fail the test - network issues shouldn't break CI
+            }
+        }
+    }
 }

......

src/tokenizer/hub.rs (new file):
use hf_hub::api::tokio::ApiBuilder;
use std::env;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 5] = [
".gitattributes",
"LICENSE",
"LICENSE.txt",
"README.md",
"USE_POLICY.md",
];
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Checks if a file is a model weight file
fn is_weight_file(filename: &str) -> bool {
filename.ends_with(".bin")
|| filename.ends_with(".safetensors")
|| filename.ends_with(".h5")
|| filename.ends_with(".msgpack")
|| filename.ends_with(".ckpt.index")
}
/// Checks if a file is an image file
fn is_image(filename: &str) -> bool {
filename.ends_with(".png")
|| filename.ends_with("PNG")
|| filename.ends_with(".jpg")
|| filename.ends_with("JPG")
|| filename.ends_with(".jpeg")
|| filename.ends_with("JPEG")
}
/// Checks if a file is a tokenizer file
fn is_tokenizer_file(filename: &str) -> bool {
filename.ends_with("tokenizer.json")
|| filename.ends_with("tokenizer_config.json")
|| filename.ends_with("special_tokens_map.json")
|| filename.ends_with("vocab.json")
|| filename.ends_with("merges.txt")
|| filename.ends_with(".model") // SentencePiece models
|| filename.ends_with(".tiktoken")
}
/// Attempt to download tokenizer files from Hugging Face
/// Returns the directory containing the downloaded tokenizer files
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
let model_id = model_id.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new()
.with_progress(true)
.with_token(token)
.build()?;
let model_name = model_id.display().to_string();
let repo = api.model(model_name.clone());
let info = match repo.info().await {
Ok(info) => info,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
model_name,
e
));
}
};
if info.siblings.is_empty() {
return Err(anyhow::anyhow!(
"Model '{}' exists but contains no downloadable files.",
model_name
));
}
let mut cache_dir = None;
let mut tokenizer_files_found = false;
// First, identify all tokenizer files to download
let tokenizer_files: Vec<_> = info
.siblings
.iter()
.filter(|sib| {
!IGNORED.contains(&sib.rfilename.as_str())
&& !is_image(&sib.rfilename)
&& !is_weight_file(&sib.rfilename)
&& is_tokenizer_file(&sib.rfilename)
})
.collect();
if tokenizer_files.is_empty() {
return Err(anyhow::anyhow!(
"No tokenizer files found for model '{}'.",
model_name
));
}
// Download all tokenizer files
for sib in tokenizer_files {
match repo.get(&sib.rfilename).await {
Ok(path) => {
if cache_dir.is_none() {
cache_dir = path.parent().map(|p| p.to_path_buf());
}
tokenizer_files_found = true;
}
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to download tokenizer file '{}' from model '{}': {}",
sib.rfilename,
model_name,
e
));
}
}
}
if !tokenizer_files_found {
return Err(anyhow::anyhow!(
"No tokenizer files could be downloaded for model '{}'.",
model_name
));
}
match cache_dir {
Some(dir) => Ok(dir),
None => Err(anyhow::anyhow!(
"Invalid HF cache path for model '{}'",
model_name
)),
}
}
/// Attempt to download a model from Hugging Face (including weights)
/// Returns the directory it is in
/// If ignore_weights is true, model weight files will be skipped
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
let name = name.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new()
.with_progress(true)
.with_token(token)
.build()?;
let model_name = name.display().to_string();
let repo = api.model(model_name.clone());
let info = match repo.info().await {
Ok(info) => info,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
model_name,
e
));
}
};
if info.siblings.is_empty() {
return Err(anyhow::anyhow!(
"Model '{}' exists but contains no downloadable files.",
model_name
));
}
let mut p = PathBuf::new();
let mut files_downloaded = false;
for sib in info.siblings {
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
continue;
}
// If ignore_weights is true, skip weight files
if ignore_weights && is_weight_file(&sib.rfilename) {
continue;
}
match repo.get(&sib.rfilename).await {
Ok(path) => {
p = path;
files_downloaded = true;
}
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to download file '{}' from model '{}': {}",
sib.rfilename,
model_name,
e
));
}
}
}
if !files_downloaded {
let file_type = if ignore_weights {
"non-weight"
} else {
"valid"
};
return Err(anyhow::anyhow!(
"No {} files found for model '{}'.",
file_type,
model_name
));
}
match p.parent() {
Some(p) => Ok(p.to_path_buf()),
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_tokenizer_file() {
assert!(is_tokenizer_file("tokenizer.json"));
assert!(is_tokenizer_file("tokenizer_config.json"));
assert!(is_tokenizer_file("special_tokens_map.json"));
assert!(is_tokenizer_file("vocab.json"));
assert!(is_tokenizer_file("merges.txt"));
assert!(is_tokenizer_file("spiece.model"));
assert!(!is_tokenizer_file("model.bin"));
assert!(!is_tokenizer_file("README.md"));
}
#[test]
fn test_is_weight_file() {
assert!(is_weight_file("model.bin"));
assert!(is_weight_file("model.safetensors"));
assert!(is_weight_file("pytorch_model.bin"));
assert!(!is_weight_file("tokenizer.json"));
assert!(!is_weight_file("config.json"));
}
}
......

src/tokenizer/huggingface.rs:

@@ -5,7 +5,6 @@ use anyhow::{Error, Result};
 use std::collections::HashMap;
 use tokenizers::tokenizer::Tokenizer as HfTokenizer;

-#[cfg(feature = "minijinja")]
 use super::chat_template::{ChatMessage, ChatTemplateProcessor};

 /// HuggingFace tokenizer wrapper
@@ -14,7 +13,6 @@ pub struct HuggingFaceTokenizer {
     special_tokens: SpecialTokens,
     vocab: HashMap<String, TokenIdType>,
     reverse_vocab: HashMap<TokenIdType, String>,
-    #[cfg(feature = "minijinja")]
     chat_template: Option<String>,
 }
@@ -43,7 +41,6 @@ impl HuggingFaceTokenizer {
             .collect();

         // Load chat template
-        #[cfg(feature = "minijinja")]
         let chat_template = if let Some(template_path) = chat_template_path {
             // Load from specified .jinja file
             Self::load_chat_template_from_file(template_path)?
@@ -57,7 +54,6 @@ impl HuggingFaceTokenizer {
             special_tokens,
             vocab,
             reverse_vocab,
-            #[cfg(feature = "minijinja")]
             chat_template,
         })
     }
@@ -76,7 +72,6 @@ impl HuggingFaceTokenizer {
             special_tokens,
             vocab,
             reverse_vocab,
-            #[cfg(feature = "minijinja")]
             chat_template: None,
         }
     }
@@ -109,7 +104,6 @@ impl HuggingFaceTokenizer {
     }

     /// Try to load chat template from tokenizer_config.json
-    #[cfg(feature = "minijinja")]
     fn load_chat_template(tokenizer_path: &str) -> Option<String> {
         // Try to find tokenizer_config.json in the same directory
         let path = std::path::Path::new(tokenizer_path);
@@ -127,7 +121,6 @@ impl HuggingFaceTokenizer {
     }

     /// Load chat template from a .jinja file
-    #[cfg(feature = "minijinja")]
     fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
         use std::fs;
@@ -141,13 +134,11 @@ impl HuggingFaceTokenizer {
     }

     /// Set or override the chat template
-    #[cfg(feature = "minijinja")]
     pub fn set_chat_template(&mut self, template: String) {
         self.chat_template = Some(template);
     }

     /// Apply chat template if available
-    #[cfg(feature = "minijinja")]
     pub fn apply_chat_template(
         &self,
         messages: &[ChatMessage],
@@ -172,24 +163,6 @@ impl HuggingFaceTokenizer {
         Ok(result)
     }
-
-    /// Apply chat template if available (without minijinja feature)
-    #[cfg(not(feature = "minijinja"))]
-    pub fn apply_chat_template(
-        &self,
-        messages: &[ChatMessage],
-        add_generation_prompt: bool,
-    ) -> Result<String> {
-        // Fallback to simple formatting
-        let mut result = String::new();
-        for msg in messages {
-            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
-        }
-        if add_generation_prompt {
-            result.push_str("assistant: ");
-        }
-        Ok(result)
-    }
 }

 impl Encoder for HuggingFaceTokenizer {
@@ -241,10 +214,8 @@ impl TokenizerTrait for HuggingFaceTokenizer {
 #[cfg(test)]
 mod tests {
-    #[cfg(feature = "minijinja")]
     use super::ChatMessage;

-    #[cfg(feature = "minijinja")]
     #[test]
     fn test_chat_message_creation() {
         let msg = ChatMessage::system("You are a helpful assistant");
......
src/tokenizer/mod.rs:

@@ -3,6 +3,7 @@ use std::ops::Deref;
 use std::sync::Arc;

 pub mod factory;
+pub mod hub;
 pub mod mock;
 pub mod sequence;
 pub mod stop;
@@ -10,13 +11,11 @@ pub mod stream;
 pub mod traits;

 // Feature-gated modules
-#[cfg(feature = "huggingface")]
 pub mod chat_template;
-#[cfg(feature = "huggingface")]
 pub mod huggingface;
-#[cfg(feature = "tiktoken")]
 pub mod tiktoken;

 #[cfg(test)]
@@ -24,21 +23,18 @@ mod tests;
 // Re-exports
 pub use factory::{
-    create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
-    TokenizerType,
+    create_tokenizer, create_tokenizer_async, create_tokenizer_from_file,
+    create_tokenizer_with_chat_template, TokenizerType,
 };
 pub use sequence::Sequence;
 pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};

-#[cfg(feature = "huggingface")]
 pub use huggingface::HuggingFaceTokenizer;

-#[cfg(feature = "huggingface")]
 pub use chat_template::ChatMessage;

-#[cfg(feature = "tiktoken")]
 pub use tiktoken::{TiktokenModel, TiktokenTokenizer};

 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
......
Chat template tests:

@@ -3,7 +3,6 @@ mod tests {
     use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_chat_message_helpers() {
         let system_msg = ChatMessage::system("You are a helpful assistant");
         assert_eq!(system_msg.role, "system");
@@ -19,7 +18,6 @@
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_llama_style_template() {
         // Test a Llama-style chat template
         let template = r#"
@@ -67,7 +65,6 @@
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_chatml_template() {
         // Test a ChatML-style template
         let template = r#"
@@ -97,7 +94,6 @@
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_template_without_generation_prompt() {
         let template = r#"
 {%- for message in messages -%}
@@ -122,7 +118,6 @@ assistant:
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_template_with_special_tokens() {
         let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
@@ -139,7 +134,6 @@ assistant:
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_empty_messages() {
         let template =
             r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
......
HuggingFace tokenizer tests:

@@ -4,7 +4,6 @@ mod tests {
     use tempfile::TempDir;

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_load_chat_template_from_file() {
         use sglang_router_rs::tokenizer::chat_template::ChatMessage;
         use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
@@ -73,7 +72,6 @@
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_override_existing_template() {
         use sglang_router_rs::tokenizer::chat_template::ChatMessage;
         use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
@@ -136,7 +134,6 @@
     }

     #[test]
-    #[cfg(feature = "huggingface")]
     fn test_set_chat_template_after_creation() {
         use sglang_router_rs::tokenizer::chat_template::ChatMessage;
         use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
......