Unverified Commit 4c9bcb9d authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[Router] Refactor protocol definitions: split spec.rs into modular files (#11677)


Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent 86b04d25
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
......
......@@ -4,7 +4,7 @@
//! tool call parsing should be performed. It simply returns the input text
//! with no tool calls detected.
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::errors::ParserResult;
use crate::tool_parser::traits::ToolParser;
use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem};
......
......@@ -15,7 +15,7 @@ use rustpython_parser::{parse, Mode};
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
......
......@@ -2,7 +2,7 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
......
......@@ -3,7 +3,7 @@ use regex::Regex;
use serde_json::Value;
use std::collections::HashMap;
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::{ParserError, ParserResult},
......
use crate::protocols::spec::Tool;
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::ParserResult,
types::{StreamingParseResult, ToolCall},
......
use sglang_router_rs::protocols::spec;
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
......@@ -173,12 +173,12 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [
spec::ChatMessage::System {
ChatMessage::System {
content: "You are helpful".to_string(),
name: None,
},
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("Hello".to_string()),
name: None,
},
];
......@@ -213,8 +213,8 @@ fn test_chat_template_with_tokens_unit_test() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()),
let messages = [ChatMessage::User {
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
......
use sglang_router_rs::protocols::spec;
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::protocols::common::{ContentPart, ImageUrl};
use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
......@@ -17,8 +18,8 @@ fn test_simple_chat_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()),
let messages = [ChatMessage::User {
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -51,8 +52,8 @@ fn test_chat_template_with_tokens() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()),
let messages = [ChatMessage::User {
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -112,12 +113,12 @@ fn test_llama_style_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [
spec::ChatMessage::System {
ChatMessage::System {
content: "You are a helpful assistant".to_string(),
name: None,
},
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("What is 2+2?".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("What is 2+2?".to_string()),
name: None,
},
];
......@@ -167,18 +168,18 @@ fn test_chatml_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
ChatMessage::Assistant {
content: Some("Hi there!".to_string()),
name: None,
tool_calls: None,
reasoning_content: None,
},
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("How are you?".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("How are you?".to_string()),
name: None,
},
];
......@@ -219,8 +220,8 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()),
let messages = [ChatMessage::User {
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -306,13 +307,13 @@ fn test_template_with_multimodal_content() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Parts(vec![
spec::ContentPart::Text {
let messages = [ChatMessage::User {
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Look at this:".to_string(),
},
spec::ContentPart::ImageUrl {
image_url: spec::ImageUrl {
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
......
#[cfg(test)]
mod tests {
use sglang_router_rs::protocols::spec;
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
use std::fs;
......@@ -58,11 +58,11 @@ mod tests {
.unwrap();
let messages = [
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
ChatMessage::Assistant {
content: Some("Hi there".to_string()),
name: None,
tool_calls: None,
......@@ -140,8 +140,8 @@ mod tests {
)
.unwrap();
let messages = [spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Test".to_string()),
let messages = [ChatMessage::User {
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -199,11 +199,11 @@ mod tests {
tokenizer.set_chat_template(new_template.to_string());
let messages = [
spec::ChatMessage::User {
content: spec::UserMessageContent::Text("Hello".to_string()),
ChatMessage::User {
content: UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
ChatMessage::Assistant {
content: Some("World".to_string()),
name: None,
tool_calls: None,
......
......@@ -15,7 +15,7 @@ use sglang_router_rs::data_connector::{
};
use sglang_router_rs::middleware::TokenBucket;
use sglang_router_rs::policies::PolicyRegistry;
use sglang_router_rs::protocols::spec::{Function, Tool};
use sglang_router_rs::protocols::common::{Function, Tool};
use sglang_router_rs::server::AppContext;
use std::fs;
use std::path::PathBuf;
......
// Integration test for Responses API
use axum::http::StatusCode;
use sglang_router_rs::protocols::spec::{
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
ToolChoiceValue, Truncation, UsageInfo,
use sglang_router_rs::protocols::common::{
GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo,
};
use sglang_router_rs::protocols::responses::{
ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType,
ResponsesRequest, ServiceTier, Truncation,
};
mod common;
......@@ -430,24 +432,18 @@ fn test_responses_request_sglang_extensions() {
assert_eq!(parsed.repetition_penalty, 1.1);
}
#[test]
fn test_responses_response_creation() {
let response = ResponsesResponse::new(
"resp_test789".to_string(),
"test-model".to_string(),
ResponseStatus::Completed,
);
assert_eq!(response.id, "resp_test789");
assert_eq!(response.model, "test-model");
assert!(response.is_complete());
assert!(!response.is_in_progress());
assert!(!response.is_failed());
}
#[test]
fn test_usage_conversion() {
let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3);
// Construct UsageInfo directly with cached token details
let usage_info = UsageInfo {
prompt_tokens: 15,
completion_tokens: 25,
total_tokens: 40,
reasoning_tokens: Some(8),
prompt_tokens_details: Some(sglang_router_rs::protocols::common::PromptTokenUsageInfo {
cached_tokens: 3,
}),
};
let response_usage = usage_info.to_response_usage();
assert_eq!(response_usage.input_tokens, 15);
......
use serde_json::json;
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions,
Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent,
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use sglang_router_rs::protocols::common::{
Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue,
ToolReference,
};
use sglang_router_rs::protocols::validated::Normalizable;
use validator::Validate;
......
use serde_json::json;
use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent};
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
#[test]
fn test_chat_message_tagged_by_role_system() {
......
use serde_json::{from_str, json, to_string};
use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest};
use sglang_router_rs::protocols::common::GenerationRequest;
use sglang_router_rs::protocols::embedding::EmbeddingRequest;
#[test]
fn test_embedding_request_serialization_string_input() {
......
use serde_json::{from_str, to_string, Number, Value};
use sglang_router_rs::protocols::spec::{
GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo,
V1RerankReqInput,
use sglang_router_rs::protocols::common::{GenerationRequest, StringOrArray, UsageInfo};
use sglang_router_rs::protocols::rerank::{
RerankRequest, RerankResponse, RerankResult, V1RerankReqInput,
};
use std::collections::HashMap;
use validator::Validate;
#[test]
fn test_rerank_request_serialization() {
......@@ -75,8 +76,7 @@ fn test_rerank_request_validation_empty_query() {
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Query cannot be empty");
assert!(result.is_err(), "Should reject empty query");
}
#[test]
......@@ -92,8 +92,7 @@ fn test_rerank_request_validation_whitespace_query() {
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Query cannot be empty");
assert!(result.is_err(), "Should reject whitespace-only query");
}
#[test]
......@@ -109,8 +108,7 @@ fn test_rerank_request_validation_empty_documents() {
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Documents list cannot be empty");
assert!(result.is_err(), "Should reject empty documents list");
}
#[test]
......@@ -126,8 +124,7 @@ fn test_rerank_request_validation_top_k_zero() {
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "top_k must be greater than 0");
assert!(result.is_err(), "Should reject top_k of zero");
}
#[test]
......
......@@ -9,18 +9,20 @@ use axum::{
Json, Router,
};
use serde_json::json;
use sglang_router_rs::data_connector::MemoryConversationItemStorage;
use sglang_router_rs::{
config::{
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
},
data_connector::{
MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage,
StoredResponse,
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
ResponseId, ResponseStorage, StoredResponse,
},
protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
ResponsesGetParams, ResponsesRequest, UserMessageContent,
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
responses::{ResponseInput, ResponsesGetParams, ResponsesRequest},
},
routers::{openai::OpenAIRouter, RouterTrait},
};
......@@ -52,7 +54,7 @@ fn create_minimal_chat_request() -> ChatCompletionRequest {
fn create_minimal_completion_request() -> CompletionRequest {
CompletionRequest {
model: "gpt-3.5-turbo".to_string(),
prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()),
prompt: StringOrArray::String("Hello".to_string()),
suffix: None,
max_tokens: Some(100),
temperature: None,
......@@ -605,12 +607,12 @@ async fn test_unsupported_endpoints() {
video_data: None,
audio_data: None,
sampling_params: None,
stream: false,
return_logprob: Some(false),
logprob_start_len: None,
top_logprobs_num: None,
token_ids_logprob: None,
return_text_in_logprobs: false,
stream: false,
log_metrics: true,
return_hidden_states: false,
modalities: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment