Unverified Commit d966b902 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] move tokenizer, reasoning, tool initialization to server (#9996)

parent de921733
...@@ -986,7 +986,7 @@ pub fn start_health_checker( ...@@ -986,7 +986,7 @@ pub fn start_health_checker(
// Periodically reset load counters to prevent drift // Periodically reset load counters to prevent drift
// Only do this when we believe all workers should be idle // Only do this when we believe all workers should be idle
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) { if check_count % LOAD_RESET_INTERVAL == 0 {
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0); let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
// Only reset if load appears to be very low (likely drift) // Only reset if load appears to be very low (likely drift)
if max_load <= 2 { if max_load <= 2 {
......
...@@ -146,16 +146,28 @@ impl RouterFactory { ...@@ -146,16 +146,28 @@ impl RouterFactory {
// Create policy // Create policy
let policy = PolicyFactory::create_from_config(policy_config); let policy = PolicyFactory::create_from_config(policy_config);
// Determine which tokenizer path to use // Get tokenizer from context
// Priority: tokenizer_path > model_path let tokenizer = ctx
let tokenizer_path = ctx .tokenizer
.router_config .as_ref()
.tokenizer_path
.clone()
.or_else(|| ctx.router_config.model_path.clone())
.ok_or_else(|| { .ok_or_else(|| {
"gRPC router requires either --tokenizer-path or --model-path to be specified" "gRPC router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC router requires reasoning parser factory to be initialized in AppContext"
.to_string() .to_string()
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC router requires tool parser registry to be initialized in AppContext".to_string()
})?; })?;
// Create gRPC router // Create gRPC router
...@@ -169,7 +181,9 @@ impl RouterFactory { ...@@ -169,7 +181,9 @@ impl RouterFactory {
ctx.router_config.effective_retry_config(), ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(), ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(), ctx.router_config.health_check.clone(),
tokenizer_path, tokenizer,
reasoning_parser_factory,
tool_parser_registry,
) )
.await?; .await?;
...@@ -193,15 +207,28 @@ impl RouterFactory { ...@@ -193,15 +207,28 @@ impl RouterFactory {
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Determine which tokenizer path to use // Get tokenizer from context
// Priority: tokenizer_path > model_path let tokenizer = ctx
let tokenizer_path = ctx .tokenizer
.router_config .as_ref()
.tokenizer_path
.clone()
.or_else(|| ctx.router_config.model_path.clone())
.ok_or_else(|| { .ok_or_else(|| {
"gRPC PD router requires either --tokenizer-path or --model-path to be specified" "gRPC PD router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC PD router requires reasoning parser factory to be initialized in AppContext"
.to_string()
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC PD router requires tool parser registry to be initialized in AppContext"
.to_string() .to_string()
})?; })?;
...@@ -218,7 +245,9 @@ impl RouterFactory { ...@@ -218,7 +245,9 @@ impl RouterFactory {
ctx.router_config.effective_retry_config(), ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(), ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(), ctx.router_config.health_check.clone(),
tokenizer_path, tokenizer,
reasoning_parser_factory,
tool_parser_registry,
) )
.await?; .await?;
......
...@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics; ...@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer}; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry; use crate::tool_parser::ParserRegistry;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
...@@ -74,21 +74,13 @@ impl GrpcPDRouter { ...@@ -74,21 +74,13 @@ impl GrpcPDRouter {
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig, circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig, health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String, tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update metrics // Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
......
...@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics; ...@@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::{factory, traits::Tokenizer}; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry; use crate::tool_parser::ParserRegistry;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
...@@ -65,21 +65,13 @@ impl GrpcRouter { ...@@ -65,21 +65,13 @@ impl GrpcRouter {
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig, circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig, health_check_config: ConfigHealthCheckConfig,
tokenizer_path_or_model: String, tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update metrics // Update metrics
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(worker_urls.len());
// Initialize tokenizer
let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?;
// Initialize reasoning parser factory
let reasoning_parser_factory = ParserFactory::new();
// Get tool parser registry
let tool_parser_registry = ParserRegistry::new();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
......
...@@ -3,8 +3,11 @@ use crate::logging::{self, LoggingConfig}; ...@@ -3,8 +3,11 @@ use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig}; use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket; use crate::middleware::TokenBucket;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterFactory, RouterTrait}; use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use axum::{ use axum::{
extract::{Query, Request, State}, extract::{Query, Request, State},
http::StatusCode, http::StatusCode,
...@@ -27,7 +30,9 @@ pub struct AppContext { ...@@ -27,7 +30,9 @@ pub struct AppContext {
pub client: Client, pub client: Client,
pub router_config: RouterConfig, pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>, pub rate_limiter: Arc<TokenBucket>,
// Future dependencies can be added here pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>,
} }
impl AppContext { impl AppContext {
...@@ -36,14 +41,45 @@ impl AppContext { ...@@ -36,14 +41,45 @@ impl AppContext {
client: Client, client: Client,
max_concurrent_requests: usize, max_concurrent_requests: usize,
rate_limit_tokens_per_second: Option<usize>, rate_limit_tokens_per_second: Option<usize>,
) -> Self { ) -> Result<Self, String> {
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests); let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens)); let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
Self {
// Initialize gRPC-specific components only when in gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
// Get tokenizer path (required for gRPC mode)
let tokenizer_path = router_config
.tokenizer_path
.clone()
.or_else(|| router_config.model_path.clone())
.ok_or_else(|| {
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
.to_string()
})?;
// Initialize all gRPC components
let tokenizer = Some(
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let tool_parser_registry = Some(ParserRegistry::new());
(tokenizer, reasoning_parser_factory, tool_parser_registry)
} else {
// HTTP mode doesn't need these components
(None, None, None)
};
Ok(Self {
client, client,
router_config, router_config,
rate_limiter, rate_limiter,
} tokenizer,
reasoning_parser_factory,
tool_parser_registry,
})
} }
} }
...@@ -291,7 +327,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -291,7 +327,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
client.clone(), client.clone(),
config.router_config.max_concurrent_requests, config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second, config.router_config.rate_limit_tokens_per_second,
)); )?);
// Create router with the context // Create router with the context
let router = RouterFactory::create_router(&app_context).await?; let router = RouterFactory::create_router(&app_context).await?;
......
...@@ -13,12 +13,15 @@ use std::sync::{Arc, Mutex, OnceLock}; ...@@ -13,12 +13,15 @@ use std::sync::{Arc, Mutex, OnceLock};
/// Helper function to create AppContext for tests /// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
Arc::new(AppContext::new( Arc::new(
AppContext::new(
config.clone(), config.clone(),
reqwest::Client::new(), reqwest::Client::new(),
config.max_concurrent_requests, config.max_concurrent_requests,
config.rate_limit_tokens_per_second, config.rate_limit_tokens_per_second,
)) )
.expect("Failed to create AppContext in test"),
)
} }
// Tokenizer download configuration // Tokenizer download configuration
......
...@@ -15,12 +15,15 @@ pub fn create_test_app( ...@@ -15,12 +15,15 @@ pub fn create_test_app(
router_config: &RouterConfig, router_config: &RouterConfig,
) -> Router { ) -> Router {
// Create AppContext // Create AppContext
let app_context = Arc::new(AppContext::new( let app_context = Arc::new(
AppContext::new(
router_config.clone(), router_config.clone(),
client, client,
router_config.max_concurrent_requests, router_config.max_concurrent_requests,
router_config.rate_limit_tokens_per_second, router_config.rate_limit_tokens_per_second,
)); )
.expect("Failed to create AppContext in test"),
);
// Create AppState with the test router and context // Create AppState with the test router and context
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
......
...@@ -195,7 +195,8 @@ mod test_pd_routing { ...@@ -195,7 +195,8 @@ mod test_pd_routing {
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
let app_context = let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None); sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
.expect("Failed to create AppContext");
let app_context = std::sync::Arc::new(app_context); let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await; let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err()); assert!(result.is_err());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment