Unverified Commit 79d34951 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add reasoning and tool parser argument in router (#11290)

parent 1519a89c
...@@ -86,6 +86,9 @@ class RouterArgs: ...@@ -86,6 +86,9 @@ class RouterArgs:
# Tokenizer configuration # Tokenizer configuration
model_path: Optional[str] = None model_path: Optional[str] = None
tokenizer_path: Optional[str] = None tokenizer_path: Optional[str] = None
# Parser configuration
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
...@@ -446,6 +449,18 @@ class RouterArgs: ...@@ -446,6 +449,18 @@ class RouterArgs:
default=None, default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)", help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
) )
parser.add_argument(
f"--{prefix}reasoning-parser",
type=str,
default=None,
help="Specify the parser for reasoning models (e.g., deepseek-r1, qwen3)",
)
parser.add_argument(
f"--{prefix}tool-call-parser",
type=str,
default=None,
help="Specify the parser for handling tool-call interactions",
)
@classmethod @classmethod
def from_cli_args( def from_cli_args(
......
...@@ -73,6 +73,10 @@ pub struct RouterConfig { ...@@ -73,6 +73,10 @@ pub struct RouterConfig {
/// Oracle history backend configuration (required when `history_backend` = "oracle") /// Oracle history backend configuration (required when `history_backend` = "oracle")
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub oracle: Option<OracleConfig>, pub oracle: Option<OracleConfig>,
/// Parser for reasoning models (e.g., deepseek-r1, qwen3)
pub reasoning_parser: Option<String>,
/// Parser for handling tool-call interactions
pub tool_call_parser: Option<String>,
} }
fn default_history_backend() -> HistoryBackend { fn default_history_backend() -> HistoryBackend {
...@@ -448,6 +452,8 @@ impl Default for RouterConfig { ...@@ -448,6 +452,8 @@ impl Default for RouterConfig {
tokenizer_path: None, tokenizer_path: None,
history_backend: default_history_backend(), history_backend: default_history_backend(),
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
} }
} }
} }
...@@ -990,6 +996,8 @@ mod tests { ...@@ -990,6 +996,8 @@ mod tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: default_history_backend(), history_backend: default_history_backend(),
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -1055,6 +1063,8 @@ mod tests { ...@@ -1055,6 +1063,8 @@ mod tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: default_history_backend(), history_backend: default_history_backend(),
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -1116,6 +1126,8 @@ mod tests { ...@@ -1116,6 +1126,8 @@ mod tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: default_history_backend(), history_backend: default_history_backend(),
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
...@@ -90,6 +90,8 @@ struct Router { ...@@ -90,6 +90,8 @@ struct Router {
connection_mode: config::ConnectionMode, connection_mode: config::ConnectionMode,
model_path: Option<String>, model_path: Option<String>,
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
} }
impl Router { impl Router {
...@@ -216,6 +218,8 @@ impl Router { ...@@ -216,6 +218,8 @@ impl Router {
tokenizer_path: self.tokenizer_path.clone(), tokenizer_path: self.tokenizer_path.clone(),
history_backend: config::HistoryBackend::Memory, history_backend: config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
}) })
} }
} }
...@@ -280,6 +284,8 @@ impl Router { ...@@ -280,6 +284,8 @@ impl Router {
rate_limit_tokens_per_second = None, rate_limit_tokens_per_second = None,
model_path = None, model_path = None,
tokenizer_path = None, tokenizer_path = None,
reasoning_parser = None,
tool_call_parser = None,
))] ))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
...@@ -339,6 +345,8 @@ impl Router { ...@@ -339,6 +345,8 @@ impl Router {
rate_limit_tokens_per_second: Option<usize>, rate_limit_tokens_per_second: Option<usize>,
model_path: Option<String>, model_path: Option<String>,
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let mut all_urls = worker_urls.clone(); let mut all_urls = worker_urls.clone();
...@@ -412,6 +420,8 @@ impl Router { ...@@ -412,6 +420,8 @@ impl Router {
connection_mode, connection_mode,
model_path, model_path,
tokenizer_path, tokenizer_path,
reasoning_parser,
tool_call_parser,
}) })
} }
......
...@@ -281,6 +281,12 @@ struct CliArgs { ...@@ -281,6 +281,12 @@ struct CliArgs {
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")] #[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>, oracle_pool_timeout_secs: Option<u64>,
#[arg(long)]
reasoning_parser: Option<String>,
#[arg(long)]
tool_call_parser: Option<String>,
} }
enum OracleConnectSource { enum OracleConnectSource {
...@@ -557,6 +563,8 @@ impl CliArgs { ...@@ -557,6 +563,8 @@ impl CliArgs {
tokenizer_path: self.tokenizer_path.clone(), tokenizer_path: self.tokenizer_path.clone(),
history_backend, history_backend,
oracle, oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
}) })
} }
......
...@@ -53,6 +53,8 @@ pub struct GrpcPDRouter { ...@@ -53,6 +53,8 @@ pub struct GrpcPDRouter {
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>,
} }
impl GrpcPDRouter { impl GrpcPDRouter {
...@@ -88,6 +90,8 @@ impl GrpcPDRouter { ...@@ -88,6 +90,8 @@ impl GrpcPDRouter {
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(), api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(),
}) })
} }
...@@ -1179,9 +1183,13 @@ impl GrpcPDRouter { ...@@ -1179,9 +1183,13 @@ impl GrpcPDRouter {
created: u64, created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) { ) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index // Get or create parser for this index
reasoning_parsers reasoning_parsers.entry(index).or_insert_with(|| {
.entry(index) utils::get_reasoning_parser(
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model)); &self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) { if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = { let (parse_result, in_reasoning) = {
...@@ -1248,9 +1256,13 @@ impl GrpcPDRouter { ...@@ -1248,9 +1256,13 @@ impl GrpcPDRouter {
let mut chunks = Vec::new(); let mut chunks = Vec::new();
// Get or create parser for this index // Get or create parser for this index
tool_parsers tool_parsers.entry(index).or_insert_with(|| {
.entry(index) utils::get_tool_parser(
.or_insert_with(|| self.tool_parser_factory.get_pooled(model)); &self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) { if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await; let mut parser = pooled_parser.lock().await;
...@@ -1737,9 +1749,11 @@ impl GrpcPDRouter { ...@@ -1737,9 +1749,11 @@ impl GrpcPDRouter {
// Check if reasoning parsing is enabled and separate_reasoning is requested // Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning { if original_request.separate_reasoning {
let pooled_parser = self let pooled_parser = utils::get_reasoning_parser(
.reasoning_parser_factory &self.reasoning_parser_factory,
.get_pooled(&original_request.model); self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser let mut parser = pooled_parser
.lock() .lock()
...@@ -1860,7 +1874,11 @@ impl GrpcPDRouter { ...@@ -1860,7 +1874,11 @@ impl GrpcPDRouter {
history_tool_calls_count: usize, history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) { ) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model // Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model); let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first // Check format detection first
let can_parse = { let can_parse = {
......
...@@ -53,6 +53,8 @@ pub struct GrpcRouter { ...@@ -53,6 +53,8 @@ pub struct GrpcRouter {
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>,
} }
impl GrpcRouter { impl GrpcRouter {
...@@ -87,6 +89,8 @@ impl GrpcRouter { ...@@ -87,6 +89,8 @@ impl GrpcRouter {
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(), api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(),
}) })
} }
...@@ -301,7 +305,11 @@ impl GrpcRouter { ...@@ -301,7 +305,11 @@ impl GrpcRouter {
history_tool_calls_count: usize, history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) { ) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model // Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model); let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first // Check format detection first
let can_parse = { let can_parse = {
...@@ -496,9 +504,13 @@ impl GrpcRouter { ...@@ -496,9 +504,13 @@ impl GrpcRouter {
created: u64, created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) { ) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index // Get or create parser for this index
reasoning_parsers reasoning_parsers.entry(index).or_insert_with(|| {
.entry(index) utils::get_reasoning_parser(
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model)); &self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) { if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = { let (parse_result, in_reasoning) = {
...@@ -569,9 +581,13 @@ impl GrpcRouter { ...@@ -569,9 +581,13 @@ impl GrpcRouter {
let mut chunks = Vec::new(); let mut chunks = Vec::new();
// Get or create parser for this index // Get or create parser for this index
tool_parsers tool_parsers.entry(index).or_insert_with(|| {
.entry(index) utils::get_tool_parser(
.or_insert_with(|| self.tool_parser_factory.get_pooled(model)); &self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) { if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await; let mut parser = pooled_parser.lock().await;
...@@ -1615,9 +1631,11 @@ impl GrpcRouter { ...@@ -1615,9 +1631,11 @@ impl GrpcRouter {
// Check if reasoning parsing is enabled and separate_reasoning is requested // Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning { if original_request.separate_reasoning {
let pooled_parser = self let pooled_parser = utils::get_reasoning_parser(
.reasoning_parser_factory &self.reasoning_parser_factory,
.get_pooled(&original_request.model); self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser let mut parser = pooled_parser
.lock() .lock()
......
...@@ -641,6 +641,64 @@ pub fn generate_tool_call_id( ...@@ -641,6 +641,64 @@ pub fn generate_tool_call_id(
} }
} }
/// Resolve the pooled reasoning parser to use for `model`.
///
/// When `configured_parser` carries a parser name, that name is looked up in
/// the factory's registry first. If no name is configured, or the configured
/// name is not registered (a warning is logged in that case), the parser is
/// obtained from the factory by model name instead.
pub fn get_reasoning_parser(
    reasoning_parser_factory: &crate::reasoning_parser::ReasoningParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> crate::reasoning_parser::PooledParser {
    use tracing::warn;

    // Nothing configured explicitly: select by model name.
    let parser_name = match configured_parser {
        Some(name) => name,
        None => return reasoning_parser_factory.get_pooled(model),
    };

    // An explicit name wins when the registry knows it.
    match reasoning_parser_factory
        .registry()
        .get_pooled_parser(parser_name)
    {
        Some(parser) => parser,
        None => {
            warn!(
                "Configured reasoning parser '{}' not found, falling back to model-based selection",
                parser_name
            );
            reasoning_parser_factory.get_pooled(model)
        }
    }
}
/// Resolve the pooled tool-call parser to use for `model`.
///
/// When `configured_parser` carries a parser name, that name is looked up in
/// the factory's registry first. If no name is configured, or the configured
/// name is not registered (a warning is logged in that case), the parser is
/// obtained from the factory by model name instead.
pub fn get_tool_parser(
    tool_parser_factory: &crate::tool_parser::ToolParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> crate::tool_parser::PooledToolParser {
    use tracing::warn;

    // Nothing configured explicitly: select by model name.
    let parser_name = match configured_parser {
        Some(name) => name,
        None => return tool_parser_factory.get_pooled(model),
    };

    // An explicit name wins when the registry knows it.
    match tool_parser_factory
        .registry()
        .get_pooled_parser(parser_name)
    {
        Some(parser) => parser,
        None => {
            warn!(
                "Configured tool parser '{}' not found, falling back to model-based selection",
                parser_name
            );
            tool_parser_factory.get_pooled(model)
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -52,6 +52,8 @@ pub struct AppContext { ...@@ -52,6 +52,8 @@ pub struct AppContext {
pub router_manager: Option<Arc<RouterManager>>, pub router_manager: Option<Arc<RouterManager>>,
pub response_storage: SharedResponseStorage, pub response_storage: SharedResponseStorage,
pub load_monitor: Option<Arc<LoadMonitor>>, pub load_monitor: Option<Arc<LoadMonitor>>,
pub configured_reasoning_parser: Option<String>,
pub configured_tool_parser: Option<String>,
} }
impl AppContext { impl AppContext {
...@@ -115,6 +117,9 @@ impl AppContext { ...@@ -115,6 +117,9 @@ impl AppContext {
router_config.worker_startup_check_interval_secs, router_config.worker_startup_check_interval_secs,
))); )));
let configured_reasoning_parser = router_config.reasoning_parser.clone();
let configured_tool_parser = router_config.tool_call_parser.clone();
Ok(Self { Ok(Self {
client, client,
router_config, router_config,
...@@ -127,6 +132,8 @@ impl AppContext { ...@@ -127,6 +132,8 @@ impl AppContext {
router_manager, router_manager,
response_storage, response_storage,
load_monitor, load_monitor,
configured_reasoning_parser,
configured_tool_parser,
}) })
} }
} }
......
...@@ -543,6 +543,8 @@ mod tests { ...@@ -543,6 +543,8 @@ mod tests {
router_manager: None, router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
load_monitor: None, load_monitor: None,
configured_reasoning_parser: None,
configured_tool_parser: None,
}) })
} }
......
...@@ -63,6 +63,8 @@ impl TestContext { ...@@ -63,6 +63,8 @@ impl TestContext {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
...@@ -1396,6 +1398,8 @@ mod error_tests { ...@@ -1396,6 +1398,8 @@ mod error_tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
...@@ -1755,6 +1759,8 @@ mod pd_mode_tests { ...@@ -1755,6 +1759,8 @@ mod pd_mode_tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
// Create app context // Create app context
...@@ -1915,6 +1921,8 @@ mod request_id_tests { ...@@ -1915,6 +1921,8 @@ mod request_id_tests {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
......
...@@ -76,6 +76,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { ...@@ -76,6 +76,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
// Create router and context // Create router and context
...@@ -508,6 +510,8 @@ async fn test_multi_turn_loop_with_mcp() { ...@@ -508,6 +510,8 @@ async fn test_multi_turn_loop_with_mcp() {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
...@@ -686,6 +690,8 @@ async fn test_max_tool_calls_limit() { ...@@ -686,6 +690,8 @@ async fn test_max_tool_calls_limit() {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
...@@ -826,6 +832,8 @@ async fn setup_streaming_mcp_test() -> ( ...@@ -826,6 +832,8 @@ async fn setup_streaming_mcp_test() -> (
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
......
...@@ -826,6 +826,8 @@ fn oracle_config_validation_requires_config_when_enabled() { ...@@ -826,6 +826,8 @@ fn oracle_config_validation_requires_config_when_enabled() {
}, },
history_backend: HistoryBackend::Oracle, history_backend: HistoryBackend::Oracle,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
..Default::default() ..Default::default()
}; };
......
...@@ -195,6 +195,8 @@ mod test_pd_routing { ...@@ -195,6 +195,8 @@ mod test_pd_routing {
tokenizer_path: None, tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None, oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}; };
let app_context = let app_context =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment