Unverified commit 728af887, authored by Simo Lin, committed by GitHub

[router] allow user to specify chat template path (#11549)

parent 7b59b0b8
@@ -83,10 +83,9 @@ class RouterArgs:
     cb_timeout_duration_secs: int = 60
     cb_window_duration_secs: int = 120
     disable_circuit_breaker: bool = False
-    # Tokenizer configuration
     model_path: Optional[str] = None
     tokenizer_path: Optional[str] = None
-    # Parser configuration
+    chat_template: Optional[str] = None
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None
@@ -449,6 +448,12 @@ class RouterArgs:
             default=None,
             help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
         )
+        parser.add_argument(
+            f"--{prefix}chat-template",
+            type=str,
+            default=None,
+            help="Chat template path (optional)",
+        )
         parser.add_argument(
             f"--{prefix}reasoning-parser",
             type=str,
...
@@ -67,6 +67,8 @@ pub struct RouterConfig {
     pub model_path: Option<String>,
     /// Explicit tokenizer path (overrides model_path tokenizer if provided)
     pub tokenizer_path: Option<String>,
+    /// Chat template path (optional)
+    pub chat_template: Option<String>,
     /// History backend configuration (memory or none, default: memory)
     #[serde(default = "default_history_backend")]
     pub history_backend: HistoryBackend,
@@ -450,6 +452,7 @@ impl Default for RouterConfig {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -994,6 +997,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -1061,6 +1065,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
@@ -1124,6 +1129,7 @@ mod tests {
             connection_mode: ConnectionMode::Http,
             model_path: None,
             tokenizer_path: None,
+            chat_template: None,
             history_backend: default_history_backend(),
             oracle: None,
             reasoning_parser: None,
...
@@ -90,6 +90,7 @@ struct Router {
     connection_mode: config::ConnectionMode,
     model_path: Option<String>,
     tokenizer_path: Option<String>,
+    chat_template: Option<String>,
    reasoning_parser: Option<String>,
     tool_call_parser: Option<String>,
 }
@@ -216,6 +217,7 @@ impl Router {
             enable_igw: self.enable_igw,
             model_path: self.model_path.clone(),
             tokenizer_path: self.tokenizer_path.clone(),
+            chat_template: self.chat_template.clone(),
             history_backend: config::HistoryBackend::Memory,
             oracle: None,
             reasoning_parser: self.reasoning_parser.clone(),
@@ -284,6 +286,7 @@ impl Router {
         rate_limit_tokens_per_second = None,
         model_path = None,
         tokenizer_path = None,
+        chat_template = None,
         reasoning_parser = None,
         tool_call_parser = None,
     ))]
@@ -345,6 +348,7 @@ impl Router {
         rate_limit_tokens_per_second: Option<i32>,
         model_path: Option<String>,
         tokenizer_path: Option<String>,
+        chat_template: Option<String>,
         reasoning_parser: Option<String>,
         tool_call_parser: Option<String>,
     ) -> PyResult<Self> {
@@ -420,6 +424,7 @@ impl Router {
             connection_mode,
             model_path,
             tokenizer_path,
+            chat_template,
             reasoning_parser,
             tool_call_parser,
         })
...
@@ -255,6 +255,9 @@ struct CliArgs {
     #[arg(long)]
     tokenizer_path: Option<String>,
+    #[arg(long)]
+    chat_template: Option<String>,
+
     #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
     history_backend: String,
@@ -561,6 +564,7 @@ impl CliArgs {
             rate_limit_tokens_per_second: None,
             model_path: self.model_path.clone(),
             tokenizer_path: self.tokenizer_path.clone(),
+            chat_template: self.chat_template.clone(),
             history_backend,
             oracle,
             reasoning_parser: self.reasoning_parser.clone(),
...
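For context, the new flag can be exercised as in the following minimal sketch. It assumes a stripped-down `CliArgs` with only the two relevant fields; the real struct in this diff carries many more options:

    // Minimal sketch (assumed, stripped-down CliArgs): `--chat-template` is an
    // optional flag, declared exactly like `--tokenizer-path`.
    use clap::Parser;

    #[derive(Parser, Debug)]
    struct CliArgs {
        #[arg(long)]
        tokenizer_path: Option<String>,
        #[arg(long)]
        chat_template: Option<String>,
    }

    fn main() {
        // e.g. `router --chat-template /path/to/chat_template.jinja`
        let args = CliArgs::parse();
        println!("chat_template = {:?}", args.chat_template);
    }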
@@ -82,28 +82,40 @@ impl AppContext {
             }
         };

-        let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
-            if router_config.connection_mode == ConnectionMode::Grpc {
-                let tokenizer_path = router_config
-                    .tokenizer_path
-                    .clone()
-                    .or_else(|| router_config.model_path.clone())
-                    .ok_or_else(|| {
-                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
-                            .to_string()
-                    })?;
-
-                let tokenizer = Some(
-                    tokenizer_factory::create_tokenizer(&tokenizer_path)
-                        .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
-                );
-
-                let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
-                let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());
-                (tokenizer, reasoning_parser_factory, tool_parser_factory)
-            } else {
-                (None, None, None)
-            };
+        let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if router_config
+            .connection_mode
+            == ConnectionMode::Grpc
+        {
+            let tokenizer_path = router_config
+                .tokenizer_path
+                .clone()
+                .or_else(|| router_config.model_path.clone())
+                .ok_or_else(|| {
+                    "gRPC mode requires either --tokenizer-path or --model-path to be specified"
+                        .to_string()
+                })?;
+
+            let tokenizer = Some(
+                tokenizer_factory::create_tokenizer_with_chat_template_blocking(
+                    &tokenizer_path,
+                    router_config.chat_template.as_deref(),
+                )
+                .map_err(|e| {
+                    format!(
+                        "Failed to create tokenizer from '{}': {}. \
+                         Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
+                         or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
+                        tokenizer_path, e
+                    )
+                })?,
+            );
+
+            let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
+            let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());
+            (tokenizer, reasoning_parser_factory, tool_parser_factory)
+        } else {
+            (None, None, None)
+        };

         let worker_registry = Arc::new(WorkerRegistry::new());
         let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
...
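Note the `router_config.chat_template.as_deref()` call above: the config stores an `Option<String>` while the factory takes `Option<&str>`. A self-contained sketch of just that conversion (the function name here is illustrative):

    // `as_deref` turns Option<String> into Option<&str> without cloning.
    fn use_template(chat_template: Option<&str>) {
        match chat_template {
            Some(path) => println!("using chat template at {path}"),
            None => println!("no chat template configured"),
        }
    }

    fn main() {
        let configured: Option<String> = Some("chat_template.jinja".to_string());
        use_template(configured.as_deref()); // borrows the inner String
        use_template(None); // nothing configured
    }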
@@ -4,6 +4,7 @@ use std::fs::File;
 use std::io::Read;
 use std::path::Path;
 use std::sync::Arc;
+use tracing::{debug, info};

 use super::huggingface::HuggingFaceTokenizer;
 use super::tiktoken::TiktokenTokenizer;
@@ -189,14 +190,57 @@ pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
     None
 }

+/// Helper function to resolve and log chat template selection
+///
+/// Resolves the final chat template to use by prioritizing provided path over auto-discovery,
+/// and logs the source for debugging purposes.
+fn resolve_and_log_chat_template(
+    provided_path: Option<&str>,
+    discovery_dir: &Path,
+    model_name: &str,
+) -> Option<String> {
+    let final_chat_template = provided_path
+        .map(|s| s.to_string())
+        .or_else(|| discover_chat_template_in_dir(discovery_dir));
+
+    match (&provided_path, &final_chat_template) {
+        (Some(provided), _) => {
+            info!("Using provided chat template: {}", provided);
+        }
+        (None, Some(discovered)) => {
+            info!(
+                "Auto-discovered chat template in '{}': {}",
+                discovery_dir.display(),
+                discovered
+            );
+        }
+        (None, None) => {
+            debug!(
+                "No chat template provided or discovered for model: {}",
+                model_name
+            );
+        }
+    }

+    final_chat_template
+}
+
 /// Factory function to create tokenizer from a model name or path (async version)
 pub async fn create_tokenizer_async(
     model_name_or_path: &str,
+) -> Result<Arc<dyn traits::Tokenizer>> {
+    create_tokenizer_async_with_chat_template(model_name_or_path, None).await
+}
+
+/// Factory function to create tokenizer with optional chat template (async version)
+pub async fn create_tokenizer_async_with_chat_template(
+    model_name_or_path: &str,
+    chat_template_path: Option<&str>,
 ) -> Result<Arc<dyn traits::Tokenizer>> {
     // Check if it's a file path
     let path = Path::new(model_name_or_path);
     if path.exists() {
-        return create_tokenizer_from_file(model_name_or_path);
+        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
     }

     // Check if it's a GPT model name that should use Tiktoken
@@ -216,8 +260,13 @@ pub async fn create_tokenizer_async(
         // Look for tokenizer.json in the cache directory
         let tokenizer_path = cache_dir.join("tokenizer.json");
         if tokenizer_path.exists() {
-            // Try to find a chat template file in the cache directory
-            let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+            // Resolve chat template: provided path takes precedence over auto-discovery
+            let final_chat_template = resolve_and_log_chat_template(
+                chat_template_path,
+                &cache_dir,
+                model_name_or_path,
+            );
             let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
                 Error::msg(format!(
                     "Tokenizer path is not valid UTF-8: {:?}",
@@ -226,7 +275,7 @@ pub async fn create_tokenizer_async(
             })?;
             create_tokenizer_with_chat_template(
                 tokenizer_path_str,
-                chat_template_path.as_deref(),
+                final_chat_template.as_deref(),
             )
         } else {
             // Try other common tokenizer file names
@@ -234,13 +283,19 @@ pub async fn create_tokenizer_async(
             for file_name in &possible_files {
                 let file_path = cache_dir.join(file_name);
                 if file_path.exists() {
-                    let chat_template_path = discover_chat_template_in_dir(&cache_dir);
+                    // Resolve chat template: provided path takes precedence over auto-discovery
+                    let final_chat_template = resolve_and_log_chat_template(
+                        chat_template_path,
+                        &cache_dir,
+                        model_name_or_path,
+                    );
                     let file_path_str = file_path.to_str().ok_or_else(|| {
                         Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
                     })?;
                     return create_tokenizer_with_chat_template(
                         file_path_str,
-                        chat_template_path.as_deref(),
+                        final_chat_template.as_deref(),
                     );
                 }
             }
@@ -258,11 +313,22 @@ pub async fn create_tokenizer_async(
 }

 /// Factory function to create tokenizer from a model name or path (blocking version)
+///
+/// This delegates to `create_tokenizer_with_chat_template_blocking` with no chat template,
+/// which handles both local files and HuggingFace Hub downloads uniformly.
 pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
+}
+
+/// Factory function to create tokenizer with optional chat template (blocking version)
+pub fn create_tokenizer_with_chat_template_blocking(
+    model_name_or_path: &str,
+    chat_template_path: Option<&str>,
+) -> Result<Arc<dyn traits::Tokenizer>> {
     // Check if it's a file path
     let path = Path::new(model_name_or_path);
     if path.exists() {
-        return create_tokenizer_from_file(model_name_or_path);
+        return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
     }

     // Check if it's a GPT model name that should use Tiktoken
@@ -280,11 +346,19 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
     // Check if we're already in a tokio runtime
     if let Ok(handle) = tokio::runtime::Handle::try_current() {
         // We're in a runtime, use block_in_place
-        tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
+        tokio::task::block_in_place(|| {
+            handle.block_on(create_tokenizer_async_with_chat_template(
+                model_name_or_path,
+                chat_template_path,
+            ))
+        })
     } else {
         // No runtime, create a temporary one
         let rt = tokio::runtime::Runtime::new()?;
-        rt.block_on(create_tokenizer_async(model_name_or_path))
+        rt.block_on(create_tokenizer_async_with_chat_template(
+            model_name_or_path,
+            chat_template_path,
+        ))
     }
 }
...
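The precedence rule that `resolve_and_log_chat_template` implements is simply "an explicitly provided template beats auto-discovery". A standalone sketch of just that resolution, with a stub standing in for the real `discover_chat_template_in_dir`:

    // Stub for discover_chat_template_in_dir: pretend nothing was found on disk.
    fn discover_stub(_dir: &str) -> Option<String> {
        None
    }

    fn resolve(provided: Option<&str>, dir: &str) -> Option<String> {
        provided
            .map(str::to_string)
            // `or_else` only runs when no explicit path was given.
            .or_else(|| discover_stub(dir))
    }

    fn main() {
        // An explicit path wins even when discovery could have found something.
        assert_eq!(
            resolve(Some("custom.jinja"), "/models/foo"),
            Some("custom.jinja".to_string())
        );
        // No explicit path and nothing discovered -> None.
        assert_eq!(resolve(None, "/models/foo"), None);
    }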
@@ -23,8 +23,9 @@ mod tests;

 // Re-exports
 pub use factory::{
-    create_tokenizer, create_tokenizer_async, create_tokenizer_from_file,
-    create_tokenizer_with_chat_template, TokenizerType,
+    create_tokenizer, create_tokenizer_async, create_tokenizer_async_with_chat_template,
+    create_tokenizer_from_file, create_tokenizer_with_chat_template,
+    create_tokenizer_with_chat_template_blocking, TokenizerType,
 };
 pub use sequence::Sequence;
 pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
...
@@ -30,6 +30,7 @@ impl TestContext {
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         // Create default router config
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
@@ -1365,6 +1366,7 @@ mod error_tests {
     async fn test_payload_too_large() {
         // Create context with small payload limit
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
@@ -1723,6 +1725,7 @@ mod pd_mode_tests {
             .unwrap_or(9000);
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::PrefillDecode {
                 prefill_urls: vec![(prefill_url, Some(prefill_port))],
                 decode_urls: vec![decode_url],
@@ -1888,6 +1891,7 @@ mod request_id_tests {
     async fn test_request_id_with_custom_headers() {
         // Create config with custom request ID headers
         let config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
...
@@ -18,6 +18,7 @@ struct TestContext {
 impl TestContext {
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         let mut config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
...
@@ -44,6 +44,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     // Build router config (HTTP OpenAI mode)
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -245,6 +246,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
 async fn test_conversations_crud_basic() {
     // Router in OpenAI mode (no actual upstream calls in these tests)
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -576,6 +578,7 @@ async fn test_multi_turn_loop_with_mcp() {
     // Build router config
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -753,6 +756,7 @@ async fn test_max_tool_calls_limit() {
     let worker_url = worker.start().await.expect("start worker");
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -896,6 +900,7 @@ async fn setup_streaming_mcp_test() -> (
     let worker_url = worker.start().await.expect("start worker");
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec![worker_url],
         },
@@ -1338,6 +1343,7 @@ async fn test_streaming_multi_turn_with_mcp() {
 async fn test_conversation_items_create_and_get() {
     // Test creating items and getting a specific item
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1440,6 +1446,7 @@ async fn test_conversation_items_create_and_get() {
 async fn test_conversation_items_delete() {
     // Test deleting an item from a conversation
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1548,6 +1555,7 @@ async fn test_conversation_items_delete() {
 async fn test_conversation_items_max_limit() {
     // Test that creating > 20 items returns error
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1626,6 +1634,7 @@ async fn test_conversation_items_max_limit() {
 async fn test_conversation_items_unsupported_type() {
     // Test that unsupported item types return error
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
@@ -1703,6 +1712,7 @@ async fn test_conversation_items_unsupported_type() {
 async fn test_conversation_items_multi_conversation_sharing() {
     // Test that items can be shared across conversations via soft delete
     let router_cfg = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["http://localhost".to_string()],
         },
...
@@ -19,6 +19,7 @@ struct TestContext {
 impl TestContext {
     async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
         let mut config = RouterConfig {
+            chat_template: None,
             mode: RoutingMode::Regular {
                 worker_urls: vec![],
             },
...
@@ -867,6 +867,7 @@ async fn test_openai_router_models_auth_forwarding() {
 #[test]
 fn oracle_config_validation_requires_config_when_enabled() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },
@@ -891,6 +892,7 @@ fn oracle_config_validation_requires_config_when_enabled() {
 #[test]
 fn oracle_config_validation_accepts_dsn_only() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },
@@ -913,6 +915,7 @@ fn oracle_config_validation_accepts_dsn_only() {
 #[test]
 fn oracle_config_validation_accepts_wallet_alias() {
     let config = RouterConfig {
+        chat_template: None,
         mode: RoutingMode::OpenAI {
             worker_urls: vec!["https://api.openai.com".to_string()],
         },
...
@@ -164,6 +164,7 @@ mod test_pd_routing {
         for (mode, policy) in test_cases {
             let config = RouterConfig {
+                chat_template: None,
                 mode,
                 policy,
                 host: "127.0.0.1".to_string(),
...