Unverified Commit 2f173ea0 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] allow one router to support different model families and serving mode (#10244)

parent 321fecab
This diff is collapsed.
......@@ -17,6 +17,7 @@ pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod http;
pub mod router_manager;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
......@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
-> Response;
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response;
/// Route a chat completion request
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response;
/// Route a completion request
......@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
model_id: Option<&str>,
) -> Response;
/// Route a responses request
......@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response;
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response;
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,
body: &RerankRequest,
model_id: Option<&str>,
) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self) -> Response;
......
This diff is collapsed.
This diff is collapsed.
......@@ -579,9 +579,8 @@ mod tests {
// Helper to create a Router instance for testing event handlers
async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::{PolicyConfig, RouterConfig};
use crate::config::RouterConfig;
use crate::middleware::TokenBucket;
use crate::policies::PolicyFactory;
use crate::routers::http::router::Router;
use crate::server::AppContext;
......@@ -591,15 +590,19 @@ mod tests {
// Create AppContext with minimal components
let app_context = Arc::new(AppContext {
client: reqwest::Client::new(),
router_config,
router_config: router_config.clone(),
rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
worker_registry: Arc::new(crate::core::WorkerRegistry::new()),
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
router_config.policy.clone(),
)),
tokenizer: None, // HTTP mode doesn't need tokenizer
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
tool_parser_registry: None, // HTTP mode doesn't need tool parser
router_manager: None, // Test doesn't need router manager
});
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, &app_context).await.unwrap();
let router = Router::new(vec![], &app_context).await.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
This diff is collapsed.
This diff is collapsed.
......@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
rid: None,
};
let response = router.route_generate(None, &generate_request).await;
let response = router.route_generate(None, &generate_request, None).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
// Test completion endpoint (should also not be supported)
let completion_request = create_minimal_completion_request();
let response = router.route_completion(None, &completion_request).await;
let response = router
.route_completion(None, &completion_request, None)
.await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
}
......@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
chat_request.temperature = Some(0.7);
// Route the request
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
// Should get a successful response from mock server
assert_eq!(response.status(), StatusCode::OK);
......@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
let chat_request: ChatCompletionRequest =
serde_json::from_str(&body_str).unwrap();
router.route_chat(Some(&parts.headers), &chat_request).await
router
.route_chat(Some(&parts.headers), &chat_request, None)
.await
}
}
}),
......@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
});
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
assert_eq!(response.status(), StatusCode::OK);
// Should be SSE
......@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
// First few requests should fail and record failures
for _ in 0..3 {
let response = router.route_chat(None, &chat_request).await;
let response = router.route_chat(None, &chat_request, None).await;
// Should get either an error or circuit breaker response
assert!(
response.status() == StatusCode::INTERNAL_SERVER_ERROR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment