Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
98c3b04f
Unverified
Commit
98c3b04f
authored
Sep 23, 2025
by
Simo Lin
Committed by
GitHub
Sep 23, 2025
Browse files
[router] responses api POST and GET with local storage (#10581)
Co-authored-by:
key4ng
<
rukeyang@gmail.com
>
parent
ddab4fc7
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1154 additions
and
172 deletions
+1154
-172
sgl-router/src/data_connector/response_memory_store.rs
sgl-router/src/data_connector/response_memory_store.rs
+18
-1
sgl-router/src/data_connector/responses.rs
sgl-router/src/data_connector/responses.rs
+12
-0
sgl-router/src/protocols/spec.rs
sgl-router/src/protocols/spec.rs
+263
-79
sgl-router/src/routers/factory.rs
sgl-router/src/routers/factory.rs
+6
-2
sgl-router/src/routers/grpc/pd_router.rs
sgl-router/src/routers/grpc/pd_router.rs
+6
-1
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+6
-1
sgl-router/src/routers/header_utils.rs
sgl-router/src/routers/header_utils.rs
+42
-0
sgl-router/src/routers/http/openai_router.rs
sgl-router/src/routers/http/openai_router.rs
+548
-33
sgl-router/src/routers/http/pd_router.rs
sgl-router/src/routers/http/pd_router.rs
+7
-2
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+7
-2
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+7
-2
sgl-router/src/routers/router_manager.rs
sgl-router/src/routers/router_manager.rs
+37
-25
sgl-router/src/server.rs
sgl-router/src/server.rs
+3
-2
sgl-router/tests/test_openai_routing.rs
sgl-router/tests/test_openai_routing.rs
+192
-22
No files found.
sgl-router/src/data_connector/response_memory_store.rs
View file @
98c3b04f
...
@@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage {
...
@@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage {
// Store the response
// Store the response
store
.responses
.insert
(
response_id
.clone
(),
response
);
store
.responses
.insert
(
response_id
.clone
(),
response
);
tracing
::
info!
(
"memory_store_size"
=
store
.responses
.len
());
Ok
(
response_id
)
Ok
(
response_id
)
}
}
async
fn
get_response
(
&
self
,
response_id
:
&
ResponseId
)
->
Result
<
Option
<
StoredResponse
>>
{
async
fn
get_response
(
&
self
,
response_id
:
&
ResponseId
)
->
Result
<
Option
<
StoredResponse
>>
{
let
store
=
self
.store
.read
();
let
store
=
self
.store
.read
();
Ok
(
store
.responses
.get
(
response_id
)
.cloned
())
let
result
=
store
.responses
.get
(
response_id
)
.cloned
();
tracing
::
info!
(
"memory_get_response"
=
%
response_id
.0
,
found
=
result
.is_some
());
Ok
(
result
)
}
}
async
fn
delete_response
(
&
self
,
response_id
:
&
ResponseId
)
->
Result
<
()
>
{
async
fn
delete_response
(
&
self
,
response_id
:
&
ResponseId
)
->
Result
<
()
>
{
...
@@ -200,6 +203,20 @@ pub struct MemoryStoreStats {
...
@@ -200,6 +203,20 @@ pub struct MemoryStoreStats {
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
#[tokio::test]
async
fn
test_store_with_custom_id
()
{
let
store
=
MemoryResponseStorage
::
new
();
let
mut
response
=
StoredResponse
::
new
(
"Input"
.to_string
(),
"Output"
.to_string
(),
None
);
response
.id
=
ResponseId
::
from_string
(
"resp_custom"
.to_string
());
store
.store_response
(
response
.clone
())
.await
.unwrap
();
let
retrieved
=
store
.get_response
(
&
ResponseId
::
from_string
(
"resp_custom"
.to_string
()))
.await
.unwrap
();
assert
!
(
retrieved
.is_some
());
assert_eq!
(
retrieved
.unwrap
()
.output
,
"Output"
);
}
#[tokio::test]
#[tokio::test]
async
fn
test_memory_store_basic
()
{
async
fn
test_memory_store_basic
()
{
let
store
=
MemoryResponseStorage
::
new
();
let
store
=
MemoryResponseStorage
::
new
();
...
...
sgl-router/src/data_connector/responses.rs
View file @
98c3b04f
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -55,6 +56,10 @@ pub struct StoredResponse {
...
@@ -55,6 +56,10 @@ pub struct StoredResponse {
/// Model used for generation
/// Model used for generation
pub
model
:
Option
<
String
>
,
pub
model
:
Option
<
String
>
,
/// Raw OpenAI response payload
#[serde(default)]
pub
raw_response
:
Value
,
}
}
impl
StoredResponse
{
impl
StoredResponse
{
...
@@ -70,6 +75,7 @@ impl StoredResponse {
...
@@ -70,6 +75,7 @@ impl StoredResponse {
created_at
:
chrono
::
Utc
::
now
(),
created_at
:
chrono
::
Utc
::
now
(),
user
:
None
,
user
:
None
,
model
:
None
,
model
:
None
,
raw_response
:
Value
::
Null
,
}
}
}
}
}
}
...
@@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync {
...
@@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync {
/// Type alias for shared storage
/// Type alias for shared storage
pub
type
SharedResponseStorage
=
Arc
<
dyn
ResponseStorage
>
;
pub
type
SharedResponseStorage
=
Arc
<
dyn
ResponseStorage
>
;
impl
Default
for
StoredResponse
{
fn
default
()
->
Self
{
Self
::
new
(
String
::
new
(),
String
::
new
(),
None
)
}
}
sgl-router/src/protocols/spec.rs
View file @
98c3b04f
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
serde_json
::
{
to_value
,
Map
,
Number
,
Value
}
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
// # Protocol Specifications
// # Protocol Specifications
...
@@ -350,7 +350,7 @@ pub struct ChatCompletionRequest {
...
@@ -350,7 +350,7 @@ pub struct ChatCompletionRequest {
/// Session parameters for continual prompting
/// Session parameters for continual prompting
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
session_params
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
pub
session_params
:
Option
<
HashMap
<
String
,
Value
>>
,
/// Separate reasoning content from final answer (O1-style models)
/// Separate reasoning content from final answer (O1-style models)
#[serde(default
=
"default_true"
)]
#[serde(default
=
"default_true"
)]
...
@@ -362,7 +362,7 @@ pub struct ChatCompletionRequest {
...
@@ -362,7 +362,7 @@ pub struct ChatCompletionRequest {
/// Chat template kwargs
/// Chat template kwargs
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
chat_template_kwargs
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
pub
chat_template_kwargs
:
Option
<
HashMap
<
String
,
Value
>>
,
/// Return model hidden states
/// Return model hidden states
#[serde(default)]
#[serde(default)]
...
@@ -447,7 +447,7 @@ pub struct ChatChoice {
...
@@ -447,7 +447,7 @@ pub struct ChatChoice {
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "tool_calls", "content_filter", "function_call"
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "tool_calls", "content_filter", "function_call"
/// Information about which stop condition was matched
/// Information about which stop condition was matched
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
matched_stop
:
Option
<
serde_json
::
Value
>
,
// Can be string or integer
pub
matched_stop
:
Option
<
Value
>
,
// Can be string or integer
/// Hidden states from the model (SGLang extension)
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
hidden_states
:
Option
<
Vec
<
f32
>>
,
pub
hidden_states
:
Option
<
Vec
<
f32
>>
,
...
@@ -606,7 +606,7 @@ pub struct CompletionRequest {
...
@@ -606,7 +606,7 @@ pub struct CompletionRequest {
/// Session parameters for continual prompting
/// Session parameters for continual prompting
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
session_params
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
pub
session_params
:
Option
<
HashMap
<
String
,
Value
>>
,
/// Return model hidden states
/// Return model hidden states
#[serde(default)]
#[serde(default)]
...
@@ -618,7 +618,7 @@ pub struct CompletionRequest {
...
@@ -618,7 +618,7 @@ pub struct CompletionRequest {
/// Additional fields including bootstrap info for PD routing
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
#[serde(flatten)]
pub
other
:
serde_json
::
Map
<
String
,
serde_json
::
Value
>
,
pub
other
:
Map
<
String
,
Value
>
,
}
}
impl
GenerationRequest
for
CompletionRequest
{
impl
GenerationRequest
for
CompletionRequest
{
...
@@ -662,7 +662,7 @@ pub struct CompletionChoice {
...
@@ -662,7 +662,7 @@ pub struct CompletionChoice {
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "content_filter", etc.
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "content_filter", etc.
/// Information about which stop condition was matched
/// Information about which stop condition was matched
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
matched_stop
:
Option
<
serde_json
::
Value
>
,
// Can be string or integer
pub
matched_stop
:
Option
<
Value
>
,
// Can be string or integer
/// Hidden states from the model (SGLang extension)
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
hidden_states
:
Option
<
Vec
<
f32
>>
,
pub
hidden_states
:
Option
<
Vec
<
f32
>>
,
...
@@ -776,6 +776,10 @@ pub enum ResponseContentPart {
...
@@ -776,6 +776,10 @@ pub enum ResponseContentPart {
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
logprobs
:
Option
<
ChatLogProbs
>
,
logprobs
:
Option
<
ChatLogProbs
>
,
},
},
#[serde(rename
=
"input_text"
)]
InputText
{
text
:
String
},
#[serde(other)]
Unknown
,
}
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
...
@@ -864,6 +868,29 @@ pub enum ResponseStatus {
...
@@ -864,6 +868,29 @@ pub enum ResponseStatus {
Cancelled
,
Cancelled
,
}
}
// ============= Reasoning Info =============
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ReasoningInfo
{
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
effort
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
summary
:
Option
<
String
>
,
}
// ============= Text Format =============
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ResponseTextFormat
{
pub
format
:
TextFormatType
,
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
TextFormatType
{
#[serde(rename
=
"type"
)]
pub
format_type
:
String
,
}
// ============= Include Fields =============
// ============= Include Fields =============
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
...
@@ -915,6 +942,13 @@ pub struct ResponseUsage {
...
@@ -915,6 +942,13 @@ pub struct ResponseUsage {
pub
output_tokens_details
:
Option
<
OutputTokensDetails
>
,
pub
output_tokens_details
:
Option
<
OutputTokensDetails
>
,
}
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[serde(untagged)]
pub
enum
ResponsesUsage
{
Classic
(
UsageInfo
),
Modern
(
ResponseUsage
),
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
InputTokensDetails
{
pub
struct
InputTokensDetails
{
pub
cached_tokens
:
u32
,
pub
cached_tokens
:
u32
,
...
@@ -970,6 +1004,34 @@ impl ResponseUsage {
...
@@ -970,6 +1004,34 @@ impl ResponseUsage {
}
}
}
}
#[derive(Debug,
Clone,
Default,
Deserialize,
Serialize)]
pub
struct
ResponsesGetParams
{
#[serde(default)]
pub
include
:
Vec
<
String
>
,
#[serde(default)]
pub
include_obfuscation
:
Option
<
bool
>
,
#[serde(default)]
pub
starting_after
:
Option
<
i64
>
,
#[serde(default)]
pub
stream
:
Option
<
bool
>
,
}
impl
ResponsesUsage
{
pub
fn
to_response_usage
(
&
self
)
->
ResponseUsage
{
match
self
{
ResponsesUsage
::
Classic
(
usage
)
=>
usage
.to_response_usage
(),
ResponsesUsage
::
Modern
(
usage
)
=>
usage
.clone
(),
}
}
pub
fn
to_usage_info
(
&
self
)
->
UsageInfo
{
match
self
{
ResponsesUsage
::
Classic
(
usage
)
=>
usage
.clone
(),
ResponsesUsage
::
Modern
(
usage
)
=>
usage
.to_usage_info
(),
}
}
}
fn
generate_request_id
()
->
String
{
fn
generate_request_id
()
->
String
{
format!
(
"resp_{}"
,
uuid
::
Uuid
::
new_v4
()
.simple
())
format!
(
"resp_{}"
,
uuid
::
Uuid
::
new_v4
()
.simple
())
}
}
...
@@ -1002,7 +1064,7 @@ pub struct ResponsesRequest {
...
@@ -1002,7 +1064,7 @@ pub struct ResponsesRequest {
/// Additional metadata
/// Additional metadata
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
metadata
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
pub
metadata
:
Option
<
HashMap
<
String
,
Value
>>
,
/// Model to use (optional to match vLLM)
/// Model to use (optional to match vLLM)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
...
@@ -1109,6 +1171,42 @@ fn default_repetition_penalty() -> f32 {
...
@@ -1109,6 +1171,42 @@ fn default_repetition_penalty() -> f32 {
1.0
1.0
}
}
impl
Default
for
ResponsesRequest
{
fn
default
()
->
Self
{
Self
{
background
:
false
,
include
:
None
,
input
:
ResponseInput
::
Text
(
String
::
new
()),
instructions
:
None
,
max_output_tokens
:
None
,
max_tool_calls
:
None
,
metadata
:
None
,
model
:
None
,
parallel_tool_calls
:
true
,
previous_response_id
:
None
,
reasoning
:
None
,
service_tier
:
ServiceTier
::
default
(),
store
:
true
,
stream
:
false
,
temperature
:
None
,
tool_choice
:
ToolChoice
::
default
(),
tools
:
Vec
::
new
(),
top_logprobs
:
0
,
top_p
:
None
,
truncation
:
Truncation
::
default
(),
user
:
None
,
request_id
:
generate_request_id
(),
priority
:
0
,
frequency_penalty
:
0.0
,
presence_penalty
:
0.0
,
stop
:
None
,
top_k
:
default_top_k
(),
min_p
:
0.0
,
repetition_penalty
:
default_repetition_penalty
(),
}
}
}
impl
ResponsesRequest
{
impl
ResponsesRequest
{
/// Default sampling parameters
/// Default sampling parameters
const
DEFAULT_TEMPERATURE
:
f32
=
0.7
;
const
DEFAULT_TEMPERATURE
:
f32
=
0.7
;
...
@@ -1118,8 +1216,8 @@ impl ResponsesRequest {
...
@@ -1118,8 +1216,8 @@ impl ResponsesRequest {
pub
fn
to_sampling_params
(
pub
fn
to_sampling_params
(
&
self
,
&
self
,
default_max_tokens
:
u32
,
default_max_tokens
:
u32
,
default_params
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
default_params
:
Option
<
HashMap
<
String
,
Value
>>
,
)
->
HashMap
<
String
,
serde_json
::
Value
>
{
)
->
HashMap
<
String
,
Value
>
{
let
mut
params
=
HashMap
::
new
();
let
mut
params
=
HashMap
::
new
();
// Use max_output_tokens if available
// Use max_output_tokens if available
...
@@ -1154,47 +1252,38 @@ impl ResponsesRequest {
...
@@ -1154,47 +1252,38 @@ impl ResponsesRequest {
params
.insert
(
params
.insert
(
"max_new_tokens"
.to_string
(),
"max_new_tokens"
.to_string
(),
serde_json
::
Value
::
Number
(
serde_json
::
Number
::
from
(
max_tokens
)),
Value
::
Number
(
Number
::
from
(
max_tokens
)),
);
);
params
.insert
(
params
.insert
(
"temperature"
.to_string
(),
"temperature"
.to_string
(),
serde_json
::
Value
::
Number
(
serde_json
::
Number
::
from_f64
(
temperature
as
f64
)
.unwrap
()),
Value
::
Number
(
Number
::
from_f64
(
temperature
as
f64
)
.unwrap
()),
);
);
params
.insert
(
params
.insert
(
"top_p"
.to_string
(),
"top_p"
.to_string
(),
serde_json
::
Value
::
Number
(
serde_json
::
Number
::
from_f64
(
top_p
as
f64
)
.unwrap
()),
Value
::
Number
(
Number
::
from_f64
(
top_p
as
f64
)
.unwrap
()),
);
);
params
.insert
(
params
.insert
(
"frequency_penalty"
.to_string
(),
"frequency_penalty"
.to_string
(),
serde_json
::
Value
::
Number
(
Value
::
Number
(
Number
::
from_f64
(
self
.frequency_penalty
as
f64
)
.unwrap
()),
serde_json
::
Number
::
from_f64
(
self
.frequency_penalty
as
f64
)
.unwrap
(),
),
);
);
params
.insert
(
params
.insert
(
"presence_penalty"
.to_string
(),
"presence_penalty"
.to_string
(),
serde_json
::
Value
::
Number
(
Value
::
Number
(
Number
::
from_f64
(
self
.presence_penalty
as
f64
)
.unwrap
()),
serde_json
::
Number
::
from_f64
(
self
.presence_penalty
as
f64
)
.unwrap
(),
),
);
params
.insert
(
"top_k"
.to_string
(),
serde_json
::
Value
::
Number
(
serde_json
::
Number
::
from
(
self
.top_k
)),
);
);
params
.insert
(
"top_k"
.to_string
(),
Value
::
Number
(
Number
::
from
(
self
.top_k
)));
params
.insert
(
params
.insert
(
"min_p"
.to_string
(),
"min_p"
.to_string
(),
serde_json
::
Value
::
Number
(
serde_json
::
Number
::
from_f64
(
self
.min_p
as
f64
)
.unwrap
()),
Value
::
Number
(
Number
::
from_f64
(
self
.min_p
as
f64
)
.unwrap
()),
);
);
params
.insert
(
params
.insert
(
"repetition_penalty"
.to_string
(),
"repetition_penalty"
.to_string
(),
serde_json
::
Value
::
Number
(
Value
::
Number
(
Number
::
from_f64
(
self
.repetition_penalty
as
f64
)
.unwrap
()),
serde_json
::
Number
::
from_f64
(
self
.repetition_penalty
as
f64
)
.unwrap
(),
),
);
);
if
let
Some
(
ref
stop
)
=
self
.stop
{
if
let
Some
(
ref
stop
)
=
self
.stop
{
match
serde_json
::
to_value
(
stop
)
{
match
to_value
(
stop
)
{
Ok
(
value
)
=>
params
.insert
(
"stop"
.to_string
(),
value
),
Ok
(
value
)
=>
params
.insert
(
"stop"
.to_string
(),
value
),
Err
(
_
)
=>
params
.insert
(
"stop"
.to_string
(),
serde_json
::
Value
::
Null
),
Err
(
_
)
=>
params
.insert
(
"stop"
.to_string
(),
Value
::
Null
),
};
};
}
}
...
@@ -1227,8 +1316,10 @@ impl GenerationRequest for ResponsesRequest {
...
@@ -1227,8 +1316,10 @@ impl GenerationRequest for ResponsesRequest {
ResponseInputOutputItem
::
Message
{
content
,
..
}
=>
{
ResponseInputOutputItem
::
Message
{
content
,
..
}
=>
{
let
texts
:
Vec
<
String
>
=
content
let
texts
:
Vec
<
String
>
=
content
.iter
()
.iter
()
.map
(|
part
|
match
part
{
.filter_map
(|
part
|
match
part
{
ResponseContentPart
::
OutputText
{
text
,
..
}
=>
text
.clone
(),
ResponseContentPart
::
OutputText
{
text
,
..
}
=>
Some
(
text
.clone
()),
ResponseContentPart
::
InputText
{
text
}
=>
Some
(
text
.clone
()),
ResponseContentPart
::
Unknown
=>
None
,
})
})
.collect
();
.collect
();
if
texts
.is_empty
()
{
if
texts
.is_empty
()
{
...
@@ -1285,6 +1376,25 @@ pub struct ResponsesResponse {
...
@@ -1285,6 +1376,25 @@ pub struct ResponsesResponse {
#[serde(default
=
"current_timestamp"
)]
#[serde(default
=
"current_timestamp"
)]
pub
created_at
:
i64
,
pub
created_at
:
i64
,
/// Response status
pub
status
:
ResponseStatus
,
/// Error information if status is failed
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
error
:
Option
<
Value
>
,
/// Incomplete details if response was truncated
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
incomplete_details
:
Option
<
Value
>
,
/// System instructions used
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
instructions
:
Option
<
String
>
,
/// Max output tokens setting
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
max_output_tokens
:
Option
<
u32
>
,
/// Model name
/// Model name
pub
model
:
String
,
pub
model
:
String
,
...
@@ -1292,16 +1402,29 @@ pub struct ResponsesResponse {
...
@@ -1292,16 +1402,29 @@ pub struct ResponsesResponse {
#[serde(default)]
#[serde(default)]
pub
output
:
Vec
<
ResponseOutputItem
>
,
pub
output
:
Vec
<
ResponseOutputItem
>
,
/// Response status
/// Whether parallel tool calls are enabled
pub
status
:
ResponseStatus
,
#[serde(default
=
"default_true"
)]
pub
parallel_tool_calls
:
bool
,
///
Usage inform
ation
///
Previous response ID if this is a continu
ation
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
usage
:
Option
<
UsageInfo
>
,
pub
previous_response_id
:
Option
<
String
>
,
/// Whether parallel tool calls are enabled
/// Reasoning information
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
reasoning
:
Option
<
ReasoningInfo
>
,
/// Whether the response is stored
#[serde(default
=
"default_true"
)]
#[serde(default
=
"default_true"
)]
pub
parallel_tool_calls
:
bool
,
pub
store
:
bool
,
/// Temperature setting used
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
temperature
:
Option
<
f32
>
,
/// Text format settings
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
text
:
Option
<
ResponseTextFormat
>
,
/// Tool choice setting
/// Tool choice setting
#[serde(default
=
"default_tool_choice"
)]
#[serde(default
=
"default_tool_choice"
)]
...
@@ -1310,6 +1433,26 @@ pub struct ResponsesResponse {
...
@@ -1310,6 +1433,26 @@ pub struct ResponsesResponse {
/// Available tools
/// Available tools
#[serde(default)]
#[serde(default)]
pub
tools
:
Vec
<
ResponseTool
>
,
pub
tools
:
Vec
<
ResponseTool
>
,
/// Top-p setting used
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
top_p
:
Option
<
f32
>
,
/// Truncation strategy used
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
truncation
:
Option
<
String
>
,
/// Usage information
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
usage
:
Option
<
ResponsesUsage
>
,
/// User identifier
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
user
:
Option
<
String
>
,
/// Additional metadata
#[serde(default)]
pub
metadata
:
HashMap
<
String
,
Value
>
,
}
}
fn
default_object_type
()
->
String
{
fn
default_object_type
()
->
String
{
...
@@ -1325,7 +1468,7 @@ impl ResponsesResponse {
...
@@ -1325,7 +1468,7 @@ impl ResponsesResponse {
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub
fn
from_request
(
pub
fn
from_request
(
request
:
&
ResponsesRequest
,
request
:
&
ResponsesRequest
,
_
sampling_params
:
&
HashMap
<
String
,
serde_json
::
Value
>
,
_
sampling_params
:
&
HashMap
<
String
,
Value
>
,
model_name
:
String
,
model_name
:
String
,
created_time
:
i64
,
created_time
:
i64
,
output
:
Vec
<
ResponseOutputItem
>
,
output
:
Vec
<
ResponseOutputItem
>
,
...
@@ -1336,11 +1479,26 @@ impl ResponsesResponse {
...
@@ -1336,11 +1479,26 @@ impl ResponsesResponse {
id
:
request
.request_id
.clone
(),
id
:
request
.request_id
.clone
(),
object
:
"response"
.to_string
(),
object
:
"response"
.to_string
(),
created_at
:
created_time
,
created_at
:
created_time
,
status
,
error
:
None
,
incomplete_details
:
None
,
instructions
:
request
.instructions
.clone
(),
max_output_tokens
:
request
.max_output_tokens
,
model
:
model_name
,
model
:
model_name
,
output
,
output
,
status
,
usage
,
parallel_tool_calls
:
request
.parallel_tool_calls
,
parallel_tool_calls
:
request
.parallel_tool_calls
,
previous_response_id
:
request
.previous_response_id
.clone
(),
reasoning
:
request
.reasoning
.as_ref
()
.map
(|
r
|
ReasoningInfo
{
effort
:
r
.effort
.as_ref
()
.map
(|
e
|
format!
(
"{:?}"
,
e
)),
summary
:
None
,
}),
store
:
request
.store
,
temperature
:
request
.temperature
,
text
:
Some
(
ResponseTextFormat
{
format
:
TextFormatType
{
format_type
:
"text"
.to_string
(),
},
}),
tool_choice
:
match
&
request
.tool_choice
{
tool_choice
:
match
&
request
.tool_choice
{
ToolChoice
::
Value
(
ToolChoiceValue
::
Auto
)
=>
"auto"
.to_string
(),
ToolChoice
::
Value
(
ToolChoiceValue
::
Auto
)
=>
"auto"
.to_string
(),
ToolChoice
::
Value
(
ToolChoiceValue
::
Required
)
=>
"required"
.to_string
(),
ToolChoice
::
Value
(
ToolChoiceValue
::
Required
)
=>
"required"
.to_string
(),
...
@@ -1348,6 +1506,14 @@ impl ResponsesResponse {
...
@@ -1348,6 +1506,14 @@ impl ResponsesResponse {
ToolChoice
::
Function
{
..
}
=>
"function"
.to_string
(),
ToolChoice
::
Function
{
..
}
=>
"function"
.to_string
(),
},
},
tools
:
request
.tools
.clone
(),
tools
:
request
.tools
.clone
(),
top_p
:
request
.top_p
,
truncation
:
match
&
request
.truncation
{
Truncation
::
Auto
=>
Some
(
"auto"
.to_string
()),
Truncation
::
Disabled
=>
Some
(
"disabled"
.to_string
()),
},
usage
:
usage
.map
(
ResponsesUsage
::
Classic
),
user
:
request
.user
.clone
(),
metadata
:
request
.metadata
.clone
()
.unwrap_or_default
(),
}
}
}
}
...
@@ -1357,13 +1523,26 @@ impl ResponsesResponse {
...
@@ -1357,13 +1523,26 @@ impl ResponsesResponse {
id
:
request_id
,
id
:
request_id
,
object
:
"response"
.to_string
(),
object
:
"response"
.to_string
(),
created_at
:
current_timestamp
(),
created_at
:
current_timestamp
(),
status
,
error
:
None
,
incomplete_details
:
None
,
instructions
:
None
,
max_output_tokens
:
None
,
model
,
model
,
output
:
Vec
::
new
(),
output
:
Vec
::
new
(),
status
,
usage
:
None
,
parallel_tool_calls
:
true
,
parallel_tool_calls
:
true
,
previous_response_id
:
None
,
reasoning
:
None
,
store
:
true
,
temperature
:
None
,
text
:
None
,
tool_choice
:
"auto"
.to_string
(),
tool_choice
:
"auto"
.to_string
(),
tools
:
Vec
::
new
(),
tools
:
Vec
::
new
(),
top_p
:
None
,
truncation
:
None
,
usage
:
None
,
user
:
None
,
metadata
:
HashMap
::
new
(),
}
}
}
}
...
@@ -1374,7 +1553,7 @@ impl ResponsesResponse {
...
@@ -1374,7 +1553,7 @@ impl ResponsesResponse {
/// Set the usage information
/// Set the usage information
pub
fn
set_usage
(
&
mut
self
,
usage
:
UsageInfo
)
{
pub
fn
set_usage
(
&
mut
self
,
usage
:
UsageInfo
)
{
self
.usage
=
Some
(
usage
);
self
.usage
=
Some
(
ResponsesUsage
::
Classic
(
usage
)
)
;
}
}
/// Update the status
/// Update the status
...
@@ -1413,12 +1592,12 @@ impl ResponsesResponse {
...
@@ -1413,12 +1592,12 @@ impl ResponsesResponse {
}
}
/// Get the response as a JSON value with usage in response format
/// Get the response as a JSON value with usage in response format
pub
fn
to_response_format
(
&
self
)
->
serde_json
::
Value
{
pub
fn
to_response_format
(
&
self
)
->
Value
{
let
mut
response
=
serde_json
::
to_value
(
self
)
.unwrap_or
(
serde_json
::
Value
::
Null
);
let
mut
response
=
to_value
(
self
)
.unwrap_or
(
Value
::
Null
);
// Convert usage to response format if present
// Convert usage to response format if present
if
let
Some
(
usage
)
=
&
self
.usage
{
if
let
Some
(
usage
)
=
&
self
.usage
{
if
let
Ok
(
usage_value
)
=
serde_json
::
to_value
(
usage
.to_response_usage
())
{
if
let
Ok
(
usage_value
)
=
to_value
(
usage
.to_response_usage
())
{
response
[
"usage"
]
=
usage_value
;
response
[
"usage"
]
=
usage_value
;
}
}
}
}
...
@@ -1641,8 +1820,13 @@ pub struct LogProbs {
...
@@ -1641,8 +1820,13 @@ pub struct LogProbs {
}
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ChatLogProbs
{
#[serde(untagged)]
pub
content
:
Option
<
Vec
<
ChatLogProbsContent
>>
,
pub
enum
ChatLogProbs
{
Detailed
{
#[serde(skip_serializing_if
=
"Option::is_none"
)]
content
:
Option
<
Vec
<
ChatLogProbsContent
>>
,
},
Raw
(
Value
),
}
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
...
@@ -1798,7 +1982,7 @@ pub struct GenerateRequest {
...
@@ -1798,7 +1982,7 @@ pub struct GenerateRequest {
/// Session parameters for continual prompting
/// Session parameters for continual prompting
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
session_params
:
Option
<
HashMap
<
String
,
serde_json
::
Value
>>
,
pub
session_params
:
Option
<
HashMap
<
String
,
Value
>>
,
/// Return model hidden states
/// Return model hidden states
#[serde(default)]
#[serde(default)]
...
@@ -2065,7 +2249,7 @@ pub struct EmbeddingRequest {
...
@@ -2065,7 +2249,7 @@ pub struct EmbeddingRequest {
pub
model
:
String
,
pub
model
:
String
,
/// Input can be a string, array of strings, tokens, or batch inputs
/// Input can be a string, array of strings, tokens, or batch inputs
pub
input
:
serde_json
::
Value
,
pub
input
:
Value
,
/// Optional encoding format (e.g., "float", "base64")
/// Optional encoding format (e.g., "float", "base64")
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
...
@@ -2097,8 +2281,8 @@ impl GenerationRequest for EmbeddingRequest {
...
@@ -2097,8 +2281,8 @@ impl GenerationRequest for EmbeddingRequest {
fn
extract_text_for_routing
(
&
self
)
->
String
{
fn
extract_text_for_routing
(
&
self
)
->
String
{
// Best effort: extract text content for routing decisions
// Best effort: extract text content for routing decisions
match
&
self
.input
{
match
&
self
.input
{
serde_json
::
Value
::
String
(
s
)
=>
s
.clone
(),
Value
::
String
(
s
)
=>
s
.clone
(),
serde_json
::
Value
::
Array
(
arr
)
=>
arr
Value
::
Array
(
arr
)
=>
arr
.iter
()
.iter
()
.filter_map
(|
v
|
v
.as_str
())
.filter_map
(|
v
|
v
.as_str
())
.collect
::
<
Vec
<
_
>>
()
.collect
::
<
Vec
<
_
>>
()
...
@@ -2173,7 +2357,7 @@ pub enum LoRAPath {
...
@@ -2173,7 +2357,7 @@ pub enum LoRAPath {
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
serde_json
;
use
serde_json
::{
from_str
,
json
,
to_string
}
;
// ==================================================================
// ==================================================================
// = RERANK REQUEST TESTS =
// = RERANK REQUEST TESTS =
...
@@ -2191,8 +2375,8 @@ mod tests {
...
@@ -2191,8 +2375,8 @@ mod tests {
user
:
Some
(
"user-456"
.to_string
()),
user
:
Some
(
"user-456"
.to_string
()),
};
};
let
serialized
=
serde_json
::
to_string
(
&
request
)
.unwrap
();
let
serialized
=
to_string
(
&
request
)
.unwrap
();
let
deserialized
:
RerankRequest
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankRequest
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.query
,
request
.query
);
assert_eq!
(
deserialized
.query
,
request
.query
);
assert_eq!
(
deserialized
.documents
,
request
.documents
);
assert_eq!
(
deserialized
.documents
,
request
.documents
);
...
@@ -2210,7 +2394,7 @@ mod tests {
...
@@ -2210,7 +2394,7 @@ mod tests {
"documents": ["doc1", "doc2"]
"documents": ["doc1", "doc2"]
}"#
;
}"#
;
let
request
:
RerankRequest
=
serde_json
::
from_str
(
json
)
.unwrap
();
let
request
:
RerankRequest
=
from_str
(
json
)
.unwrap
();
assert_eq!
(
request
.query
,
"test query"
);
assert_eq!
(
request
.query
,
"test query"
);
assert_eq!
(
request
.documents
,
vec!
[
"doc1"
,
"doc2"
]);
assert_eq!
(
request
.documents
,
vec!
[
"doc1"
,
"doc2"
]);
...
@@ -2402,8 +2586,8 @@ mod tests {
...
@@ -2402,8 +2586,8 @@ mod tests {
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
);
let
serialized
=
serde_json
::
to_string
(
&
response
)
.unwrap
();
let
serialized
=
to_string
(
&
response
)
.unwrap
();
let
deserialized
:
RerankResponse
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResponse
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.results
.len
(),
response
.results
.len
());
assert_eq!
(
deserialized
.results
.len
(),
response
.results
.len
());
assert_eq!
(
deserialized
.model
,
response
.model
);
assert_eq!
(
deserialized
.model
,
response
.model
);
...
@@ -2539,13 +2723,13 @@ mod tests {
...
@@ -2539,13 +2723,13 @@ mod tests {
(
"confidence"
.to_string
(),
Value
::
String
(
"high"
.to_string
())),
(
"confidence"
.to_string
(),
Value
::
String
(
"high"
.to_string
())),
(
(
"processing_time"
.to_string
(),
"processing_time"
.to_string
(),
Value
::
Number
(
serde_json
::
Number
::
from
(
150
)),
Value
::
Number
(
Number
::
from
(
150
)),
),
),
])),
])),
};
};
let
serialized
=
serde_json
::
to_string
(
&
result
)
.unwrap
();
let
serialized
=
to_string
(
&
result
)
.unwrap
();
let
deserialized
:
RerankResult
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResult
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.score
,
result
.score
);
assert_eq!
(
deserialized
.score
,
result
.score
);
assert_eq!
(
deserialized
.document
,
result
.document
);
assert_eq!
(
deserialized
.document
,
result
.document
);
...
@@ -2562,8 +2746,8 @@ mod tests {
...
@@ -2562,8 +2746,8 @@ mod tests {
meta_info
:
None
,
meta_info
:
None
,
};
};
let
serialized
=
serde_json
::
to_string
(
&
result
)
.unwrap
();
let
serialized
=
to_string
(
&
result
)
.unwrap
();
let
deserialized
:
RerankResult
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResult
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.score
,
result
.score
);
assert_eq!
(
deserialized
.score
,
result
.score
);
assert_eq!
(
deserialized
.document
,
result
.document
);
assert_eq!
(
deserialized
.document
,
result
.document
);
...
@@ -2582,8 +2766,8 @@ mod tests {
...
@@ -2582,8 +2766,8 @@ mod tests {
documents
:
vec!
[
"doc1"
.to_string
(),
"doc2"
.to_string
()],
documents
:
vec!
[
"doc1"
.to_string
(),
"doc2"
.to_string
()],
};
};
let
serialized
=
serde_json
::
to_string
(
&
v1_input
)
.unwrap
();
let
serialized
=
to_string
(
&
v1_input
)
.unwrap
();
let
deserialized
:
V1RerankReqInput
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
V1RerankReqInput
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.query
,
v1_input
.query
);
assert_eq!
(
deserialized
.query
,
v1_input
.query
);
assert_eq!
(
deserialized
.documents
,
v1_input
.documents
);
assert_eq!
(
deserialized
.documents
,
v1_input
.documents
);
...
@@ -2724,8 +2908,8 @@ mod tests {
...
@@ -2724,8 +2908,8 @@ mod tests {
prompt_tokens_details
:
None
,
prompt_tokens_details
:
None
,
});
});
let
serialized
=
serde_json
::
to_string
(
&
response
)
.unwrap
();
let
serialized
=
to_string
(
&
response
)
.unwrap
();
let
deserialized
:
RerankResponse
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResponse
=
from_str
(
&
serialized
)
.unwrap
();
assert
!
(
deserialized
.usage
.is_some
());
assert
!
(
deserialized
.usage
.is_some
());
let
usage
=
deserialized
.usage
.unwrap
();
let
usage
=
deserialized
.usage
.unwrap
();
...
@@ -2805,8 +2989,8 @@ mod tests {
...
@@ -2805,8 +2989,8 @@ mod tests {
assert_eq!
(
response
.model
,
"rerank-model"
);
assert_eq!
(
response
.model
,
"rerank-model"
);
// Serialize and deserialize
// Serialize and deserialize
let
serialized
=
serde_json
::
to_string
(
&
response
)
.unwrap
();
let
serialized
=
to_string
(
&
response
)
.unwrap
();
let
deserialized
:
RerankResponse
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResponse
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.results
.len
(),
2
);
assert_eq!
(
deserialized
.results
.len
(),
2
);
assert_eq!
(
deserialized
.model
,
response
.model
);
assert_eq!
(
deserialized
.model
,
response
.model
);
}
}
...
@@ -2819,15 +3003,15 @@ mod tests {
...
@@ -2819,15 +3003,15 @@ mod tests {
fn
test_embedding_request_serialization_string_input
()
{
fn
test_embedding_request_serialization_string_input
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"test-emb"
.to_string
(),
model
:
"test-emb"
.to_string
(),
input
:
serde_json
::
Value
::
String
(
"hello"
.to_string
()),
input
:
Value
::
String
(
"hello"
.to_string
()),
encoding_format
:
Some
(
"float"
.to_string
()),
encoding_format
:
Some
(
"float"
.to_string
()),
user
:
Some
(
"user-1"
.to_string
()),
user
:
Some
(
"user-1"
.to_string
()),
dimensions
:
Some
(
128
),
dimensions
:
Some
(
128
),
rid
:
Some
(
"rid-123"
.to_string
()),
rid
:
Some
(
"rid-123"
.to_string
()),
};
};
let
serialized
=
serde_json
::
to_string
(
&
req
)
.unwrap
();
let
serialized
=
to_string
(
&
req
)
.unwrap
();
let
deserialized
:
EmbeddingRequest
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
EmbeddingRequest
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.model
,
req
.model
);
assert_eq!
(
deserialized
.model
,
req
.model
);
assert_eq!
(
deserialized
.input
,
req
.input
);
assert_eq!
(
deserialized
.input
,
req
.input
);
...
@@ -2841,15 +3025,15 @@ mod tests {
...
@@ -2841,15 +3025,15 @@ mod tests {
fn
test_embedding_request_serialization_array_input
()
{
fn
test_embedding_request_serialization_array_input
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"test-emb"
.to_string
(),
model
:
"test-emb"
.to_string
(),
input
:
serde_json
::
json!
([
"a"
,
"b"
,
"c"
]),
input
:
json!
([
"a"
,
"b"
,
"c"
]),
encoding_format
:
None
,
encoding_format
:
None
,
user
:
None
,
user
:
None
,
dimensions
:
None
,
dimensions
:
None
,
rid
:
None
,
rid
:
None
,
};
};
let
serialized
=
serde_json
::
to_string
(
&
req
)
.unwrap
();
let
serialized
=
to_string
(
&
req
)
.unwrap
();
let
de
:
EmbeddingRequest
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
de
:
EmbeddingRequest
=
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
de
.model
,
req
.model
);
assert_eq!
(
de
.model
,
req
.model
);
assert_eq!
(
de
.input
,
req
.input
);
assert_eq!
(
de
.input
,
req
.input
);
}
}
...
@@ -2858,7 +3042,7 @@ mod tests {
...
@@ -2858,7 +3042,7 @@ mod tests {
fn
test_embedding_generation_request_trait_string
()
{
fn
test_embedding_generation_request_trait_string
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"emb-model"
.to_string
(),
model
:
"emb-model"
.to_string
(),
input
:
serde_json
::
Value
::
String
(
"hello"
.to_string
()),
input
:
Value
::
String
(
"hello"
.to_string
()),
encoding_format
:
None
,
encoding_format
:
None
,
user
:
None
,
user
:
None
,
dimensions
:
None
,
dimensions
:
None
,
...
@@ -2873,7 +3057,7 @@ mod tests {
...
@@ -2873,7 +3057,7 @@ mod tests {
fn
test_embedding_generation_request_trait_array
()
{
fn
test_embedding_generation_request_trait_array
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"emb-model"
.to_string
(),
model
:
"emb-model"
.to_string
(),
input
:
serde_json
::
json!
([
"hello"
,
"world"
]),
input
:
json!
([
"hello"
,
"world"
]),
encoding_format
:
None
,
encoding_format
:
None
,
user
:
None
,
user
:
None
,
dimensions
:
None
,
dimensions
:
None
,
...
@@ -2886,7 +3070,7 @@ mod tests {
...
@@ -2886,7 +3070,7 @@ mod tests {
fn
test_embedding_generation_request_trait_non_text
()
{
fn
test_embedding_generation_request_trait_non_text
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"emb-model"
.to_string
(),
model
:
"emb-model"
.to_string
(),
input
:
serde_json
::
json!
({
"tokens"
:
[
1
,
2
,
3
]}),
input
:
json!
({
"tokens"
:
[
1
,
2
,
3
]}),
encoding_format
:
None
,
encoding_format
:
None
,
user
:
None
,
user
:
None
,
dimensions
:
None
,
dimensions
:
None
,
...
@@ -2899,7 +3083,7 @@ mod tests {
...
@@ -2899,7 +3083,7 @@ mod tests {
fn
test_embedding_generation_request_trait_mixed_array_ignores_nested
()
{
fn
test_embedding_generation_request_trait_mixed_array_ignores_nested
()
{
let
req
=
EmbeddingRequest
{
let
req
=
EmbeddingRequest
{
model
:
"emb-model"
.to_string
(),
model
:
"emb-model"
.to_string
(),
input
:
serde_json
::
json!
([
"a"
,
[
"b"
,
"c"
],
123
,
{
"k"
:
"v"
}]),
input
:
json!
([
"a"
,
[
"b"
,
"c"
],
123
,
{
"k"
:
"v"
}]),
encoding_format
:
None
,
encoding_format
:
None
,
user
:
None
,
user
:
None
,
dimensions
:
None
,
dimensions
:
None
,
...
...
sgl-router/src/routers/factory.rs
View file @
98c3b04f
...
@@ -166,8 +166,12 @@ impl RouterFactory {
...
@@ -166,8 +166,12 @@ impl RouterFactory {
.cloned
()
.cloned
()
.ok_or_else
(||
"OpenAI mode requires at least one worker URL"
.to_string
())
?
;
.ok_or_else
(||
"OpenAI mode requires at least one worker URL"
.to_string
())
?
;
let
router
=
let
router
=
OpenAIRouter
::
new
(
OpenAIRouter
::
new
(
base_url
,
Some
(
ctx
.router_config.circuit_breaker
.clone
()))
.await
?
;
base_url
,
Some
(
ctx
.router_config.circuit_breaker
.clone
()),
ctx
.response_storage
.clone
(),
)
.await
?
;
Ok
(
Box
::
new
(
router
))
Ok
(
Box
::
new
(
router
))
}
}
...
...
sgl-router/src/routers/grpc/pd_router.rs
View file @
98c3b04f
...
@@ -308,7 +308,12 @@ impl RouterTrait for GrpcPDRouter {
...
@@ -308,7 +308,12 @@ impl RouterTrait for GrpcPDRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
_
params
:
&
crate
::
protocols
::
spec
::
ResponsesGetParams
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/grpc/router.rs
View file @
98c3b04f
...
@@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter {
...
@@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
_
params
:
&
crate
::
protocols
::
spec
::
ResponsesGetParams
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/header_utils.rs
View file @
98c3b04f
...
@@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool {
...
@@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool {
"host"
// Should not forward the backend's host header
"host"
// Should not forward the backend's host header
)
)
}
}
/// Apply headers to a reqwest request builder, filtering out headers that shouldn't be forwarded
/// or that will be set automatically by reqwest
pub
fn
apply_request_headers
(
headers
:
&
HeaderMap
,
mut
request_builder
:
reqwest
::
RequestBuilder
,
skip_content_headers
:
bool
,
)
->
reqwest
::
RequestBuilder
{
// Always forward Authorization header first if present
if
let
Some
(
auth
)
=
headers
.get
(
"authorization"
)
.or_else
(||
headers
.get
(
"Authorization"
))
{
request_builder
=
request_builder
.header
(
"Authorization"
,
auth
.clone
());
}
// Forward other headers, filtering out problematic ones
for
(
key
,
value
)
in
headers
.iter
()
{
let
key_str
=
key
.as_str
()
.to_lowercase
();
// Skip headers that:
// - Are set automatically by reqwest (content-type, content-length for POST/PUT)
// - We already handled (authorization)
// - Are hop-by-hop headers (connection, transfer-encoding)
// - Should not be forwarded (host)
let
should_skip
=
key_str
==
"authorization"
||
// Already handled above
key_str
==
"host"
||
key_str
==
"connection"
||
key_str
==
"transfer-encoding"
||
key_str
==
"keep-alive"
||
key_str
==
"te"
||
key_str
==
"trailers"
||
key_str
==
"upgrade"
||
(
skip_content_headers
&&
(
key_str
==
"content-type"
||
key_str
==
"content-length"
));
if
!
should_skip
{
request_builder
=
request_builder
.header
(
key
.clone
(),
value
.clone
());
}
}
request_builder
}
sgl-router/src/routers/http/openai_router.rs
View file @
98c3b04f
//! OpenAI router implementation
(reqwest-based)
//! OpenAI router implementation
use
crate
::
config
::
CircuitBreakerConfig
;
use
crate
::
config
::
CircuitBreakerConfig
;
use
crate
::
core
::{
CircuitBreaker
,
CircuitBreakerConfig
as
CoreCircuitBreakerConfig
};
use
crate
::
core
::{
CircuitBreaker
,
CircuitBreakerConfig
as
CoreCircuitBreakerConfig
};
use
crate
::
data_connector
::{
ResponseId
,
SharedResponseStorage
,
StoredResponse
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponseContentPart
,
ResponseInput
,
ResponseInputOutputItem
,
ResponseOutputItem
,
ResponseStatus
,
ResponseTextFormat
,
ResponsesGetParams
,
ResponsesRequest
,
ResponsesResponse
,
TextFormatType
,
};
};
use
crate
::
routers
::
header_utils
::{
apply_request_headers
,
preserve_response_headers
};
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
...
@@ -13,13 +18,17 @@ use axum::{
...
@@ -13,13 +18,17 @@ use axum::{
response
::{
IntoResponse
,
Response
},
response
::{
IntoResponse
,
Response
},
};
};
use
futures_util
::
StreamExt
;
use
futures_util
::
StreamExt
;
use
serde_json
::{
json
,
to_value
,
Value
};
use
std
::{
use
std
::{
any
::
Any
,
any
::
Any
,
collections
::
HashMap
,
sync
::
atomic
::{
AtomicBool
,
Ordering
},
sync
::
atomic
::{
AtomicBool
,
Ordering
},
};
};
use
tokio
::
sync
::
mpsc
;
use
tokio_stream
::
wrappers
::
UnboundedReceiverStream
;
use
tracing
::{
error
,
info
,
warn
};
/// Router for OpenAI backend
/// Router for OpenAI backend
#[derive(Debug)]
pub
struct
OpenAIRouter
{
pub
struct
OpenAIRouter
{
/// HTTP client for upstream OpenAI-compatible API
/// HTTP client for upstream OpenAI-compatible API
client
:
reqwest
::
Client
,
client
:
reqwest
::
Client
,
...
@@ -29,6 +38,17 @@ pub struct OpenAIRouter {
...
@@ -29,6 +38,17 @@ pub struct OpenAIRouter {
circuit_breaker
:
CircuitBreaker
,
circuit_breaker
:
CircuitBreaker
,
/// Health status
/// Health status
healthy
:
AtomicBool
,
healthy
:
AtomicBool
,
/// Response storage for managing conversation history
response_storage
:
SharedResponseStorage
,
}
impl
std
::
fmt
::
Debug
for
OpenAIRouter
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
f
.debug_struct
(
"OpenAIRouter"
)
.field
(
"base_url"
,
&
self
.base_url
)
.field
(
"healthy"
,
&
self
.healthy
)
.finish
()
}
}
}
impl
OpenAIRouter
{
impl
OpenAIRouter
{
...
@@ -36,6 +56,7 @@ impl OpenAIRouter {
...
@@ -36,6 +56,7 @@ impl OpenAIRouter {
pub
async
fn
new
(
pub
async
fn
new
(
base_url
:
String
,
base_url
:
String
,
circuit_breaker_config
:
Option
<
CircuitBreakerConfig
>
,
circuit_breaker_config
:
Option
<
CircuitBreakerConfig
>
,
response_storage
:
SharedResponseStorage
,
)
->
Result
<
Self
,
String
>
{
)
->
Result
<
Self
,
String
>
{
let
client
=
reqwest
::
Client
::
builder
()
let
client
=
reqwest
::
Client
::
builder
()
.timeout
(
std
::
time
::
Duration
::
from_secs
(
300
))
.timeout
(
std
::
time
::
Duration
::
from_secs
(
300
))
...
@@ -61,8 +82,246 @@ impl OpenAIRouter {
...
@@ -61,8 +82,246 @@ impl OpenAIRouter {
base_url
,
base_url
,
circuit_breaker
,
circuit_breaker
,
healthy
:
AtomicBool
::
new
(
true
),
healthy
:
AtomicBool
::
new
(
true
),
response_storage
,
})
})
}
}
async
fn
handle_non_streaming_response
(
&
self
,
url
:
String
,
headers
:
Option
<&
HeaderMap
>
,
payload
:
Value
,
original_body
:
&
ResponsesRequest
,
original_previous_response_id
:
Option
<
String
>
,
)
->
Response
{
let
request_builder
=
self
.client
.post
(
&
url
)
.json
(
&
payload
);
// Apply headers with filtering
let
request_builder
=
if
let
Some
(
headers
)
=
headers
{
apply_request_headers
(
headers
,
request_builder
,
true
)
}
else
{
request_builder
};
match
request_builder
.send
()
.await
{
Ok
(
response
)
=>
{
let
status
=
response
.status
();
if
!
status
.is_success
()
{
let
error_text
=
response
.text
()
.await
.unwrap_or_else
(|
e
|
format!
(
"Failed to get error body: {}"
,
e
));
return
(
status
,
error_text
)
.into_response
();
}
// Parse the response
match
response
.json
::
<
Value
>
()
.await
{
Ok
(
mut
openai_response_json
)
=>
{
if
let
Some
(
prev_id
)
=
original_previous_response_id
{
if
let
Some
(
obj
)
=
openai_response_json
.as_object_mut
()
{
let
should_insert
=
obj
.get
(
"previous_response_id"
)
.map
(|
v
|
v
.is_null
())
.unwrap_or
(
true
);
if
should_insert
{
obj
.insert
(
"previous_response_id"
.to_string
(),
Value
::
String
(
prev_id
),
);
}
}
}
if
let
Some
(
obj
)
=
openai_response_json
.as_object_mut
()
{
if
!
obj
.contains_key
(
"instructions"
)
{
if
let
Some
(
instructions
)
=
&
original_body
.instructions
{
obj
.insert
(
"instructions"
.to_string
(),
Value
::
String
(
instructions
.clone
()),
);
}
}
if
!
obj
.contains_key
(
"metadata"
)
{
if
let
Some
(
metadata
)
=
&
original_body
.metadata
{
let
metadata_map
:
serde_json
::
Map
<
String
,
Value
>
=
metadata
.iter
()
.map
(|(
k
,
v
)|
(
k
.clone
(),
v
.clone
()))
.collect
();
obj
.insert
(
"metadata"
.to_string
(),
Value
::
Object
(
metadata_map
));
}
}
// Reflect the client's requested store preference in the response body
obj
.insert
(
"store"
.to_string
(),
Value
::
Bool
(
original_body
.store
));
}
if
original_body
.store
{
if
let
Err
(
e
)
=
self
.store_response_internal
(
&
openai_response_json
,
original_body
)
.await
{
warn!
(
"Failed to store response: {}"
,
e
);
}
}
match
serde_json
::
to_string
(
&
openai_response_json
)
{
Ok
(
json_str
)
=>
(
StatusCode
::
OK
,
[(
"content-type"
,
"application/json"
)],
json_str
,
)
.into_response
(),
Err
(
e
)
=>
{
error!
(
"Failed to serialize response: {}"
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
json!
({
"error"
:
{
"message"
:
"Failed to serialize response"
,
"type"
:
"internal_error"
}})
.to_string
(),
)
.into_response
()
}
}
}
Err
(
e
)
=>
{
error!
(
"Failed to parse OpenAI response: {}"
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to parse response: {}"
,
e
),
)
.into_response
()
}
}
}
Err
(
e
)
=>
(
StatusCode
::
BAD_GATEWAY
,
format!
(
"Failed to forward request to OpenAI: {}"
,
e
),
)
.into_response
(),
}
}
async
fn
handle_streaming_response
(
&
self
,
_u
rl
:
String
,
_
headers
:
Option
<&
HeaderMap
>
,
_
payload
:
Value
,
_
original_body
:
&
ResponsesRequest
,
_
original_previous_response_id
:
Option
<
String
>
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Streaming responses not yet implemented"
,
)
.into_response
()
}
async
fn
store_response_internal
(
&
self
,
response_json
:
&
Value
,
original_body
:
&
ResponsesRequest
,
)
->
Result
<
(),
String
>
{
if
!
original_body
.store
{
return
Ok
(());
}
match
Self
::
store_response_impl
(
&
self
.response_storage
,
response_json
,
original_body
)
.await
{
Ok
(
response_id
)
=>
{
info!
(
response_id
=
%
response_id
.0
,
"Stored response locally"
);
Ok
(())
}
Err
(
e
)
=>
Err
(
e
),
}
}
async
fn
store_response_impl
(
response_storage
:
&
SharedResponseStorage
,
response_json
:
&
Value
,
original_body
:
&
ResponsesRequest
,
)
->
Result
<
ResponseId
,
String
>
{
let
input_text
=
match
&
original_body
.input
{
ResponseInput
::
Text
(
text
)
=>
text
.clone
(),
ResponseInput
::
Items
(
_
)
=>
"complex input"
.to_string
(),
};
let
output_text
=
Self
::
extract_primary_output_text
(
response_json
)
.unwrap_or_default
();
let
mut
stored_response
=
StoredResponse
::
new
(
input_text
,
output_text
,
None
);
stored_response
.instructions
=
response_json
.get
(
"instructions"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
s
|
s
.to_string
())
.or_else
(||
original_body
.instructions
.clone
());
stored_response
.model
=
response_json
.get
(
"model"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
s
|
s
.to_string
())
.or_else
(||
original_body
.model
.clone
());
stored_response
.user
=
response_json
.get
(
"user"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
s
|
s
.to_string
())
.or_else
(||
original_body
.user
.clone
());
stored_response
.metadata
=
response_json
.get
(
"metadata"
)
.and_then
(|
v
|
v
.as_object
())
.map
(|
m
|
{
m
.iter
()
.map
(|(
k
,
v
)|
(
k
.clone
(),
v
.clone
()))
.collect
::
<
HashMap
<
_
,
_
>>
()
})
.unwrap_or_else
(||
original_body
.metadata
.clone
()
.unwrap_or_default
());
stored_response
.previous_response_id
=
response_json
.get
(
"previous_response_id"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
s
|
ResponseId
::
from_string
(
s
.to_string
()))
.or_else
(||
{
original_body
.previous_response_id
.as_ref
()
.map
(|
id
|
ResponseId
::
from_string
(
id
.clone
()))
});
if
let
Some
(
id_str
)
=
response_json
.get
(
"id"
)
.and_then
(|
v
|
v
.as_str
())
{
stored_response
.id
=
ResponseId
::
from_string
(
id_str
.to_string
());
}
stored_response
.raw_response
=
response_json
.clone
();
response_storage
.store_response
(
stored_response
)
.await
.map_err
(|
e
|
format!
(
"Failed to store response: {}"
,
e
))
}
fn
extract_primary_output_text
(
response_json
:
&
Value
)
->
Option
<
String
>
{
if
let
Some
(
items
)
=
response_json
.get
(
"output"
)
.and_then
(|
v
|
v
.as_array
())
{
for
item
in
items
{
if
let
Some
(
content
)
=
item
.get
(
"content"
)
.and_then
(|
v
|
v
.as_array
())
{
for
part
in
content
{
if
part
.get
(
"type"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
t
|
t
==
"output_text"
)
.unwrap_or
(
false
)
{
if
let
Some
(
text
)
=
part
.get
(
"text"
)
.and_then
(|
v
|
v
.as_str
())
{
return
Some
(
text
.to_string
());
}
}
}
}
}
}
None
}
}
}
#[async_trait]
#[async_trait]
...
@@ -108,7 +367,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -108,7 +367,7 @@ impl super::super::RouterTrait for OpenAIRouter {
}
}
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
let
info
=
serde_json
::
json!
({
let
info
=
json!
({
"router_type"
:
"openai"
,
"router_type"
:
"openai"
,
"workers"
:
1
,
"workers"
:
1
,
"base_url"
:
&
self
.base_url
"base_url"
:
&
self
.base_url
...
@@ -192,7 +451,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -192,7 +451,7 @@ impl super::super::RouterTrait for OpenAIRouter {
}
}
// Serialize request body, removing SGLang-only fields
// Serialize request body, removing SGLang-only fields
let
mut
payload
=
match
serde_json
::
to_value
(
body
)
{
let
mut
payload
=
match
to_value
(
body
)
{
Ok
(
v
)
=>
v
,
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
{
Err
(
e
)
=>
{
return
(
return
(
...
@@ -282,7 +541,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -282,7 +541,7 @@ impl super::super::RouterTrait for OpenAIRouter {
}
else
{
}
else
{
// Stream SSE bytes to client
// Stream SSE bytes to client
let
stream
=
resp
.bytes_stream
();
let
stream
=
resp
.bytes_stream
();
let
(
tx
,
rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
tx
,
rx
)
=
mpsc
::
unbounded_channel
();
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
let
mut
s
=
stream
;
let
mut
s
=
stream
;
while
let
Some
(
chunk
)
=
s
.next
()
.await
{
while
let
Some
(
chunk
)
=
s
.next
()
.await
{
...
@@ -299,9 +558,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -299,9 +558,7 @@ impl super::super::RouterTrait for OpenAIRouter {
}
}
}
}
});
});
let
mut
response
=
Response
::
new
(
Body
::
from_stream
(
let
mut
response
=
Response
::
new
(
Body
::
from_stream
(
UnboundedReceiverStream
::
new
(
rx
)));
tokio_stream
::
wrappers
::
UnboundedReceiverStream
::
new
(
rx
),
));
*
response
.status_mut
()
=
status
;
*
response
.status_mut
()
=
status
;
response
response
.headers_mut
()
.headers_mut
()
...
@@ -326,36 +583,294 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -326,36 +583,294 @@ impl super::super::RouterTrait for OpenAIRouter {
async
fn
route_responses
(
async
fn
route_responses
(
&
self
,
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
crate
::
protocols
::
spec
::
ResponsesRequest
,
body
:
&
ResponsesRequest
,
_
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
(
let
url
=
format!
(
"{}/v1/responses"
,
self
.base_url
);
StatusCode
::
NOT_IMPLEMENTED
,
"Responses endpoint not implemented for OpenAI router"
,
info!
(
)
requested_store
=
body
.store
,
.into_response
()
is_streaming
=
body
.stream
,
"openai_responses_request"
);
if
body
.stream
{
return
(
StatusCode
::
NOT_IMPLEMENTED
,
"Streaming responses not yet implemented"
,
)
.into_response
();
}
// Clone the body and override model if needed
let
mut
request_body
=
body
.clone
();
if
let
Some
(
model
)
=
model_id
{
request_body
.model
=
Some
(
model
.to_string
());
}
// Store the original previous_response_id for the response
let
original_previous_response_id
=
request_body
.previous_response_id
.clone
();
// Handle previous_response_id by loading prior context
let
mut
conversation_items
:
Option
<
Vec
<
ResponseInputOutputItem
>>
=
None
;
if
let
Some
(
prev_id_str
)
=
request_body
.previous_response_id
.clone
()
{
let
prev_id
=
ResponseId
::
from_string
(
prev_id_str
.clone
());
match
self
.response_storage
.get_response_chain
(
&
prev_id
,
None
)
.await
{
Ok
(
chain
)
=>
{
if
!
chain
.responses
.is_empty
()
{
let
mut
items
=
Vec
::
new
();
for
stored
in
chain
.responses
.iter
()
{
let
trimmed_id
=
stored
.id
.0
.trim_start_matches
(
"resp_"
);
if
!
stored
.input
.is_empty
()
{
items
.push
(
ResponseInputOutputItem
::
Message
{
id
:
format!
(
"msg_u_{}"
,
trimmed_id
),
role
:
"user"
.to_string
(),
status
:
Some
(
"completed"
.to_string
()),
content
:
vec!
[
ResponseContentPart
::
InputText
{
text
:
stored
.input
.clone
(),
}],
});
}
if
!
stored
.output
.is_empty
()
{
items
.push
(
ResponseInputOutputItem
::
Message
{
id
:
format!
(
"msg_a_{}"
,
trimmed_id
),
role
:
"assistant"
.to_string
(),
status
:
Some
(
"completed"
.to_string
()),
content
:
vec!
[
ResponseContentPart
::
OutputText
{
text
:
stored
.output
.clone
(),
annotations
:
vec!
[],
logprobs
:
None
,
}],
});
}
}
conversation_items
=
Some
(
items
);
}
else
{
info!
(
previous_response_id
=
%
prev_id_str
,
"previous chain empty"
);
}
}
Err
(
err
)
=>
{
warn!
(
previous_response_id
=
%
prev_id_str
,
%
err
,
"failed to fetch previous response chain"
);
}
}
// Clear previous_response_id from request since we're converting to conversation
request_body
.previous_response_id
=
None
;
}
if
let
Some
(
mut
items
)
=
conversation_items
{
match
&
request_body
.input
{
ResponseInput
::
Text
(
text
)
=>
{
items
.push
(
ResponseInputOutputItem
::
Message
{
id
:
format!
(
"msg_u_current_{}"
,
items
.len
()),
role
:
"user"
.to_string
(),
status
:
Some
(
"completed"
.to_string
()),
content
:
vec!
[
ResponseContentPart
::
InputText
{
text
:
text
.clone
()
}],
});
}
ResponseInput
::
Items
(
existing
)
=>
{
items
.extend
(
existing
.clone
());
}
}
request_body
.input
=
ResponseInput
::
Items
(
items
);
}
// Always set store=false for OpenAI (we store internally)
request_body
.store
=
false
;
// Convert to JSON payload and strip SGLang-specific fields before forwarding
let
mut
payload
=
match
to_value
(
&
request_body
)
{
Ok
(
value
)
=>
value
,
Err
(
err
)
=>
{
return
(
StatusCode
::
BAD_REQUEST
,
format!
(
"Failed to serialize responses request: {}"
,
err
),
)
.into_response
();
}
};
if
let
Some
(
obj
)
=
payload
.as_object_mut
()
{
for
key
in
[
"request_id"
,
"priority"
,
"frequency_penalty"
,
"presence_penalty"
,
"stop"
,
"top_k"
,
"min_p"
,
"repetition_penalty"
,
]
{
obj
.remove
(
key
);
}
}
// Check if streaming is requested
if
body
.stream
{
// Handle streaming response
self
.handle_streaming_response
(
url
,
headers
,
payload
,
body
,
original_previous_response_id
,
)
.await
}
else
{
// Handle non-streaming response
self
.handle_non_streaming_response
(
url
,
headers
,
payload
,
body
,
original_previous_response_id
,
)
.await
}
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
,
params
:
&
ResponsesGetParams
,
)
->
Response
{
let
stored_id
=
ResponseId
::
from_string
(
response_id
.to_string
());
if
let
Ok
(
Some
(
stored_response
))
=
self
.response_storage
.get_response
(
&
stored_id
)
.await
{
let
stream_requested
=
params
.stream
.unwrap_or
(
false
);
let
raw_value
=
stored_response
.raw_response
.clone
();
if
!
raw_value
.is_null
()
{
if
stream_requested
{
return
(
StatusCode
::
NOT_IMPLEMENTED
,
"Streaming retrieval not yet implemented"
,
)
.into_response
();
}
return
(
StatusCode
::
OK
,
[(
"content-type"
,
"application/json"
)],
raw_value
.to_string
(),
)
.into_response
();
}
let
openai_response
=
ResponsesResponse
{
id
:
stored_response
.id
.0
.clone
(),
object
:
"response"
.to_string
(),
created_at
:
stored_response
.created_at
.timestamp
(),
status
:
ResponseStatus
::
Completed
,
error
:
None
,
incomplete_details
:
None
,
instructions
:
stored_response
.instructions
.clone
(),
max_output_tokens
:
None
,
model
:
stored_response
.model
.unwrap_or_else
(||
"gpt-4o"
.to_string
()),
output
:
vec!
[
ResponseOutputItem
::
Message
{
id
:
format!
(
"msg_{}"
,
stored_response
.id
.0
),
role
:
"assistant"
.to_string
(),
status
:
"completed"
.to_string
(),
content
:
vec!
[
ResponseContentPart
::
OutputText
{
text
:
stored_response
.output
,
annotations
:
vec!
[],
logprobs
:
None
,
}],
}],
parallel_tool_calls
:
true
,
previous_response_id
:
stored_response
.previous_response_id
.map
(|
id
|
id
.0
),
reasoning
:
None
,
store
:
true
,
temperature
:
Some
(
1.0
),
text
:
Some
(
ResponseTextFormat
{
format
:
TextFormatType
{
format_type
:
"text"
.to_string
(),
},
}),
tool_choice
:
"auto"
.to_string
(),
tools
:
vec!
[],
top_p
:
Some
(
1.0
),
truncation
:
Some
(
"disabled"
.to_string
()),
usage
:
None
,
user
:
stored_response
.user
.clone
(),
metadata
:
stored_response
.metadata
.clone
(),
};
if
stream_requested
{
return
(
StatusCode
::
NOT_IMPLEMENTED
,
"Streaming retrieval not yet implemented"
,
)
.into_response
();
}
return
(
StatusCode
::
OK
,
[(
"content-type"
,
"application/json"
)],
serde_json
::
to_string
(
&
openai_response
)
.unwrap_or_else
(|
e
|
{
format!
(
"{{
\"
error
\"
:
\"
Failed to serialize response: {}
\"
}}"
,
e
)
}),
)
.into_response
();
}
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_FOUND
,
"Responses retrieve endpoint not implemented for OpenAI router"
,
format!
(
"Response with id '{}' not found in local storage"
,
response_id
),
)
)
.into_response
()
.into_response
()
}
}
async
fn
cancel_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
async
fn
cancel_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
{
(
// Forward to OpenAI's cancel endpoint
StatusCode
::
NOT_IMPLEMENTED
,
let
url
=
format!
(
"{}/v1/responses/{}/cancel"
,
self
.base_url
,
response_id
);
"Responses cancel endpoint not implemented for OpenAI router"
,
)
let
request_builder
=
self
.client
.post
(
&
url
);
.into_response
()
// Apply headers with filtering (skip content headers for POST without body)
let
request_builder
=
if
let
Some
(
headers
)
=
headers
{
apply_request_headers
(
headers
,
request_builder
,
true
)
}
else
{
request_builder
};
match
request_builder
.send
()
.await
{
Ok
(
response
)
=>
{
let
status
=
response
.status
();
let
headers
=
response
.headers
()
.clone
();
match
response
.text
()
.await
{
Ok
(
body_text
)
=>
{
let
mut
response
=
(
status
,
body_text
)
.into_response
();
*
response
.headers_mut
()
=
preserve_response_headers
(
&
headers
);
response
}
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to read response body: {}"
,
e
),
)
.into_response
(),
}
}
Err
(
e
)
=>
(
StatusCode
::
BAD_GATEWAY
,
format!
(
"Failed to cancel response on OpenAI: {}"
,
e
),
)
.into_response
(),
}
}
}
async
fn
flush_cache
(
&
self
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
FORBIDDEN
,
"flush_cache not supported for OpenAI router"
,
"flush_cache not supported for OpenAI router"
,
)
)
.into_response
()
.into_response
()
...
@@ -363,7 +878,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -363,7 +878,7 @@ impl super::super::RouterTrait for OpenAIRouter {
async
fn
get_worker_loads
(
&
self
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
FORBIDDEN
,
"get_worker_loads not supported for OpenAI router"
,
"get_worker_loads not supported for OpenAI router"
,
)
)
.into_response
()
.into_response
()
...
@@ -384,12 +899,12 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -384,12 +899,12 @@ impl super::super::RouterTrait for OpenAIRouter {
async
fn
route_embeddings
(
async
fn
route_embeddings
(
&
self
,
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
crate
::
protocols
::
spec
::
EmbeddingRequest
,
_
body
:
&
EmbeddingRequest
,
_
model_id
:
Option
<&
str
>
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
FORBIDDEN
,
"Embeddings endpoint not
implemen
ted for OpenAI backend"
,
"Embeddings endpoint not
suppor
ted for OpenAI backend"
,
)
)
.into_response
()
.into_response
()
}
}
...
@@ -401,8 +916,8 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -401,8 +916,8 @@ impl super::super::RouterTrait for OpenAIRouter {
_
model_id
:
Option
<&
str
>
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
FORBIDDEN
,
"Rerank endpoint not
implemen
ted for OpenAI backend"
,
"Rerank endpoint not
suppor
ted for OpenAI backend"
,
)
)
.into_response
()
.into_response
()
}
}
...
...
sgl-router/src/routers/http/pd_router.rs
View file @
98c3b04f
...
@@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics;
...
@@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics;
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
StringOrArray
,
UserMessageContent
,
ResponsesGetParams
,
ResponsesRequest
,
StringOrArray
,
UserMessageContent
,
};
};
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
...
@@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter {
...
@@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter {
.into_response
()
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
_
params
:
&
ResponsesGetParams
,
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_IMPLEMENTED
,
"Responses retrieve endpoint not implemented for PD router"
,
"Responses retrieve endpoint not implemented for PD router"
,
...
...
sgl-router/src/routers/http/router.rs
View file @
98c3b04f
...
@@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics;
...
@@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics;
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
GenerationRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
GenerationRequest
,
RerankRequest
,
RerankResponse
,
RerankResult
,
ResponsesRequest
,
RerankRequest
,
RerankResponse
,
RerankResult
,
ResponsesGetParams
,
ResponsesRequest
,
};
};
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
...
@@ -903,7 +903,12 @@ impl RouterTrait for Router {
...
@@ -903,7 +903,12 @@ impl RouterTrait for Router {
.await
.await
}
}
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
,
_
params
:
&
ResponsesGetParams
,
)
->
Response
{
let
endpoint
=
format!
(
"v1/responses/{}"
,
response_id
);
let
endpoint
=
format!
(
"v1/responses/{}"
,
response_id
);
self
.route_get_request
(
headers
,
&
endpoint
)
.await
self
.route_get_request
(
headers
,
&
endpoint
)
.await
}
}
...
...
sgl-router/src/routers/mod.rs
View file @
98c3b04f
...
@@ -11,7 +11,7 @@ use std::fmt::Debug;
...
@@ -11,7 +11,7 @@ use std::fmt::Debug;
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
ResponsesGetParams
,
ResponsesRequest
,
};
};
pub
mod
factory
;
pub
mod
factory
;
...
@@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug {
...
@@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug {
)
->
Response
;
)
->
Response
;
/// Retrieve a stored/background response by id
/// Retrieve a stored/background response by id
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
;
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
,
params
:
&
ResponsesGetParams
,
)
->
Response
;
/// Cancel a background response by id
/// Cancel a background response by id
async
fn
cancel_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
;
async
fn
cancel_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
;
...
...
sgl-router/src/routers/router_manager.rs
View file @
98c3b04f
...
@@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode};
...
@@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode};
use
crate
::
core
::{
WorkerRegistry
,
WorkerType
};
use
crate
::
core
::{
WorkerRegistry
,
WorkerType
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
ResponsesGetParams
,
ResponsesRequest
,
};
};
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
server
::{
AppContext
,
ServerConfig
};
use
crate
::
server
::{
AppContext
,
ServerConfig
};
...
@@ -402,10 +402,37 @@ impl RouterTrait for RouterManager {
...
@@ -402,10 +402,37 @@ impl RouterTrait for RouterManager {
}
}
async
fn
route_responses
(
async
fn
route_responses
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ResponsesRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
let
selected_model
=
body
.model
.as_deref
()
.or
(
model_id
);
let
router
=
self
.select_router_for_request
(
headers
,
selected_model
);
if
let
Some
(
router
)
=
router
{
router
.route_responses
(
headers
,
body
,
selected_model
)
.await
}
else
{
(
StatusCode
::
NOT_FOUND
,
"No router available to handle responses request"
,
)
.into_response
()
}
}
async
fn
delete_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
list_response_input_items
(
&
self
,
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
ResponsesRequest
,
_
response_id
:
&
str
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_IMPLEMENTED
,
...
@@ -414,10 +441,15 @@ impl RouterTrait for RouterManager {
...
@@ -414,10 +441,15 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
,
params
:
&
ResponsesGetParams
,
)
->
Response
{
let
router
=
self
.select_router_for_request
(
headers
,
None
);
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
router
.get_response
(
headers
,
response_id
)
.await
router
.get_response
(
headers
,
response_id
,
params
)
.await
}
else
{
}
else
{
(
(
StatusCode
::
NOT_FOUND
,
StatusCode
::
NOT_FOUND
,
...
@@ -440,26 +472,6 @@ impl RouterTrait for RouterManager {
...
@@ -440,26 +472,6 @@ impl RouterTrait for RouterManager {
}
}
}
}
async
fn
delete_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
list_response_input_items
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
route_embeddings
(
async
fn
route_embeddings
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
...
...
sgl-router/src/server.rs
View file @
98c3b04f
...
@@ -9,7 +9,7 @@ use crate::{
...
@@ -9,7 +9,7 @@ use crate::{
protocols
::{
protocols
::{
spec
::{
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
V1RerankReqInput
,
RerankRequest
,
ResponsesGetParams
,
ResponsesRequest
,
V1RerankReqInput
,
},
},
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
},
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
},
},
},
...
@@ -224,10 +224,11 @@ async fn v1_responses_get(
...
@@ -224,10 +224,11 @@ async fn v1_responses_get(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Path
(
response_id
):
Path
<
String
>
,
Path
(
response_id
):
Path
<
String
>
,
headers
:
http
::
HeaderMap
,
headers
:
http
::
HeaderMap
,
Query
(
params
):
Query
<
ResponsesGetParams
>
,
)
->
Response
{
)
->
Response
{
state
state
.router
.router
.get_response
(
Some
(
&
headers
),
&
response_id
)
.get_response
(
Some
(
&
headers
),
&
response_id
,
&
params
)
.await
.await
}
}
...
...
sgl-router/tests/test_openai_routing.rs
View file @
98c3b04f
...
@@ -5,17 +5,23 @@ use axum::{
...
@@ -5,17 +5,23 @@ use axum::{
extract
::
Request
,
extract
::
Request
,
http
::{
Method
,
StatusCode
},
http
::{
Method
,
StatusCode
},
routing
::
post
,
routing
::
post
,
Router
,
Json
,
Router
,
};
};
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::{
use
sglang_router_rs
::{
config
::{
RouterConfig
,
RoutingMode
},
config
::{
RouterConfig
,
RoutingMode
},
data_connector
::{
MemoryResponseStorage
,
ResponseId
,
ResponseStorage
},
protocols
::
spec
::{
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
UserMessageContent
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
ResponseInput
,
ResponsesGetParams
,
ResponsesRequest
,
UserMessageContent
,
},
},
routers
::{
openai_router
::
OpenAIRouter
,
RouterTrait
},
routers
::{
openai_router
::
OpenAIRouter
,
RouterTrait
},
};
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::{
atomic
::{
AtomicUsize
,
Ordering
},
Arc
,
};
use
tokio
::
net
::
TcpListener
;
use
tower
::
ServiceExt
;
use
tower
::
ServiceExt
;
mod
common
;
mod
common
;
...
@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest {
...
@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest {
/// Test basic OpenAI router creation and configuration
/// Test basic OpenAI router creation and configuration
#[tokio::test]
#[tokio::test]
async
fn
test_openai_router_creation
()
{
async
fn
test_openai_router_creation
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
.await
;
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
;
assert
!
(
router
.is_ok
(),
"Router creation should succeed"
);
assert
!
(
router
.is_ok
(),
"Router creation should succeed"
);
...
@@ -90,9 +101,13 @@ async fn test_openai_router_creation() {
...
@@ -90,9 +101,13 @@ async fn test_openai_router_creation() {
/// Test health endpoints
/// Test health endpoints
#[tokio::test]
#[tokio::test]
async
fn
test_openai_router_health
()
{
async
fn
test_openai_router_health
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
let
router
=
OpenAIRouter
::
new
(
.await
"https://api.openai.com"
.to_string
(),
.unwrap
();
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
.unwrap
();
let
req
=
Request
::
builder
()
let
req
=
Request
::
builder
()
.method
(
Method
::
GET
)
.method
(
Method
::
GET
)
...
@@ -107,9 +122,13 @@ async fn test_openai_router_health() {
...
@@ -107,9 +122,13 @@ async fn test_openai_router_health() {
/// Test server info endpoint
/// Test server info endpoint
#[tokio::test]
#[tokio::test]
async
fn
test_openai_router_server_info
()
{
async
fn
test_openai_router_server_info
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
let
router
=
OpenAIRouter
::
new
(
.await
"https://api.openai.com"
.to_string
(),
.unwrap
();
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
.unwrap
();
let
req
=
Request
::
builder
()
let
req
=
Request
::
builder
()
.method
(
Method
::
GET
)
.method
(
Method
::
GET
)
...
@@ -132,9 +151,13 @@ async fn test_openai_router_server_info() {
...
@@ -132,9 +151,13 @@ async fn test_openai_router_server_info() {
async
fn
test_openai_router_models
()
{
async
fn
test_openai_router_models
()
{
// Use mock server for deterministic models response
// Use mock server for deterministic models response
let
mock_server
=
MockOpenAIServer
::
new
()
.await
;
let
mock_server
=
MockOpenAIServer
::
new
()
.await
;
let
router
=
OpenAIRouter
::
new
(
mock_server
.base_url
(),
None
)
let
router
=
OpenAIRouter
::
new
(
.await
mock_server
.base_url
(),
.unwrap
();
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
.unwrap
();
let
req
=
Request
::
builder
()
let
req
=
Request
::
builder
()
.method
(
Method
::
GET
)
.method
(
Method
::
GET
)
...
@@ -154,6 +177,138 @@ async fn test_openai_router_models() {
...
@@ -154,6 +177,138 @@ async fn test_openai_router_models() {
assert
!
(
models
[
"data"
]
.is_array
());
assert
!
(
models
[
"data"
]
.is_array
());
}
}
#[tokio::test]
async
fn
test_openai_router_responses_with_mock
()
{
let
listener
=
TcpListener
::
bind
(
"127.0.0.1:0"
)
.await
.unwrap
();
let
addr
=
listener
.local_addr
()
.unwrap
();
let
counter
=
Arc
::
new
(
AtomicUsize
::
new
(
0
));
let
counter_clone
=
counter
.clone
();
let
app
=
Router
::
new
()
.route
(
"/v1/responses"
,
post
({
move
|
Json
(
request
):
Json
<
serde_json
::
Value
>
|
{
let
counter
=
counter_clone
.clone
();
async
move
{
let
idx
=
counter
.fetch_add
(
1
,
Ordering
::
SeqCst
)
+
1
;
let
model
=
request
.get
(
"model"
)
.and_then
(|
v
|
v
.as_str
())
.unwrap_or
(
"gpt-4o-mini"
)
.to_string
();
let
id
=
format!
(
"resp_mock_{idx}"
);
let
response
=
json!
({
"id"
:
id
,
"object"
:
"response"
,
"created_at"
:
1_700_000_000
+
idx
as
i64
,
"status"
:
"completed"
,
"model"
:
model
,
"output"
:
[{
"type"
:
"message"
,
"id"
:
format!
(
"msg_{idx}"
),
"role"
:
"assistant"
,
"status"
:
"completed"
,
"content"
:
[{
"type"
:
"output_text"
,
"text"
:
format!
(
"mock_output_{idx}"
),
"annotations"
:
[]
}]
}],
"metadata"
:
{}
});
Json
(
response
)
}
}
}),
);
let
server
=
tokio
::
spawn
(
async
move
{
axum
::
serve
(
listener
,
app
)
.await
.unwrap
();
});
let
base_url
=
format!
(
"http://{}"
,
addr
);
let
storage
=
Arc
::
new
(
MemoryResponseStorage
::
new
());
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
,
storage
.clone
())
.await
.unwrap
();
let
request1
=
ResponsesRequest
{
model
:
Some
(
"gpt-4o-mini"
.to_string
()),
input
:
ResponseInput
::
Text
(
"Say hi"
.to_string
()),
store
:
true
,
..
Default
::
default
()
};
let
response1
=
router
.route_responses
(
None
,
&
request1
,
None
)
.await
;
assert_eq!
(
response1
.status
(),
StatusCode
::
OK
);
let
body1_bytes
=
axum
::
body
::
to_bytes
(
response1
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
body1
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
body1_bytes
)
.unwrap
();
let
resp1_id
=
body1
[
"id"
]
.as_str
()
.expect
(
"id missing"
)
.to_string
();
assert_eq!
(
body1
[
"previous_response_id"
],
serde_json
::
Value
::
Null
);
let
request2
=
ResponsesRequest
{
model
:
Some
(
"gpt-4o-mini"
.to_string
()),
input
:
ResponseInput
::
Text
(
"Thanks"
.to_string
()),
store
:
true
,
previous_response_id
:
Some
(
resp1_id
.clone
()),
..
Default
::
default
()
};
let
response2
=
router
.route_responses
(
None
,
&
request2
,
None
)
.await
;
assert_eq!
(
response2
.status
(),
StatusCode
::
OK
);
let
body2_bytes
=
axum
::
body
::
to_bytes
(
response2
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
body2
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
body2_bytes
)
.unwrap
();
let
resp2_id
=
body2
[
"id"
]
.as_str
()
.expect
(
"second id missing"
);
assert_eq!
(
body2
[
"previous_response_id"
]
.as_str
(),
Some
(
resp1_id
.as_str
())
);
let
stored1
=
storage
.get_response
(
&
ResponseId
::
from_string
(
resp1_id
.clone
()))
.await
.unwrap
()
.expect
(
"first response missing"
);
assert_eq!
(
stored1
.input
,
"Say hi"
);
assert_eq!
(
stored1
.output
,
"mock_output_1"
);
assert
!
(
stored1
.previous_response_id
.is_none
());
let
stored2
=
storage
.get_response
(
&
ResponseId
::
from_string
(
resp2_id
.to_string
()))
.await
.unwrap
()
.expect
(
"second response missing"
);
assert_eq!
(
stored2
.previous_response_id
.unwrap
()
.0
,
resp1_id
);
assert_eq!
(
stored2
.output
,
"mock_output_2"
);
let
get1
=
router
.get_response
(
None
,
&
stored1
.id
.0
,
&
ResponsesGetParams
::
default
())
.await
;
assert_eq!
(
get1
.status
(),
StatusCode
::
OK
);
let
get1_body_bytes
=
axum
::
body
::
to_bytes
(
get1
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
get1_json
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
get1_body_bytes
)
.unwrap
();
assert_eq!
(
get1_json
,
body1
);
let
get2
=
router
.get_response
(
None
,
&
stored2
.id
.0
,
&
ResponsesGetParams
::
default
())
.await
;
assert_eq!
(
get2
.status
(),
StatusCode
::
OK
);
let
get2_body_bytes
=
axum
::
body
::
to_bytes
(
get2
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
get2_json
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
get2_body_bytes
)
.unwrap
();
assert_eq!
(
get2_json
,
body2
);
server
.abort
();
}
/// Test router factory with OpenAI routing mode
/// Test router factory with OpenAI routing mode
#[tokio::test]
#[tokio::test]
async
fn
test_router_factory_openai_mode
()
{
async
fn
test_router_factory_openai_mode
()
{
...
@@ -179,9 +334,13 @@ async fn test_router_factory_openai_mode() {
...
@@ -179,9 +334,13 @@ async fn test_router_factory_openai_mode() {
/// Test that unsupported endpoints return proper error codes
/// Test that unsupported endpoints return proper error codes
#[tokio::test]
#[tokio::test]
async
fn
test_unsupported_endpoints
()
{
async
fn
test_unsupported_endpoints
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
let
router
=
OpenAIRouter
::
new
(
.await
"https://api.openai.com"
.to_string
(),
.unwrap
();
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
.unwrap
();
// Test generate endpoint (SGLang-specific, should not be supported)
// Test generate endpoint (SGLang-specific, should not be supported)
let
generate_request
=
GenerateRequest
{
let
generate_request
=
GenerateRequest
{
...
@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() {
...
@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() {
let
base_url
=
mock_server
.base_url
();
let
base_url
=
mock_server
.base_url
();
// Create router pointing to mock server
// Create router pointing to mock server
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
)
.await
.unwrap
();
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()))
.await
.unwrap
();
// Create a minimal chat completion request
// Create a minimal chat completion request
let
mut
chat_request
=
create_minimal_chat_request
();
let
mut
chat_request
=
create_minimal_chat_request
();
...
@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() {
...
@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() {
let
base_url
=
mock_server
.base_url
();
let
base_url
=
mock_server
.base_url
();
// Create router
// Create router
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
)
.await
.unwrap
();
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()))
.await
.unwrap
();
// Create Axum app with chat completions endpoint
// Create Axum app with chat completions endpoint
let
app
=
Router
::
new
()
.route
(
let
app
=
Router
::
new
()
.route
(
...
@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() {
...
@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() {
async
fn
test_openai_router_chat_streaming_with_mock
()
{
async
fn
test_openai_router_chat_streaming_with_mock
()
{
let
mock_server
=
MockOpenAIServer
::
new
()
.await
;
let
mock_server
=
MockOpenAIServer
::
new
()
.await
;
let
base_url
=
mock_server
.base_url
();
let
base_url
=
mock_server
.base_url
();
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
)
.await
.unwrap
();
let
router
=
OpenAIRouter
::
new
(
base_url
,
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()))
.await
.unwrap
();
// Build a streaming chat request
// Build a streaming chat request
let
val
=
json!
({
let
val
=
json!
({
...
@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() {
...
@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() {
let
router
=
OpenAIRouter
::
new
(
let
router
=
OpenAIRouter
::
new
(
"http://invalid-url-that-will-fail"
.to_string
(),
"http://invalid-url-that-will-fail"
.to_string
(),
Some
(
cb_config
),
Some
(
cb_config
),
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
)
.await
.await
.unwrap
();
.unwrap
();
...
@@ -391,9 +557,13 @@ async fn test_openai_router_models_auth_forwarding() {
...
@@ -391,9 +557,13 @@ async fn test_openai_router_models_auth_forwarding() {
// Start a mock server that requires Authorization
// Start a mock server that requires Authorization
let
expected_auth
=
"Bearer test-token"
.to_string
();
let
expected_auth
=
"Bearer test-token"
.to_string
();
let
mock_server
=
MockOpenAIServer
::
new_with_auth
(
Some
(
expected_auth
.clone
()))
.await
;
let
mock_server
=
MockOpenAIServer
::
new_with_auth
(
Some
(
expected_auth
.clone
()))
.await
;
let
router
=
OpenAIRouter
::
new
(
mock_server
.base_url
(),
None
)
let
router
=
OpenAIRouter
::
new
(
.await
mock_server
.base_url
(),
.unwrap
();
None
,
Arc
::
new
(
MemoryResponseStorage
::
new
()),
)
.await
.unwrap
();
// 1) Without auth header -> expect 401
// 1) Without auth header -> expect 401
let
req
=
Request
::
builder
()
let
req
=
Request
::
builder
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment