Unverified commit 5983e5bd, authored by Simo Lin, committed by GitHub
Browse files

[router] migrate app context to builder pattern 1/n (#12086)

parent 4b046a72
use std::sync::{Arc, OnceLock};
use reqwest::Client;
use crate::{
config::RouterConfig,
core::{workflow::WorkflowEngine, JobQueue, LoadMonitor, WorkerRegistry},
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
middleware::TokenBucket,
policies::PolicyRegistry,
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::router_manager::RouterManager,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
/// Error produced by `AppContextBuilder::build` when a mandatory component
/// was never supplied.
///
/// The wrapped string is the name of the builder field that is missing.
#[derive(Debug)]
pub struct AppContextBuildError(&'static str);

impl std::error::Error for AppContextBuildError {}

impl std::fmt::Display for AppContextBuildError {
    /// Renders the error as `Missing required field: <field name>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(field) = self;
        f.write_fmt(format_args!("Missing required field: {}", field))
    }
}
/// Bundle of shared router components: HTTP client, configuration,
/// registries, storage handles and late-initialised services.
///
/// Values are assembled through `AppContext::builder()`. The type derives
/// `Clone`.
#[derive(Clone)]
pub struct AppContext {
/// `reqwest` HTTP client.
pub client: Client,
/// Router configuration this context was built from.
pub router_config: RouterConfig,
/// Optional shared `TokenBucket` from the `middleware` module.
pub rate_limiter: Option<Arc<TokenBucket>>,
/// Optional shared tokenizer trait object.
pub tokenizer: Option<Arc<dyn Tokenizer>>,
/// Optional factory for reasoning parsers.
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
/// Optional factory for tool parsers.
pub tool_parser_factory: Option<ToolParserFactory>,
/// Shared worker registry.
pub worker_registry: Arc<WorkerRegistry>,
/// Shared policy registry.
pub policy_registry: Arc<PolicyRegistry>,
/// Optional shared router manager.
pub router_manager: Option<Arc<RouterManager>>,
/// Response storage handle from `data_connector`.
pub response_storage: SharedResponseStorage,
/// Conversation storage handle from `data_connector`.
pub conversation_storage: SharedConversationStorage,
/// Conversation-item storage handle from `data_connector`.
pub conversation_item_storage: SharedConversationItemStorage,
/// Optional shared load monitor.
pub load_monitor: Option<Arc<LoadMonitor>>,
/// Clone of `router_config.reasoning_parser`, captured by the builder in `build()`.
pub configured_reasoning_parser: Option<String>,
/// Clone of `router_config.tool_call_parser`, captured by the builder in `build()`.
pub configured_tool_parser: Option<String>,
/// Shared write-once cell holding the job queue.
// NOTE(review): presumably filled after the context exists, because the queue
// is constructed from a weak reference to the context — confirm in startup.
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
/// Shared write-once cell holding the workflow engine.
// NOTE(review): initialisation site is not visible here — confirm against callers.
pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
}
/// Incremental constructor for [`AppContext`].
///
/// Every slot starts out as `None` (see `AppContextBuilder::new`). Calling
/// `build()` fails with an `AppContextBuildError` naming the first missing
/// required slot.
///
/// Required slots: `client`, `router_config`, `worker_registry`,
/// `policy_registry`, `response_storage`, `conversation_storage`,
/// `conversation_item_storage`, `worker_job_queue`, `workflow_engine`.
///
/// Optional slots (passed through to the context unchanged): `rate_limiter`,
/// `tokenizer`, `reasoning_parser_factory`, `tool_parser_factory`,
/// `router_manager`, `load_monitor`.
///
/// The type is `#[must_use]` so that a setter chain which never reaches
/// `.build()` is reported by the compiler instead of being silently dropped.
#[must_use = "an AppContextBuilder does nothing until `.build()` is called"]
pub struct AppContextBuilder {
    /// Required: HTTP client.
    client: Option<Client>,
    /// Required: router configuration.
    router_config: Option<RouterConfig>,
    /// Optional: shared token bucket.
    rate_limiter: Option<Arc<TokenBucket>>,
    /// Optional: shared tokenizer.
    tokenizer: Option<Arc<dyn Tokenizer>>,
    /// Optional: reasoning parser factory.
    reasoning_parser_factory: Option<ReasoningParserFactory>,
    /// Optional: tool parser factory.
    tool_parser_factory: Option<ToolParserFactory>,
    /// Required: shared worker registry.
    worker_registry: Option<Arc<WorkerRegistry>>,
    /// Required: shared policy registry.
    policy_registry: Option<Arc<PolicyRegistry>>,
    /// Optional: shared router manager.
    router_manager: Option<Arc<RouterManager>>,
    /// Required: response storage handle.
    response_storage: Option<SharedResponseStorage>,
    /// Required: conversation storage handle.
    conversation_storage: Option<SharedConversationStorage>,
    /// Required: conversation-item storage handle.
    conversation_item_storage: Option<SharedConversationItemStorage>,
    /// Optional: shared load monitor.
    load_monitor: Option<Arc<LoadMonitor>>,
    /// Required: write-once cell for the job queue.
    worker_job_queue: Option<Arc<OnceLock<Arc<JobQueue>>>>,
    /// Required: write-once cell for the workflow engine.
    workflow_engine: Option<Arc<OnceLock<Arc<WorkflowEngine>>>>,
}
impl AppContext {
    /// Starts an empty [`AppContextBuilder`]; finish the chain with `.build()`.
    pub fn builder() -> AppContextBuilder {
        AppContextBuilder::default()
    }
}
impl AppContextBuilder {
    /// Returns a builder in which no component has been supplied yet.
    pub fn new() -> Self {
        Self {
            client: None,
            router_config: None,
            rate_limiter: None,
            tokenizer: None,
            reasoning_parser_factory: None,
            tool_parser_factory: None,
            worker_registry: None,
            policy_registry: None,
            router_manager: None,
            response_storage: None,
            conversation_storage: None,
            conversation_item_storage: None,
            load_monitor: None,
            worker_job_queue: None,
            workflow_engine: None,
        }
    }

    /// Supplies the HTTP client (required).
    pub fn client(self, client: Client) -> Self {
        Self {
            client: Some(client),
            ..self
        }
    }

    /// Supplies the router configuration (required).
    pub fn router_config(self, router_config: RouterConfig) -> Self {
        Self {
            router_config: Some(router_config),
            ..self
        }
    }

    /// Sets the rate limiter slot; `None` clears it.
    pub fn rate_limiter(self, rate_limiter: Option<Arc<TokenBucket>>) -> Self {
        Self {
            rate_limiter,
            ..self
        }
    }

    /// Sets the tokenizer slot; `None` clears it.
    pub fn tokenizer(self, tokenizer: Option<Arc<dyn Tokenizer>>) -> Self {
        Self { tokenizer, ..self }
    }

    /// Sets the reasoning parser factory slot; `None` clears it.
    pub fn reasoning_parser_factory(
        self,
        reasoning_parser_factory: Option<ReasoningParserFactory>,
    ) -> Self {
        Self {
            reasoning_parser_factory,
            ..self
        }
    }

    /// Sets the tool parser factory slot; `None` clears it.
    pub fn tool_parser_factory(self, tool_parser_factory: Option<ToolParserFactory>) -> Self {
        Self {
            tool_parser_factory,
            ..self
        }
    }

    /// Supplies the worker registry (required).
    pub fn worker_registry(self, worker_registry: Arc<WorkerRegistry>) -> Self {
        Self {
            worker_registry: Some(worker_registry),
            ..self
        }
    }

    /// Supplies the policy registry (required).
    pub fn policy_registry(self, policy_registry: Arc<PolicyRegistry>) -> Self {
        Self {
            policy_registry: Some(policy_registry),
            ..self
        }
    }

    /// Sets the router manager slot; `None` clears it.
    pub fn router_manager(self, router_manager: Option<Arc<RouterManager>>) -> Self {
        Self {
            router_manager,
            ..self
        }
    }

    /// Supplies the response storage handle (required).
    pub fn response_storage(self, response_storage: SharedResponseStorage) -> Self {
        Self {
            response_storage: Some(response_storage),
            ..self
        }
    }

    /// Supplies the conversation storage handle (required).
    pub fn conversation_storage(self, conversation_storage: SharedConversationStorage) -> Self {
        Self {
            conversation_storage: Some(conversation_storage),
            ..self
        }
    }

    /// Supplies the conversation-item storage handle (required).
    pub fn conversation_item_storage(
        self,
        conversation_item_storage: SharedConversationItemStorage,
    ) -> Self {
        Self {
            conversation_item_storage: Some(conversation_item_storage),
            ..self
        }
    }

    /// Sets the load monitor slot; `None` clears it.
    pub fn load_monitor(self, load_monitor: Option<Arc<LoadMonitor>>) -> Self {
        Self {
            load_monitor,
            ..self
        }
    }

    /// Supplies the write-once job queue cell (required).
    pub fn worker_job_queue(self, worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>) -> Self {
        Self {
            worker_job_queue: Some(worker_job_queue),
            ..self
        }
    }

    /// Supplies the write-once workflow engine cell (required).
    pub fn workflow_engine(self, workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>) -> Self {
        Self {
            workflow_engine: Some(workflow_engine),
            ..self
        }
    }

    /// Consumes the builder and produces the finished [`AppContext`].
    ///
    /// `configured_reasoning_parser` and `configured_tool_parser` are cloned
    /// out of the router configuration before it is moved into the context.
    ///
    /// # Errors
    ///
    /// Returns an [`AppContextBuildError`] naming the first required slot
    /// that is still unset. Slots are checked in this order: `router_config`,
    /// `client`, `worker_registry`, `policy_registry`, `response_storage`,
    /// `conversation_storage`, `conversation_item_storage`,
    /// `worker_job_queue`, `workflow_engine`.
    pub fn build(self) -> Result<AppContext, AppContextBuildError> {
        let Some(router_config) = self.router_config else {
            return Err(AppContextBuildError("router_config"));
        };
        // Snapshot the parser names while `router_config` is still borrowable.
        let configured_reasoning_parser = router_config.reasoning_parser.clone();
        let configured_tool_parser = router_config.tool_call_parser.clone();

        let Some(client) = self.client else {
            return Err(AppContextBuildError("client"));
        };
        let Some(worker_registry) = self.worker_registry else {
            return Err(AppContextBuildError("worker_registry"));
        };
        let Some(policy_registry) = self.policy_registry else {
            return Err(AppContextBuildError("policy_registry"));
        };
        let Some(response_storage) = self.response_storage else {
            return Err(AppContextBuildError("response_storage"));
        };
        let Some(conversation_storage) = self.conversation_storage else {
            return Err(AppContextBuildError("conversation_storage"));
        };
        let Some(conversation_item_storage) = self.conversation_item_storage else {
            return Err(AppContextBuildError("conversation_item_storage"));
        };
        let Some(worker_job_queue) = self.worker_job_queue else {
            return Err(AppContextBuildError("worker_job_queue"));
        };
        let Some(workflow_engine) = self.workflow_engine else {
            return Err(AppContextBuildError("workflow_engine"));
        };

        Ok(AppContext {
            client,
            router_config,
            rate_limiter: self.rate_limiter,
            tokenizer: self.tokenizer,
            reasoning_parser_factory: self.reasoning_parser_factory,
            tool_parser_factory: self.tool_parser_factory,
            worker_registry,
            policy_registry,
            router_manager: self.router_manager,
            response_storage,
            conversation_storage,
            conversation_item_storage,
            load_monitor: self.load_monitor,
            configured_reasoning_parser,
            configured_tool_parser,
            worker_job_queue,
            workflow_engine,
        })
    }
}
impl Default for AppContextBuilder {
fn default() -> Self {
Self::new()
}
}
......@@ -14,6 +14,7 @@ use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::{
app_context::AppContext,
config::{RouterConfig, RoutingMode},
core::workflow::{
steps::WorkerRemovalRequest, WorkflowContext, WorkflowEngine, WorkflowId,
......@@ -21,7 +22,6 @@ use crate::{
},
metrics::RouterMetrics,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
server::AppContext,
};
/// Job types for control plane operations
......
......@@ -21,13 +21,13 @@ use serde_json::Value;
use tracing::{debug, info, warn};
use crate::{
app_context::AppContext,
core::{
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
DPAwareWorkerBuilder, HealthConfig, Worker, WorkerType,
},
grpc_client::SglangSchedulerClient,
protocols::worker_spec::WorkerConfigRequest,
server::AppContext,
};
// HTTP client for metadata fetching
......
......@@ -15,8 +15,8 @@ use async_trait::async_trait;
use tracing::{debug, info};
use crate::{
app_context::AppContext,
core::{workflow::*, Worker},
server::AppContext,
};
/// Request structure for worker removal
......
use pyo3::prelude::*;
pub mod app_context;
pub mod config;
pub mod logging;
use std::collections::HashMap;
......
......@@ -9,10 +9,10 @@ use super::{
RouterTrait,
};
use crate::{
app_context::AppContext,
config::{PolicyConfig, RoutingMode},
core::ConnectionMode,
policies::PolicyFactory,
server::AppContext,
};
/// Factory for creating router instances based on configuration
......
......@@ -13,6 +13,7 @@ use tracing::debug;
use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
app_context::AppContext,
config::types::RetryConfig,
core::{ConnectionMode, WorkerRegistry, WorkerType},
policies::PolicyRegistry,
......@@ -27,7 +28,6 @@ use crate::{
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
server::AppContext,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
......
......@@ -18,6 +18,7 @@ use super::{
responses::{self, BackgroundTaskInfo},
};
use crate::{
app_context::AppContext,
config::types::RetryConfig,
core::WorkerRegistry,
data_connector::{
......@@ -35,7 +36,6 @@ use crate::{
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::RouterTrait,
server::AppContext,
tokenizer::traits::Tokenizer,
tool_parser::ParserFactory as ToolParserFactory,
};
......
......@@ -127,7 +127,7 @@ impl PDRouter {
}
}
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::app_context::AppContext>) -> Result<Self, String> {
Ok(PDRouter {
worker_registry: Arc::clone(&ctx.worker_registry),
policy_registry: Arc::clone(&ctx.policy_registry),
......
......@@ -48,7 +48,7 @@ pub struct Router {
impl Router {
/// Create a new router with injected policy and client
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::app_context::AppContext>) -> Result<Self, String> {
let workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
......
......@@ -18,6 +18,7 @@ use serde_json::Value;
use tracing::{debug, info, warn};
use crate::{
app_context::AppContext,
config::RoutingMode,
core::{ConnectionMode, WorkerRegistry, WorkerType},
protocols::{
......@@ -30,7 +31,7 @@ use crate::{
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::RouterTrait,
server::{AppContext, ServerConfig},
server::ServerConfig,
};
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
......
......@@ -20,6 +20,7 @@ use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
use crate::{
app_context::AppContext,
config::{HistoryBackend, RouterConfig, RoutingMode},
core::{
worker_to_info,
......@@ -62,72 +63,6 @@ use crate::{
tool_parser::ParserFactory as ToolParserFactory,
};
//
/// Bundle of shared router components: HTTP client, configuration,
/// registries, storage handles and late-initialised services.
///
/// Constructed with the positional `AppContext::new` defined directly below.
#[derive(Clone)]
pub struct AppContext {
/// `reqwest` HTTP client.
pub client: Client,
/// Router configuration this context was built from.
pub router_config: RouterConfig,
/// Optional shared `TokenBucket`.
pub rate_limiter: Option<Arc<TokenBucket>>,
/// Optional shared tokenizer trait object.
pub tokenizer: Option<Arc<dyn Tokenizer>>,
/// Optional factory for reasoning parsers.
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
/// Optional factory for tool parsers.
pub tool_parser_factory: Option<ToolParserFactory>,
/// Shared worker registry.
pub worker_registry: Arc<WorkerRegistry>,
/// Shared policy registry.
pub policy_registry: Arc<PolicyRegistry>,
/// Optional shared router manager; `AppContext::new` always stores `None` here.
pub router_manager: Option<Arc<RouterManager>>,
/// Response storage handle.
pub response_storage: SharedResponseStorage,
/// Conversation storage handle.
pub conversation_storage: SharedConversationStorage,
/// Conversation-item storage handle.
pub conversation_item_storage: SharedConversationItemStorage,
/// Optional shared load monitor.
pub load_monitor: Option<Arc<LoadMonitor>>,
/// Clone of `router_config.reasoning_parser`, taken in `AppContext::new`.
pub configured_reasoning_parser: Option<String>,
/// Clone of `router_config.tool_call_parser`, taken in `AppContext::new`.
pub configured_tool_parser: Option<String>,
/// Shared write-once cell holding the job queue.
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
/// Shared write-once cell holding the workflow engine.
pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
}
impl AppContext {
/// Positional constructor taking every component as a separate argument.
///
/// `configured_reasoning_parser` and `configured_tool_parser` are cloned from
/// `router_config.reasoning_parser` and `router_config.tool_call_parser`
/// before `router_config` is moved into the struct. `router_manager` is
/// always initialised to `None`; this constructor has no parameter for it.
// The lint is silenced because the constructor takes fourteen positional arguments.
#[allow(clippy::too_many_arguments)]
pub fn new(
router_config: RouterConfig,
client: Client,
rate_limiter: Option<Arc<TokenBucket>>,
tokenizer: Option<Arc<dyn Tokenizer>>,
reasoning_parser_factory: Option<ReasoningParserFactory>,
tool_parser_factory: Option<ToolParserFactory>,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
load_monitor: Option<Arc<LoadMonitor>>,
worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
) -> Self {
// Clone the parser names first: `router_config` is moved into `Self` below.
let configured_reasoning_parser = router_config.reasoning_parser.clone();
let configured_tool_parser = router_config.tool_call_parser.clone();
Self {
client,
router_config,
rate_limiter,
tokenizer,
reasoning_parser_factory,
tool_parser_factory,
worker_registry,
policy_registry,
// No constructor argument exists for the router manager.
router_manager: None,
response_storage,
conversation_storage,
conversation_item_storage,
load_monitor,
configured_reasoning_parser,
configured_tool_parser,
worker_job_queue,
workflow_engine,
}
}
}
#[derive(Clone)]
pub struct AppState {
pub router: Arc<dyn RouterTrait>,
......@@ -994,26 +929,27 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
// Create AppContext with all initialized components
let app_context = AppContext::new(
config.router_config.clone(),
client.clone(),
rate_limiter,
tokenizer,
reasoning_parser_factory,
tool_parser_factory,
worker_registry,
policy_registry,
response_storage,
conversation_storage,
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
// Create AppContext with all initialized components using builder pattern
let app_context = Arc::new(
AppContext::builder()
.router_config(config.router_config.clone())
.client(client.clone())
.rate_limiter(rate_limiter)
.tokenizer(tokenizer)
.reasoning_parser_factory(reasoning_parser_factory)
.tool_parser_factory(tool_parser_factory)
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.build()
.map_err(|e| e.to_string())?,
);
let app_context = Arc::new(app_context);
let weak_context = Arc::downgrade(&app_context);
let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
app_context
......
......@@ -18,7 +18,7 @@ use rustls;
use tokio::{task, time};
use tracing::{debug, error, info, warn};
use crate::{core::Job, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
use crate::{app_context::AppContext, core::Job, protocols::worker_spec::WorkerConfigRequest};
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
......
......@@ -11,10 +11,10 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
app_context::AppContext,
config::{RouterConfig, RoutingMode},
core::Job,
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
use tower::ServiceExt;
......
......@@ -15,6 +15,7 @@ use std::{
use serde_json::json;
use sglang_router_rs::{
app_context::AppContext,
config::RouterConfig,
core::{LoadMonitor, WorkerRegistry},
data_connector::{
......@@ -23,7 +24,6 @@ use sglang_router_rs::{
middleware::TokenBucket,
policies::PolicyRegistry,
protocols::common::{Function, Tool},
server::AppContext,
};
/// Helper function to create AppContext for tests
......@@ -66,22 +66,25 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
let app_context = Arc::new(AppContext::new(
config,
client,
rate_limiter,
None, // tokenizer
None, // reasoning_parser_factory
None, // tool_parser_factory
worker_registry,
policy_registry,
response_storage,
conversation_storage,
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
));
let app_context = Arc::new(
AppContext::builder()
.router_config(config)
.client(client)
.rate_limiter(rate_limiter)
.tokenizer(None) // tokenizer
.reasoning_parser_factory(None) // reasoning_parser_factory
.tool_parser_factory(None) // tool_parser_factory
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.build()
.unwrap(),
);
// Initialize JobQueue after AppContext is created
let weak_context = Arc::downgrade(&app_context);
......
......@@ -3,6 +3,7 @@ use std::sync::{Arc, OnceLock};
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
app_context::AppContext,
config::RouterConfig,
core::{LoadMonitor, WorkerRegistry},
data_connector::{
......@@ -11,7 +12,7 @@ use sglang_router_rs::{
middleware::{AuthConfig, TokenBucket},
policies::PolicyRegistry,
routers::RouterTrait,
server::{build_app, AppContext, AppState},
server::{build_app, AppState},
};
/// Create a test Axum application using the actual server's build_app function
......@@ -57,23 +58,26 @@ pub fn create_test_app(
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
// Create AppContext
let app_context = Arc::new(AppContext::new(
router_config.clone(),
client,
rate_limiter,
None, // tokenizer
None, // reasoning_parser_factory
None, // tool_parser_factory
worker_registry,
policy_registry,
response_storage,
conversation_storage,
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
));
// Create AppContext using builder pattern
let app_context = Arc::new(
AppContext::builder()
.router_config(router_config.clone())
.client(client)
.rate_limiter(rate_limiter)
.tokenizer(None) // tokenizer
.reasoning_parser_factory(None) // reasoning_parser_factory
.tool_parser_factory(None) // tool_parser_factory
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.build()
.unwrap(),
);
// Create AppState with the test router and context
let app_state = Arc::new(AppState {
......
......@@ -2,6 +2,7 @@
mod test_pd_routing {
use serde_json::json;
use sglang_router_rs::{
app_context::AppContext,
config::{PolicyConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
......@@ -221,22 +222,25 @@ mod test_pd_routing {
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
Arc::new(sglang_router_rs::server::AppContext::new(
config,
client,
rate_limiter,
None, // tokenizer
None, // reasoning_parser_factory
None, // tool_parser_factory
worker_registry,
policy_registry,
response_storage,
conversation_storage,
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
))
Arc::new(
AppContext::builder()
.router_config(config)
.client(client)
.rate_limiter(rate_limiter)
.tokenizer(None) // tokenizer
.reasoning_parser_factory(None) // reasoning_parser_factory
.tool_parser_factory(None) // tool_parser_factory
.worker_registry(worker_registry)
.policy_registry(policy_registry)
.response_storage(response_storage)
.conversation_storage(conversation_storage)
.conversation_item_storage(conversation_item_storage)
.load_monitor(load_monitor)
.worker_job_queue(worker_job_queue)
.workflow_engine(workflow_engine)
.build()
.unwrap(),
)
};
let result = RouterFactory::create_router(&app_context).await;
assert!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment