Unverified commit 229f236d authored by Simo Lin, committed by GitHub

[router] migrate app context to builder pattern 2/n (#12089)
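The new entry point collapses the manual wiring into a single call. A minimal sketch of the intended call site, assuming the crate paths shown in this diff (the helper function name and surrounding error handling are illustrative, not part of the change):

use std::sync::Arc;

use crate::{app_context::AppContext, config::RouterConfig};

// Hypothetical startup helper: one call now builds the HTTP client, rate limiter,
// optional gRPC tokenizer/parser factories, registries, storage backends, load
// monitor, and the job-queue/workflow-engine slots that server.rs used to wire by hand.
async fn build_app_context(
    router_config: RouterConfig,
    request_timeout_secs: u64,
) -> Result<Arc<AppContext>, String> {
    let ctx = AppContext::from_config(router_config, request_timeout_secs).await?;
    Ok(Arc::new(ctx))
}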

parent 5983e5bd
use std::{
sync::{Arc, OnceLock},
time::Duration,
};
use reqwest::Client;
use tracing::info;
use crate::{
config::{HistoryBackend, RouterConfig},
core::{workflow::WorkflowEngine, ConnectionMode, JobQueue, LoadMonitor, WorkerRegistry},
data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
OracleConversationStorage, OracleResponseStorage, SharedConversationItemStorage,
SharedConversationStorage, SharedResponseStorage,
},
middleware::TokenBucket,
policies::PolicyRegistry,
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::router_manager::RouterManager,
tokenizer::{
cache::{CacheConfig, CachedTokenizer},
factory as tokenizer_factory,
traits::Tokenizer,
},
tool_parser::ParserFactory as ToolParserFactory,
};
@@ -71,6 +82,18 @@ impl AppContext {
pub fn builder() -> AppContextBuilder {
AppContextBuilder::new()
}
/// Create AppContext from config with all components initialized
/// This is the main entry point that replaces ~194 lines of initialization in server.rs
pub async fn from_config(
router_config: RouterConfig,
request_timeout_secs: u64,
) -> Result<Self, String> {
AppContextBuilder::from_config(router_config, request_timeout_secs)
.await?
.build()
.map_err(|e| e.to_string())
}
}
impl AppContextBuilder {
@@ -216,6 +239,291 @@ impl AppContextBuilder {
.ok_or(AppContextBuildError("workflow_engine"))?,
})
}
/// Initialize AppContext from config - creates ALL components
/// This replaces ~194 lines of initialization logic from server.rs
pub async fn from_config(
router_config: RouterConfig,
request_timeout_secs: u64,
) -> Result<Self, String> {
Ok(Self::new()
.with_client(&router_config, request_timeout_secs)?
.maybe_rate_limiter(&router_config)
.maybe_tokenizer(&router_config)?
.maybe_reasoning_parser_factory(&router_config)
.maybe_tool_parser_factory(&router_config)
.with_worker_registry()
.with_policy_registry(&router_config)
.with_response_storage(&router_config)
.await?
.with_conversation_storage(&router_config)
.await?
.with_conversation_item_storage(&router_config)
.await?
.with_load_monitor(&router_config)
.with_worker_job_queue()
.with_workflow_engine()
.router_config(router_config))
}
/// Create HTTP client with TLS/mTLS configuration
fn with_client(mut self, config: &RouterConfig, timeout_secs: u64) -> Result<Self, String> {
// FIXME: Current implementation creates a single HTTP client for all workers.
// This works well for single security domain deployments where all workers share
// the same CA and can accept the same client certificate.
//
// For multi-domain deployments (e.g., different model families with different CAs),
// this architecture needs significant refactoring:
// 1. Move client creation into worker registration workflow (per-worker clients)
// 2. Store client per worker in WorkerRegistry
// 3. Update PDRouter and other routers to fetch client from worker
// 4. Add per-worker TLS spec in WorkerConfigRequest
//
// Current single-domain approach is sufficient for most deployments.
//
// Use rustls TLS backend when TLS/mTLS is configured (client cert or CA certs provided).
// This ensures proper PKCS#8 key format support. For plain HTTP workers, use default
// backend to avoid unnecessary TLS initialization overhead.
let has_tls_config = config.client_identity.is_some() || !config.ca_certificates.is_empty();
let mut client_builder = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(500)
.timeout(Duration::from_secs(timeout_secs))
.connect_timeout(Duration::from_secs(10))
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30)));
// Force rustls backend when TLS is configured
if has_tls_config {
client_builder = client_builder.use_rustls_tls();
info!("Using rustls TLS backend for TLS/mTLS connections");
}
// Configure mTLS client identity if provided (certificates already loaded during config creation)
if let Some(identity_pem) = &config.client_identity {
let identity = reqwest::Identity::from_pem(identity_pem)
.map_err(|e| format!("Failed to create client identity: {}", e))?;
client_builder = client_builder.identity(identity);
info!("mTLS client authentication enabled");
}
// Add CA certificates for verifying worker TLS (certificates already loaded during config creation)
for ca_cert in &config.ca_certificates {
let cert = reqwest::Certificate::from_pem(ca_cert)
.map_err(|e| format!("Failed to add CA certificate: {}", e))?;
client_builder = client_builder.add_root_certificate(cert);
}
if !config.ca_certificates.is_empty() {
info!(
"Added {} CA certificate(s) for worker verification",
config.ca_certificates.len()
);
}
let client = client_builder
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
self.client = Some(client);
Ok(self)
}
/// Create rate limiter based on config
fn maybe_rate_limiter(mut self, config: &RouterConfig) -> Self {
self.rate_limiter = match config.max_concurrent_requests {
n if n <= 0 => None,
n => {
let rate_limit_tokens = config
.rate_limit_tokens_per_second
.filter(|&t| t > 0)
.unwrap_or(n);
Some(Arc::new(TokenBucket::new(
n as usize,
rate_limit_tokens as usize,
)))
}
};
self
}
/// Create tokenizer for gRPC mode
fn maybe_tokenizer(mut self, config: &RouterConfig) -> Result<Self, String> {
if matches!(config.connection_mode, ConnectionMode::Grpc { .. }) {
let tokenizer_path = config
.tokenizer_path
.clone()
.or_else(|| config.model_path.clone())
.ok_or_else(|| {
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
.to_string()
})?;
let base_tokenizer = tokenizer_factory::create_tokenizer_with_chat_template_blocking(
&tokenizer_path,
config.chat_template.as_deref(),
)
.map_err(|e| {
format!(
"Failed to create tokenizer from '{}': {}. \
Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
tokenizer_path, e
)
})?;
// Conditionally wrap with caching layer if at least one cache is enabled
self.tokenizer = if config.tokenizer_cache.enable_l0 || config.tokenizer_cache.enable_l1
{
let cache_config = CacheConfig {
enable_l0: config.tokenizer_cache.enable_l0,
l0_max_entries: config.tokenizer_cache.l0_max_entries,
enable_l1: config.tokenizer_cache.enable_l1,
l1_max_memory: config.tokenizer_cache.l1_max_memory,
};
Some(Arc::new(CachedTokenizer::new(base_tokenizer, cache_config))
as Arc<dyn Tokenizer>)
} else {
// Use base tokenizer directly without caching
Some(base_tokenizer)
};
}
Ok(self)
}
/// Create reasoning parser factory for gRPC mode
fn maybe_reasoning_parser_factory(mut self, config: &RouterConfig) -> Self {
if matches!(config.connection_mode, ConnectionMode::Grpc { .. }) {
self.reasoning_parser_factory = Some(ReasoningParserFactory::new());
}
self
}
/// Create tool parser factory for gRPC mode
fn maybe_tool_parser_factory(mut self, config: &RouterConfig) -> Self {
if matches!(config.connection_mode, ConnectionMode::Grpc { .. }) {
self.tool_parser_factory = Some(ToolParserFactory::new());
}
self
}
/// Create worker registry
fn with_worker_registry(mut self) -> Self {
self.worker_registry = Some(Arc::new(WorkerRegistry::new()));
self
}
/// Create policy registry
fn with_policy_registry(mut self, config: &RouterConfig) -> Self {
self.policy_registry = Some(Arc::new(PolicyRegistry::new(config.policy.clone())));
self
}
/// Create response storage based on history_backend config
async fn with_response_storage(mut self, config: &RouterConfig) -> Result<Self, String> {
self.response_storage = Some(match config.history_backend {
HistoryBackend::Memory => {
info!("Initializing response storage: Memory");
Arc::new(MemoryResponseStorage::new())
}
HistoryBackend::None => {
info!("Initializing response storage: None (no persistence)");
Arc::new(NoOpResponseStorage::new())
}
HistoryBackend::Oracle => {
let oracle_cfg = config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
info!(
"Initializing response storage: Oracle ATP (pool: {}-{})",
oracle_cfg.pool_min, oracle_cfg.pool_max
);
Arc::new(OracleResponseStorage::new(oracle_cfg).map_err(|err| {
format!("failed to initialize Oracle response storage: {err}")
})?)
}
});
Ok(self)
}
/// Create conversation storage based on history_backend config
async fn with_conversation_storage(mut self, config: &RouterConfig) -> Result<Self, String> {
self.conversation_storage = Some(match config.history_backend {
HistoryBackend::Memory => Arc::new(MemoryConversationStorage::new()),
HistoryBackend::None => Arc::new(NoOpConversationStorage::new()),
HistoryBackend::Oracle => {
let oracle_cfg = config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
info!("Initializing conversation storage: Oracle ATP");
Arc::new(OracleConversationStorage::new(oracle_cfg).map_err(|err| {
format!("failed to initialize Oracle conversation storage: {err}")
})?)
}
});
Ok(self)
}
/// Create conversation item storage based on history_backend config
async fn with_conversation_item_storage(
mut self,
config: &RouterConfig,
) -> Result<Self, String> {
self.conversation_item_storage = Some(match config.history_backend {
HistoryBackend::Oracle => {
let oracle_cfg = config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
info!("Initializing conversation item storage: Oracle ATP");
Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
format!("failed to initialize Oracle conversation item storage: {e}")
})?)
}
HistoryBackend::Memory => {
info!("Initializing conversation item storage: Memory");
Arc::new(MemoryConversationItemStorage::new())
}
HistoryBackend::None => {
info!("Initializing conversation item storage: Memory (no NoOp implementation available)");
Arc::new(MemoryConversationItemStorage::new())
}
});
Ok(self)
}
/// Create load monitor
fn with_load_monitor(mut self, config: &RouterConfig) -> Self {
let client = self
.client
.as_ref()
.expect("client must be set before load monitor");
self.load_monitor = Some(Arc::new(LoadMonitor::new(
self.worker_registry
.as_ref()
.expect("worker_registry must be set")
.clone(),
self.policy_registry
.as_ref()
.expect("policy_registry must be set")
.clone(),
client.clone(),
config.worker_startup_check_interval_secs,
)));
self
}
/// Create worker job queue OnceLock container
fn with_worker_job_queue(mut self) -> Self {
self.worker_job_queue = Some(Arc::new(OnceLock::new()));
self
}
/// Create workflow engine OnceLock container
fn with_workflow_engine(mut self) -> Self {
self.workflow_engine = Some(Arc::new(OnceLock::new()));
self
}
}
impl Default for AppContextBuilder {
...
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
@@ -13,7 +13,6 @@ use axum::{
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
@@ -21,26 +20,18 @@ use tracing::{error, info, warn, Level};
use crate::{
app_context::AppContext,
config::{RouterConfig, RoutingMode},
core::{
worker_to_info,
workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
WorkflowEngine,
},
Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType,
WorkerType,
},
data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
OracleConversationStorage, OracleResponseStorage, SharedConversationItemStorage,
SharedConversationStorage, SharedResponseStorage,
},
logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig},
middleware::{self, AuthConfig, QueuedRequest},
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
@@ -52,15 +43,8 @@ use crate::{
validated::ValidatedJson,
worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{
cache::{CacheConfig, CachedTokenizer},
factory as tokenizer_factory,
traits::Tokenizer,
},
tool_parser::ParserFactory as ToolParserFactory,
};
#[derive(Clone)]
@@ -734,220 +718,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.max_payload_size / (1024 * 1024)
);
// FIXME: Current implementation creates a single HTTP client for all workers.
// This works well for single security domain deployments where all workers share
// the same CA and can accept the same client certificate.
//
// For multi-domain deployments (e.g., different model families with different CAs),
// this architecture needs significant refactoring:
// 1. Move client creation into worker registration workflow (per-worker clients)
// 2. Store client per worker in WorkerRegistry
// 3. Update PDRouter and other routers to fetch client from worker
// 4. Add per-worker TLS spec in WorkerConfigRequest
//
// Current single-domain approach is sufficient for most deployments.
//
// Use rustls TLS backend when TLS/mTLS is configured (client cert or CA certs provided).
// This ensures proper PKCS#8 key format support. For plain HTTP workers, use default
// backend to avoid unnecessary TLS initialization overhead.
let has_tls_config = config.router_config.client_identity.is_some()
|| !config.router_config.ca_certificates.is_empty();
let mut client_builder = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(500)
.timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10))
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30)));
// Force rustls backend when TLS is configured
if has_tls_config {
client_builder = client_builder.use_rustls_tls();
info!("Using rustls TLS backend for TLS/mTLS connections");
}
// Configure mTLS client identity if provided (certificates already loaded during config creation)
if let Some(identity_pem) = &config.router_config.client_identity {
let identity = reqwest::Identity::from_pem(identity_pem)?;
client_builder = client_builder.identity(identity);
info!("mTLS client authentication enabled");
}
// Add CA certificates for verifying worker TLS (certificates already loaded during config creation)
for ca_cert in &config.router_config.ca_certificates {
let cert = reqwest::Certificate::from_pem(ca_cert)?;
client_builder = client_builder.add_root_certificate(cert);
}
if !config.router_config.ca_certificates.is_empty() {
info!(
"Added {} CA certificate(s) for worker verification",
config.router_config.ca_certificates.len()
);
}
let client = client_builder
.build()
.expect("Failed to create HTTP client");
// Initialize rate limiter
let rate_limiter = match config.router_config.max_concurrent_requests {
n if n <= 0 => None,
n => {
let rate_limit_tokens = config
.router_config
.rate_limit_tokens_per_second
.filter(|&t| t > 0)
.unwrap_or(n);
Some(Arc::new(TokenBucket::new(
n as usize,
rate_limit_tokens as usize,
)))
}
};
// Initialize tokenizer and parser factories for gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if matches!(
config.router_config.connection_mode,
ConnectionMode::Grpc { .. }
) {
let tokenizer_path = config
.router_config
.tokenizer_path
.clone()
.or_else(|| config.router_config.model_path.clone())
.ok_or_else(|| {
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
.to_string()
})?;
let base_tokenizer =
tokenizer_factory::create_tokenizer_with_chat_template_blocking(
&tokenizer_path,
config.router_config.chat_template.as_deref(),
)
.map_err(|e| {
format!(
"Failed to create tokenizer from '{}': {}. \
Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
tokenizer_path, e
)
})?;
// Conditionally wrap with caching layer if at least one cache is enabled
let tokenizer = if config.router_config.tokenizer_cache.enable_l0
|| config.router_config.tokenizer_cache.enable_l1
{
let cache_config = CacheConfig {
enable_l0: config.router_config.tokenizer_cache.enable_l0,
l0_max_entries: config.router_config.tokenizer_cache.l0_max_entries,
enable_l1: config.router_config.tokenizer_cache.enable_l1,
l1_max_memory: config.router_config.tokenizer_cache.l1_max_memory,
};
Some(Arc::new(CachedTokenizer::new(base_tokenizer, cache_config)) as Arc<dyn Tokenizer>)
} else {
// Use base tokenizer directly without caching
Some(base_tokenizer)
};
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory)
} else {
(None, None, None)
};
// Initialize worker registry and policy registry
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(config.router_config.policy.clone()));
// Initialize storage backends
let (response_storage, conversation_storage): (
SharedResponseStorage,
SharedConversationStorage,
) = match config.router_config.history_backend {
HistoryBackend::Memory => {
info!("Initializing data connector: Memory");
(
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
)
}
HistoryBackend::None => {
info!("Initializing data connector: None (no persistence)");
(
Arc::new(NoOpResponseStorage::new()),
Arc::new(NoOpConversationStorage::new()),
)
}
HistoryBackend::Oracle => {
let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
info!(
"Initializing data connector: Oracle ATP (pool: {}-{})",
oracle_cfg.pool_min, oracle_cfg.pool_max
);
let response_storage = OracleResponseStorage::new(oracle_cfg.clone())
.map_err(|err| format!("failed to initialize Oracle response storage: {err}"))?;
let conversation_storage =
OracleConversationStorage::new(oracle_cfg.clone()).map_err(|err| {
format!("failed to initialize Oracle conversation storage: {err}")
})?;
info!("Data connector initialized successfully: Oracle ATP");
(Arc::new(response_storage), Arc::new(conversation_storage))
}
};
// Initialize conversation items storage
let conversation_item_storage: SharedConversationItemStorage =
match config.router_config.history_backend {
HistoryBackend::Oracle => {
let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
format!("failed to initialize Oracle conversation item storage: {e}")
})?)
}
_ => Arc::new(MemoryConversationItemStorage::new()),
};
// Initialize load monitor
let load_monitor = Some(Arc::new(LoadMonitor::new(
worker_registry.clone(),
policy_registry.clone(),
client.clone(),
config.router_config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue and workflow engine (will be initialized below)
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
// Create AppContext with all initialized components using builder pattern
let app_context = Arc::new(
AppContext::builder()
AppContext::from_config(config.router_config.clone(), config.request_timeout_secs).await?,
.router_config(config.router_config.clone())
.client(client.clone())
.rate_limiter(rate_limiter)
.tokenizer(tokenizer)
.reasoning_parser_factory(reasoning_parser_factory)
.tool_parser_factory(tool_parser_factory)
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.build()
.map_err(|e| e.to_string())?,
);
let weak_context = Arc::downgrade(&app_context);
...
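Two failure modes in the new builder are worth keeping in mind: step helpers such as with_load_monitor assume earlier steps already populated the client and registries (they panic via expect otherwise), while build() reports any still-missing required field through AppContextBuildError. A minimal self-contained sketch of that two-level check, using placeholder fields and a stand-in error type rather than the real AppContext definitions:

use std::sync::Arc;

// Stand-in for AppContextBuildError; only the missing-field name is carried.
#[derive(Debug)]
struct BuildError(&'static str);

#[derive(Default)]
struct ContextBuilder {
    client: Option<Arc<String>>,       // placeholder for the reqwest::Client slot
    load_monitor: Option<Arc<String>>, // placeholder for the LoadMonitor slot
}

impl ContextBuilder {
    fn client(mut self, value: &str) -> Self {
        self.client = Some(Arc::new(value.to_string()));
        self
    }

    // Ordering dependency: this step reads a field an earlier step must have set,
    // mirroring the expect() calls in with_load_monitor.
    fn with_load_monitor(mut self) -> Self {
        let client = self
            .client
            .as_ref()
            .expect("client must be set before load monitor");
        self.load_monitor = Some(Arc::new(format!("monitor over {client}")));
        self
    }

    // Final validation: every required field must be present, or build() fails
    // with the name of the missing component.
    fn build(self) -> Result<(Arc<String>, Arc<String>), BuildError> {
        Ok((
            self.client.ok_or(BuildError("client"))?,
            self.load_monitor.ok_or(BuildError("load_monitor"))?,
        ))
    }
}

fn main() {
    // Skipping the setup steps shows up at build() time as a named error...
    let err = ContextBuilder::default().build().unwrap_err();
    assert_eq!(err.0, "client");

    // ...while a correctly ordered chain yields the assembled components.
    let (client, monitor) = ContextBuilder::default()
        .client("http-client")
        .with_load_monitor()
        .build()
        .expect("all required fields set");
    println!("{client} / {monitor}");
}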