Unverified Commit 4c9bcb9d authored by Keyang Ru, committed by GitHub
Browse files

[Router] Refactor protocol definitions: split spec.rs into modular files (#11677)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 86b04d25
...@@ -19,7 +19,14 @@ use tracing::{debug, error, warn}; ...@@ -19,7 +19,14 @@ use tracing::{debug, error, warn};
use super::context; use super::context;
use super::utils; use super::utils;
use crate::grpc_client::proto; use crate::grpc_client::proto;
use crate::protocols::spec::*; use crate::protocols::chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
};
use crate::protocols::common::{
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
};
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ReasoningParser; use crate::reasoning_parser::ReasoningParser;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
......
...@@ -4,10 +4,12 @@ use super::ProcessedMessages; ...@@ -4,10 +4,12 @@ use super::ProcessedMessages;
use crate::core::Worker; use crate::core::Worker;
use crate::grpc_client::sglang_scheduler::AbortOnDropStream; use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ use crate::protocols::chat::{ChatCompletionRequest, ChatMessage};
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse, use crate::protocols::common::{
GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb, ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
}; };
use crate::protocols::generate::GenerateFinishReason;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer; use crate::tokenizer::HuggingFaceTokenizer;
...@@ -952,7 +954,8 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate ...@@ -952,7 +954,8 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent}; use crate::protocols::chat::{ChatMessage, UserMessageContent};
use crate::protocols::common::{ContentPart, ImageUrl};
use crate::tokenizer::chat_template::ChatTemplateContentFormat; use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json; use serde_json::json;
......
...@@ -5,10 +5,13 @@ use crate::core::{ ...@@ -5,10 +5,13 @@ use crate::core::{
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{ use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, use crate::protocols::common::{InputIds, StringOrArray};
ResponsesGetParams, ResponsesRequest, StringOrArray, UserMessageContent, use crate::protocols::completion::CompletionRequest;
}; use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils; use crate::routers::header_utils;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -150,9 +153,10 @@ impl PDRouter { ...@@ -150,9 +153,10 @@ impl PDRouter {
} }
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> { fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
if let Some(text) = &req.text { // GenerateRequest doesn't support batch via arrays, only via input_ids
if text.contains("[") && text.contains("]") { if let Some(InputIds::Batch(batches)) = &req.input_ids {
return None; if !batches.is_empty() {
return Some(batches.len());
} }
} }
None None
...@@ -1185,7 +1189,7 @@ impl RouterTrait for PDRouter { ...@@ -1185,7 +1189,7 @@ impl RouterTrait for PDRouter {
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::EmbeddingRequest, _body: &EmbeddingRequest,
_model_id: Option<&str>, _model_id: Option<&str>,
) -> Response { ) -> Response {
( (
......
...@@ -4,10 +4,13 @@ use crate::core::{ ...@@ -4,10 +4,13 @@ use crate::core::{
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::chat::ChatCompletionRequest;
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, use crate::protocols::common::GenerationRequest;
RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest, use crate::protocols::completion::CompletionRequest;
}; use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils; use crate::routers::header_utils;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use axum::body::to_bytes; use axum::body::to_bytes;
...@@ -628,7 +631,7 @@ impl Router { ...@@ -628,7 +631,7 @@ impl Router {
let rerank_results = serde_json::from_slice::<Vec<RerankResult>>(&body_bytes)?; let rerank_results = serde_json::from_slice::<Vec<RerankResult>>(&body_bytes)?;
let mut rerank_response = let mut rerank_response =
RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone()); RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone());
rerank_response.sort_by_score(); // Sorting is handled by Python worker (serving_rerank.py)
if let Some(top_k) = req.top_k { if let Some(top_k) = req.top_k {
rerank_response.apply_top_k(top_k); rerank_response.apply_top_k(top_k);
} }
...@@ -748,9 +751,6 @@ impl RouterTrait for Router { ...@@ -748,9 +751,6 @@ impl RouterTrait for Router {
body: &RerankRequest, body: &RerankRequest,
model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
if let Err(e) = body.validate() {
return (StatusCode::BAD_REQUEST, e).into_response();
}
let response = self let response = self
.route_typed_request(headers, body, "/v1/rerank", model_id) .route_typed_request(headers, body, "/v1/rerank", model_id)
.await; .await;
......
...@@ -9,10 +9,12 @@ use axum::{ ...@@ -9,10 +9,12 @@ use axum::{
}; };
use std::fmt::Debug; use std::fmt::Debug;
use crate::protocols::spec::{ use crate::protocols::chat::ChatCompletionRequest;
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, use crate::protocols::completion::CompletionRequest;
ResponsesGetParams, ResponsesRequest, use crate::protocols::embedding::EmbeddingRequest;
}; use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use serde_json::Value; use serde_json::Value;
pub mod factory; pub mod factory;
......
...@@ -6,7 +6,7 @@ use crate::data_connector::{ ...@@ -6,7 +6,7 @@ use crate::data_connector::{
NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage, NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage,
SharedConversationStorage, SharedConversationStorage,
}; };
use crate::protocols::spec::{ResponseInput, ResponsesRequest}; use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::Json; use axum::Json;
...@@ -1028,7 +1028,7 @@ async fn persist_items_with_storages( ...@@ -1028,7 +1028,7 @@ async fn persist_items_with_storages(
ResponseInput::Items(items_array) => { ResponseInput::Items(items_array) => {
for input_item in items_array { for input_item in items_array {
match input_item { match input_item {
crate::protocols::spec::ResponseInputOutputItem::Message { ResponseInputOutputItem::Message {
role, role,
content, content,
status, status,
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
//! - Metadata injection for MCP operations //! - Metadata injection for MCP operations
use crate::mcp::McpClientManager; use crate::mcp::McpClientManager;
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; use crate::protocols::responses::{
ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers; use crate::routers::header_utils::apply_request_headers;
use axum::http::HeaderMap; use axum::http::HeaderMap;
use bytes::Bytes; use bytes::Bytes;
...@@ -127,7 +129,7 @@ impl FunctionCallInProgress { ...@@ -127,7 +129,7 @@ impl FunctionCallInProgress {
/// Build a request-scoped MCP manager from request tools, if present. /// Build a request-scoped MCP manager from request tools, if present.
pub(super) async fn mcp_manager_from_request_tools( pub(super) async fn mcp_manager_from_request_tools(
tools: &[crate::protocols::spec::ResponseTool], tools: &[ResponseTool],
) -> Option<Arc<McpClientManager>> { ) -> Option<Arc<McpClientManager>> {
let tool = tools let tool = tools
.iter() .iter()
......
//! Response storage, patching, and extraction utilities //! Response storage, patching, and extraction utilities
use crate::data_connector::{ResponseId, StoredResponse}; use crate::data_connector::{ResponseId, StoredResponse};
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use tracing::warn; use tracing::warn;
......
...@@ -6,8 +6,12 @@ use crate::data_connector::{ ...@@ -6,8 +6,12 @@ use crate::data_connector::{
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId, conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
}; };
use crate::protocols::spec::{ use crate::protocols::chat::ChatCompletionRequest;
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest, ResponsesRequest,
}; };
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
use crate::data_connector::{ use crate::data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
}; };
use crate::protocols::spec::{ResponseToolType, ResponsesRequest}; use crate::protocols::responses::{ResponseToolType, ResponsesRequest};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use axum::{ use axum::{
body::Body, body::Body,
......
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
use crate::config::{ConnectionMode, RoutingMode}; use crate::config::{ConnectionMode, RoutingMode};
use crate::core::{WorkerRegistry, WorkerType}; use crate::core::{WorkerRegistry, WorkerType};
use crate::protocols::spec::{ use crate::protocols::chat::ChatCompletionRequest;
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, use crate::protocols::completion::CompletionRequest;
ResponsesGetParams, ResponsesRequest, use crate::protocols::embedding::EmbeddingRequest;
}; use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::server::{AppContext, ServerConfig}; use crate::server::{AppContext, ServerConfig};
use async_trait::async_trait; use async_trait::async_trait;
......
...@@ -15,10 +15,12 @@ use crate::{ ...@@ -15,10 +15,12 @@ use crate::{
middleware::{self, AuthConfig, QueuedRequest, TokenBucket}, middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
spec::{ chat::ChatCompletionRequest,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, completion::CompletionRequest,
RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput, embedding::EmbeddingRequest,
}, generate::GenerateRequest,
rerank::{RerankRequest, V1RerankReqInput},
responses::{ResponsesGetParams, ResponsesRequest},
validated::ValidatedJson, validated::ValidatedJson,
worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo}, worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
}, },
...@@ -223,7 +225,7 @@ async fn v1_completions( ...@@ -223,7 +225,7 @@ async fn v1_completions(
async fn rerank( async fn rerank(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<RerankRequest>, ValidatedJson(body): ValidatedJson<RerankRequest>,
) -> Response { ) -> Response {
state.router.route_rerank(Some(&headers), &body, None).await state.router.route_rerank(Some(&headers), &body, None).await
} }
......
...@@ -2,7 +2,7 @@ use async_trait::async_trait; ...@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
...@@ -2,7 +2,7 @@ use async_trait::async_trait; ...@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
use async_trait::async_trait; use async_trait::async_trait;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ParserResult, errors::ParserResult,
......
...@@ -2,7 +2,7 @@ use async_trait::async_trait; ...@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
......
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
...@@ -2,7 +2,7 @@ use async_trait::async_trait; ...@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ParserResult, errors::ParserResult,
......
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment