Unverified Commit 35ef3f29 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] fix worker registration in multi model mode (#10486)

parent 31fb19a0
......@@ -804,6 +804,37 @@ impl WorkerFactory {
Box::new(worker)
}
/// Create a prefill worker that carries caller-supplied labels.
///
/// Builds a `BasicWorker` of type `WorkerType::Prefill` for `url`, applies
/// `circuit_breaker_config` to it, and then stores `labels` in the worker's
/// metadata before boxing it as a `dyn Worker`.
///
/// # Arguments
/// * `url` - Address of the worker; ownership is handed to the new worker.
/// * `bootstrap_port` - Optional port embedded in the `Prefill` worker type.
/// * `labels` - Key/value labels that become the worker's metadata labels.
/// * `circuit_breaker_config` - Circuit breaker settings for this worker.
///
/// # Returns
/// The configured worker as a boxed `Worker` trait object.
pub fn create_prefill_with_labels(
    url: String,
    bootstrap_port: Option<u16>,
    labels: std::collections::HashMap<String, String>,
    circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
    // `url` is owned and not used again, so it is moved rather than cloned.
    let mut worker = BasicWorker::new(url, WorkerType::Prefill { bootstrap_port })
        .with_circuit_breaker_config(circuit_breaker_config);
    // Overwrite the metadata labels with the caller-supplied map.
    worker.metadata.labels = labels;
    Box::new(worker)
}
/// Create a decode worker that carries caller-supplied labels.
///
/// Builds a `BasicWorker` of type `WorkerType::Decode` for `url`, applies
/// `circuit_breaker_config` to it, and then stores `labels` in the worker's
/// metadata before boxing it as a `dyn Worker`.
///
/// # Arguments
/// * `url` - Address of the worker; ownership is handed to the new worker.
/// * `labels` - Key/value labels that become the worker's metadata labels.
/// * `circuit_breaker_config` - Circuit breaker settings for this worker.
///
/// # Returns
/// The configured worker as a boxed `Worker` trait object.
pub fn create_decode_with_labels(
    url: String,
    labels: std::collections::HashMap<String, String>,
    circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
    // `url` is owned and not used again, so it is moved rather than cloned.
    let mut worker = BasicWorker::new(url, WorkerType::Decode)
        .with_circuit_breaker_config(circuit_breaker_config);
    // Overwrite the metadata labels with the caller-supplied map.
    worker.metadata.labels = labels;
    Box::new(worker)
}
/// Create a DP-aware worker of specified type
pub fn create_dp_aware(
base_url: String,
......
......@@ -41,7 +41,6 @@ impl RouterId {
}
/// Router Manager - Central coordinator for routers and workers
/// Only created when enable_igw=true
pub struct RouterManager {
/// Worker registry (single source of truth in multi-router mode)
worker_registry: Arc<WorkerRegistry>,
......@@ -49,7 +48,7 @@ pub struct RouterManager {
/// Policy registry for managing model-to-policy mappings
policy_registry: Arc<crate::policies::PolicyRegistry>,
/// All routers managed by this manager (max 4 routers in Phase 2)
/// All routers managed by this manager
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
......@@ -83,12 +82,7 @@ impl RouterManager {
}
/// Register a router with the manager
pub fn register_router(
&mut self,
id: RouterId,
router: Arc<dyn RouterTrait>,
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
) {
pub fn register_router(&mut self, id: RouterId, router: Arc<dyn RouterTrait>) {
// Store router
self.routers.insert(id.clone(), router);
......@@ -210,32 +204,28 @@ impl RouterManager {
labels.insert("chat_template".to_string(), chat_template);
}
// Create worker based on type
// Note: For prefill and decode workers, we can't easily add labels after creation
// since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
let worker = match config.worker_type.as_deref() {
Some("prefill") => {
// For now, prefill workers won't have custom labels
// TODO: Enhance WorkerFactory to accept labels for prefill workers
WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
}
Some("decode") => {
// For now, decode workers won't have custom labels
// TODO: Enhance WorkerFactory to accept labels for decode workers
WorkerFactory::create_decode(config.url.clone())
}
_ => {
// Regular workers can have labels
WorkerFactory::create_regular_with_labels(
Some("prefill") => WorkerFactory::create_prefill_with_labels(
config.url.clone(),
config.bootstrap_port,
labels.clone(),
CircuitBreakerConfig::default(),
)
}
),
Some("decode") => WorkerFactory::create_decode_with_labels(
config.url.clone(),
labels.clone(),
CircuitBreakerConfig::default(),
),
_ => WorkerFactory::create_regular_with_labels(
config.url.clone(),
labels.clone(),
CircuitBreakerConfig::default(),
),
};
// Register worker
let worker_id = self.worker_registry.register(Arc::from(worker));
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
let worker_id = self.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
// Extract policy hint from labels if provided
......@@ -262,7 +252,6 @@ impl RouterManager {
);
// Return worker info
let worker_arc = self.worker_registry.get(&worker_id).unwrap();
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
Ok(WorkerApiResponse {
......@@ -375,7 +364,11 @@ impl RouterManager {
model_id: worker.model_id().to_string(),
priority: worker.priority(),
cost: worker.cost(),
worker_type: format!("{:?}", worker.worker_type()),
worker_type: match worker.worker_type() {
WorkerType::Regular => "regular".to_string(),
WorkerType::Prefill { .. } => "prefill".to_string(),
WorkerType::Decode => "decode".to_string(),
},
is_healthy: worker.is_healthy(),
load: worker.load(),
connection_mode: format!("{:?}", worker.connection_mode()),
......@@ -387,11 +380,6 @@ impl RouterManager {
}
}
// Note: calculate_stats removed - using WorkerRegistry::stats() instead
// === Phase 2: Router Management ===
// Note: Dynamic router creation removed - routers are created and registered externally
/// Get the appropriate router for a request based on headers and request content
pub fn select_router_for_request(
&self,
......@@ -474,11 +462,6 @@ impl RouterManager {
}
}
// Note: Default implementation removed as RouterManager now requires AppContext
// which cannot be defaulted. RouterManager must be created with explicit context.
// === Phase 2: RouterManager as RouterTrait ===
/// RouterManager implements RouterTrait to act as a meta-router
/// that delegates requests to the appropriate underlying router
#[async_trait]
......
......@@ -317,8 +317,9 @@ async fn create_worker(
State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>,
) -> Response {
// Check if RouterManager is available (enable_igw=true)
if let Some(router_manager) = &state.context.router_manager {
// Check if the router is actually a RouterManager (enable_igw=true)
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
// Call RouterManager's add_worker method directly with the full config
match router_manager.add_worker(config).await {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
......@@ -347,7 +348,7 @@ async fn create_worker(
/// GET /workers - List all workers with details
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
let response = router_manager.list_workers();
Json(response).into_response()
} else {
......@@ -358,7 +359,11 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
let mut worker_info = serde_json::json!({
"url": worker.url(),
"model_id": worker.model_id(),
"worker_type": format!("{:?}", worker.worker_type()),
"worker_type": match worker.worker_type() {
WorkerType::Regular => "regular",
WorkerType::Prefill { .. } => "prefill",
WorkerType::Decode => "decode",
},
"is_healthy": worker.is_healthy(),
"load": worker.load(),
"connection_mode": format!("{:?}", worker.connection_mode()),
......@@ -386,7 +391,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
/// GET /workers/{url} - Get specific worker info
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
if let Some(worker) = router_manager.get_worker(&url) {
Json(worker).into_response()
} else {
......@@ -417,7 +422,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
/// DELETE /workers/{url} - Remove a worker
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
match router_manager.remove_worker_from_registry(&url) {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
......@@ -603,7 +608,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
router_manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
vec![], // Models will be determined by workers
);
}
Err(e) => {
......@@ -624,11 +628,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
{
Ok(http_pd) => {
info!("Created HTTP PD router");
router_manager.register_router(
RouterId::new("http-pd".to_string()),
Arc::from(http_pd),
vec![],
);
router_manager
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
}
Err(e) => {
warn!("Failed to create HTTP PD router: {e}");
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment