Unverified Commit d966b902 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] move tokenizer, reasoning, tool initialization to server (#9996)

parent de921733
......@@ -986,7 +986,7 @@ pub fn start_health_checker(
// Periodically reset load counters to prevent drift
// Only do this when we believe all workers should be idle
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
if check_count % LOAD_RESET_INTERVAL == 0 {
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
// Only reset if load appears to be very low (likely drift)
if max_load <= 2 {
......
......@@ -146,17 +146,29 @@ impl RouterFactory {
// Create policy
let policy = PolicyFactory::create_from_config(policy_config);
// Determine which tokenizer path to use
// Priority: tokenizer_path > model_path
let tokenizer_path = ctx
.router_config
.tokenizer_path
.clone()
.or_else(|| ctx.router_config.model_path.clone())
// Get tokenizer from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| {
"gRPC router requires either --tokenizer-path or --model-path to be specified"
"gRPC router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC router requires reasoning parser factory to be initialized in AppContext"
.to_string()
})?;
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC router requires tool parser registry to be initialized in AppContext".to_string()
})?;
// Create gRPC router
let router = GrpcRouter::new(
......@@ -169,7 +181,9 @@ impl RouterFactory {
ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(),
tokenizer_path,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
)
.await?;
......@@ -193,17 +207,30 @@ impl RouterFactory {
let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Determine which tokenizer path to use
// Priority: tokenizer_path > model_path
let tokenizer_path = ctx
.router_config
.tokenizer_path
.clone()
.or_else(|| ctx.router_config.model_path.clone())
// Get tokenizer from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| {
"gRPC PD router requires either --tokenizer-path or --model-path to be specified"
"gRPC PD router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC PD router requires reasoning parser factory to be initialized in AppContext"
.to_string()
})?;
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC PD router requires tool parser registry to be initialized in AppContext"
.to_string()
})?;
// Create gRPC PD router
let router = GrpcPDRouter::new(
......@@ -218,7 +245,9 @@ impl RouterFactory {
ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(),
tokenizer_path,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
)
.await?;
......
......@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
......@@ -74,21 +74,13 @@ impl GrpcPDRouter {
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
......
......@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
......@@ -65,21 +65,13 @@ impl GrpcRouter {
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(worker_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
......
......@@ -3,8 +3,11 @@ use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use axum::{
extract::{Query, Request, State},
http::StatusCode,
......@@ -27,7 +30,9 @@ pub struct AppContext {
pub client: Client,
pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>,
// Future dependencies can be added here
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>,
}
impl AppContext {
......@@ -36,14 +41,45 @@ impl AppContext {
client: Client,
max_concurrent_requests: usize,
rate_limit_tokens_per_second: Option<usize>,
) -> Self {
) -> Result<Self, String> {
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
Self {
// Initialize gRPC-specific components only when in gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
// Get tokenizer path (required for gRPC mode)
let tokenizer_path = router_config
.tokenizer_path
.clone()
.or_else(|| router_config.model_path.clone())
.ok_or_else(|| {
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
.to_string()
})?;
// Initialize all gRPC components
let tokenizer = Some(
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let tool_parser_registry = Some(ParserRegistry::new());
(tokenizer, reasoning_parser_factory, tool_parser_registry)
} else {
// HTTP mode doesn't need these components
(None, None, None)
};
Ok(Self {
client,
router_config,
rate_limiter,
}
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
})
}
}
......@@ -291,7 +327,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
client.clone(),
config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second,
));
)?);
// Create router with the context
let router = RouterFactory::create_router(&app_context).await?;
......
......@@ -13,12 +13,15 @@ use std::sync::{Arc, Mutex, OnceLock};
/// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
Arc::new(AppContext::new(
config.clone(),
reqwest::Client::new(),
config.max_concurrent_requests,
config.rate_limit_tokens_per_second,
))
Arc::new(
AppContext::new(
config.clone(),
reqwest::Client::new(),
config.max_concurrent_requests,
config.rate_limit_tokens_per_second,
)
.expect("Failed to create AppContext in test"),
)
}
// Tokenizer download configuration
......
......@@ -15,12 +15,15 @@ pub fn create_test_app(
router_config: &RouterConfig,
) -> Router {
// Create AppContext
let app_context = Arc::new(AppContext::new(
router_config.clone(),
client,
router_config.max_concurrent_requests,
router_config.rate_limit_tokens_per_second,
));
let app_context = Arc::new(
AppContext::new(
router_config.clone(),
client,
router_config.max_concurrent_requests,
router_config.rate_limit_tokens_per_second,
)
.expect("Failed to create AppContext in test"),
);
// Create AppState with the test router and context
let app_state = Arc::new(AppState {
......
......@@ -195,7 +195,8 @@ mod test_pd_routing {
// Router creation will fail due to health checks, but config should be valid
let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None);
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
.expect("Failed to create AppContext");
let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment