Unverified Commit 2f173ea0 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] allow one router to support different model families and serving mode (#10244)

parent 321fecab
This diff is collapsed.
...@@ -17,6 +17,7 @@ pub mod factory; ...@@ -17,6 +17,7 @@ pub mod factory;
pub mod grpc; pub mod grpc;
pub mod header_utils; pub mod header_utils;
pub mod http; pub mod http;
pub mod router_manager;
pub use factory::RouterFactory; pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working) // Re-export HTTP routers for convenience (keeps routers::openai_router path working)
...@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async fn get_model_info(&self, req: Request<Body>) -> Response; async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request /// Route a generate request
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest) async fn route_generate(
-> Response; &self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response;
/// Route a chat completion request /// Route a chat completion request
async fn route_chat( async fn route_chat(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response; ) -> Response;
/// Route a completion request /// Route a completion request
...@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &CompletionRequest, body: &CompletionRequest,
model_id: Option<&str>,
) -> Response; ) -> Response;
/// Route a responses request /// Route a responses request
...@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ResponsesRequest, body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response; ) -> Response;
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response; async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response;
/// Flush cache on all workers /// Flush cache on all workers
async fn flush_cache(&self) -> Response; async fn flush_cache(&self) -> Response;
......
This diff is collapsed.
use crate::config::RouterConfig; use crate::config::RouterConfig;
use crate::core::WorkerRegistry;
use crate::logging::{self, LoggingConfig}; use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig}; use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket; use crate::middleware::TokenBucket;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
V1RerankReqInput, V1RerankReqInput,
}; };
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::router_manager::{RouterId, RouterManager};
use crate::routers::{RouterFactory, RouterTrait}; use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer}; use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
...@@ -36,6 +40,9 @@ pub struct AppContext { ...@@ -36,6 +40,9 @@ pub struct AppContext {
pub tokenizer: Option<Arc<dyn Tokenizer>>, pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>, pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>, pub tool_parser_registry: Option<&'static ParserRegistry>,
pub worker_registry: Arc<WorkerRegistry>, // Shared worker registry
pub policy_registry: Arc<PolicyRegistry>, // Shared policy registry
pub router_manager: Option<Arc<RouterManager>>, // Only present when enable_igw=true
} }
impl AppContext { impl AppContext {
...@@ -75,6 +82,15 @@ impl AppContext { ...@@ -75,6 +82,15 @@ impl AppContext {
(None, None, None) (None, None, None)
}; };
// Initialize shared registries
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new(
router_config.policy.clone(), // Use default policy from config
));
// Initialize RouterManager only when enable_igw is true
let router_manager = None; // Will be initialized in startup() based on config
Ok(Self { Ok(Self {
client, client,
router_config, router_config,
...@@ -82,6 +98,9 @@ impl AppContext { ...@@ -82,6 +98,9 @@ impl AppContext {
tokenizer, tokenizer,
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
worker_registry,
policy_registry,
router_manager,
}) })
} }
} }
...@@ -134,7 +153,10 @@ async fn generate( ...@@ -134,7 +153,10 @@ async fn generate(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<GenerateRequest>, Json(body): Json<GenerateRequest>,
) -> Response { ) -> Response {
state.router.route_generate(Some(&headers), &body).await state
.router
.route_generate(Some(&headers), &body, None)
.await
} }
async fn v1_chat_completions( async fn v1_chat_completions(
...@@ -142,7 +164,7 @@ async fn v1_chat_completions( ...@@ -142,7 +164,7 @@ async fn v1_chat_completions(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>, Json(body): Json<ChatCompletionRequest>,
) -> Response { ) -> Response {
state.router.route_chat(Some(&headers), &body).await state.router.route_chat(Some(&headers), &body, None).await
} }
async fn v1_completions( async fn v1_completions(
...@@ -150,7 +172,10 @@ async fn v1_completions( ...@@ -150,7 +172,10 @@ async fn v1_completions(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<CompletionRequest>, Json(body): Json<CompletionRequest>,
) -> Response { ) -> Response {
state.router.route_completion(Some(&headers), &body).await state
.router
.route_completion(Some(&headers), &body, None)
.await
} }
async fn rerank( async fn rerank(
...@@ -158,7 +183,7 @@ async fn rerank( ...@@ -158,7 +183,7 @@ async fn rerank(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<RerankRequest>, Json(body): Json<RerankRequest>,
) -> Response { ) -> Response {
state.router.route_rerank(Some(&headers), &body).await state.router.route_rerank(Some(&headers), &body, None).await
} }
async fn v1_rerank( async fn v1_rerank(
...@@ -168,7 +193,7 @@ async fn v1_rerank( ...@@ -168,7 +193,7 @@ async fn v1_rerank(
) -> Response { ) -> Response {
state state
.router .router
.route_rerank(Some(&headers), &body.into()) .route_rerank(Some(&headers), &body.into(), None)
.await .await
} }
...@@ -177,7 +202,10 @@ async fn v1_responses( ...@@ -177,7 +202,10 @@ async fn v1_responses(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<ResponsesRequest>, Json(body): Json<ResponsesRequest>,
) -> Response { ) -> Response {
state.router.route_responses(Some(&headers), &body).await state
.router
.route_responses(Some(&headers), &body, None)
.await
} }
// Worker management endpoints // Worker management endpoints
...@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons ...@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state.router.get_worker_loads().await state.router.get_worker_loads().await
} }
// New RESTful worker management endpoints (when enable_igw=true)
/// POST /workers - Add a new worker with full configuration
///
/// In multi-router mode (enable_igw=true) the full `WorkerConfigRequest` is
/// forwarded to the RouterManager; in single-router mode only the URL is
/// used, since the router's `add_worker` accepts just an address.
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
    match &state.context.router_manager {
        // RouterManager present: it owns worker registration end-to-end.
        Some(router_manager) => match router_manager.add_worker(config).await {
            Ok(response) => (StatusCode::OK, Json(response)).into_response(),
            Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
        },
        // Single router mode: register by URL only.
        None => match state.router.add_worker(&config.url).await {
            Ok(message) => {
                let body = WorkerApiResponse {
                    success: true,
                    message,
                    worker: None,
                };
                (StatusCode::OK, Json(body)).into_response()
            }
            Err(error) => {
                let body = WorkerErrorResponse {
                    error,
                    code: "ADD_WORKER_FAILED".to_string(),
                };
                (StatusCode::BAD_REQUEST, Json(body)).into_response()
            }
        },
    }
}
/// GET /workers - List all workers with details
///
/// In multi-router mode (enable_igw=true) the RouterManager produces the
/// response; otherwise the handler assembles a JSON summary directly from
/// the shared worker registry in the app context.
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
    if let Some(router_manager) = &state.context.router_manager {
        let response = router_manager.list_workers();
        Json(response).into_response()
    } else {
        // In single router mode, get detailed worker info from registry
        let workers = state.context.worker_registry.get_all();
        let response = serde_json::json!({
            // One JSON object per registered worker.
            "workers": workers.iter().map(|worker| {
                let mut worker_info = serde_json::json!({
                    "url": worker.url(),
                    "model_id": worker.model_id(),
                    // Debug formatting — enum variants rendered as their Rust names.
                    "worker_type": format!("{:?}", worker.worker_type()),
                    "is_healthy": worker.is_healthy(),
                    "load": worker.load(),
                    "connection_mode": format!("{:?}", worker.connection_mode()),
                    "priority": worker.priority(),
                    "cost": worker.cost(),
                });
                // Add bootstrap_port for Prefill workers
                if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
                    worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
                }
                worker_info
            }).collect::<Vec<_>>(),
            "total": workers.len(),
            // Aggregate counts by worker role for quick inspection.
            "stats": {
                "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
                "decode_count": state.context.worker_registry.get_decode_workers().len(),
                "regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(),
            }
        });
        Json(response).into_response()
    }
}
/// GET /workers/{url} - Get specific worker info
///
/// Looks the worker up via the RouterManager when enable_igw=true;
/// otherwise falls back to the single router's URL list and returns a
/// minimal record (the single-router path tracks URLs only).
async fn get_worker(
    State(state): State<Arc<AppState>>,
    axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
    // Shared 404 payload used by both modes.
    let not_found = |url: &str| {
        let error = WorkerErrorResponse {
            error: format!("Worker {} not found", url),
            code: "WORKER_NOT_FOUND".to_string(),
        };
        (StatusCode::NOT_FOUND, Json(error)).into_response()
    };

    match &state.context.router_manager {
        Some(router_manager) => match router_manager.get_worker(&url) {
            Some(worker) => Json(worker).into_response(),
            None => not_found(&url),
        },
        None => {
            // Single router mode: only URLs are known, so report a minimal
            // record when the worker exists.
            if state.router.get_worker_urls().contains(&url) {
                Json(serde_json::json!({
                    "url": url,
                    "model_id": "unknown",
                    "is_healthy": true
                }))
                .into_response()
            } else {
                not_found(&url)
            }
        }
    }
}
/// DELETE /workers/{url} - Remove a worker
///
/// Routes removal through the RouterManager registry in multi-router mode;
/// in single-router mode delegates to the router's own `remove_worker`.
async fn delete_worker(
    State(state): State<Arc<AppState>>,
    axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
    match &state.context.router_manager {
        Some(router_manager) => match router_manager.remove_worker_from_registry(&url) {
            Ok(response) => (StatusCode::OK, Json(response)).into_response(),
            Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
        },
        None => {
            // Single router mode: removal is infallible from the caller's
            // perspective, so always report success.
            state.router.remove_worker(&url);
            let body = WorkerApiResponse {
                success: true,
                message: format!("Worker {} removed successfully", url),
                worker: None,
            };
            (StatusCode::OK, Json(body)).into_response()
        }
    }
}
pub struct ServerConfig { pub struct ServerConfig {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
...@@ -281,11 +440,19 @@ pub fn build_app( ...@@ -281,11 +440,19 @@ pub fn build_app(
.route("/flush_cache", post(flush_cache)) .route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads)); .route("/get_loads", get(get_loads));
// Worker management routes
let worker_routes = Router::new()
.route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker))
.route("/workers/{url}", axum::routing::delete(delete_worker));
// Build app with all routes and middleware // Build app with all routes and middleware
Router::new() Router::new()
.merge(protected_routes) .merge(protected_routes)
.merge(public_routes) .merge(public_routes)
.merge(admin_routes) .merge(admin_routes)
.merge(worker_routes)
// Request body size limiting // Request body size limiting
.layer(tower_http::limit::RequestBodyLimitLayer::new( .layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size, max_payload_size,
...@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
// Create the application context with all dependencies // Create the application context with all dependencies
let app_context = Arc::new(AppContext::new( let app_context = AppContext::new(
config.router_config.clone(), config.router_config.clone(),
client.clone(), client.clone(),
config.router_config.max_concurrent_requests, config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second, config.router_config.rate_limit_tokens_per_second,
)?); )?;
let app_context = Arc::new(app_context);
// Create the appropriate router based on enable_igw flag
let router: Box<dyn RouterTrait> = if config.router_config.enable_igw {
info!("Multi-router mode enabled (enable_igw=true)");
// Create RouterManager with shared registries from AppContext
let mut router_manager = RouterManager::new(
config.router_config.clone(),
client.clone(),
app_context.worker_registry.clone(),
app_context.policy_registry.clone(),
);
// Create HTTP routers at startup (with empty worker lists)
// Workers will be added to these routers dynamically via RouterManager's worker registry
// 1. HTTP Regular Router
match RouterFactory::create_regular_router(
&[], // Empty worker list - workers added later
&app_context,
)
.await
{
Ok(http_regular) => {
info!("Created HTTP Regular router");
router_manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
vec![], // Models will be determined by workers
);
}
Err(e) => {
warn!("Failed to create HTTP Regular router: {}", e);
}
}
// 2. HTTP PD Router
match RouterFactory::create_pd_router(
&[], // Empty prefill URLs
&[], // Empty decode URLs
None, // Use default prefill policy
None, // Use default decode policy
&config.router_config.policy,
&app_context,
)
.await
{
Ok(http_pd) => {
info!("Created HTTP PD router");
router_manager.register_router(
RouterId::new("http-pd".to_string()),
Arc::from(http_pd),
vec![],
);
}
Err(e) => {
warn!("Failed to create HTTP PD router: {}", e);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
// Currently gRPC routers require tokenizer to be initialized first,
// but each model needs its own tokenizer. Once we implement dynamic
// tokenizer loading per model, we can enable gRPC routers here:
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
// - RouterType::GrpcPd (RouterId: "grpc-pd")
// Create router with the context info!(
let router = RouterFactory::create_router(&app_context).await?; "RouterManager initialized with {} routers",
router_manager.router_count()
);
Box::new(router_manager)
} else {
info!("Single router mode (enable_igw=false)");
// Create single router with the context
RouterFactory::create_router(&app_context).await?
};
// Start health checker for all workers in the registry
let _health_checker = app_context
.worker_registry
.start_health_checker(config.router_config.health_check.check_interval_secs);
info!(
"Started health checker for workers with {}s interval",
config.router_config.health_check.check_interval_secs
);
// Set up concurrency limiter with queue if configured // Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new( let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
......
...@@ -579,9 +579,8 @@ mod tests { ...@@ -579,9 +579,8 @@ mod tests {
// Helper to create a Router instance for testing event handlers // Helper to create a Router instance for testing event handlers
async fn create_test_router() -> Arc<dyn RouterTrait> { async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::{PolicyConfig, RouterConfig}; use crate::config::RouterConfig;
use crate::middleware::TokenBucket; use crate::middleware::TokenBucket;
use crate::policies::PolicyFactory;
use crate::routers::http::router::Router; use crate::routers::http::router::Router;
use crate::server::AppContext; use crate::server::AppContext;
...@@ -591,15 +590,19 @@ mod tests { ...@@ -591,15 +590,19 @@ mod tests {
// Create AppContext with minimal components // Create AppContext with minimal components
let app_context = Arc::new(AppContext { let app_context = Arc::new(AppContext {
client: reqwest::Client::new(), client: reqwest::Client::new(),
router_config, router_config: router_config.clone(),
rate_limiter: Arc::new(TokenBucket::new(1000, 1000)), rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
worker_registry: Arc::new(crate::core::WorkerRegistry::new()),
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
router_config.policy.clone(),
)),
tokenizer: None, // HTTP mode doesn't need tokenizer tokenizer: None, // HTTP mode doesn't need tokenizer
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
tool_parser_registry: None, // HTTP mode doesn't need tool parser tool_parser_registry: None, // HTTP mode doesn't need tool parser
router_manager: None, // Test doesn't need router manager
}); });
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let router = Router::new(vec![], &app_context).await.unwrap();
let router = Router::new(vec![], policy, &app_context).await.unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
} }
......
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
use std::collections::HashMap;
use std::sync::Arc;
#[test]
fn test_backward_compatibility_with_empty_model_id() {
    // Policy with background eviction disabled so the test is deterministic.
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0, // Disable background eviction for testing
        max_tree_size: 100,
    });

    // Worker without a model_id label (legacy routers) — should default to "unknown".
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);

    // Worker explicitly labeled "unknown" — should land in the same tree.
    let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
        .with_labels(HashMap::from([(
            "model_id".to_string(),
            "unknown".to_string(),
        )]));

    // Both go to the "default" tree.
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    // Selection must succeed without errors.
    let workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];
    let selected = policy.select_worker(&workers, Some("test request"));
    assert!(selected.is_some(), "Should select a worker");

    // Removal must also work without errors.
    policy.remove_worker(&worker1);
    policy.remove_worker(&worker2);
}
#[test]
fn test_mixed_model_ids() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    });

    // Helper: build a Regular worker carrying a model_id label.
    let labeled = |url: &str, model: &str| {
        let mut labels = HashMap::new();
        labels.insert("model_id".to_string(), model.to_string());
        BasicWorker::new(url.to_string(), WorkerType::Regular).with_labels(labels)
    };

    // worker1 has no model_id label — defaults to "unknown" ("default" tree).
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
    let worker2 = labeled("http://worker2:8080", "llama-3");
    let worker3 = labeled("http://worker3:8080", "unknown");
    let worker4 = labeled("http://worker4:8080", "llama-3");

    // Register all four workers with the policy.
    for worker in [&worker1, &worker2, &worker3, &worker4] {
        policy.add_worker(worker);
    }

    // Selection across default-tree workers only.
    let default_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    assert!(
        policy
            .select_worker(&default_workers, Some("test request"))
            .is_some(),
        "Should select from default workers"
    );

    // Selection across workers of one specific model only.
    let llama_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    assert!(
        policy
            .select_worker(&llama_workers, Some("test request"))
            .is_some(),
        "Should select from llama-3 workers"
    );

    // Selection across a mixed set of all workers.
    let all_workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1.clone()),
        Arc::new(worker2.clone()),
        Arc::new(worker3.clone()),
        Arc::new(worker4.clone()),
    ];
    assert!(
        policy
            .select_worker(&all_workers, Some("test request"))
            .is_some(),
        "Should select from all workers"
    );
}
#[test]
fn test_remove_worker_by_url_backward_compat() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig::default());

    // worker1 carries an explicit model_id; worker2 has no label and
    // therefore defaults to "unknown".
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
        .with_labels(HashMap::from([(
            "model_id".to_string(),
            "llama-3".to_string(),
        )]));
    let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);

    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    // Backward-compatible removal by URL: the model is unknown at the call
    // site, so the worker is removed from every tree.
    policy.remove_worker_by_url("http://worker1:8080");

    // Only worker2 should remain selectable.
    let remaining: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
    let selected = policy.select_worker(&remaining, Some("test"));
    assert_eq!(selected, Some(0), "Should only have worker2 left");
}
This diff is collapsed.
...@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() { ...@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
rid: None, rid: None,
}; };
let response = router.route_generate(None, &generate_request).await; let response = router.route_generate(None, &generate_request, None).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
// Test completion endpoint (should also not be supported) // Test completion endpoint (should also not be supported)
let completion_request = create_minimal_completion_request(); let completion_request = create_minimal_completion_request();
let response = router.route_completion(None, &completion_request).await; let response = router
.route_completion(None, &completion_request, None)
.await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
} }
...@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() { ...@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
chat_request.temperature = Some(0.7); chat_request.temperature = Some(0.7);
// Route the request // Route the request
let response = router.route_chat(None, &chat_request).await; let response = router.route_chat(None, &chat_request, None).await;
// Should get a successful response from mock server // Should get a successful response from mock server
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
...@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() { ...@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
let chat_request: ChatCompletionRequest = let chat_request: ChatCompletionRequest =
serde_json::from_str(&body_str).unwrap(); serde_json::from_str(&body_str).unwrap();
router.route_chat(Some(&parts.headers), &chat_request).await router
.route_chat(Some(&parts.headers), &chat_request, None)
.await
} }
} }
}), }),
...@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() { ...@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
}); });
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap(); let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
let response = router.route_chat(None, &chat_request).await; let response = router.route_chat(None, &chat_request, None).await;
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
// Should be SSE // Should be SSE
...@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() { ...@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
// First few requests should fail and record failures // First few requests should fail and record failures
for _ in 0..3 { for _ in 0..3 {
let response = router.route_chat(None, &chat_request).await; let response = router.route_chat(None, &chat_request, None).await;
// Should get either an error or circuit breaker response // Should get either an error or circuit breaker response
assert!( assert!(
response.status() == StatusCode::INTERNAL_SERVER_ERROR response.status() == StatusCode::INTERNAL_SERVER_ERROR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment