Unverified commit 2f173ea0, authored by Simo Lin, committed by GitHub

[router] allow one router to support different model families and serving mode (#10244)

parent 321fecab
This diff is collapsed.
@@ -17,6 +17,7 @@ pub mod factory;
 pub mod grpc;
 pub mod header_utils;
 pub mod http;
+pub mod router_manager;

 pub use factory::RouterFactory;
 // Re-export HTTP routers for convenience (keeps routers::openai_router path working)
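
The new `router_manager` module is what lets one router deployment serve several model families and serving modes (the point of this commit). Its real API is not shown in this diff, so the following is only a hedged, self-contained sketch of the dispatch idea it suggests; every name here (`ToyRouterManager`, `ServingMode`, the model ids) is hypothetical.

use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq)]
enum ServingMode {
    Http,
    Grpc,
}

struct ToyRouterManager {
    // model id -> serving mode of the router that owns that model family
    routes: HashMap<String, ServingMode>,
    default_mode: ServingMode,
}

impl ToyRouterManager {
    // Pick the router for a request: an explicit model id wins,
    // anything unknown or unspecified falls back to the default router.
    fn select(&self, model_id: Option<&str>) -> ServingMode {
        model_id
            .and_then(|id| self.routes.get(id).copied())
            .unwrap_or(self.default_mode)
    }
}

fn main() {
    let manager = ToyRouterManager {
        routes: HashMap::from([
            ("llama-3.1-8b".to_string(), ServingMode::Http),
            ("qwen2.5-7b".to_string(), ServingMode::Grpc),
        ]),
        default_mode: ServingMode::Http,
    };

    // A request naming a gRPC-served model goes to the gRPC router...
    assert_eq!(manager.select(Some("qwen2.5-7b")), ServingMode::Grpc);
    // ...and an unknown or unspecified model falls back to the default.
    assert_eq!(manager.select(None), ServingMode::Http);
}
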
@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
     async fn get_model_info(&self, req: Request<Body>) -> Response;

     /// Route a generate request
-    async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
-        -> Response;
+    async fn route_generate(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &GenerateRequest,
+        model_id: Option<&str>,
+    ) -> Response;

     /// Route a chat completion request
     async fn route_chat(
         &self,
         headers: Option<&HeaderMap>,
         body: &ChatCompletionRequest,
+        model_id: Option<&str>,
     ) -> Response;

     /// Route a completion request
@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
         &self,
         headers: Option<&HeaderMap>,
         body: &CompletionRequest,
+        model_id: Option<&str>,
     ) -> Response;

     /// Route a responses request
@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
         &self,
         headers: Option<&HeaderMap>,
         body: &ResponsesRequest,
+        model_id: Option<&str>,
     ) -> Response;

     async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;

-    async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response;
+    async fn route_rerank(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &RerankRequest,
+        model_id: Option<&str>,
+    ) -> Response;

     /// Flush cache on all workers
     async fn flush_cache(&self) -> Response;
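
Every routing method above now takes an extra `model_id: Option<&str>`. One plausible way an implementation could use it is to prefer the explicit id chosen by the caller (for example, a router manager) and otherwise fall back to the model named in the request body. A minimal sketch under that assumption, using a hypothetical stand-in type rather than the crate's real `ChatCompletionRequest`:

// Hypothetical stand-in for a request body that carries a model name.
struct FakeChatRequest {
    model: String,
}

/// Prefer the explicit `model_id` passed by the caller; otherwise fall back
/// to the model declared in the request body. (Assumed behaviour, not taken
/// from the crate.)
fn resolve_model<'a>(model_id: Option<&'a str>, body: &'a FakeChatRequest) -> &'a str {
    model_id.unwrap_or(body.model.as_str())
}

fn main() {
    let body = FakeChatRequest {
        model: "llama-3.1-8b".to_string(), // placeholder model name
    };

    // No explicit override: route by the model named in the body.
    assert_eq!(resolve_model(None, &body), "llama-3.1-8b");

    // Explicit override wins, e.g. when an upstream manager has already
    // picked a model family for this request.
    assert_eq!(resolve_model(Some("qwen2.5-7b"), &body), "qwen2.5-7b");

    println!("resolved: {}", resolve_model(None, &body));
}
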
This diff is collapsed.
This diff is collapsed.
@@ -579,9 +579,8 @@ mod tests {
     // Helper to create a Router instance for testing event handlers
     async fn create_test_router() -> Arc<dyn RouterTrait> {
-        use crate::config::{PolicyConfig, RouterConfig};
+        use crate::config::RouterConfig;
         use crate::middleware::TokenBucket;
-        use crate::policies::PolicyFactory;
         use crate::routers::http::router::Router;
         use crate::server::AppContext;
@@ -591,15 +590,19 @@ mod tests {
         // Create AppContext with minimal components
         let app_context = Arc::new(AppContext {
             client: reqwest::Client::new(),
-            router_config,
+            router_config: router_config.clone(),
             rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
+            worker_registry: Arc::new(crate::core::WorkerRegistry::new()),
+            policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
+                router_config.policy.clone(),
+            )),
             tokenizer: None, // HTTP mode doesn't need tokenizer
             reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
             tool_parser_registry: None, // HTTP mode doesn't need tool parser
+            router_manager: None, // Test doesn't need router manager
         });

-        let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
-        let router = Router::new(vec![], policy, &app_context).await.unwrap();
+        let router = Router::new(vec![], &app_context).await.unwrap();
         Arc::new(router) as Arc<dyn RouterTrait>
     }
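
The test's `AppContext` now carries `worker_registry` and `policy_registry` handles wrapped in `Arc`, so every router created from the same context observes the same registrations instead of owning a private policy. A toy sketch of that sharing pattern, assuming a simplified hypothetical registry type rather than the crate's real `WorkerRegistry`:

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Toy registry mapping a model id to its worker URLs (hypothetical type).
#[derive(Default)]
struct ToyWorkerRegistry {
    workers: RwLock<HashMap<String, Vec<String>>>,
}

impl ToyWorkerRegistry {
    fn register(&self, model: &str, url: &str) {
        self.workers
            .write()
            .unwrap()
            .entry(model.to_string())
            .or_default()
            .push(url.to_string());
    }

    fn workers_for(&self, model: &str) -> Vec<String> {
        self.workers
            .read()
            .unwrap()
            .get(model)
            .cloned()
            .unwrap_or_default()
    }
}

fn main() {
    // One registry, shared by every router through the app context.
    let registry = Arc::new(ToyWorkerRegistry::default());
    let seen_by_router_a = Arc::clone(&registry);
    let seen_by_router_b = Arc::clone(&registry);

    seen_by_router_a.register("llama-3.1-8b", "http://worker-1:8000");

    // Because the registry is shared, router B immediately sees the worker
    // that router A registered.
    assert_eq!(
        seen_by_router_b.workers_for("llama-3.1-8b"),
        vec!["http://worker-1:8000".to_string()]
    );
}
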
This diff is collapsed.
This diff is collapsed.
@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
         rid: None,
     };

-    let response = router.route_generate(None, &generate_request).await;
+    let response = router.route_generate(None, &generate_request, None).await;
     assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);

     // Test completion endpoint (should also not be supported)
     let completion_request = create_minimal_completion_request();
-    let response = router.route_completion(None, &completion_request).await;
+    let response = router
+        .route_completion(None, &completion_request, None)
+        .await;
     assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
 }
@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
     chat_request.temperature = Some(0.7);

     // Route the request
-    let response = router.route_chat(None, &chat_request).await;
+    let response = router.route_chat(None, &chat_request, None).await;

     // Should get a successful response from mock server
     assert_eq!(response.status(), StatusCode::OK);
@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
                 let chat_request: ChatCompletionRequest =
                     serde_json::from_str(&body_str).unwrap();

-                router.route_chat(Some(&parts.headers), &chat_request).await
+                router
+                    .route_chat(Some(&parts.headers), &chat_request, None)
+                    .await
             }
         }
     }),
@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
     });
     let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();

-    let response = router.route_chat(None, &chat_request).await;
+    let response = router.route_chat(None, &chat_request, None).await;
     assert_eq!(response.status(), StatusCode::OK);

     // Should be SSE
@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
     // First few requests should fail and record failures
     for _ in 0..3 {
-        let response = router.route_chat(None, &chat_request).await;
+        let response = router.route_chat(None, &chat_request, None).await;
         // Should get either an error or circuit breaker response
         assert!(
             response.status() == StatusCode::INTERNAL_SERVER_ERROR