Unverified Commit 5ef545e6 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Move all protocols to spec.rs file (#9519)

parent c4500233
// Supporting types for Responses API
use crate::protocols::openai::common::ChatLogProbs;
use serde::{Deserialize, Serialize};
// ============= Tool Definitions =============
/// A tool made available to the model in a Responses API request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
    // Serialized as "type"; the raw identifier sidesteps the Rust keyword.
    #[serde(rename = "type")]
    pub r#type: ResponseToolType,
}
/// Built-in tool kinds supported by the Responses API.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
    // Serialized as "web_search_preview".
    WebSearchPreview,
    // Serialized as "code_interpreter".
    CodeInterpreter,
}
// ============= Reasoning Configuration =============
/// Reasoning configuration for a Responses API request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
    // Defaults to Some(Medium) when the field is absent from incoming JSON.
    #[serde(default = "default_reasoning_effort")]
    pub effort: Option<ReasoningEffort>,
}
/// Serde default for [`ResponseReasoningParam::effort`]: medium effort.
fn default_reasoning_effort() -> Option<ReasoningEffort> {
    // `From<T> for Option<T>` wraps the value in `Some`.
    ReasoningEffort::Medium.into()
}
/// How much reasoning effort the model should spend.
/// Serialized as "low" / "medium" / "high".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
// ============= Input/Output Items =============
/// An item appearing in a response's input or output list.
///
/// Externally tagged on `"type"`. Explicit per-variant renames are kept to
/// make the wire names obvious, though `rename_all = "snake_case"` would
/// produce the same strings.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
    /// A chat-style message with one or more content parts.
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    /// A reasoning trace emitted by the model.
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // `default` must accompany `skip_serializing_if`: without it, JSON
        // serialized with an empty summary (field omitted) fails to
        // deserialize with a "missing field `summary`" error.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    /// A function/tool call, optionally carrying its output once executed.
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
}
/// A single content part inside a message item.
///
/// Externally tagged on `"type"`; currently only `"output_text"` exists.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
    #[serde(rename = "output_text")]
    OutputText {
        text: String,
        // `default` must accompany `skip_serializing_if`: without it, JSON
        // serialized with no annotations (field omitted) fails to
        // deserialize with a "missing field `annotations`" error.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        annotations: Vec<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        logprobs: Option<ChatLogProbs>,
    },
}
/// A content part inside a reasoning item.
/// Externally tagged on `"type"`; currently only `"reasoning_text"` exists.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
    #[serde(rename = "reasoning_text")]
    ReasoningText { text: String },
}
// ============= Output Items for Response =============
/// An item in a completed response's output list.
///
/// Mirrors [`ResponseInputOutputItem`] but with required `status` on the
/// `Message` and `FunctionToolCall` variants, since finished output always
/// carries a status.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        status: String,
    },
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // `default` must accompany `skip_serializing_if`: without it, JSON
        // serialized with an empty summary (field omitted) fails to
        // deserialize with a "missing field `summary`" error.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        status: String,
    },
}
// ============= Service Tier =============
/// Requested processing tier for the request.
/// Serialized as "auto" / "default" / "flex" / "scale" / "priority".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Default,
    Flex,
    Scale,
    Priority,
}
impl Default for ServiceTier {
fn default() -> Self {
Self::Auto
}
}
// ============= Tool Choice =============
/// How the model is allowed to pick tools.
/// Serialized as "auto" / "required" / "none".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Auto
}
}
// ============= Truncation =============
/// Context-window truncation strategy.
/// Serialized as "auto" / "disabled".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
    Auto,
    Disabled,
}
impl Default for Truncation {
fn default() -> Self {
Self::Disabled
}
}
// ============= Response Status =============
/// Lifecycle state of a response.
/// Serialized as "queued" / "in_progress" / "completed" / "failed" / "cancelled".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    Queued,
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
// ============= Include Fields =============
/// Extra data the client can ask to include in the response.
///
/// Wire names contain dots, so each variant carries an explicit rename
/// (`rename_all` alone cannot produce them).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
    #[serde(rename = "computer_call_output.output.image_url")]
    ComputerCallOutputImageUrl,
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    #[serde(rename = "message.output_text.logprobs")]
    MessageOutputTextLogprobs,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}
// ============= Usage Info =============
/// Token accounting in the standard (Chat Completions style) format.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    // Tokens spent on reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
/// Breakdown of prompt-token usage (Chat Completions style).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    // Prompt tokens served from cache.
    pub cached_tokens: u32,
}
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo):
/// input/output terminology instead of prompt/completion, with nested
/// detail structs rather than flat optional counters.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}
/// Breakdown of input-token usage (Responses API style).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
    // Input tokens served from cache.
    pub cached_tokens: u32,
}
/// Breakdown of output-token usage (Responses API style).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
    // Output tokens spent on reasoning.
    pub reasoning_tokens: u32,
}
impl UsageInfo {
    /// Convert to the OpenAI Responses API usage format.
    ///
    /// Field mapping: prompt -> input, completion -> output; the optional
    /// detail fields are carried over only when present.
    pub fn to_response_usage(&self) -> ResponseUsage {
        let input_tokens_details = self
            .prompt_tokens_details
            .as_ref()
            .map(|details| InputTokensDetails {
                cached_tokens: details.cached_tokens,
            });
        let output_tokens_details = self
            .reasoning_tokens
            .map(|reasoning_tokens| OutputTokensDetails { reasoning_tokens });
        ResponseUsage {
            input_tokens: self.prompt_tokens,
            output_tokens: self.completion_tokens,
            total_tokens: self.total_tokens,
            input_tokens_details,
            output_tokens_details,
        }
    }
}
impl From<UsageInfo> for ResponseUsage {
fn from(usage: UsageInfo) -> Self {
usage.to_response_usage()
}
}
impl ResponseUsage {
    /// Convert back to the standard UsageInfo format.
    ///
    /// Inverse of [`UsageInfo::to_response_usage`]: input -> prompt,
    /// output -> completion; optional details map across when present.
    pub fn to_usage_info(&self) -> UsageInfo {
        let prompt_tokens_details = self
            .input_tokens_details
            .as_ref()
            .map(|details| PromptTokenUsageInfo {
                cached_tokens: details.cached_tokens,
            });
        let reasoning_tokens = self
            .output_tokens_details
            .as_ref()
            .map(|details| details.reasoning_tokens);
        UsageInfo {
            prompt_tokens: self.input_tokens,
            completion_tokens: self.output_tokens,
            total_tokens: self.total_tokens,
            reasoning_tokens,
            prompt_tokens_details,
        }
    }
}
This diff is collapsed.
This diff is collapsed.
......@@ -9,10 +9,7 @@ use axum::{
};
use std::fmt::Debug;
use crate::protocols::{
generate::GenerateRequest,
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod header_utils;
......
......@@ -12,13 +12,9 @@ use crate::core::{
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::protocols::{
common::StringOrArray,
generate::GenerateRequest,
openai::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
completions::CompletionRequest,
},
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
UserMessageContent,
};
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
......
......@@ -9,10 +9,8 @@ use crate::core::{
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::protocols::{
common::GenerationRequest,
generate::GenerateRequest,
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
};
use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
......
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::protocols::{
generate::GenerateRequest,
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use axum::{
......
......@@ -5,13 +5,9 @@
use serde_json::{from_str, to_string, to_value};
use sglang_router_rs::core::{BasicWorker, WorkerType};
use sglang_router_rs::protocols::{
common::StringOrArray,
generate::{GenerateParameters, GenerateRequest, SamplingParams},
openai::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
completions::CompletionRequest,
},
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
/// Create a default GenerateRequest for benchmarks with minimal fields set
......
// Integration test for Responses API
use sglang_router_rs::protocols::common::GenerationRequest;
use sglang_router_rs::protocols::openai::responses::request::ResponseInput;
use sglang_router_rs::protocols::openai::responses::*;
use sglang_router_rs::protocols::spec::{
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
ToolChoiceValue, Truncation, UsageInfo,
};
#[test]
fn test_responses_request_creation() {
......@@ -24,7 +26,7 @@ fn test_responses_request_creation() {
store: true,
stream: false,
temperature: Some(0.7),
tool_choice: ToolChoice::Auto,
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
}],
......@@ -67,7 +69,7 @@ fn test_sampling_params_conversion() {
store: true, // Use default true
stream: false,
temperature: Some(0.8),
tool_choice: ToolChoice::Auto,
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![],
top_logprobs: 0, // Use default 0
top_p: Some(0.95),
......@@ -177,7 +179,7 @@ fn test_json_serialization() {
store: false,
stream: true,
temperature: Some(0.9),
tool_choice: ToolChoice::Required,
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
}],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment