Unverified commit 79d34951, authored by Simo Lin and committed by GitHub
Browse files

[router] add reasoning and tool parser argument in router (#11290)

parent 1519a89c
......@@ -86,6 +86,9 @@ class RouterArgs:
# Tokenizer configuration
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
# Parser configuration
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
@staticmethod
def add_cli_args(
......@@ -446,6 +449,18 @@ class RouterArgs:
default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
)
parser.add_argument(
f"--{prefix}reasoning-parser",
type=str,
default=None,
help="Specify the parser for reasoning models (e.g., deepseek-r1, qwen3)",
)
parser.add_argument(
f"--{prefix}tool-call-parser",
type=str,
default=None,
help="Specify the parser for handling tool-call interactions",
)
@classmethod
def from_cli_args(
......
......@@ -73,6 +73,10 @@ pub struct RouterConfig {
/// Oracle history backend configuration (required when `history_backend` = "oracle")
#[serde(skip_serializing_if = "Option::is_none")]
pub oracle: Option<OracleConfig>,
/// Parser for reasoning models (e.g., deepseek-r1, qwen3)
pub reasoning_parser: Option<String>,
/// Parser for handling tool-call interactions
pub tool_call_parser: Option<String>,
}
fn default_history_backend() -> HistoryBackend {
......@@ -448,6 +452,8 @@ impl Default for RouterConfig {
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
}
}
}
......@@ -990,6 +996,8 @@ mod tests {
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
assert!(config.mode.is_pd_mode());
......@@ -1055,6 +1063,8 @@ mod tests {
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
assert!(!config.mode.is_pd_mode());
......@@ -1116,6 +1126,8 @@ mod tests {
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
assert!(config.has_service_discovery());
......
......@@ -90,6 +90,8 @@ struct Router {
connection_mode: config::ConnectionMode,
model_path: Option<String>,
tokenizer_path: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
}
impl Router {
......@@ -216,6 +218,8 @@ impl Router {
tokenizer_path: self.tokenizer_path.clone(),
history_backend: config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
})
}
}
......@@ -280,6 +284,8 @@ impl Router {
rate_limit_tokens_per_second = None,
model_path = None,
tokenizer_path = None,
reasoning_parser = None,
tool_call_parser = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
......@@ -339,6 +345,8 @@ impl Router {
rate_limit_tokens_per_second: Option<usize>,
model_path: Option<String>,
tokenizer_path: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
) -> PyResult<Self> {
let mut all_urls = worker_urls.clone();
......@@ -412,6 +420,8 @@ impl Router {
connection_mode,
model_path,
tokenizer_path,
reasoning_parser,
tool_call_parser,
})
}
......
......@@ -281,6 +281,12 @@ struct CliArgs {
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>,
#[arg(long)]
reasoning_parser: Option<String>,
#[arg(long)]
tool_call_parser: Option<String>,
}
enum OracleConnectSource {
......@@ -557,6 +563,8 @@ impl CliArgs {
tokenizer_path: self.tokenizer_path.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
})
}
......
......@@ -53,6 +53,8 @@ pub struct GrpcPDRouter {
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>,
}
impl GrpcPDRouter {
......@@ -88,6 +90,8 @@ impl GrpcPDRouter {
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(),
})
}
......@@ -1179,9 +1183,13 @@ impl GrpcPDRouter {
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index
reasoning_parsers
.entry(index)
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = {
......@@ -1248,9 +1256,13 @@ impl GrpcPDRouter {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers
.entry(index)
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
......@@ -1737,9 +1749,11 @@ impl GrpcPDRouter {
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = self
.reasoning_parser_factory
.get_pooled(&original_request.model);
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
......@@ -1860,7 +1874,11 @@ impl GrpcPDRouter {
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model);
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first
let can_parse = {
......
......@@ -53,6 +53,8 @@ pub struct GrpcRouter {
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>,
}
impl GrpcRouter {
......@@ -87,6 +89,8 @@ impl GrpcRouter {
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
configured_tool_parser: ctx.configured_tool_parser.clone(),
})
}
......@@ -301,7 +305,11 @@ impl GrpcRouter {
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model);
let pooled_parser = utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
);
// Check format detection first
let can_parse = {
......@@ -496,9 +504,13 @@ impl GrpcRouter {
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index
reasoning_parsers
.entry(index)
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let (parse_result, in_reasoning) = {
......@@ -569,9 +581,13 @@ impl GrpcRouter {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers
.entry(index)
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
});
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
......@@ -1615,9 +1631,11 @@ impl GrpcRouter {
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = self
.reasoning_parser_factory
.get_pooled(&original_request.model);
let pooled_parser = utils::get_reasoning_parser(
&self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(),
&original_request.model,
);
let mut parser = pooled_parser
.lock()
......
......@@ -641,6 +641,64 @@ pub fn generate_tool_call_id(
}
}
/// Resolve the pooled reasoning parser to use for `model`.
///
/// Resolution order:
/// 1. When `configured_parser` is set, that name is looked up in the factory's registry.
/// 2. If the registry has no parser under that name, a warning is logged and
///    resolution falls back to the factory's model-based lookup.
/// 3. When no name is configured, the factory's model-based lookup is used directly.
pub fn get_reasoning_parser(
    reasoning_parser_factory: &crate::reasoning_parser::ReasoningParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> crate::reasoning_parser::PooledParser {
    use tracing::warn;

    match configured_parser {
        // No explicit override: let the factory pick a parser from the model name.
        None => reasoning_parser_factory.get_pooled(model),
        // Explicit override: prefer the named parser, but never fail the request
        // just because the configured name is unknown to the registry.
        Some(name) => match reasoning_parser_factory.registry().get_pooled_parser(name) {
            Some(parser) => parser,
            None => {
                warn!(
                    "Configured reasoning parser '{}' not found, falling back to model-based selection",
                    name
                );
                reasoning_parser_factory.get_pooled(model)
            }
        },
    }
}
/// Resolve the pooled tool-call parser to use for `model`.
///
/// Resolution order:
/// 1. When `configured_parser` is set, that name is looked up in the factory's registry.
/// 2. If the registry has no parser under that name, a warning is logged and
///    resolution falls back to the factory's model-based lookup.
/// 3. When no name is configured, the factory's model-based lookup is used directly.
pub fn get_tool_parser(
    tool_parser_factory: &crate::tool_parser::ToolParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> crate::tool_parser::PooledToolParser {
    use tracing::warn;

    match configured_parser {
        // No explicit override: let the factory pick a parser from the model name.
        None => tool_parser_factory.get_pooled(model),
        // Explicit override: prefer the named parser, but never fail the request
        // just because the configured name is unknown to the registry.
        Some(name) => match tool_parser_factory.registry().get_pooled_parser(name) {
            Some(parser) => parser,
            None => {
                warn!(
                    "Configured tool parser '{}' not found, falling back to model-based selection",
                    name
                );
                tool_parser_factory.get_pooled(model)
            }
        },
    }
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -52,6 +52,8 @@ pub struct AppContext {
pub router_manager: Option<Arc<RouterManager>>,
pub response_storage: SharedResponseStorage,
pub load_monitor: Option<Arc<LoadMonitor>>,
pub configured_reasoning_parser: Option<String>,
pub configured_tool_parser: Option<String>,
}
impl AppContext {
......@@ -115,6 +117,9 @@ impl AppContext {
router_config.worker_startup_check_interval_secs,
)));
let configured_reasoning_parser = router_config.reasoning_parser.clone();
let configured_tool_parser = router_config.tool_call_parser.clone();
Ok(Self {
client,
router_config,
......@@ -127,6 +132,8 @@ impl AppContext {
router_manager,
response_storage,
load_monitor,
configured_reasoning_parser,
configured_tool_parser,
})
}
}
......
......@@ -543,6 +543,8 @@ mod tests {
router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
load_monitor: None,
configured_reasoning_parser: None,
configured_tool_parser: None,
})
}
......
......@@ -63,6 +63,8 @@ impl TestContext {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
Self::new_with_config(config, worker_configs).await
......@@ -1396,6 +1398,8 @@ mod error_tests {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = TestContext::new_with_config(
......@@ -1755,6 +1759,8 @@ mod pd_mode_tests {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
// Create app context
......@@ -1915,6 +1921,8 @@ mod request_id_tests {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = TestContext::new_with_config(
......
......@@ -76,6 +76,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
// Create router and context
......@@ -508,6 +510,8 @@ async fn test_multi_turn_loop_with_mcp() {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
......@@ -686,6 +690,8 @@ async fn test_max_tool_calls_limit() {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
......@@ -826,6 +832,8 @@ async fn setup_streaming_mcp_test() -> (
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
......
......@@ -826,6 +826,8 @@ fn oracle_config_validation_requires_config_when_enabled() {
},
history_backend: HistoryBackend::Oracle,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
..Default::default()
};
......
......@@ -195,6 +195,8 @@ mod test_pd_routing {
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
};
let app_context =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment