Unverified Commit 4c9bcb9d authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[Router] Refactor protocol definitions: split spec.rs into modular files (#11677)


Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent 86b04d25
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
//! tool call parsing should be performed. It simply returns the input text //! tool call parsing should be performed. It simply returns the input text
//! with no tool calls detected. //! with no tool calls detected.
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::errors::ParserResult; use crate::tool_parser::errors::ParserResult;
use crate::tool_parser::traits::ToolParser; use crate::tool_parser::traits::ToolParser;
use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem}; use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem};
......
...@@ -15,7 +15,7 @@ use rustpython_parser::{parse, Mode}; ...@@ -15,7 +15,7 @@ use rustpython_parser::{parse, Mode};
use serde_json::{Map, Number, Value}; use serde_json::{Map, Number, Value};
use std::sync::OnceLock; use std::sync::OnceLock;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
...@@ -2,7 +2,7 @@ use async_trait::async_trait; ...@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
...@@ -3,7 +3,7 @@ use regex::Regex; ...@@ -3,7 +3,7 @@ use regex::Regex;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
......
use crate::protocols::spec::Tool; use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ParserResult, errors::ParserResult,
types::{StreamingParseResult, ToolCall}, types::{StreamingParseResult, ToolCall},
......
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::tokenizer::chat_template::{ use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor, ChatTemplateProcessor,
...@@ -173,12 +173,12 @@ assistant: ...@@ -173,12 +173,12 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [ let messages = [
spec::ChatMessage::System { ChatMessage::System {
content: "You are helpful".to_string(), content: "You are helpful".to_string(),
name: None, name: None,
}, },
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()), content: UserMessageContent::Text("Hello".to_string()),
name: None, name: None,
}, },
]; ];
...@@ -213,8 +213,8 @@ fn test_chat_template_with_tokens_unit_test() { ...@@ -213,8 +213,8 @@ fn test_chat_template_with_tokens_unit_test() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()), content: UserMessageContent::Text("Test".to_string()),
name: None, name: None,
}]; }];
......
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::protocols::common::{ContentPart, ImageUrl};
use sglang_router_rs::tokenizer::chat_template::{ use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor, ChatTemplateProcessor,
...@@ -17,8 +18,8 @@ fn test_simple_chat_template() { ...@@ -17,8 +18,8 @@ fn test_simple_chat_template() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()), content: UserMessageContent::Text("Test".to_string()),
name: None, name: None,
}]; }];
...@@ -51,8 +52,8 @@ fn test_chat_template_with_tokens() { ...@@ -51,8 +52,8 @@ fn test_chat_template_with_tokens() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()), content: UserMessageContent::Text("Test".to_string()),
name: None, name: None,
}]; }];
...@@ -112,12 +113,12 @@ fn test_llama_style_template() { ...@@ -112,12 +113,12 @@ fn test_llama_style_template() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [ let messages = [
spec::ChatMessage::System { ChatMessage::System {
content: "You are a helpful assistant".to_string(), content: "You are a helpful assistant".to_string(),
name: None, name: None,
}, },
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("What is 2+2?".to_string()), content: UserMessageContent::Text("What is 2+2?".to_string()),
name: None, name: None,
}, },
]; ];
...@@ -167,18 +168,18 @@ fn test_chatml_template() { ...@@ -167,18 +168,18 @@ fn test_chatml_template() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![ let messages = vec![
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()), content: UserMessageContent::Text("Hello".to_string()),
name: None, name: None,
}, },
spec::ChatMessage::Assistant { ChatMessage::Assistant {
content: Some("Hi there!".to_string()), content: Some("Hi there!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
reasoning_content: None, reasoning_content: None,
}, },
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("How are you?".to_string()), content: UserMessageContent::Text("How are you?".to_string()),
name: None, name: None,
}, },
]; ];
...@@ -219,8 +220,8 @@ assistant: ...@@ -219,8 +220,8 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()), content: UserMessageContent::Text("Test".to_string()),
name: None, name: None,
}]; }];
...@@ -306,13 +307,13 @@ fn test_template_with_multimodal_content() { ...@@ -306,13 +307,13 @@ fn test_template_with_multimodal_content() {
let processor = ChatTemplateProcessor::new(template.to_string()); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Parts(vec![ content: UserMessageContent::Parts(vec![
spec::ContentPart::Text { ContentPart::Text {
text: "Look at this:".to_string(), text: "Look at this:".to_string(),
}, },
spec::ContentPart::ImageUrl { ContentPart::ImageUrl {
image_url: spec::ImageUrl { image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(), url: "https://example.com/image.jpg".to_string(),
detail: None, detail: None,
}, },
......
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams; use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
use std::fs; use std::fs;
...@@ -58,11 +58,11 @@ mod tests { ...@@ -58,11 +58,11 @@ mod tests {
.unwrap(); .unwrap();
let messages = [ let messages = [
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()), content: UserMessageContent::Text("Hello".to_string()),
name: None, name: None,
}, },
spec::ChatMessage::Assistant { ChatMessage::Assistant {
content: Some("Hi there".to_string()), content: Some("Hi there".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
...@@ -140,8 +140,8 @@ mod tests { ...@@ -140,8 +140,8 @@ mod tests {
) )
.unwrap(); .unwrap();
let messages = [spec::ChatMessage::User { let messages = [ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()), content: UserMessageContent::Text("Test".to_string()),
name: None, name: None,
}]; }];
...@@ -199,11 +199,11 @@ mod tests { ...@@ -199,11 +199,11 @@ mod tests {
tokenizer.set_chat_template(new_template.to_string()); tokenizer.set_chat_template(new_template.to_string());
let messages = [ let messages = [
spec::ChatMessage::User { ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()), content: UserMessageContent::Text("Hello".to_string()),
name: None, name: None,
}, },
spec::ChatMessage::Assistant { ChatMessage::Assistant {
content: Some("World".to_string()), content: Some("World".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
......
...@@ -15,7 +15,7 @@ use sglang_router_rs::data_connector::{ ...@@ -15,7 +15,7 @@ use sglang_router_rs::data_connector::{
}; };
use sglang_router_rs::middleware::TokenBucket; use sglang_router_rs::middleware::TokenBucket;
use sglang_router_rs::policies::PolicyRegistry; use sglang_router_rs::policies::PolicyRegistry;
use sglang_router_rs::protocols::spec::{Function, Tool}; use sglang_router_rs::protocols::common::{Function, Tool};
use sglang_router_rs::server::AppContext; use sglang_router_rs::server::AppContext;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
......
// Integration test for Responses API // Integration test for Responses API
use axum::http::StatusCode; use axum::http::StatusCode;
use sglang_router_rs::protocols::spec::{ use sglang_router_rs::protocols::common::{
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus, GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo,
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice, };
ToolChoiceValue, Truncation, UsageInfo, use sglang_router_rs::protocols::responses::{
ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType,
ResponsesRequest, ServiceTier, Truncation,
}; };
mod common; mod common;
...@@ -430,24 +432,18 @@ fn test_responses_request_sglang_extensions() { ...@@ -430,24 +432,18 @@ fn test_responses_request_sglang_extensions() {
assert_eq!(parsed.repetition_penalty, 1.1); assert_eq!(parsed.repetition_penalty, 1.1);
} }
#[test]
fn test_responses_response_creation() {
let response = ResponsesResponse::new(
"resp_test789".to_string(),
"test-model".to_string(),
ResponseStatus::Completed,
);
assert_eq!(response.id, "resp_test789");
assert_eq!(response.model, "test-model");
assert!(response.is_complete());
assert!(!response.is_in_progress());
assert!(!response.is_failed());
}
#[test] #[test]
fn test_usage_conversion() { fn test_usage_conversion() {
let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3); // Construct UsageInfo directly with cached token details
let usage_info = UsageInfo {
prompt_tokens: 15,
completion_tokens: 25,
total_tokens: 40,
reasoning_tokens: Some(8),
prompt_tokens_details: Some(sglang_router_rs::protocols::common::PromptTokenUsageInfo {
cached_tokens: 3,
}),
};
let response_usage = usage_info.to_response_usage(); let response_usage = usage_info.to_response_usage();
assert_eq!(response_usage.input_tokens, 15); assert_eq!(response_usage.input_tokens, 15);
......
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::spec::{ use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions, use sglang_router_rs::protocols::common::{
Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent, Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue,
ToolReference,
}; };
use sglang_router_rs::protocols::validated::Normalizable; use sglang_router_rs::protocols::validated::Normalizable;
use validator::Validate; use validator::Validate;
......
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent}; use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
#[test] #[test]
fn test_chat_message_tagged_by_role_system() { fn test_chat_message_tagged_by_role_system() {
......
use serde_json::{from_str, json, to_string}; use serde_json::{from_str, json, to_string};
use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest}; use sglang_router_rs::protocols::common::GenerationRequest;
use sglang_router_rs::protocols::embedding::EmbeddingRequest;
#[test] #[test]
fn test_embedding_request_serialization_string_input() { fn test_embedding_request_serialization_string_input() {
......
use serde_json::{from_str, to_string, Number, Value}; use serde_json::{from_str, to_string, Number, Value};
use sglang_router_rs::protocols::spec::{ use sglang_router_rs::protocols::common::{GenerationRequest, StringOrArray, UsageInfo};
GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo, use sglang_router_rs::protocols::rerank::{
V1RerankReqInput, RerankRequest, RerankResponse, RerankResult, V1RerankReqInput,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use validator::Validate;
#[test] #[test]
fn test_rerank_request_serialization() { fn test_rerank_request_serialization() {
...@@ -75,8 +76,7 @@ fn test_rerank_request_validation_empty_query() { ...@@ -75,8 +76,7 @@ fn test_rerank_request_validation_empty_query() {
}; };
let result = request.validate(); let result = request.validate();
assert!(result.is_err()); assert!(result.is_err(), "Should reject empty query");
assert_eq!(result.unwrap_err(), "Query cannot be empty");
} }
#[test] #[test]
...@@ -92,8 +92,7 @@ fn test_rerank_request_validation_whitespace_query() { ...@@ -92,8 +92,7 @@ fn test_rerank_request_validation_whitespace_query() {
}; };
let result = request.validate(); let result = request.validate();
assert!(result.is_err()); assert!(result.is_err(), "Should reject whitespace-only query");
assert_eq!(result.unwrap_err(), "Query cannot be empty");
} }
#[test] #[test]
...@@ -109,8 +108,7 @@ fn test_rerank_request_validation_empty_documents() { ...@@ -109,8 +108,7 @@ fn test_rerank_request_validation_empty_documents() {
}; };
let result = request.validate(); let result = request.validate();
assert!(result.is_err()); assert!(result.is_err(), "Should reject empty documents list");
assert_eq!(result.unwrap_err(), "Documents list cannot be empty");
} }
#[test] #[test]
...@@ -126,8 +124,7 @@ fn test_rerank_request_validation_top_k_zero() { ...@@ -126,8 +124,7 @@ fn test_rerank_request_validation_top_k_zero() {
}; };
let result = request.validate(); let result = request.validate();
assert!(result.is_err()); assert!(result.is_err(), "Should reject top_k of zero");
assert_eq!(result.unwrap_err(), "top_k must be greater than 0");
} }
#[test] #[test]
......
...@@ -9,18 +9,20 @@ use axum::{ ...@@ -9,18 +9,20 @@ use axum::{
Json, Router, Json, Router,
}; };
use serde_json::json; use serde_json::json;
use sglang_router_rs::data_connector::MemoryConversationItemStorage;
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
}, },
data_connector::{ data_connector::{
MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage, MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
StoredResponse, ResponseId, ResponseStorage, StoredResponse,
}, },
protocols::spec::{ protocols::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
ResponsesGetParams, ResponsesRequest, UserMessageContent, common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
responses::{ResponseInput, ResponsesGetParams, ResponsesRequest},
}, },
routers::{openai::OpenAIRouter, RouterTrait}, routers::{openai::OpenAIRouter, RouterTrait},
}; };
...@@ -52,7 +54,7 @@ fn create_minimal_chat_request() -> ChatCompletionRequest { ...@@ -52,7 +54,7 @@ fn create_minimal_chat_request() -> ChatCompletionRequest {
fn create_minimal_completion_request() -> CompletionRequest { fn create_minimal_completion_request() -> CompletionRequest {
CompletionRequest { CompletionRequest {
model: "gpt-3.5-turbo".to_string(), model: "gpt-3.5-turbo".to_string(),
prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()), prompt: StringOrArray::String("Hello".to_string()),
suffix: None, suffix: None,
max_tokens: Some(100), max_tokens: Some(100),
temperature: None, temperature: None,
...@@ -605,12 +607,12 @@ async fn test_unsupported_endpoints() { ...@@ -605,12 +607,12 @@ async fn test_unsupported_endpoints() {
video_data: None, video_data: None,
audio_data: None, audio_data: None,
sampling_params: None, sampling_params: None,
stream: false,
return_logprob: Some(false), return_logprob: Some(false),
logprob_start_len: None, logprob_start_len: None,
top_logprobs_num: None, top_logprobs_num: None,
token_ids_logprob: None, token_ids_logprob: None,
return_text_in_logprobs: false, return_text_in_logprobs: false,
stream: false,
log_metrics: true, log_metrics: true,
return_hidden_states: false, return_hidden_states: false,
modalities: None, modalities: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment