Unverified Commit d513ee93 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[2/2] [feature] support openai like classification api in router (#11670)

parent a7ae61ed
...@@ -480,6 +480,26 @@ impl RouterMetrics { ...@@ -480,6 +480,26 @@ impl RouterMetrics {
gauge!("sgl_router_embeddings_queue_size").set(size as f64); gauge!("sgl_router_embeddings_queue_size").set(size as f64);
} }
pub fn record_classify_request() {
counter!("sgl_router_classify_total").increment(1);
}
pub fn record_classify_duration(duration: Duration) {
histogram!("sgl_router_classify_duration_seconds").record(duration.as_secs_f64());
}
pub fn record_classify_error(error_type: &str) {
counter!(
"sgl_router_classify_errors_total",
"error_type" => error_type.to_string()
)
.increment(1);
}
pub fn set_classify_queue_size(size: usize) {
gauge!("sgl_router_classify_queue_size").set(size as f64);
}
pub fn set_running_requests(worker: &str, count: usize) { pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests", gauge!("sgl_router_running_requests",
"worker" => worker.to_string() "worker" => worker.to_string()
......
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::common::GenerationRequest;
// ============================================================================
// Embedding API
// ============================================================================
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClassifyRequest {
/// ID of the model to use
pub model: String,
/// Input can be a string, array of strings, tokens, or batch inputs
pub input: Value,
/// Optional encoding format (e.g., "float", "base64")
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
/// Optional user identifier
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Optional number of dimensions for the embedding
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// SGLang extension: request id for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
impl GenerationRequest for ClassifyRequest {
fn is_stream(&self) -> bool {
// Embeddings are non-streaming
false
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Best effort: extract text content for routing decisions
match &self.input {
Value::String(s) => s.clone(),
Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
.join(" "),
_ => String::new(),
}
}
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// This module provides a structured approach to handling different API protocols // This module provides a structured approach to handling different API protocols
pub mod chat; pub mod chat;
pub mod classify;
pub mod common; pub mod common;
pub mod completion; pub mod completion;
pub mod embedding; pub mod embedding;
......
...@@ -18,6 +18,7 @@ use crate::{ ...@@ -18,6 +18,7 @@ use crate::{
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -254,6 +255,15 @@ impl RouterTrait for GrpcPDRouter { ...@@ -254,6 +255,15 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
......
...@@ -18,6 +18,7 @@ use crate::{ ...@@ -18,6 +18,7 @@ use crate::{
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -236,6 +237,15 @@ impl RouterTrait for GrpcRouter { ...@@ -236,6 +237,15 @@ impl RouterTrait for GrpcRouter {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
......
...@@ -24,6 +24,7 @@ use crate::{ ...@@ -24,6 +24,7 @@ use crate::{
policies::{LoadBalancingPolicy, PolicyRegistry}, policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{ protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
classify::ClassifyRequest,
common::{InputIds, StringOrArray}, common::{InputIds, StringOrArray},
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
...@@ -1190,6 +1191,19 @@ impl RouterTrait for PDRouter { ...@@ -1190,6 +1191,19 @@ impl RouterTrait for PDRouter {
.into_response() .into_response()
} }
async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Classify endpoint not implemented for PD router",
)
.into_response()
}
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
......
...@@ -24,6 +24,7 @@ use crate::{ ...@@ -24,6 +24,7 @@ use crate::{
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
common::GenerationRequest, common::GenerationRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
...@@ -749,6 +750,30 @@ impl RouterTrait for Router { ...@@ -749,6 +750,30 @@ impl RouterTrait for Router {
res res
} }
async fn route_classify(
&self,
headers: Option<&HeaderMap>,
body: &ClassifyRequest,
model_id: Option<&str>,
) -> Response {
// Record classification-specific metrics in addition to general request metrics
let start = Instant::now();
let res = self
.route_typed_request(headers, body, "/v1/classify", model_id)
.await;
// Classification specific metrics
if res.status().is_success() {
RouterMetrics::record_classify_request();
RouterMetrics::record_classify_duration(start.elapsed());
} else {
let error_type = format!("http_{}", res.status().as_u16());
RouterMetrics::record_classify_error(&error_type);
}
res
}
async fn route_rerank( async fn route_rerank(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
......
...@@ -13,6 +13,7 @@ use serde_json::Value; ...@@ -13,6 +13,7 @@ use serde_json::Value;
use crate::protocols::{ use crate::protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -125,6 +126,14 @@ pub trait RouterTrait: Send + Sync + Debug { ...@@ -125,6 +126,14 @@ pub trait RouterTrait: Send + Sync + Debug {
model_id: Option<&str>, model_id: Option<&str>,
) -> Response; ) -> Response;
/// Route classification requests (OpenAI-compatible /v1/classify)
async fn route_classify(
&self,
headers: Option<&HeaderMap>,
body: &ClassifyRequest,
model_id: Option<&str>,
) -> Response;
async fn route_rerank( async fn route_rerank(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
......
...@@ -41,10 +41,11 @@ pub(super) async fn create_conversation( ...@@ -41,10 +41,11 @@ pub(super) async fn create_conversation(
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(json!({ Json(json!({
"error": format!( "error":
"metadata cannot have more than {} properties", format!(
MAX_METADATA_PROPERTIES "metadata cannot have more than {} properties",
) MAX_METADATA_PROPERTIES
)
})), })),
) )
.into_response(); .into_response();
...@@ -70,7 +71,9 @@ pub(super) async fn create_conversation( ...@@ -70,7 +71,9 @@ pub(super) async fn create_conversation(
} }
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create conversation: {}", e)})), Json(json!({
"error": format!("Failed to create conversation: {}", e)
})),
) )
.into_response(), .into_response(),
} }
...@@ -97,7 +100,9 @@ pub(super) async fn get_conversation( ...@@ -97,7 +100,9 @@ pub(super) async fn get_conversation(
.into_response(), .into_response(),
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(), .into_response(),
} }
...@@ -126,7 +131,9 @@ pub(super) async fn update_conversation( ...@@ -126,7 +131,9 @@ pub(super) async fn update_conversation(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -174,10 +181,11 @@ pub(super) async fn update_conversation( ...@@ -174,10 +181,11 @@ pub(super) async fn update_conversation(
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(json!({ Json(json!({
"error": format!( "error":
"metadata cannot have more than {} properties", format!(
MAX_METADATA_PROPERTIES "metadata cannot have more than {} properties",
) MAX_METADATA_PROPERTIES
)
})), })),
) )
.into_response(); .into_response();
...@@ -204,7 +212,9 @@ pub(super) async fn update_conversation( ...@@ -204,7 +212,9 @@ pub(super) async fn update_conversation(
.into_response(), .into_response(),
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update conversation: {}", e)})), Json(json!({
"error": format!("Failed to update conversation: {}", e)
})),
) )
.into_response(), .into_response(),
} }
...@@ -232,7 +242,9 @@ pub(super) async fn delete_conversation( ...@@ -232,7 +242,9 @@ pub(super) async fn delete_conversation(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -256,7 +268,9 @@ pub(super) async fn delete_conversation( ...@@ -256,7 +268,9 @@ pub(super) async fn delete_conversation(
} }
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete conversation: {}", e)})), Json(json!({
"error": format!("Failed to delete conversation: {}", e)
})),
) )
.into_response(), .into_response(),
} }
...@@ -286,7 +300,9 @@ pub(super) async fn list_conversation_items( ...@@ -286,7 +300,9 @@ pub(super) async fn list_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -346,7 +362,7 @@ pub(super) async fn list_conversation_items( ...@@ -346,7 +362,7 @@ pub(super) async fn list_conversation_items(
} }
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to list items: {}", e)})), Json(json!({ "error": format!("Failed to list items: {}", e) })),
) )
.into_response(), .into_response(),
} }
...@@ -417,7 +433,9 @@ pub(super) async fn create_conversation_items( ...@@ -417,7 +433,9 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -476,14 +494,18 @@ pub(super) async fn create_conversation_items( ...@@ -476,14 +494,18 @@ pub(super) async fn create_conversation_items(
Ok(None) => { Ok(None) => {
return ( return (
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(json!({"error": format!("Referenced item '{}' not found", ref_id)})), Json(json!({
"error": format!("Referenced item '{}' not found", ref_id)
})),
) )
.into_response(); .into_response();
} }
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get referenced item: {}", e)})), Json(json!({
"error": format!("Failed to get referenced item: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -517,7 +539,9 @@ pub(super) async fn create_conversation_items( ...@@ -517,7 +539,9 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item link: {}", e)})), Json(json!({
"error": format!("Failed to check item link: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -553,7 +577,7 @@ pub(super) async fn create_conversation_items( ...@@ -553,7 +577,7 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(json!({"error": format!("Invalid item: {}", e)})), Json(json!({ "error": format!("Invalid item: {}", e) })),
) )
.into_response(); .into_response();
} }
...@@ -570,7 +594,7 @@ pub(super) async fn create_conversation_items( ...@@ -570,7 +594,7 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create item: {}", e)})), Json(json!({ "error": format!("Failed to create item: {}", e) })),
) )
.into_response(); .into_response();
} }
...@@ -579,7 +603,9 @@ pub(super) async fn create_conversation_items( ...@@ -579,7 +603,9 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item existence: {}", e)})), Json(json!({
"error": format!("Failed to check item existence: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -593,7 +619,7 @@ pub(super) async fn create_conversation_items( ...@@ -593,7 +619,7 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(json!({"error": format!("Invalid item: {}", e)})), Json(json!({ "error": format!("Invalid item: {}", e) })),
) )
.into_response(); .into_response();
} }
...@@ -610,7 +636,7 @@ pub(super) async fn create_conversation_items( ...@@ -610,7 +636,7 @@ pub(super) async fn create_conversation_items(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create item: {}", e)})), Json(json!({ "error": format!("Failed to create item: {}", e) })),
) )
.into_response(); .into_response();
} }
...@@ -678,7 +704,9 @@ pub(super) async fn get_conversation_item( ...@@ -678,7 +704,9 @@ pub(super) async fn get_conversation_item(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -693,7 +721,9 @@ pub(super) async fn get_conversation_item( ...@@ -693,7 +721,9 @@ pub(super) async fn get_conversation_item(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item link: {}", e)})), Json(json!({
"error": format!("Failed to check item link: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -721,7 +751,7 @@ pub(super) async fn get_conversation_item( ...@@ -721,7 +751,7 @@ pub(super) async fn get_conversation_item(
.into_response(), .into_response(),
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get item: {}", e)})), Json(json!({ "error": format!("Failed to get item: {}", e) })),
) )
.into_response(), .into_response(),
} }
...@@ -753,7 +783,9 @@ pub(super) async fn delete_conversation_item( ...@@ -753,7 +783,9 @@ pub(super) async fn delete_conversation_item(
Err(e) => { Err(e) => {
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})), Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
) )
.into_response(); .into_response();
} }
...@@ -773,7 +805,7 @@ pub(super) async fn delete_conversation_item( ...@@ -773,7 +805,7 @@ pub(super) async fn delete_conversation_item(
} }
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete item: {}", e)})), Json(json!({ "error": format!("Failed to delete item: {}", e) })),
) )
.into_response(), .into_response(),
} }
......
...@@ -156,7 +156,7 @@ pub(super) fn patch_streaming_response_json( ...@@ -156,7 +156,7 @@ pub(super) fn patch_streaming_response_json(
// Attach conversation id for client response if present (final aggregated JSON) // Attach conversation id for client response if present (final aggregated JSON)
if let Some(conv_id) = original_body.conversation.clone() { if let Some(conv_id) = original_body.conversation.clone() {
obj.insert("conversation".to_string(), json!({"id": conv_id})); obj.insert("conversation".to_string(), json!({ "id": conv_id }));
} }
} }
} }
...@@ -234,7 +234,7 @@ pub(super) fn rewrite_streaming_block( ...@@ -234,7 +234,7 @@ pub(super) fn rewrite_streaming_block(
// Attach conversation id into streaming event response content with ordering // Attach conversation id into streaming event response content with ordering
if let Some(conv_id) = original_body.conversation.clone() { if let Some(conv_id) = original_body.conversation.clone() {
response_obj.insert("conversation".to_string(), json!({"id": conv_id})); response_obj.insert("conversation".to_string(), json!({ "id": conv_id }));
changed = true; changed = true;
} }
} }
......
...@@ -42,6 +42,7 @@ use crate::{ ...@@ -42,6 +42,7 @@ use crate::{
}, },
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -828,7 +829,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -828,7 +829,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
.into_response(), .into_response(),
Err(e) => ( Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get response: {}", e)})), Json(json!({ "error": format!("Failed to get response: {}", e) })),
) )
.into_response(), .into_response(),
} }
...@@ -882,6 +883,15 @@ impl crate::routers::RouterTrait for OpenAIRouter { ...@@ -882,6 +883,15 @@ impl crate::routers::RouterTrait for OpenAIRouter {
(StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response() (StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response()
} }
async fn route_classify(
&self,
_headers: Option<&HeaderMap>,
_body: &ClassifyRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED, "Classify not supported").into_response()
}
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response { async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
create_conversation(&self.conversation_storage, body.clone()).await create_conversation(&self.conversation_storage, body.clone()).await
} }
......
...@@ -22,6 +22,7 @@ use crate::{ ...@@ -22,6 +22,7 @@ use crate::{
core::{WorkerRegistry, WorkerType}, core::{WorkerRegistry, WorkerType},
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -329,10 +330,7 @@ impl RouterTrait for RouterManager { ...@@ -329,10 +330,7 @@ impl RouterTrait for RouterManager {
} else { } else {
( (
StatusCode::OK, StatusCode::OK,
serde_json::json!({ serde_json::json!({ "models": models }).to_string(),
"models": models
})
.to_string(),
) )
.into_response() .into_response()
} }
...@@ -517,6 +515,25 @@ impl RouterTrait for RouterManager { ...@@ -517,6 +515,25 @@ impl RouterTrait for RouterManager {
} }
} }
async fn route_classify(
&self,
headers: Option<&HeaderMap>,
body: &ClassifyRequest,
model_id: Option<&str>,
) -> Response {
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
router.route_classify(headers, body, model_id).await
} else {
(
StatusCode::NOT_FOUND,
format!("Model '{}' not found or no router available", body.model),
)
.into_response()
}
}
fn router_type(&self) -> &'static str { fn router_type(&self) -> &'static str {
"manager" "manager"
} }
......
...@@ -37,6 +37,7 @@ use crate::{ ...@@ -37,6 +37,7 @@ use crate::{
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest, completion::CompletionRequest,
embedding::EmbeddingRequest, embedding::EmbeddingRequest,
generate::GenerateRequest, generate::GenerateRequest,
...@@ -270,6 +271,17 @@ async fn v1_embeddings( ...@@ -270,6 +271,17 @@ async fn v1_embeddings(
.await .await
} }
async fn v1_classify(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ClassifyRequest>,
) -> Response {
state
.router
.route_classify(Some(&headers), &body, None)
.await
}
async fn v1_responses_get( async fn v1_responses_get(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Path(response_id): Path<String>, Path(response_id): Path<String>,
...@@ -534,13 +546,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons ...@@ -534,13 +546,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
}) })
.collect(); .collect();
( (StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
StatusCode::OK,
Json(json!({
"workers": loads
})),
)
.into_response()
} }
async fn create_worker( async fn create_worker(
...@@ -707,6 +713,7 @@ pub fn build_app( ...@@ -707,6 +713,7 @@ pub fn build_app(
.route("/v1/rerank", post(v1_rerank)) .route("/v1/rerank", post(v1_rerank))
.route("/v1/responses", post(v1_responses)) .route("/v1/responses", post(v1_responses))
.route("/v1/embeddings", post(v1_embeddings)) .route("/v1/embeddings", post(v1_embeddings))
.route("/v1/classify", post(v1_classify))
.route("/v1/responses/{response_id}", get(v1_responses_get)) .route("/v1/responses/{response_id}", get(v1_responses_get))
.route( .route(
"/v1/responses/{response_id}/cancel", "/v1/responses/{response_id}/cancel",
......
...@@ -1617,7 +1617,7 @@ async fn test_conversation_items_max_limit() { ...@@ -1617,7 +1617,7 @@ async fn test_conversation_items_max_limit() {
"content": [{"type": "input_text", "text": format!("Message {}", i)}] "content": [{"type": "input_text", "text": format!("Message {}", i)}]
})); }));
} }
let create_items = serde_json::json!({"items": items}); let create_items = serde_json::json!({ "items": items });
let items_resp = router let items_resp = router
.create_conversation_items(None, conv_id, &create_items) .create_conversation_items(None, conv_id, &create_items)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment