Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5ef545e6
Unverified
Commit
5ef545e6
authored
Aug 22, 2025
by
Keyang Ru
Committed by
GitHub
Aug 22, 2025
Browse files
[router] Move all protocols to spec.rs file (#9519)
parent
c4500233
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
2426 additions
and
463 deletions
+2426
-463
sgl-router/src/protocols/openai/responses/types.rs
sgl-router/src/protocols/openai/responses/types.rs
+0
-296
sgl-router/src/protocols/spec.rs
sgl-router/src/protocols/spec.rs
+1867
-0
sgl-router/src/protocols/validation.rs
sgl-router/src/protocols/validation.rs
+541
-135
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+1
-4
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+3
-7
sgl-router/src/routers/router.rs
sgl-router/src/routers/router.rs
+2
-4
sgl-router/src/server.rs
sgl-router/src/server.rs
+1
-4
sgl-router/tests/benchmark_integration.rs
sgl-router/tests/benchmark_integration.rs
+3
-7
sgl-router/tests/responses_api_test.rs
sgl-router/tests/responses_api_test.rs
+8
-6
No files found.
sgl-router/src/protocols/openai/responses/types.rs
deleted
100644 → 0
View file @
c4500233
// Supporting types for Responses API
use
crate
::
protocols
::
openai
::
common
::
ChatLogProbs
;
use
serde
::{
Deserialize
,
Serialize
};
// ============= Tool Definitions =============
/// A tool entry for the Responses API request. Only the tool's kind is
/// carried. The raw identifier `r#type` is needed because `type` is a Rust
/// keyword; serde maps it to/from the JSON field `"type"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
    #[serde(rename = "type")]
    pub r#type: ResponseToolType,
}

/// Built-in tool kinds accepted by the Responses API.
/// Serialized in snake_case ("web_search_preview", "code_interpreter").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
    WebSearchPreview,
    CodeInterpreter,
}
// ============= Reasoning Configuration =============
/// Reasoning configuration for the Responses API.
///
/// When the `effort` key is absent from incoming JSON, serde uses
/// `default_reasoning_effort` (i.e. `Some(Medium)`). Note an explicit JSON
/// `null` still deserializes to `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
    #[serde(default = "default_reasoning_effort")]
    pub effort: Option<ReasoningEffort>,
}

/// serde default for `ResponseReasoningParam::effort`: medium effort.
fn default_reasoning_effort() -> Option<ReasoningEffort> {
    Some(ReasoningEffort::Medium)
}

/// Reasoning effort levels, serialized in snake_case ("low", "medium", "high").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
// ============= Input/Output Items =============
/// An item that can appear in Responses API input or output lists.
///
/// Internally tagged on the JSON `"type"` field; the explicit `rename`
/// attributes pin the tag values ("message", "reasoning",
/// "function_tool_call").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        // Omitted from JSON when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
}
/// One content part of a Responses API message, tagged on the JSON
/// `"type"` field. Only `output_text` is modeled here.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
    #[serde(rename = "output_text")]
    OutputText {
        text: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        annotations: Vec<String>,
        // Omitted from JSON when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        logprobs: Option<ChatLogProbs>,
    },
}
/// Content of a reasoning item, tagged on the JSON `"type"` field
/// (only variant: "reasoning_text").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
    #[serde(rename = "reasoning_text")]
    ReasoningText { text: String },
}
// ============= Output Items for Response =============
/// An item emitted in a Responses API response.
///
/// Mirrors `ResponseInputOutputItem`, except `status` is required (plain
/// `String`) on `Message` and `FunctionToolCall` here.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        status: String,
    },
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        status: String,
    },
}
// ============= Service Tier =============
/// Requested processing tier, serialized in snake_case.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Default,
    Flex,
    Scale,
    Priority,
}

impl Default for ServiceTier {
    /// Defaults to `Auto`.
    fn default() -> Self {
        Self::Auto
    }
}
// ============= Tool Choice =============
/// Tool-choice mode for the Responses API, serialized in snake_case
/// ("auto", "required", "none").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
}

impl Default for ToolChoice {
    /// Defaults to `Auto`.
    fn default() -> Self {
        Self::Auto
    }
}
// ============= Truncation =============
/// Truncation strategy, serialized in snake_case ("auto", "disabled").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
    Auto,
    Disabled,
}

impl Default for Truncation {
    /// Defaults to `Disabled` (no automatic truncation).
    fn default() -> Self {
        Self::Disabled
    }
}
// ============= Response Status =============
/// Lifecycle status of a Responses API response, serialized in snake_case
/// ("queued", "in_progress", "completed", "failed", "cancelled").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    Queued,
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
// ============= Include Fields =============
/// Extra fields a client can ask to be included in a Responses API
/// response. Each variant carries an explicit dotted-path rename because
/// the wire names contain '.' and cannot be produced by `rename_all`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
    #[serde(rename = "computer_call_output.output.image_url")]
    ComputerCallOutputImageUrl,
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    #[serde(rename = "message.output_text.logprobs")]
    MessageOutputTextLogprobs,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}
// ============= Usage Info =============
/// Standard (Chat Completions-style) token usage accounting.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}

/// Breakdown of prompt-token usage (cached portion).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    pub cached_tokens: u32,
}
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo)
///
/// Field names differ from `UsageInfo` (input/output vs prompt/completion);
/// see `UsageInfo::to_response_usage` / `ResponseUsage::to_usage_info` for
/// the mapping between the two.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}

/// Breakdown of input-token usage (cached portion).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
    pub cached_tokens: u32,
}

/// Breakdown of output-token usage (reasoning portion).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
    pub reasoning_tokens: u32,
}
impl UsageInfo {
    /// Convert to OpenAI Responses API format.
    ///
    /// Maps prompt/completion counts onto input/output counts; the optional
    /// detail structs are only produced when the corresponding source field
    /// is present.
    pub fn to_response_usage(&self) -> ResponseUsage {
        let input_tokens_details = self
            .prompt_tokens_details
            .as_ref()
            .map(|d| InputTokensDetails {
                cached_tokens: d.cached_tokens,
            });
        let output_tokens_details = self
            .reasoning_tokens
            .map(|reasoning_tokens| OutputTokensDetails { reasoning_tokens });
        ResponseUsage {
            input_tokens: self.prompt_tokens,
            output_tokens: self.completion_tokens,
            total_tokens: self.total_tokens,
            input_tokens_details,
            output_tokens_details,
        }
    }
}
/// Enables `ResponseUsage::from(usage)` / `usage.into()` by delegating to
/// [`UsageInfo::to_response_usage`].
impl From<UsageInfo> for ResponseUsage {
    fn from(usage: UsageInfo) -> Self {
        usage.to_response_usage()
    }
}
impl ResponseUsage {
    /// Convert back to standard UsageInfo format.
    ///
    /// Inverse of `UsageInfo::to_response_usage`: input/output counts map to
    /// prompt/completion counts, and the optional detail structs are
    /// flattened back into `reasoning_tokens` / `prompt_tokens_details`.
    pub fn to_usage_info(&self) -> UsageInfo {
        let reasoning_tokens = self
            .output_tokens_details
            .as_ref()
            .map(|d| d.reasoning_tokens);
        let prompt_tokens_details = self
            .input_tokens_details
            .as_ref()
            .map(|d| PromptTokenUsageInfo {
                cached_tokens: d.cached_tokens,
            });
        UsageInfo {
            prompt_tokens: self.input_tokens,
            completion_tokens: self.output_tokens,
            total_tokens: self.total_tokens,
            reasoning_tokens,
            prompt_tokens_details,
        }
    }
}
sgl-router/src/protocols/spec.rs
0 → 100644
View file @
5ef545e6
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
// # Protocol Specifications
//
// This module contains all protocol definitions for OpenAI and SGLang APIs.
//
// ## Table of Contents
//
// 1. **OPENAI SPEC - Chat Completions API**
// - Message Types
// - Response Format Types
// - Tool/Function Types
// - Streaming Delta Types
// - Request/Response structures
//
// 2. **OPENAI SPEC - Completions API**
// - Request/Response structures
// - Streaming support
//
// 3. **OPENAI SPEC - Responses API**
// - Tool Definitions
// - Reasoning Configuration
// - Input/Output Items
// - Service Tier & Tool Choice
// - Request/Response structures
//
// 4. **OPENAI SPEC - Common**
// - Shared Request Components
// - Tool Choice Types
// - Usage Tracking
// - Logprobs Types
// - Error Response Types
//
// 5. **SGLANG SPEC - GENERATE API**
// - Generate Parameters
// - Sampling Parameters
// - Request/Response structures
//
// 6. **COMMON**
// - GenerationRequest trait
// - StringOrArray & LoRAPath types
// - Helper functions
// ==================================================================
// = OPENAI SPEC - Chat Completions API =
// ==================================================================
// ============= Message Types =============
/// A single chat message, one variant per role.
///
/// `untagged`: serde picks the first variant whose fields match the incoming
/// JSON, so variant order matters for deserialization. The `role` string is
/// kept as data rather than used as a tag.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
    System {
        role: String,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        role: String, // "user"
        content: UserMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        role: String, // "assistant"
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        function_call: Option<FunctionCallResponse>,
        /// Reasoning content for O1-style models (SGLang extension)
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    Tool {
        role: String, // "tool"
        content: String,
        tool_call_id: String,
    },
    Function {
        role: String, // "function"
        content: String,
        name: String,
    },
}
/// Content of a user message: either plain text or a list of multimodal
/// parts. `untagged`: deserialization tries `Text` first, then `Parts`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
    Text(String),
    Parts(Vec<ContentPart>),
}

/// One multimodal part of a user message, tagged on the JSON `"type"`
/// field ("text" or "image_url").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
/// Image reference inside a user message content part.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    pub url: String,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>, // "auto", "low", or "high"
}
// ============= Response Format Types =============
/// Output format constraint for chat completions, tagged on the JSON
/// `"type"` field ("text", "json_object", or "json_schema").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_object")]
    JsonObject,
    #[serde(rename = "json_schema")]
    JsonSchema { json_schema: JsonSchemaFormat },
}

/// Named JSON schema used with `ResponseFormat::JsonSchema`. The schema
/// itself is kept as raw `serde_json::Value`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
    pub name: String,
    pub schema: Value,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
// ============= Streaming Delta Types =============
/// Incremental message fragment sent in streaming chat responses.
/// All fields are optional; each chunk carries only what changed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCallDelta>,
    /// Reasoning content delta for O1-style models (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

/// Incremental tool-call fragment within a streamed delta; `index`
/// identifies which tool call the fragment belongs to.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // Serialized as the JSON field "type" (`type` is a Rust keyword).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "type")]
    pub tool_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}

/// Incremental function-call fragment: name and/or a piece of the
/// arguments string.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
// ============= Request =============
/// Request body for the OpenAI Chat Completions API (`/v1/chat/completions`),
/// extended with SGLang-specific sampling and routing fields.
///
/// Optional fields are omitted from serialized JSON when `None`; boolean
/// flags use `#[serde(default)]` (false) or `default_true` as noted.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
    /// ID of the model to use
    pub model: String,
    /// A list of messages comprising the conversation so far
    pub messages: Vec<ChatMessage>,
    /// What sampling temperature to use, between 0 and 2
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// If set, partial message deltas will be sent
    #[serde(default)]
    pub stream: bool,
    /// Options for streaming response
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Up to 4 sequences where the API will stop generating further tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    /// The maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// An upper bound for the number of tokens that can be generated for a completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// If specified, our system will make a best effort to sample deterministically
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Whether to return log probabilities of the output tokens
    #[serde(default)]
    pub logprobs: bool,
    /// An integer between 0 and 20 specifying the number of most likely tokens to return
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    /// An object specifying the format that the model must output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// A list of tools the model may call
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Controls which (if any) tool is called by the model
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Deprecated: use tools instead
    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<Function>>,
    /// Deprecated: use tool_choice instead
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,

    // ============= SGLang Extensions =============
    /// Top-k sampling parameter (-1 to disable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    /// Min-p nucleus sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,
    /// Minimum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    /// Repetition penalty for reducing repetitive text
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    /// Regex constraint for output generation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// EBNF grammar constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,
    /// Specific token IDs to use as stop conditions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,
    /// Skip trimming stop tokens from output
    #[serde(default)]
    pub no_stop_trim: bool,
    /// Ignore end-of-sequence tokens during generation
    #[serde(default)]
    pub ignore_eos: bool,
    /// Continue generating from final assistant message
    #[serde(default)]
    pub continue_final_message: bool,
    /// Skip special tokens during detokenization
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,
    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,
    /// Separate reasoning content from final answer (O1-style models)
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,
    /// Stream reasoning tokens during generation
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,
    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,
}
impl GenerationRequest for ChatCompletionRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Flatten the conversation into a single space-joined string so the
    /// router can make text-based decisions. Non-text content parts are
    /// skipped; messages with no text contribute nothing.
    fn extract_text_for_routing(&self) -> String {
        let pieces: Vec<String> = self
            .messages
            .iter()
            .filter_map(|message| match message {
                // These roles carry a plain string body.
                ChatMessage::System { content, .. }
                | ChatMessage::Tool { content, .. }
                | ChatMessage::Function { content, .. } => Some(content.clone()),
                ChatMessage::User { content, .. } => match content {
                    UserMessageContent::Text(text) => Some(text.clone()),
                    UserMessageContent::Parts(parts) => {
                        // Keep only textual parts; images etc. are ignored.
                        let part_texts: Vec<String> = parts
                            .iter()
                            .filter_map(|part| match part {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect();
                        Some(part_texts.join(" "))
                    }
                },
                ChatMessage::Assistant {
                    content,
                    reasoning_content,
                    ..
                } => {
                    // Combine the answer text with any reasoning text.
                    let answer = content.clone().unwrap_or_default();
                    let reasoning = reasoning_content.clone().unwrap_or_default();
                    if answer.is_empty() && reasoning.is_empty() {
                        None
                    } else {
                        Some(format!("{} {}", answer, reasoning).trim().to_string())
                    }
                }
            })
            .collect();
        pieces.join(" ")
    }
}
// ============= Regular Response =============
/// Non-streaming response body for the Chat Completions API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String, // "chat.completion"
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}

/// One generated completion choice within a `ChatCompletionResponse`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
    /// Information about which stop condition was matched
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
    /// Hidden states from the model (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
/// One SSE chunk of a streaming chat completion.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String, // "chat.completion.chunk"
    pub created: u64,
    pub model: String,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatStreamChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}

/// One choice within a streaming chunk; carries a delta rather than a
/// full message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>,
}
// ==================================================================
// = OPENAI SPEC - Completions API                                  =
// ==================================================================

// Completions API request types (v1/completions) - DEPRECATED but still supported
/// Request body for the legacy OpenAI Completions API, extended with
/// SGLang-specific fields. Unknown keys are preserved in `other` via
/// `#[serde(flatten)]`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
    pub model: String,
    /// The prompt(s) to generate completions for
    pub prompt: StringOrArray,
    /// The suffix that comes after a completion of inserted text
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// The maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// What sampling temperature to use, between 0 and 2
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature (nucleus sampling)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many completions to generate for each prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Whether to stream back partial progress
    #[serde(default)]
    pub stream: bool,
    /// Options for streaming response
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Include the log probabilities on the logprobs most likely tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,
    /// Echo back the prompt in addition to the completion
    #[serde(default)]
    pub echo: bool,
    /// Up to 4 sequences where the API will stop generating further tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Generates best_of completions server-side and returns the "best"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,
    /// Modify the likelihood of specified tokens appearing in the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// If specified, our system will make a best effort to sample deterministically
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    // ============= SGLang Extensions =============
    /// Top-k sampling parameter (-1 to disable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    /// Min-p nucleus sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,
    /// Minimum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    /// Repetition penalty for reducing repetitive text
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    /// Regex constraint for output generation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// EBNF grammar constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,
    /// JSON schema constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
    /// Specific token IDs to use as stop conditions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,
    /// Skip trimming stop tokens from output
    #[serde(default)]
    pub no_stop_trim: bool,
    /// Ignore end-of-sequence tokens during generation
    #[serde(default)]
    pub ignore_eos: bool,
    /// Skip special tokens during detokenization
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,
    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,
    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,
    /// Additional fields including bootstrap info for PD routing
    #[serde(flatten)]
    pub other: serde_json::Map<String, serde_json::Value>,
}
impl GenerationRequest for CompletionRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Flatten the prompt into a single string for routing: a string prompt
    /// is used as-is, a list of prompts is joined with single spaces.
    fn extract_text_for_routing(&self) -> String {
        match &self.prompt {
            StringOrArray::String(text) => text.clone(),
            StringOrArray::Array(items) => items.join(" "),
        }
    }
}
// ============= Regular Response =============
/// Non-streaming response body for the legacy Completions API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String, // "text_completion"
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}

/// One generated completion choice within a `CompletionResponse`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
    /// Information about which stop condition was matched
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
    /// Hidden states from the model (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
/// One SSE chunk of a streaming (legacy) completion.
/// NOTE: field order differs from `CompletionResponse` (`choices` before
/// `model`); serde serializes fields in declaration order.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
    pub id: String,
    pub object: String, // "text_completion"
    pub created: u64,
    pub choices: Vec<CompletionStreamChoice>,
    pub model: String,
    // Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}

/// One choice within a streaming completion chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>,
}
// ==================================================================
// = OPENAI SPEC - Responses API                                    =
// ==================================================================

// ============= Tool Definitions =============
/// A tool entry for the Responses API request. Only the tool's kind is
/// carried. The raw identifier `r#type` is needed because `type` is a Rust
/// keyword; serde maps it to/from the JSON field `"type"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
    #[serde(rename = "type")]
    pub r#type: ResponseToolType,
}

/// Built-in tool kinds accepted by the Responses API.
/// Serialized in snake_case ("web_search_preview", "code_interpreter").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
    WebSearchPreview,
    CodeInterpreter,
}
// ============= Reasoning Configuration =============
/// Reasoning configuration for the Responses API.
///
/// When the `effort` key is absent from incoming JSON, serde uses
/// `default_reasoning_effort` (i.e. `Some(Medium)`). Note an explicit JSON
/// `null` still deserializes to `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
    #[serde(default = "default_reasoning_effort")]
    pub effort: Option<ReasoningEffort>,
}

/// serde default for `ResponseReasoningParam::effort`: medium effort.
fn default_reasoning_effort() -> Option<ReasoningEffort> {
    Some(ReasoningEffort::Medium)
}

/// Reasoning effort levels, serialized in snake_case ("low", "medium", "high").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
// ============= Input/Output Items =============
/// An item that can appear in Responses API input or output lists.
///
/// Internally tagged on the JSON `"type"` field; the explicit `rename`
/// attributes pin the tag values ("message", "reasoning",
/// "function_tool_call").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        // Omitted from JSON when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
}
/// One content part of a Responses API message, tagged on the JSON
/// `"type"` field. Only `output_text` is modeled here.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
    #[serde(rename = "output_text")]
    OutputText {
        text: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        annotations: Vec<String>,
        // Omitted from JSON when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        logprobs: Option<ChatLogProbs>,
    },
}

/// Content of a reasoning item, tagged on the JSON `"type"` field
/// (only variant: "reasoning_text").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
    #[serde(rename = "reasoning_text")]
    ReasoningText { text: String },
}
// ============= Output Items for Response =============
/// An item emitted in a Responses API response.
///
/// Mirrors `ResponseInputOutputItem`, except `status` is required (plain
/// `String`) on `Message` and `FunctionToolCall` here.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
    #[serde(rename = "message")]
    Message {
        id: String,
        role: String,
        content: Vec<ResponseContentPart>,
        status: String,
    },
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        // Omitted from JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        summary: Vec<String>,
        content: Vec<ResponseReasoningContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        status: Option<String>,
    },
    #[serde(rename = "function_tool_call")]
    FunctionToolCall {
        id: String,
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<String>,
        status: String,
    },
}
// ============= Service Tier =============

/// OpenAI service tier selector (serialized in snake_case).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Default,
    Flex,
    Scale,
    Priority,
}
impl
Default
for
ServiceTier
{
fn
default
()
->
Self
{
Self
::
Auto
}
}
// ============= Truncation =============

/// Context truncation behavior (serialized in snake_case).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
    Auto,
    Disabled,
}
impl
Default
for
Truncation
{
fn
default
()
->
Self
{
Self
::
Disabled
}
}
// ============= Response Status =============

/// Lifecycle state of a Responses API response (serialized in snake_case).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    Queued,
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
// ============= Include Fields =============

/// Values accepted in the request `include` list; each variant maps to the
/// dotted field path used by the OpenAI API.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
    #[serde(rename = "computer_call_output.output.image_url")]
    ComputerCallOutputImageUrl,
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    #[serde(rename = "message.output_text.logprobs")]
    MessageOutputTextLogprobs,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}
// ============= Usage Info =============

/// Standard (chat-completions style) token usage accounting.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
/// Breakdown of prompt tokens (currently only the cached count).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    pub cached_tokens: u32,
}
// ============= Response Usage Format =============

/// OpenAI Responses API usage format (different from standard UsageInfo):
/// uses input/output token naming and nested detail objects.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}
/// Input-token detail for [`ResponseUsage`] (cached prompt tokens).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
    pub cached_tokens: u32,
}
/// Output-token detail for [`ResponseUsage`] (reasoning tokens).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
    pub reasoning_tokens: u32,
}
impl UsageInfo {
    /// Convert this standard usage record into the OpenAI Responses API
    /// shape (`input_*`/`output_*` naming plus nested detail objects).
    pub fn to_response_usage(&self) -> ResponseUsage {
        let input_tokens_details = self
            .prompt_tokens_details
            .as_ref()
            .map(|d| InputTokensDetails {
                cached_tokens: d.cached_tokens,
            });
        let output_tokens_details = self
            .reasoning_tokens
            .map(|reasoning_tokens| OutputTokensDetails { reasoning_tokens });
        ResponseUsage {
            input_tokens: self.prompt_tokens,
            output_tokens: self.completion_tokens,
            total_tokens: self.total_tokens,
            input_tokens_details,
            output_tokens_details,
        }
    }
}
impl From<UsageInfo> for ResponseUsage {
    /// Delegates to [`UsageInfo::to_response_usage`].
    fn from(value: UsageInfo) -> Self {
        value.to_response_usage()
    }
}
impl ResponseUsage {
    /// Convert back to the standard [`UsageInfo`] format (the inverse of
    /// [`UsageInfo::to_response_usage`]).
    pub fn to_usage_info(&self) -> UsageInfo {
        let reasoning_tokens = self
            .output_tokens_details
            .as_ref()
            .map(|d| d.reasoning_tokens);
        let prompt_tokens_details = self
            .input_tokens_details
            .as_ref()
            .map(|d| PromptTokenUsageInfo {
                cached_tokens: d.cached_tokens,
            });
        UsageInfo {
            prompt_tokens: self.input_tokens,
            completion_tokens: self.output_tokens,
            total_tokens: self.total_tokens,
            reasoning_tokens,
            prompt_tokens_details,
        }
    }
}
/// Serde default for `ResponsesRequest::request_id`: a fresh `resp_<uuid>` id.
fn generate_request_id() -> String {
    let id = uuid::Uuid::new_v4();
    format!("resp_{}", id.simple())
}
/// Request body for the OpenAI Responses API, extended with SGLang-specific
/// sampling fields. Absent optional fields are skipped on serialization.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
    // ============= Core OpenAI API fields =============
    /// Run the request in the background
    #[serde(default)]
    pub background: bool,
    /// Fields to include in the response
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include: Option<Vec<IncludeField>>,
    /// Input content - can be string or structured items
    pub input: ResponseInput,
    /// System instructions for the model
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Maximum number of output tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    /// Maximum number of tool calls
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tool_calls: Option<u32>,
    /// Additional metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, serde_json::Value>>,
    /// Model to use (optional to match vLLM)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Whether to enable parallel tool calls
    #[serde(default = "default_true")]
    pub parallel_tool_calls: bool,
    /// ID of previous response to continue from
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// Reasoning configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ResponseReasoningParam>,
    /// Service tier
    #[serde(default)]
    pub service_tier: ServiceTier,
    /// Whether to store the response
    #[serde(default = "default_true")]
    pub store: bool,
    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,
    /// Temperature for sampling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tool choice behavior
    #[serde(default)]
    pub tool_choice: ToolChoice,
    /// Available tools
    #[serde(default)]
    pub tools: Vec<ResponseTool>,
    /// Number of top logprobs to return
    #[serde(default)]
    pub top_logprobs: u32,
    /// Top-p sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Truncation behavior
    #[serde(default)]
    pub truncation: Truncation,
    /// User identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    // ============= SGLang Extensions =============
    /// Request ID (auto-generated `resp_<uuid>` when omitted)
    #[serde(default = "generate_request_id")]
    pub request_id: String,
    /// Request priority
    #[serde(default)]
    pub priority: i32,
    /// Frequency penalty
    #[serde(default)]
    pub frequency_penalty: f32,
    /// Presence penalty
    #[serde(default)]
    pub presence_penalty: f32,
    /// Stop sequences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    /// Top-k sampling parameter (-1 disables top-k)
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    /// Min-p sampling parameter
    #[serde(default)]
    pub min_p: f32,
    /// Repetition penalty (1.0 = no penalty)
    #[serde(default = "default_repetition_penalty")]
    pub repetition_penalty: f32,
}
/// Input payload for the Responses API: either a bare prompt string or a
/// list of structured input/output items (untagged — shape decides).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ResponseInput {
    Text(String),
    Items(Vec<ResponseInputOutputItem>),
}
/// Serde default for `top_k`: -1 means top-k filtering is disabled.
fn default_top_k() -> i32 {
    -1
}
/// Serde default for `repetition_penalty`: 1.0 applies no penalty.
fn default_repetition_penalty() -> f32 {
    1.0
}
impl ResponsesRequest {
    /// Default sampling parameters
    const DEFAULT_TEMPERATURE: f32 = 0.7;
    const DEFAULT_TOP_P: f32 = 1.0;

    /// Convert to sampling parameters for generation.
    ///
    /// `default_max_tokens` caps `max_output_tokens`; `default_params`
    /// supplies fallback values for `temperature`/`top_p` and is merged in
    /// last without overriding keys already set from the request.
    pub fn to_sampling_params(
        &self,
        default_max_tokens: u32,
        default_params: Option<HashMap<String, serde_json::Value>>,
    ) -> HashMap<String, serde_json::Value> {
        // Insert a float parameter. Non-finite values (NaN/inf) have no JSON
        // representation, so they are skipped instead of panicking (the
        // previous `from_f64(..).unwrap()` would panic on such input).
        fn insert_f64(params: &mut HashMap<String, serde_json::Value>, key: &str, value: f64) {
            if let Some(num) = serde_json::Number::from_f64(value) {
                params.insert(key.to_string(), serde_json::Value::Number(num));
            }
        }

        let mut params = HashMap::new();

        // Use max_output_tokens if available, capped by the server default;
        // subtract 1 token to avoid exceeding the context length.
        let max_tokens = self
            .max_output_tokens
            .map_or(default_max_tokens, |m| std::cmp::min(m, default_max_tokens))
            .saturating_sub(1);

        // Fallback chain: request value -> default_params entry -> constant.
        let fallback_f32 = |key: &str, hard_default: f32| -> f32 {
            default_params
                .as_ref()
                .and_then(|p| p.get(key))
                .and_then(|v| v.as_f64())
                .map(|v| v as f32)
                .unwrap_or(hard_default)
        };

        let temperature = self
            .temperature
            .unwrap_or_else(|| fallback_f32("temperature", Self::DEFAULT_TEMPERATURE));
        let top_p = self
            .top_p
            .unwrap_or_else(|| fallback_f32("top_p", Self::DEFAULT_TOP_P));

        params.insert(
            "max_new_tokens".to_string(),
            serde_json::Value::Number(serde_json::Number::from(max_tokens)),
        );
        insert_f64(&mut params, "temperature", temperature as f64);
        insert_f64(&mut params, "top_p", top_p as f64);
        insert_f64(&mut params, "frequency_penalty", self.frequency_penalty as f64);
        insert_f64(&mut params, "presence_penalty", self.presence_penalty as f64);
        params.insert(
            "top_k".to_string(),
            serde_json::Value::Number(serde_json::Number::from(self.top_k)),
        );
        insert_f64(&mut params, "min_p", self.min_p as f64);
        insert_f64(&mut params, "repetition_penalty", self.repetition_penalty as f64);

        if let Some(ref stop) = self.stop {
            // Serializing StringOrArray cannot realistically fail; keep the
            // original behavior of inserting Null on the impossible error.
            let value = serde_json::to_value(stop).unwrap_or(serde_json::Value::Null);
            params.insert("stop".to_string(), value);
        }

        // Apply any additional default parameters without overriding
        // explicitly-set values.
        if let Some(default_params) = default_params {
            for (key, value) in default_params {
                params.entry(key).or_insert(value);
            }
        }

        params
    }
}
impl GenerationRequest for ResponsesRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_model(&self) -> Option<&str> {
        self.model.as_deref()
    }

    /// Flatten all textual input (message parts, reasoning text, tool-call
    /// arguments) into one space-joined string for routing decisions.
    fn extract_text_for_routing(&self) -> String {
        let input_items = match &self.input {
            ResponseInput::Text(text) => return text.clone(),
            ResponseInput::Items(items) => items,
        };

        // Join non-empty texts with a single space, or yield None.
        fn join_nonempty(texts: Vec<String>) -> Option<String> {
            if texts.is_empty() {
                None
            } else {
                Some(texts.join(" "))
            }
        }

        input_items
            .iter()
            .filter_map(|item| match item {
                ResponseInputOutputItem::Message { content, .. } => {
                    let texts: Vec<String> = content
                        .iter()
                        .map(|part| match part {
                            ResponseContentPart::OutputText { text, .. } => text.clone(),
                        })
                        .collect();
                    join_nonempty(texts)
                }
                ResponseInputOutputItem::Reasoning { content, .. } => {
                    let texts: Vec<String> = content
                        .iter()
                        .map(|part| match part {
                            ResponseReasoningContent::ReasoningText { text } => text.clone(),
                        })
                        .collect();
                    join_nonempty(texts)
                }
                ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
                    Some(arguments.clone())
                }
            })
            .collect::<Vec<String>>()
            .join(" ")
    }
}
/// Serde default for `ResponsesResponse::id`: a fresh `resp_<uuid>` id.
fn generate_response_id() -> String {
    let id = uuid::Uuid::new_v4();
    format!("resp_{}", id.simple())
}
/// Seconds since the Unix epoch; returns 0 if the system clock is set
/// before the epoch (matching the original fallback behavior).
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0)
}
/// Response body for the OpenAI Responses API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesResponse {
    /// Response ID (auto-generated `resp_<uuid>` when omitted)
    #[serde(default = "generate_response_id")]
    pub id: String,
    /// Object type (always "response")
    #[serde(default = "default_object_type")]
    pub object: String,
    /// Creation timestamp (Unix seconds)
    #[serde(default = "current_timestamp")]
    pub created_at: i64,
    /// Model name
    pub model: String,
    /// Output items
    #[serde(default)]
    pub output: Vec<ResponseOutputItem>,
    /// Response status
    pub status: ResponseStatus,
    /// Usage information
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<UsageInfo>,
    /// Whether parallel tool calls are enabled
    #[serde(default = "default_true")]
    pub parallel_tool_calls: bool,
    /// Tool choice setting ("auto", "required", "none", or "function")
    #[serde(default = "default_tool_choice")]
    pub tool_choice: String,
    /// Available tools
    #[serde(default)]
    pub tools: Vec<ResponseTool>,
}
/// Serde default for `ResponsesResponse::object`.
fn default_object_type() -> String {
    String::from("response")
}
/// Serde default for `ResponsesResponse::tool_choice`.
fn default_tool_choice() -> String {
    String::from("auto")
}
impl
ResponsesResponse
{
/// Create a response from a request
#[allow(clippy::too_many_arguments)]
pub
fn
from_request
(
request
:
&
ResponsesRequest
,
_
sampling_params
:
&
HashMap
<
String
,
serde_json
::
Value
>
,
model_name
:
String
,
created_time
:
i64
,
output
:
Vec
<
ResponseOutputItem
>
,
status
:
ResponseStatus
,
usage
:
Option
<
UsageInfo
>
,
)
->
Self
{
Self
{
id
:
request
.request_id
.clone
(),
object
:
"response"
.to_string
(),
created_at
:
created_time
,
model
:
model_name
,
output
,
status
,
usage
,
parallel_tool_calls
:
request
.parallel_tool_calls
,
tool_choice
:
match
&
request
.tool_choice
{
ToolChoice
::
Value
(
ToolChoiceValue
::
Auto
)
=>
"auto"
.to_string
(),
ToolChoice
::
Value
(
ToolChoiceValue
::
Required
)
=>
"required"
.to_string
(),
ToolChoice
::
Value
(
ToolChoiceValue
::
None
)
=>
"none"
.to_string
(),
ToolChoice
::
Function
{
..
}
=>
"function"
.to_string
(),
},
tools
:
request
.tools
.clone
(),
}
}
/// Create a new response with default values
pub
fn
new
(
request_id
:
String
,
model
:
String
,
status
:
ResponseStatus
)
->
Self
{
Self
{
id
:
request_id
,
object
:
"response"
.to_string
(),
created_at
:
current_timestamp
(),
model
,
output
:
Vec
::
new
(),
status
,
usage
:
None
,
parallel_tool_calls
:
true
,
tool_choice
:
"auto"
.to_string
(),
tools
:
Vec
::
new
(),
}
}
/// Add an output item to the response
pub
fn
add_output
(
&
mut
self
,
item
:
ResponseOutputItem
)
{
self
.output
.push
(
item
);
}
/// Set the usage information
pub
fn
set_usage
(
&
mut
self
,
usage
:
UsageInfo
)
{
self
.usage
=
Some
(
usage
);
}
/// Update the status
pub
fn
set_status
(
&
mut
self
,
status
:
ResponseStatus
)
{
self
.status
=
status
;
}
/// Check if the response is complete
pub
fn
is_complete
(
&
self
)
->
bool
{
matches!
(
self
.status
,
ResponseStatus
::
Completed
)
}
/// Check if the response is in progress
pub
fn
is_in_progress
(
&
self
)
->
bool
{
matches!
(
self
.status
,
ResponseStatus
::
InProgress
)
}
/// Check if the response failed
pub
fn
is_failed
(
&
self
)
->
bool
{
matches!
(
self
.status
,
ResponseStatus
::
Failed
)
}
/// Check if the response was cancelled
pub
fn
is_cancelled
(
&
self
)
->
bool
{
matches!
(
self
.status
,
ResponseStatus
::
Cancelled
)
}
/// Check if the response is queued
pub
fn
is_queued
(
&
self
)
->
bool
{
matches!
(
self
.status
,
ResponseStatus
::
Queued
)
}
/// Convert usage to OpenAI Responses API format
pub
fn
usage_in_response_format
(
&
self
)
->
Option
<
ResponseUsage
>
{
self
.usage
.as_ref
()
.map
(|
usage
|
usage
.to_response_usage
())
}
/// Get the response as a JSON value with usage in response format
pub
fn
to_response_format
(
&
self
)
->
serde_json
::
Value
{
let
mut
response
=
serde_json
::
to_value
(
self
)
.unwrap_or
(
serde_json
::
Value
::
Null
);
// Convert usage to response format if present
if
let
Some
(
usage
)
=
&
self
.usage
{
if
let
Ok
(
usage_value
)
=
serde_json
::
to_value
(
usage
.to_response_usage
())
{
response
[
"usage"
]
=
usage_value
;
}
}
response
}
}
// ============= Helper Functions =============
impl
ResponseOutputItem
{
/// Create a new message output item
pub
fn
new_message
(
id
:
String
,
role
:
String
,
content
:
Vec
<
ResponseContentPart
>
,
status
:
String
,
)
->
Self
{
Self
::
Message
{
id
,
role
,
content
,
status
,
}
}
/// Create a new reasoning output item
pub
fn
new_reasoning
(
id
:
String
,
summary
:
Vec
<
String
>
,
content
:
Vec
<
ResponseReasoningContent
>
,
status
:
Option
<
String
>
,
)
->
Self
{
Self
::
Reasoning
{
id
,
summary
,
content
,
status
,
}
}
/// Create a new function tool call output item
pub
fn
new_function_tool_call
(
id
:
String
,
name
:
String
,
arguments
:
String
,
output
:
Option
<
String
>
,
status
:
String
,
)
->
Self
{
Self
::
FunctionToolCall
{
id
,
name
,
arguments
,
output
,
status
,
}
}
}
impl
ResponseContentPart
{
/// Create a new text content part
pub
fn
new_text
(
text
:
String
,
annotations
:
Vec
<
String
>
,
logprobs
:
Option
<
ChatLogProbs
>
,
)
->
Self
{
Self
::
OutputText
{
text
,
annotations
,
logprobs
,
}
}
}
impl
ResponseReasoningContent
{
/// Create a new reasoning text content
pub
fn
new_reasoning_text
(
text
:
String
)
->
Self
{
Self
::
ReasoningText
{
text
}
}
}
impl UsageInfo {
    /// Create a new usage info with token counts.
    ///
    /// `total_tokens` is derived as prompt + completion; `saturating_add`
    /// guards against u32 overflow instead of panicking in debug builds.
    pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
            reasoning_tokens,
            prompt_tokens_details: None,
        }
    }

    /// Create usage info with cached-token details; delegates to [`Self::new`]
    /// so the total-token derivation lives in one place.
    pub fn new_with_cached(
        prompt_tokens: u32,
        completion_tokens: u32,
        reasoning_tokens: Option<u32>,
        cached_tokens: u32,
    ) -> Self {
        let mut info = Self::new(prompt_tokens, completion_tokens, reasoning_tokens);
        info.prompt_tokens_details = Some(PromptTokenUsageInfo { cached_tokens });
        info
    }
}
// ==================================================================
// = OPENAI SPEC - Common =
// ==================================================================

// ============= Shared Request Components =============

/// Streaming options shared by chat/completions requests.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
    /// When true, a final chunk carrying usage statistics is emitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
// ============= Tool Choice Types =============

/// Tool choice value for simple string options
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    Auto,
    Required,
    None,
}
/// Tool choice for both Chat Completion and Responses APIs: either a simple
/// string value or a `{"type": "function", "function": {...}}` object
/// (untagged — the JSON shape selects the variant).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Value(ToolChoiceValue),
    Function {
        #[serde(rename = "type")]
        tool_type: String, // "function"
        function: FunctionChoice,
    },
}
impl
Default
for
ToolChoice
{
fn
default
()
->
Self
{
Self
::
Value
(
ToolChoiceValue
::
Auto
)
}
}
/// Function choice specification for ToolChoice::Function
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    pub name: String,
}
/// A tool made available to the model (currently always a function).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: Function,
}
/// Declaration of a callable function: name, optional description, and a
/// JSON-Schema parameter specification.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: Value, // JSON Schema
}
/// A tool call emitted by the model in a chat response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: FunctionCallResponse,
}
/// Legacy `function_call` request field.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `None`/`Auto`
/// only match JSON `null` — not the strings `"none"`/`"auto"` the OpenAI
/// API uses. Verify against callers whether string values must deserialize
/// into these variants.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
    None,
    Auto,
    Function { name: String },
}
/// The function name and serialized arguments of a model tool call.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    pub name: String,
    pub arguments: String, // JSON string
}
// ============= Usage Tracking =============

/// Chat-completions token usage block.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Breakdown of completion tokens (currently only reasoning tokens).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    pub reasoning_tokens: Option<u32>,
}
// ============= Logprobs Types =============

/// Completions-API style logprobs: parallel arrays indexed by token position.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<Option<f32>>,
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    pub text_offset: Vec<u32>,
}
/// Chat-API style logprobs: one entry per generated token.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
    pub content: Option<Vec<ChatLogProbsContent>>,
}
/// Logprob information for one generated token, with its top alternatives.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogProb>,
}
/// One alternative token candidate and its log probability.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
/// Top-level OpenAI-style error envelope (`{"error": {...}}`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}
/// OpenAI-style error payload: message, type, and optional param/code.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
// ==================================================================
// = SGLANG SPEC - GENERATE API =
// ==================================================================

/// Pre-tokenized input: one sequence or a batch of sequences (untagged —
/// the nesting depth of the JSON array selects the variant).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}
/// TGI-style generation parameters for the native generate API; every field
/// is optional and omitted from serialization when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decoder_input_details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub do_sample: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_full_text: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub typical_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub watermark: Option<bool>,
}
/// SGLang-native sampling parameters for the generate API; every field is
/// optional and omitted from serialization when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore_eos: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_special_tokens: Option<bool>,
    // Structured-output constraints (mutually exclusive in practice):
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_stop_trim: Option<bool>,
}
/// Request body for the SGLang-native generate API. Input may be given as
/// `prompt` (OpenAI style), `text` (SGLang native), or pre-tokenized
/// `input_ids`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to generate from (OpenAI style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<StringOrArray>,
    /// Text input - SGLang native format
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Input IDs for tokenized input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,
    /// Generation parameters
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<GenerateParameters>,
    /// Sampling parameters (sglang style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,
    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,
    /// Whether to return logprobs
    #[serde(default)]
    pub return_logprob: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,
    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,
    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,
    /// Request ID for tracking
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,
}
impl GenerationRequest for GenerateRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_model(&self) -> Option<&str> {
        // Generate requests typically don't have a model field
        None
    }

    /// Collect routable text, checking fields in priority order:
    /// `text`, then `prompt`, then `input_ids` (ids rendered as decimal
    /// strings joined by spaces).
    fn extract_text_for_routing(&self) -> String {
        if let Some(ref text) = self.text {
            return text.clone();
        }
        if let Some(ref prompt) = self.prompt {
            return match prompt {
                StringOrArray::String(s) => s.clone(),
                StringOrArray::Array(parts) => parts.join(" "),
            };
        }
        match &self.input_ids {
            Some(InputIds::Single(ids)) => {
                let mut words = Vec::with_capacity(ids.len());
                for id in ids {
                    words.push(id.to_string());
                }
                words.join(" ")
            }
            Some(InputIds::Batch(batches)) => {
                let mut words = Vec::new();
                for batch in batches {
                    for id in batch {
                        words.push(id.to_string());
                    }
                }
                words.join(" ")
            }
            // No text input found
            None => String::new(),
        }
    }
}
// ==================================================================
// = COMMON =
// ==================================================================

/// Helper function for serde default value
pub fn default_true() -> bool {
    true
}
/// Common trait for all generation requests across different APIs
pub
trait
GenerationRequest
:
Send
+
Sync
{
/// Check if the request is for streaming
fn
is_stream
(
&
self
)
->
bool
;
/// Get the model name if specified
fn
get_model
(
&
self
)
->
Option
<&
str
>
;
/// Extract text content for routing decisions
fn
extract_text_for_routing
(
&
self
)
->
String
;
}
/// Helper type for string or array of strings (untagged — the JSON shape
/// selects the variant).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    String(String),
    Array(Vec<String>),
}
impl
StringOrArray
{
/// Get the number of items in the StringOrArray
pub
fn
len
(
&
self
)
->
usize
{
match
self
{
StringOrArray
::
String
(
_
)
=>
1
,
StringOrArray
::
Array
(
arr
)
=>
arr
.len
(),
}
}
/// Check if the StringOrArray is empty
pub
fn
is_empty
(
&
self
)
->
bool
{
match
self
{
StringOrArray
::
String
(
s
)
=>
s
.is_empty
(),
StringOrArray
::
Array
(
arr
)
=>
arr
.is_empty
(),
}
}
/// Convert to a vector of strings
pub
fn
to_vec
(
&
self
)
->
Vec
<
String
>
{
match
self
{
StringOrArray
::
String
(
s
)
=>
vec!
[
s
.clone
()],
StringOrArray
::
Array
(
arr
)
=>
arr
.clone
(),
}
}
}
/// LoRA adapter path - can be single path or batch of paths (SGLang extension).
/// `None` entries mean "no adapter" for the corresponding request.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
    Single(Option<String>),
    Batch(Vec<Option<String>>),
}
sgl-router/src/protocols/validation.rs
View file @
5ef545e6
...
...
@@ -4,6 +4,11 @@ use anyhow::Result;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
fmt
::
Display
;
// Import types from spec module
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
ResponseFormat
,
StringOrArray
,
UserMessageContent
,
};
/// Validation constants for OpenAI API parameters
pub
mod
constants
{
/// Temperature range: 0.0 to 2.0 (OpenAI spec)
...
...
@@ -257,7 +262,7 @@ pub mod utils {
)
->
Result
<
(),
ValidationError
>
{
if
let
Some
(
stop
)
=
request
.get_stop_sequences
()
{
match
stop
{
crate
::
protocols
::
common
::
StringOrArray
::
String
(
s
)
=>
{
StringOrArray
::
String
(
s
)
=>
{
if
s
.is_empty
()
{
return
Err
(
ValidationError
::
InvalidValue
{
parameter
:
"stop"
.to_string
(),
...
...
@@ -266,7 +271,7 @@ pub mod utils {
});
}
}
crate
::
protocols
::
common
::
StringOrArray
::
Array
(
arr
)
=>
{
StringOrArray
::
Array
(
arr
)
=>
{
validate_max_items
(
arr
,
constants
::
MAX_STOP_SEQUENCES
,
"stop"
)
?
;
for
(
i
,
s
)
in
arr
.iter
()
.enumerate
()
{
if
s
.is_empty
()
{
...
...
@@ -469,7 +474,7 @@ pub trait SamplingOptionsProvider {
/// Trait for validating stop conditions
pub
trait
StopConditionsProvider
{
/// Get stop sequences
fn
get_stop_sequences
(
&
self
)
->
Option
<&
crate
::
protocols
::
common
::
StringOrArray
>
;
fn
get_stop_sequences
(
&
self
)
->
Option
<&
StringOrArray
>
;
}
/// Trait for validating token limits
...
...
@@ -532,25 +537,237 @@ pub trait ValidatableRequest:
}
}
// ==================================================================
// = OPENAI CHAT COMPLETION VALIDATION =
// ==================================================================

/// Expose the chat request's sampling knobs to the shared validator.
impl SamplingOptionsProvider for ChatCompletionRequest {
    fn get_temperature(&self) -> Option<f32> {
        self.temperature
    }

    fn get_top_p(&self) -> Option<f32> {
        self.top_p
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
        self.frequency_penalty
    }

    fn get_presence_penalty(&self) -> Option<f32> {
        self.presence_penalty
    }
}
/// Expose the chat request's stop sequences to the shared validator.
impl StopConditionsProvider for ChatCompletionRequest {
    fn get_stop_sequences(&self) -> Option<&StringOrArray> {
        self.stop.as_ref()
    }
}
impl TokenLimitsProvider for ChatCompletionRequest {
    /// Prefer `max_completion_tokens` over the deprecated `max_tokens`
    /// when both are set.
    fn get_max_tokens(&self) -> Option<u32> {
        match self.max_completion_tokens {
            Some(limit) => Some(limit),
            None => self.max_tokens,
        }
    }

    fn get_min_tokens(&self) -> Option<u32> {
        self.min_tokens
    }
}
impl LogProbsProvider for ChatCompletionRequest {
    /// Chat-API `logprobs` is a boolean; report 1 when enabled so the
    /// shared validator has a count to range-check.
    fn get_logprobs(&self) -> Option<u32> {
        self.logprobs.then_some(1)
    }

    fn get_top_logprobs(&self) -> Option<u32> {
        self.top_logprobs
    }
}
/// Expose the SGLang-specific sampling extensions to the shared validator.
impl SGLangExtensionsProvider for ChatCompletionRequest {
    fn get_top_k(&self) -> Option<i32> {
        self.top_k
    }

    fn get_min_p(&self) -> Option<f32> {
        self.min_p
    }

    fn get_repetition_penalty(&self) -> Option<f32> {
        self.repetition_penalty
    }
}
/// Expose the requested completion count (`n`) to the shared validator.
impl CompletionCountProvider for ChatCompletionRequest {
    fn get_n(&self) -> Option<u32> {
        self.n
    }
}
impl
ChatCompletionRequest
{
/// Validate message-specific requirements
pub
fn
validate_messages
(
&
self
)
->
Result
<
(),
ValidationError
>
{
// Ensure messages array is not empty
utils
::
validate_non_empty_array
(
&
self
.messages
,
"messages"
)
?
;
// Validate message content is not empty
for
(
i
,
msg
)
in
self
.messages
.iter
()
.enumerate
()
{
if
let
ChatMessage
::
User
{
content
,
..
}
=
msg
{
match
content
{
UserMessageContent
::
Text
(
text
)
if
text
.is_empty
()
=>
{
return
Err
(
ValidationError
::
InvalidValue
{
parameter
:
format!
(
"messages[{}].content"
,
i
),
value
:
"empty"
.to_string
(),
reason
:
"message content cannot be empty"
.to_string
(),
});
}
UserMessageContent
::
Parts
(
parts
)
if
parts
.is_empty
()
=>
{
return
Err
(
ValidationError
::
InvalidValue
{
parameter
:
format!
(
"messages[{}].content"
,
i
),
value
:
"empty array"
.to_string
(),
reason
:
"message content parts cannot be empty"
.to_string
(),
});
}
_
=>
{}
}
}
}
Ok
(())
}
/// Validate response format if specified
pub
fn
validate_response_format
(
&
self
)
->
Result
<
(),
ValidationError
>
{
if
let
Some
(
ResponseFormat
::
JsonSchema
{
json_schema
})
=
&
self
.response_format
{
if
json_schema
.name
.is_empty
()
{
return
Err
(
ValidationError
::
InvalidValue
{
parameter
:
"response_format.json_schema.name"
.to_string
(),
value
:
"empty"
.to_string
(),
reason
:
"JSON schema name cannot be empty"
.to_string
(),
});
}
}
Ok
(())
}
/// Validate chat API specific logprobs requirements
pub
fn
validate_chat_logprobs
(
&
self
)
->
Result
<
(),
ValidationError
>
{
// In chat API, if logprobs=true, top_logprobs must be specified
if
self
.logprobs
&&
self
.top_logprobs
.is_none
()
{
return
Err
(
ValidationError
::
MissingRequired
{
parameter
:
"top_logprobs"
.to_string
(),
});
}
// If top_logprobs is specified, logprobs should be true
if
self
.top_logprobs
.is_some
()
&&
!
self
.logprobs
{
return
Err
(
ValidationError
::
InvalidValue
{
parameter
:
"logprobs"
.to_string
(),
value
:
"false"
.to_string
(),
reason
:
"must be true when top_logprobs is specified"
.to_string
(),
});
}
Ok
(())
}
/// Validate cross-parameter relationships specific to chat completions
pub
fn
validate_chat_cross_parameters
(
&
self
)
->
Result
<
(),
ValidationError
>
{
// Validate that both max_tokens and max_completion_tokens aren't set
utils
::
validate_conflicting_parameters
(
"max_tokens"
,
self
.max_tokens
.is_some
(),
"max_completion_tokens"
,
self
.max_completion_tokens
.is_some
(),
"cannot specify both max_tokens and max_completion_tokens"
,
)
?
;
// Validate that tools and functions aren't both specified (deprecated)
utils
::
validate_conflicting_parameters
(
"tools"
,
self
.tools
.is_some
(),
"functions"
,
self
.functions
.is_some
(),
"functions is deprecated, use tools instead"
,
)
?
;
// Validate structured output constraints don't conflict with JSON response format
let
has_json_format
=
matches!
(
self
.response_format
,
Some
(
ResponseFormat
::
JsonObject
|
ResponseFormat
::
JsonSchema
{
..
})
);
utils
::
validate_conflicting_parameters
(
"response_format"
,
has_json_format
,
"regex"
,
self
.regex
.is_some
(),
"cannot use regex constraint with JSON response format"
,
)
?
;
utils
::
validate_conflicting_parameters
(
"response_format"
,
has_json_format
,
"ebnf"
,
self
.ebnf
.is_some
(),
"cannot use EBNF constraint with JSON response format"
,
)
?
;
// Only one structured output constraint should be active
let
structured_constraints
=
[
(
"regex"
,
self
.regex
.is_some
()),
(
"ebnf"
,
self
.ebnf
.is_some
()),
(
"json_schema"
,
matches!
(
self
.response_format
,
Some
(
ResponseFormat
::
JsonSchema
{
..
})
),
),
];
utils
::
validate_mutually_exclusive_options
(
&
structured_constraints
,
"Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time"
,
)
?
;
Ok
(())
}
}
impl
ValidatableRequest
for
ChatCompletionRequest
{
fn
validate
(
&
self
)
->
Result
<
(),
ValidationError
>
{
// Call the common validation function from the validation module
utils
::
validate_common_request_params
(
self
)
?
;
// Then validate chat-specific parameters
self
.validate_messages
()
?
;
self
.validate_response_format
()
?
;
self
.validate_chat_logprobs
()
?
;
self
.validate_chat_cross_parameters
()
?
;
Ok
(())
}
}
#[cfg(test)]
mod
tests
{
use
super
::
constants
::
*
;
use
super
::
utils
::
*
;
use
super
::
*
;
use
crate
::
protocols
::
common
::
StringOrArray
;
use
crate
::
protocols
::
spec
::
StringOrArray
;
// Mock request type for testing validation traits
#[derive(Debug,
Default)]
struct
MockRequest
{
temperature
:
Option
<
f32
>
,
top_p
:
Option
<
f32
>
,
frequency_penalty
:
Option
<
f32
>
,
presence_penalty
:
Option
<
f32
>
,
stop
:
Option
<
StringOrArray
>
,
max_tokens
:
Option
<
u32
>
,
min_tokens
:
Option
<
u32
>
,
logprobs
:
Option
<
u32
>
,
top_logprobs
:
Option
<
u32
>
,
}
impl
SamplingOptionsProvider
for
MockRequest
{
...
...
@@ -558,13 +775,13 @@ mod tests {
self
.temperature
}
fn
get_top_p
(
&
self
)
->
Option
<
f32
>
{
self
.top_p
None
}
fn
get_frequency_penalty
(
&
self
)
->
Option
<
f32
>
{
self
.frequency_penalty
None
}
fn
get_presence_penalty
(
&
self
)
->
Option
<
f32
>
{
self
.presence_penalty
None
}
}
...
...
@@ -585,173 +802,362 @@ mod tests {
impl
LogProbsProvider
for
MockRequest
{
fn
get_logprobs
(
&
self
)
->
Option
<
u32
>
{
self
.logprobs
None
}
fn
get_top_logprobs
(
&
self
)
->
Option
<
u32
>
{
self
.top_logprobs
None
}
}
impl
SGLangExtensionsProvider
for
MockRequest
{
// Default implementations return None, so no custom logic needed
}
impl
CompletionCountProvider
for
MockRequest
{
// Default implementation returns None, so no custom logic needed
}
impl
SGLangExtensionsProvider
for
MockRequest
{}
impl
CompletionCountProvider
for
MockRequest
{}
impl
ValidatableRequest
for
MockRequest
{}
#[test]
fn
test_validate_range_valid
()
{
let
result
=
validate_range
(
1.5f32
,
&
TEMPERATURE_RANGE
,
"temperature"
);
assert
!
(
result
.is_ok
());
assert_eq!
(
result
.unwrap
(),
1.5f32
);
}
#[test]
fn
test_validate_range_too_low
()
{
let
result
=
validate_range
(
-
0.1f32
,
&
TEMPERATURE_RANGE
,
"temperature"
);
assert
!
(
result
.is_err
());
match
result
.unwrap_err
()
{
ValidationError
::
OutOfRange
{
parameter
,
..
}
=>
{
assert_eq!
(
parameter
,
"temperature"
);
}
_
=>
panic!
(
"Expected OutOfRange error"
),
}
}
#[test]
fn
test_validate_positive_valid
()
{
let
result
=
validate_positive
(
5i32
,
"max_tokens"
);
assert
!
(
result
.is_ok
());
assert_eq!
(
result
.unwrap
(),
5i32
);
}
#[test]
fn
test_validate_max_items_valid
()
{
let
items
=
vec!
[
"stop1"
,
"stop2"
];
let
result
=
validate_max_items
(
&
items
,
MAX_STOP_SEQUENCES
,
"stop"
);
assert
!
(
result
.is_ok
());
fn
test_range_validation
()
{
// Valid range
assert
!
(
validate_range
(
1.5f32
,
&
TEMPERATURE_RANGE
,
"temperature"
)
.is_ok
());
// Invalid range
assert
!
(
validate_range
(
-
0.1f32
,
&
TEMPERATURE_RANGE
,
"temperature"
)
.is_err
());
assert
!
(
validate_range
(
3.0f32
,
&
TEMPERATURE_RANGE
,
"temperature"
)
.is_err
());
}
#[test]
fn
test_
validate_top_k
()
{
fn
test_
sglang_top_k_validation
()
{
assert
!
(
validate_top_k
(
-
1
)
.is_ok
());
// Disabled
assert
!
(
validate_top_k
(
50
)
.is_ok
());
//
P
ositive
assert
!
(
validate_top_k
(
50
)
.is_ok
());
//
Valid p
ositive
assert
!
(
validate_top_k
(
0
)
.is_err
());
// Invalid
assert
!
(
validate_top_k
(
-
5
)
.is_err
());
// Invalid
}
#[test]
fn
test_
valid_request
()
{
fn
test_
stop_sequences_limits
()
{
let
request
=
MockRequest
{
temperature
:
Some
(
1.0
),
top_p
:
Some
(
0.9
),
frequency_penalty
:
Some
(
0.5
),
presence_penalty
:
Some
(
-
0.5
),
stop
:
Some
(
StringOrArray
::
Array
(
vec!
[
"stop1"
.to_string
(),
"stop2"
.to_string
(),
"stop3"
.to_string
(),
"stop4"
.to_string
(),
"stop5"
.to_string
(),
// Too many
])),
max_tokens
:
Some
(
100
),
min_tokens
:
Some
(
10
),
logprobs
:
Some
(
3
),
top_logprobs
:
Some
(
15
),
..
Default
::
default
()
};
assert
!
(
request
.validate
()
.is_ok
());
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_
invalid_temperature
()
{
fn
test_
token_limits_conflict
()
{
let
request
=
MockRequest
{
temperature
:
Some
(
3.0
),
// Invalid: too high
min_tokens
:
Some
(
100
),
max_tokens
:
Some
(
50
),
// min > max
..
Default
::
default
()
};
let
result
=
request
.validate
();
assert
!
(
result
.is_err
());
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_
too_many_stop_sequences
()
{
fn
test_
valid_request
()
{
let
request
=
MockRequest
{
stop
:
Some
(
StringOrArray
::
Array
(
vec!
[
temperature
:
Some
(
1.0
),
stop
:
Some
(
StringOrArray
::
Array
(
vec!
[
"stop"
.to_string
()])),
max_tokens
:
Some
(
100
),
min_tokens
:
Some
(
10
),
};
assert
!
(
request
.validate
()
.is_ok
());
}
// Chat completion specific tests
#[cfg(test)]
mod
chat_tests
{
use
super
::
*
;
fn
create_valid_chat_request
()
->
ChatCompletionRequest
{
ChatCompletionRequest
{
model
:
"gpt-4"
.to_string
(),
messages
:
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Hello"
.to_string
()),
name
:
None
,
}],
temperature
:
Some
(
1.0
),
top_p
:
Some
(
0.9
),
n
:
Some
(
1
),
stream
:
false
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
Some
(
100
),
max_completion_tokens
:
None
,
presence_penalty
:
Some
(
0.0
),
frequency_penalty
:
Some
(
0.0
),
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
logprobs
:
false
,
top_logprobs
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
// SGLang extensions
top_k
:
None
,
min_p
:
None
,
min_tokens
:
None
,
repetition_penalty
:
None
,
regex
:
None
,
ebnf
:
None
,
stop_token_ids
:
None
,
no_stop_trim
:
false
,
ignore_eos
:
false
,
continue_final_message
:
false
,
skip_special_tokens
:
true
,
lora_path
:
None
,
session_params
:
None
,
separate_reasoning
:
true
,
stream_reasoning
:
true
,
return_hidden_states
:
false
,
}
}
#[test]
fn
test_chat_validation_basics
()
{
// Valid request
assert
!
(
create_valid_chat_request
()
.validate
()
.is_ok
());
// Empty messages
let
mut
request
=
create_valid_chat_request
();
request
.messages
=
vec!
[];
assert
!
(
request
.validate
()
.is_err
());
// Invalid temperature
let
mut
request
=
create_valid_chat_request
();
request
.temperature
=
Some
(
3.0
);
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_chat_conflicts
()
{
let
mut
request
=
create_valid_chat_request
();
// Conflicting max_tokens
request
.max_tokens
=
Some
(
100
);
request
.max_completion_tokens
=
Some
(
200
);
assert
!
(
request
.validate
()
.is_err
());
// Logprobs without top_logprobs
request
.max_tokens
=
None
;
request
.logprobs
=
true
;
request
.top_logprobs
=
None
;
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_sglang_extensions
()
{
let
mut
request
=
create_valid_chat_request
();
// Valid SGLang parameters
request
.top_k
=
Some
(
-
1
);
request
.min_p
=
Some
(
0.1
);
request
.repetition_penalty
=
Some
(
1.2
);
assert
!
(
request
.validate
()
.is_ok
());
// Invalid parameters
request
.top_k
=
Some
(
0
);
// Invalid
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_parameter_ranges
()
{
let
mut
request
=
create_valid_chat_request
();
// Test temperature range (0.0 to 2.0)
request
.temperature
=
Some
(
1.5
);
assert
!
(
request
.validate
()
.is_ok
());
request
.temperature
=
Some
(
-
0.1
);
assert
!
(
request
.validate
()
.is_err
());
request
.temperature
=
Some
(
3.0
);
assert
!
(
request
.validate
()
.is_err
());
// Test top_p range (0.0 to 1.0)
request
.temperature
=
Some
(
1.0
);
// Reset
request
.top_p
=
Some
(
0.9
);
assert
!
(
request
.validate
()
.is_ok
());
request
.top_p
=
Some
(
-
0.1
);
assert
!
(
request
.validate
()
.is_err
());
request
.top_p
=
Some
(
1.5
);
assert
!
(
request
.validate
()
.is_err
());
// Test frequency_penalty range (-2.0 to 2.0)
request
.top_p
=
Some
(
0.9
);
// Reset
request
.frequency_penalty
=
Some
(
1.5
);
assert
!
(
request
.validate
()
.is_ok
());
request
.frequency_penalty
=
Some
(
-
2.5
);
assert
!
(
request
.validate
()
.is_err
());
request
.frequency_penalty
=
Some
(
3.0
);
assert
!
(
request
.validate
()
.is_err
());
// Test presence_penalty range (-2.0 to 2.0)
request
.frequency_penalty
=
Some
(
0.0
);
// Reset
request
.presence_penalty
=
Some
(
-
1.5
);
assert
!
(
request
.validate
()
.is_ok
());
request
.presence_penalty
=
Some
(
-
3.0
);
assert
!
(
request
.validate
()
.is_err
());
request
.presence_penalty
=
Some
(
2.5
);
assert
!
(
request
.validate
()
.is_err
());
// Test repetition_penalty range (0.0 to 2.0)
request
.presence_penalty
=
Some
(
0.0
);
// Reset
request
.repetition_penalty
=
Some
(
1.2
);
assert
!
(
request
.validate
()
.is_ok
());
request
.repetition_penalty
=
Some
(
-
0.1
);
assert
!
(
request
.validate
()
.is_err
());
request
.repetition_penalty
=
Some
(
2.1
);
assert
!
(
request
.validate
()
.is_err
());
// Test min_p range (0.0 to 1.0)
request
.repetition_penalty
=
Some
(
1.0
);
// Reset
request
.min_p
=
Some
(
0.5
);
assert
!
(
request
.validate
()
.is_ok
());
request
.min_p
=
Some
(
-
0.1
);
assert
!
(
request
.validate
()
.is_err
());
request
.min_p
=
Some
(
1.5
);
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_structured_output_conflicts
()
{
let
mut
request
=
create_valid_chat_request
();
// JSON response format with regex should conflict
request
.response_format
=
Some
(
ResponseFormat
::
JsonObject
);
request
.regex
=
Some
(
".*"
.to_string
());
assert
!
(
request
.validate
()
.is_err
());
// JSON response format with EBNF should conflict
request
.regex
=
None
;
request
.ebnf
=
Some
(
"grammar"
.to_string
());
assert
!
(
request
.validate
()
.is_err
());
// Multiple structured constraints should conflict
request
.response_format
=
None
;
request
.regex
=
Some
(
".*"
.to_string
());
request
.ebnf
=
Some
(
"grammar"
.to_string
());
assert
!
(
request
.validate
()
.is_err
());
// Only one constraint should work
request
.ebnf
=
None
;
request
.regex
=
Some
(
".*"
.to_string
());
assert
!
(
request
.validate
()
.is_ok
());
request
.regex
=
None
;
request
.ebnf
=
Some
(
"grammar"
.to_string
());
assert
!
(
request
.validate
()
.is_ok
());
request
.ebnf
=
None
;
request
.response_format
=
Some
(
ResponseFormat
::
JsonObject
);
assert
!
(
request
.validate
()
.is_ok
());
}
#[test]
fn
test_stop_sequences_validation
()
{
let
mut
request
=
create_valid_chat_request
();
// Valid stop sequences
request
.stop
=
Some
(
StringOrArray
::
Array
(
vec!
[
"stop1"
.to_string
(),
"stop2"
.to_string
(),
]));
assert
!
(
request
.validate
()
.is_ok
());
// Too many stop sequences (max 4)
request
.stop
=
Some
(
StringOrArray
::
Array
(
vec!
[
"stop1"
.to_string
(),
"stop2"
.to_string
(),
"stop3"
.to_string
(),
"stop4"
.to_string
(),
"stop5"
.to_string
(),
// Too many
])),
..
Default
::
default
()
};
"stop5"
.to_string
(),
]));
assert
!
(
request
.validate
()
.is_err
());
let
result
=
request
.validate
();
assert
!
(
result
.is_err
());
match
result
.unwrap_err
()
{
ValidationError
::
TooManyItems
{
parameter
,
count
,
max
,
}
=>
{
assert_eq!
(
parameter
,
"stop"
);
assert_eq!
(
count
,
5
);
assert_eq!
(
max
,
MAX_STOP_SEQUENCES
);
}
_
=>
panic!
(
"Expected TooManyItems error"
),
// Empty stop sequence should fail
request
.stop
=
Some
(
StringOrArray
::
String
(
""
.to_string
()));
assert
!
(
request
.validate
()
.is_err
());
// Empty string in array should fail
request
.stop
=
Some
(
StringOrArray
::
Array
(
vec!
[
"stop1"
.to_string
(),
""
.to_string
(),
]));
assert
!
(
request
.validate
()
.is_err
());
}
}
#[test]
fn
test_conflicting_token_limits
()
{
let
request
=
MockRequest
{
min_tokens
:
Some
(
100
),
max_tokens
:
Some
(
50
),
// Invalid: min > max
..
Default
::
default
()
};
#[test]
fn
test_logprobs_validation
()
{
let
mut
request
=
create_valid_chat_request
();
let
result
=
request
.validate
();
assert
!
(
result
.is_err
());
match
result
.unwrap_err
()
{
ValidationError
::
ConflictingParameters
{
parameter1
,
parameter2
,
..
}
=>
{
assert_eq!
(
parameter1
,
"min_tokens"
);
assert_eq!
(
parameter2
,
"max_tokens"
);
}
_
=>
panic!
(
"Expected ConflictingParameters error"
),
}
}
// Valid logprobs configuration
request
.logprobs
=
true
;
request
.top_logprobs
=
Some
(
10
);
assert
!
(
request
.validate
()
.is_ok
());
#[test]
fn
test_boundary_values
()
{
let
request
=
MockRequest
{
temperature
:
Some
(
0.0
),
// Boundary: minimum
top_p
:
Some
(
1.0
),
// Boundary: maximum
frequency_penalty
:
Some
(
-
2.0
),
// Boundary: minimum
presence_penalty
:
Some
(
2.0
),
// Boundary: maximum
logprobs
:
Some
(
0
),
// Boundary: minimum
top_logprobs
:
Some
(
20
),
// Boundary: maximum
..
Default
::
default
()
};
// logprobs=true without top_logprobs should fail
request
.top_logprobs
=
None
;
assert
!
(
request
.validate
()
.is_err
());
assert
!
(
request
.validate
()
.is_ok
());
}
// top_logprobs without logprobs=true should fail
request
.logprobs
=
false
;
request
.top_logprobs
=
Some
(
10
);
assert
!
(
request
.validate
()
.is_err
());
#[test]
fn
test_validation_error_display
()
{
let
error
=
ValidationError
::
OutOfRange
{
parameter
:
"temperature"
.to_string
(),
value
:
"3.0"
.to_string
(),
min
:
"0.0"
.to_string
(),
max
:
"2.0"
.to_string
(),
};
// top_logprobs out of range (0-20)
request
.logprobs
=
true
;
request
.top_logprobs
=
Some
(
25
);
assert
!
(
request
.validate
()
.is_err
());
}
let
message
=
format!
(
"{}"
,
error
);
assert
!
(
message
.contains
(
"temperature"
));
assert
!
(
message
.contains
(
"3.0"
));
#[test]
fn
test_n_parameter_validation
()
{
let
mut
request
=
create_valid_chat_request
();
// Valid n values (1-10)
request
.n
=
Some
(
1
);
assert
!
(
request
.validate
()
.is_ok
());
request
.n
=
Some
(
5
);
assert
!
(
request
.validate
()
.is_ok
());
request
.n
=
Some
(
10
);
assert
!
(
request
.validate
()
.is_ok
());
// Invalid n values
request
.n
=
Some
(
0
);
assert
!
(
request
.validate
()
.is_err
());
request
.n
=
Some
(
15
);
assert
!
(
request
.validate
()
.is_err
());
}
#[test]
fn
test_min_max_tokens_validation
()
{
let
mut
request
=
create_valid_chat_request
();
// Valid token limits
request
.min_tokens
=
Some
(
10
);
request
.max_tokens
=
Some
(
100
);
assert
!
(
request
.validate
()
.is_ok
());
// min_tokens > max_tokens should fail
request
.min_tokens
=
Some
(
150
);
request
.max_tokens
=
Some
(
100
);
assert
!
(
request
.validate
()
.is_err
());
// Should work with max_completion_tokens instead
request
.max_tokens
=
None
;
request
.max_completion_tokens
=
Some
(
200
);
request
.min_tokens
=
Some
(
50
);
assert
!
(
request
.validate
()
.is_ok
());
// min_tokens > max_completion_tokens should fail
request
.min_tokens
=
Some
(
250
);
assert
!
(
request
.validate
()
.is_err
());
}
}
}
sgl-router/src/routers/mod.rs
View file @
5ef545e6
...
...
@@ -9,10 +9,7 @@ use axum::{
};
use
std
::
fmt
::
Debug
;
use
crate
::
protocols
::{
generate
::
GenerateRequest
,
openai
::{
chat
::
ChatCompletionRequest
,
completions
::
CompletionRequest
},
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
pub
mod
factory
;
pub
mod
header_utils
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
5ef545e6
...
...
@@ -12,13 +12,9 @@ use crate::core::{
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
protocols
::{
common
::
StringOrArray
,
generate
::
GenerateRequest
,
openai
::{
chat
::{
ChatCompletionRequest
,
ChatMessage
,
UserMessageContent
},
completions
::
CompletionRequest
,
},
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
StringOrArray
,
UserMessageContent
,
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
async_trait
::
async_trait
;
...
...
sgl-router/src/routers/router.rs
View file @
5ef545e6
...
...
@@ -9,10 +9,8 @@ use crate::core::{
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
protocols
::{
common
::
GenerationRequest
,
generate
::
GenerateRequest
,
openai
::{
chat
::
ChatCompletionRequest
,
completions
::
CompletionRequest
},
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
axum
::{
...
...
sgl-router/src/server.rs
View file @
5ef545e6
use
crate
::
config
::
RouterConfig
;
use
crate
::
logging
::{
self
,
LoggingConfig
};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
protocols
::{
generate
::
GenerateRequest
,
openai
::{
chat
::
ChatCompletionRequest
,
completions
::
CompletionRequest
},
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
use
crate
::
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
};
use
axum
::{
...
...
sgl-router/tests/benchmark_integration.rs
View file @
5ef545e6
...
...
@@ -5,13 +5,9 @@
use
serde_json
::{
from_str
,
to_string
,
to_value
};
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
protocols
::{
common
::
StringOrArray
,
generate
::{
GenerateParameters
,
GenerateRequest
,
SamplingParams
},
openai
::{
chat
::{
ChatCompletionRequest
,
ChatMessage
,
UserMessageContent
},
completions
::
CompletionRequest
,
},
use
sglang_router_rs
::
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
/// Create a default GenerateRequest for benchmarks with minimal fields set
...
...
sgl-router/tests/responses_api_test.rs
View file @
5ef545e6
// Integration test for Responses API
use
sglang_router_rs
::
protocols
::
common
::
GenerationRequest
;
use
sglang_router_rs
::
protocols
::
openai
::
responses
::
request
::
ResponseInput
;
use
sglang_router_rs
::
protocols
::
openai
::
responses
::
*
;
use
sglang_router_rs
::
protocols
::
spec
::{
GenerationRequest
,
ReasoningEffort
,
ResponseInput
,
ResponseReasoningParam
,
ResponseStatus
,
ResponseTool
,
ResponseToolType
,
ResponsesRequest
,
ResponsesResponse
,
ServiceTier
,
ToolChoice
,
ToolChoiceValue
,
Truncation
,
UsageInfo
,
};
#[test]
fn
test_responses_request_creation
()
{
...
...
@@ -24,7 +26,7 @@ fn test_responses_request_creation() {
store
:
true
,
stream
:
false
,
temperature
:
Some
(
0.7
),
tool_choice
:
ToolChoice
::
Auto
,
tool_choice
:
ToolChoice
::
Value
(
ToolChoiceValue
::
Auto
)
,
tools
:
vec!
[
ResponseTool
{
r
#
type
:
ResponseToolType
::
WebSearchPreview
,
}],
...
...
@@ -67,7 +69,7 @@ fn test_sampling_params_conversion() {
store
:
true
,
// Use default true
stream
:
false
,
temperature
:
Some
(
0.8
),
tool_choice
:
ToolChoice
::
Auto
,
tool_choice
:
ToolChoice
::
Value
(
ToolChoiceValue
::
Auto
)
,
tools
:
vec!
[],
top_logprobs
:
0
,
// Use default 0
top_p
:
Some
(
0.95
),
...
...
@@ -177,7 +179,7 @@ fn test_json_serialization() {
store
:
false
,
stream
:
true
,
temperature
:
Some
(
0.9
),
tool_choice
:
ToolChoice
::
Required
,
tool_choice
:
ToolChoice
::
Value
(
ToolChoiceValue
::
Required
)
,
tools
:
vec!
[
ResponseTool
{
r
#
type
:
ResponseToolType
::
CodeInterpreter
,
}],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment