sglang / Commits / 77258ce0

Commit 77258ce0 (unverified), authored Oct 22, 2025 by Keyang Ru, committed via GitHub on Oct 22, 2025. Parent: 1d097aac

[router] Support multiple worker URLs for OpenAI router (#11723)

9 changed files with 426 additions and 150 deletions (+426 / -150):
sgl-router/src/config/validation.rs          +5    -9
sgl-router/src/protocols/responses.rs        +7    -7
sgl-router/src/routers/factory.rs            +5    -6
sgl-router/src/routers/openai/responses.rs   +5    -4
sgl-router/src/routers/openai/router.rs      +243  -97
sgl-router/src/routers/openai/utils.rs       +127  -0
sgl-router/src/routers/router_manager.rs     +1    -1
sgl-router/tests/responses_api_test.rs       +10   -10
sgl-router/tests/test_openai_routing.rs      +23   -16
sgl-router/src/config/validation.rs

```diff
@@ -165,18 +165,14 @@ impl ConfigValidator {
                 }
             }
             RoutingMode::OpenAI { worker_urls } => {
-                // Require exactly one worker URL for OpenAI router
-                if worker_urls.len() != 1 {
+                // Require at least one worker URL for OpenAI router
+                if worker_urls.is_empty() {
                     return Err(ConfigError::ValidationFailed {
-                        reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
+                        reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
                     });
                 }
-                // Validate URL format
-                if let Err(e) = url::Url::parse(&worker_urls[0]) {
-                    return Err(ConfigError::ValidationFailed {
-                        reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
-                    });
-                }
+                // Validate URLs
+                Self::validate_urls(worker_urls)?;
             }
         }
         Ok(())
```
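The substantive change: OpenAI mode now accepts any non-empty list of worker URLs, and validation moves from hand-parsing `worker_urls[0]` to the shared `Self::validate_urls` helper so every entry is checked. A minimal standalone sketch of the same logic (assuming the `url` crate, which the old code already used; the error type is simplified here):

```rust
// Sketch only: the real code returns ConfigError::ValidationFailed and
// calls Self::validate_urls instead of this inline loop.
fn validate_openai_worker_urls(worker_urls: &[String]) -> Result<(), String> {
    // Only the empty list is rejected now; any positive count is accepted.
    if worker_urls.is_empty() {
        return Err("OpenAI mode requires at least one --worker-urls entry".to_string());
    }
    // Every URL is validated, not just the first.
    for raw in worker_urls {
        url::Url::parse(raw)
            .map_err(|e| format!("Invalid OpenAI worker URL '{}': {}", raw, e))?;
    }
    Ok(())
}

fn main() {
    assert!(validate_openai_worker_urls(&[]).is_err());
    assert!(validate_openai_worker_urls(&["https://api.openai.com".to_string()]).is_ok());
    assert!(validate_openai_worker_urls(&[
        "https://api.openai.com".to_string(),
        "https://api.x.ai".to_string(),
    ])
    .is_ok());
}
```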
sgl-router/src/protocols/responses.rs

```diff
@@ -8,8 +8,8 @@ use serde_json::Value;
 // Import shared types from common module
 use super::common::{
-    default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, StringOrArray, ToolChoice,
-    UsageInfo,
+    default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo,
+    StringOrArray, ToolChoice, UsageInfo,
 };

 // ============================================================================
@@ -452,9 +452,9 @@ pub struct ResponsesRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub metadata: Option<HashMap<String, Value>>,

-    /// Model to use (optional to match vLLM)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub model: Option<String>,
+    /// Model to use
+    #[serde(default = "default_model")]
+    pub model: String,

     /// Optional conversation id to persist input/output as items
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -565,7 +565,7 @@ impl Default for ResponsesRequest {
             max_output_tokens: None,
             max_tool_calls: None,
             metadata: None,
-            model: None,
+            model: default_model(),
             conversation: None,
             parallel_tool_calls: None,
             previous_response_id: None,
@@ -598,7 +598,7 @@ impl GenerationRequest for ResponsesRequest {
     }

     fn get_model(&self) -> Option<&str> {
-        self.model.as_deref()
+        Some(self.model.as_str())
     }

     fn extract_text_for_routing(&self) -> String {
```
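With `model` now a plain `String`, a request that omits the field deserializes via `default_model()` instead of to `None`, and `get_model` can always return `Some`. A small sketch of the serde behavior (this toy struct and default value stand in for the real `ResponsesRequest` and `protocols::common::default_model`):

```rust
use serde::Deserialize;

// Stand-in for protocols::common::default_model; the real default may differ.
fn default_model() -> String {
    "default-model".to_string()
}

#[derive(Debug, Deserialize)]
struct ToyRequest {
    #[serde(default = "default_model")]
    model: String,
}

fn main() {
    // Field omitted: serde fills in the default rather than producing None.
    let r: ToyRequest = serde_json::from_str("{}").unwrap();
    assert_eq!(r.model, "default-model");

    // Field present: used as-is.
    let r: ToyRequest = serde_json::from_str(r#"{"model":"gpt-4o-mini"}"#).unwrap();
    assert_eq!(r.model, "gpt-4o-mini");
}
```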
sgl-router/src/routers/factory.rs

```diff
@@ -55,7 +55,7 @@ impl RouterFactory {
                 )
                 .await
             }
-            RoutingMode::OpenAI { worker_urls, .. } => {
+            RoutingMode::OpenAI { worker_urls } => {
                 Self::create_openai_router(worker_urls.clone(), ctx).await
             }
         },
@@ -122,13 +122,12 @@ impl RouterFactory {
         worker_urls: Vec<String>,
         ctx: &Arc<AppContext>,
     ) -> Result<Box<dyn RouterTrait>, String> {
-        let base_url = worker_urls
-            .first()
-            .cloned()
-            .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
+        if worker_urls.is_empty() {
+            return Err("OpenAI mode requires at least one worker URL".to_string());
+        }
         let router = OpenAIRouter::new(
-            base_url,
+            worker_urls,
             Some(ctx.router_config.circuit_breaker.clone()),
             ctx.response_storage.clone(),
             ctx.conversation_storage.clone(),
```
sgl-router/src/routers/openai/responses.rs

```diff
@@ -39,7 +39,7 @@ pub(super) fn build_stored_response(
         .get("model")
         .and_then(|v| v.as_str())
         .map(|s| s.to_string())
-        .or_else(|| original_body.model.clone());
+        .or_else(|| Some(original_body.model.clone()));
     stored_response.user = response_json
         .get("user")
@@ -143,9 +143,10 @@ pub(super) fn patch_streaming_response_json(
         .map(|s| s.is_empty())
         .unwrap_or(true)
     {
-        if let Some(model) = &original_body.model {
-            obj.insert("model".to_string(), Value::String(model.clone()));
-        }
+        obj.insert(
+            "model".to_string(),
+            Value::String(original_body.model.clone()),
+        );
     }
     if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
```
sgl-router/src/routers/openai/router.rs

```diff
@@ -3,6 +3,7 @@
 use std::{
     any::Any,
     sync::{atomic::AtomicBool, Arc},
+    time::{Duration, Instant},
 };
 use axum::{
@@ -12,6 +13,7 @@ use axum::{
     response::{IntoResponse, Response},
     Json,
 };
+use dashmap::DashMap;
 use futures_util::StreamExt;
 use serde_json::{json, to_value, Value};
 use tokio::sync::mpsc;
@@ -31,6 +33,7 @@ use super::{
     },
     responses::{mask_tools_as_mcp, patch_streaming_response_json},
     streaming::handle_streaming_response,
+    utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
 };
 use crate::{
     config::CircuitBreakerConfig,
@@ -59,12 +62,21 @@ use crate::{
 // OpenAIRouter Struct
 // ============================================================================

+/// Cached endpoint information
+#[derive(Clone, Debug)]
+struct CachedEndpoint {
+    url: String,
+    cached_at: Instant,
+}
+
 /// Router for OpenAI backend
 pub struct OpenAIRouter {
     /// HTTP client for upstream OpenAI-compatible API
     client: reqwest::Client,
-    /// Base URL for identification (no trailing slash)
-    base_url: String,
+    /// Multiple OpenAI-compatible API endpoints (OpenAI, xAI, etc.)
+    worker_urls: Vec<String>,
+    /// Model cache: model_id -> endpoint URL
+    model_cache: Arc<DashMap<String, CachedEndpoint>>,
     /// Circuit breaker
     circuit_breaker: CircuitBreaker,
     /// Health status
@@ -82,7 +94,7 @@ pub struct OpenAIRouter {
 impl std::fmt::Debug for OpenAIRouter {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("OpenAIRouter")
-            .field("base_url", &self.base_url)
+            .field("worker_urls", &self.worker_urls)
             .field("healthy", &self.healthy)
             .finish()
     }
@@ -92,28 +104,35 @@ impl OpenAIRouter {
     /// Maximum number of conversation items to attach as input when a conversation is provided
     const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
+    /// Model discovery cache TTL (1 hour)
+    const MODEL_CACHE_TTL_SECS: u64 = 3600;

     /// Create a new OpenAI router
     pub async fn new(
-        base_url: String,
+        worker_urls: Vec<String>,
         circuit_breaker_config: Option<CircuitBreakerConfig>,
         response_storage: SharedResponseStorage,
         conversation_storage: SharedConversationStorage,
         conversation_item_storage: SharedConversationItemStorage,
     ) -> Result<Self, String> {
         let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(300))
+            .timeout(Duration::from_secs(300))
             .build()
             .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
-        let base_url = base_url.trim_end_matches('/').to_string();
+        // Normalize URLs (remove trailing slashes)
+        let worker_urls: Vec<String> = worker_urls
+            .into_iter()
+            .map(|url| url.trim_end_matches('/').to_string())
+            .collect();
         // Convert circuit breaker config
         let core_cb_config = circuit_breaker_config
             .map(|cb| CoreCircuitBreakerConfig {
                 failure_threshold: cb.failure_threshold,
                 success_threshold: cb.success_threshold,
-                timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
-                window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
+                timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
+                window_duration: Duration::from_secs(cb.window_duration_secs),
             })
             .unwrap_or_default();
@@ -141,7 +160,8 @@ impl OpenAIRouter {
         Ok(Self {
             client,
-            base_url,
+            worker_urls,
+            model_cache: Arc::new(DashMap::new()),
             circuit_breaker,
             healthy: AtomicBool::new(true),
             response_storage,
@@ -151,6 +171,67 @@ impl OpenAIRouter {
         })
     }

+    /// Discover which endpoint has the model
+    async fn find_endpoint_for_model(
+        &self,
+        model_id: &str,
+        auth_header: Option<&str>,
+    ) -> Result<String, Response> {
+        // Single endpoint - fast path
+        if self.worker_urls.len() == 1 {
+            return Ok(self.worker_urls[0].clone());
+        }
+
+        // Check cache
+        if let Some(entry) = self.model_cache.get(model_id) {
+            if entry.cached_at.elapsed() < Duration::from_secs(Self::MODEL_CACHE_TTL_SECS) {
+                return Ok(entry.url.clone());
+            }
+        }
+
+        // Probe all endpoints in parallel
+        let mut handles = vec![];
+        let model = model_id.to_string();
+        let auth = auth_header.map(|s| s.to_string());
+        for url in &self.worker_urls {
+            let handle = tokio::spawn(probe_endpoint_for_model(
+                self.client.clone(),
+                url.clone(),
+                model.clone(),
+                auth.clone(),
+            ));
+            handles.push(handle);
+        }
+
+        // Return first successful endpoint
+        for handle in handles {
+            if let Ok(Ok(url)) = handle.await {
+                // Cache it
+                self.model_cache.insert(
+                    model_id.to_string(),
+                    CachedEndpoint {
+                        url: url.clone(),
+                        cached_at: Instant::now(),
+                    },
+                );
+                return Ok(url);
+            }
+        }
+
+        // Model not found on any endpoint
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(json!({
+                "error": {
+                    "message": format!("Model '{}' not found on any endpoint", model_id),
+                    "type": "model_not_found",
+                }
+            })),
+        )
+            .into_response())
+    }
+
     /// Handle non-streaming response with optional MCP tool loop
     async fn handle_non_streaming_response(
         &self,
@@ -282,85 +363,145 @@ impl crate::routers::RouterTrait for OpenAIRouter {
     }

     async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // Simple upstream probe: GET {base}/v1/models without auth
-        let url = format!("{}/v1/models", self.base_url);
-        match self
-            .client
-            .get(&url)
-            .timeout(std::time::Duration::from_secs(2))
-            .send()
-            .await
-        {
-            Ok(resp) => {
-                let code = resp.status();
-                // Treat success and auth-required as healthy (endpoint reachable)
-                if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
-                    (StatusCode::OK, "OK").into_response()
-                } else {
-                    (
-                        StatusCode::SERVICE_UNAVAILABLE,
-                        format!("Upstream status: {}", code),
-                    )
-                        .into_response()
-                }
-            }
-            Err(e) => (
-                StatusCode::SERVICE_UNAVAILABLE,
-                format!("Upstream error: {}", e),
-            )
-                .into_response(),
-        }
+        // Check all endpoints in parallel - only healthy if ALL are healthy
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+            let handle = tokio::spawn(async move {
+                let probe_url = format!("{}/v1/models", url);
+                match client
+                    .get(&probe_url)
+                    .timeout(Duration::from_secs(2))
+                    .send()
+                    .await
+                {
+                    Ok(resp) => {
+                        let code = resp.status();
+                        // Treat success and auth-required as healthy (endpoint reachable)
+                        if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
+                            Ok(())
+                        } else {
+                            Err(format!("Endpoint {} returned status {}", url, code))
+                        }
+                    }
+                    Err(e) => Err(format!("Endpoint {} error: {}", url, e)),
+                }
+            });
+            handles.push(handle);
+        }
+        // Collect all results
+        let mut errors = Vec::new();
+        for handle in handles {
+            match handle.await {
+                Ok(Ok(())) => (),
+                Ok(Err(e)) => errors.push(e),
+                Err(e) => errors.push(format!("Task join error: {}", e)),
+            }
+        }
+        if errors.is_empty() {
+            (StatusCode::OK, "OK").into_response()
+        } else {
+            (
+                StatusCode::SERVICE_UNAVAILABLE,
+                format!("Some endpoints unhealthy: {}", errors.join(", ")),
+            )
+                .into_response()
+        }
     }

     async fn get_server_info(&self, _req: Request<Body>) -> Response {
         let info = json!({
             "router_type": "openai",
-            "workers": 1,
-            "base_url": &self.base_url
+            "workers": self.worker_urls.len(),
+            "worker_urls": &self.worker_urls
         });
         (StatusCode::OK, info.to_string()).into_response()
     }

     async fn get_models(&self, req: Request<Body>) -> Response {
-        // Proxy to upstream /v1/models; forward Authorization header if provided
-        let headers = req.headers();
-        let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
-        if let Some(auth) = headers
-            .get("authorization")
-            .or_else(|| headers.get("Authorization"))
-        {
-            upstream = upstream.header("Authorization", auth);
-        }
-        match upstream.send().await {
-            Ok(res) => {
-                let status = StatusCode::from_u16(res.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                let content_type = res.headers().get(CONTENT_TYPE).cloned();
-                match res.bytes().await {
-                    Ok(body) => {
-                        let mut response = Response::new(Body::from(body));
-                        *response.status_mut() = status;
-                        if let Some(ct) = content_type {
-                            response.headers_mut().insert(CONTENT_TYPE, ct);
-                        }
-                        response
-                    }
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read upstream response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
+        // Aggregate models from all endpoints
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+        let headers = req.headers();
+        let auth = headers
+            .get("authorization")
+            .or_else(|| headers.get("Authorization"));
+
+        // Query all endpoints in parallel
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+            let auth = auth.cloned();
+            let handle = tokio::spawn(async move {
+                let models_url = format!("{}/v1/models", url);
+                let req = client.get(&models_url);
+                // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+                let req = apply_provider_headers(req, &url, auth.as_ref());
+                match req.send().await {
+                    Ok(res) => {
+                        if res.status().is_success() {
+                            match res.json::<Value>().await {
+                                Ok(json) => Ok(json),
+                                Err(e) => {
+                                    tracing::warn!(
+                                        "Failed to parse models response from '{}': {}",
+                                        url,
+                                        e
+                                    );
+                                    Err(())
+                                }
+                            }
+                        } else {
+                            tracing::warn!(
+                                "Getting models from '{}' failed with status: {}",
+                                url,
+                                res.status()
+                            );
+                            Err(())
+                        }
+                    }
+                    Err(e) => {
+                        tracing::warn!("Request to get models from '{}' failed: {}", url, e);
+                        Err(())
+                    }
+                }
+            });
+            handles.push(handle);
+        }
+
+        // Collect all model lists
+        let mut all_models = Vec::new();
+        for handle in handles {
+            if let Ok(Ok(json)) = handle.await {
+                if let Some(data) = json.get("data").and_then(|v| v.as_array()) {
+                    all_models.extend_from_slice(data);
+                }
+            }
+        }
+
+        // Return aggregated models
+        let response_json = json!({
+            "object": "list",
+            "data": all_models
+        });
+        (StatusCode::OK, Json(response_json)).into_response()
     }

     async fn get_model_info(&self, _req: Request<Body>) -> Response {
@@ -396,6 +537,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
         }

+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model
+        let base_url = match self
+            .find_endpoint_for_model(body.model.as_str(), auth)
+            .await
+        {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
         // Serialize request body, removing SGLang-only fields
         let mut payload = match to_value(body) {
             Ok(v) => v,
@@ -431,9 +584,14 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             ] {
                 obj.remove(key);
             }
+            // Remove logprobs if false (Gemini don't accept it)
+            if obj.get("logprobs").and_then(|v| v.as_bool()) == Some(false) {
+                obj.remove("logprobs");
+            }
         }

-        let url = format!("{}/v1/chat/completions", self.base_url);
+        let url = format!("{}/v1/chat/completions", base_url);
         let mut req = self.client.post(&url).json(&payload);
         // Forward Authorization header if provided
@@ -534,7 +692,17 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let url = format!("{}/v1/responses", self.base_url);
+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model (use model_id if provided, otherwise use body.model)
+        let model = model_id.unwrap_or(body.model.as_str());
+        let base_url = match self.find_endpoint_for_model(model, auth).await {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
+        let url = format!("{}/v1/responses", base_url);
         // Validate mutually exclusive params: previous_response_id and conversation
         // TODO: this validation logic should move the right place, also we need a proper error message module
@@ -556,7 +724,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         // Clone the body for validation and logic, but we'll build payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
-            request_body.model = Some(model.to_string());
+            request_body.model = model.to_string();
         }
         // Do not forward conversation field upstream; retain for local persistence only
         request_body.conversation = None;
@@ -847,34 +1015,12 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
     }

-    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        // Forward cancellation to upstream
-        let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
-        let mut req = self.client.post(&url);
-        if let Some(h) = headers {
-            req = apply_request_headers(h, req, false);
-        }
-        match req.send().await {
-            Ok(resp) => {
-                let status = StatusCode::from_u16(resp.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                match resp.text().await {
-                    Ok(body) => (status, body).into_response(),
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
-    }
+    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Cancel response not implemented for OpenAI router",
+        )
+            .into_response()
+    }

     async fn route_embeddings(
```
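The heart of this file's change is `find_endpoint_for_model`: a single endpoint short-circuits, a TTL cache answers repeat lookups, and otherwise every endpoint is probed concurrently with the first hit cached. A stripped-down sketch of that pattern (a `Mutex<HashMap>` replaces `DashMap`, and the probe is a stub, so this only illustrates the control flow, not the real implementation):

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};

const TTL: Duration = Duration::from_secs(3600); // mirrors MODEL_CACHE_TTL_SECS

struct Discovery {
    worker_urls: Vec<String>,
    cache: Mutex<HashMap<String, (String, Instant)>>, // model -> (url, cached_at)
}

impl Discovery {
    async fn find(&self, model: &str) -> Option<String> {
        // Fast path: nothing to discover with a single endpoint.
        if self.worker_urls.len() == 1 {
            return Some(self.worker_urls[0].clone());
        }
        // Serve fresh cache entries without probing.
        if let Some((url, at)) = self.cache.lock().unwrap().get(model).cloned() {
            if at.elapsed() < TTL {
                return Some(url);
            }
        }
        // Probe all endpoints concurrently and take the first success.
        let mut handles = Vec::new();
        for url in &self.worker_urls {
            let (url, model) = (url.clone(), model.to_string());
            handles.push(tokio::spawn(async move {
                probe(&url, &model).await.then_some(url)
            }));
        }
        for h in handles {
            if let Ok(Some(url)) = h.await {
                self.cache
                    .lock()
                    .unwrap()
                    .insert(model.to_string(), (url.clone(), Instant::now()));
                return Some(url);
            }
        }
        None // maps to the 404 "model_not_found" response in the router
    }
}

// Stub standing in for probe_endpoint_for_model, which GETs
// {url}/v1/models/{model} with provider-specific headers.
async fn probe(url: &str, _model: &str) -> bool {
    url.contains("good")
}

#[tokio::main]
async fn main() {
    let d = Discovery {
        worker_urls: vec!["https://bad.example".into(), "https://good.example".into()],
        cache: Mutex::new(HashMap::new()),
    };
    assert_eq!(d.find("any-model").await.as_deref(), Some("https://good.example"));
    // The second lookup is served from the cache.
    assert_eq!(d.find("any-model").await.as_deref(), Some("https://good.example"));
}
```

Note the awaiting loop preserves `worker_urls` order, so "first successful endpoint" means first in configuration order among the parallel probes, not first to respond.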
sgl-router/src/routers/openai/utils.rs

```diff
@@ -2,6 +2,8 @@
 use std::collections::HashMap;

+use axum::http::{HeaderMap, HeaderValue};
+
 // ============================================================================
 // SSE Event Type Constants
 // ============================================================================
@@ -93,6 +95,131 @@ impl OutputIndexMapper {
     }
 }

+// ============================================================================
+// Provider Detection and Header Handling
+// ============================================================================
+
+/// Extract authorization header from request headers
+/// Checks both "authorization" and "Authorization" (case variations)
+pub fn extract_auth_header(headers: Option<&HeaderMap>) -> Option<&str> {
+    headers.and_then(|h| {
+        h.get("authorization")
+            .or_else(|| h.get("Authorization"))
+            .and_then(|v| v.to_str().ok())
+    })
+}
+
+/// API provider types
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ApiProvider {
+    Anthropic,
+    Xai,
+    OpenAi,
+    Gemini,
+    Generic,
+}
+
+impl ApiProvider {
+    /// Detect provider type from URL
+    pub fn from_url(url: &str) -> Self {
+        if url.contains("anthropic") {
+            ApiProvider::Anthropic
+        } else if url.contains("x.ai") {
+            ApiProvider::Xai
+        } else if url.contains("openai.com") {
+            ApiProvider::OpenAi
+        } else if url.contains("googleapis.com") {
+            ApiProvider::Gemini
+        } else {
+            ApiProvider::Generic
+        }
+    }
+}
+
+/// Apply provider-specific headers to request
+pub fn apply_provider_headers(
+    mut req: reqwest::RequestBuilder,
+    url: &str,
+    auth_header: Option<&HeaderValue>,
+) -> reqwest::RequestBuilder {
+    let provider = ApiProvider::from_url(url);
+    match provider {
+        ApiProvider::Anthropic => {
+            // Anthropic requires x-api-key instead of Authorization
+            // Extract Bearer token and use as x-api-key
+            if let Some(auth) = auth_header {
+                if let Ok(auth_str) = auth.to_str() {
+                    let api_key = auth_str.strip_prefix("Bearer ").unwrap_or(auth_str);
+                    req = req
+                        .header("x-api-key", api_key)
+                        .header("anthropic-version", "2023-06-01");
+                }
+            }
+        }
+        ApiProvider::Gemini | ApiProvider::Xai | ApiProvider::OpenAi | ApiProvider::Generic => {
+            // Standard OpenAI-compatible: use Authorization header as-is
+            if let Some(auth) = auth_header {
+                req = req.header("Authorization", auth);
+            }
+        }
+    }
+    req
+}
+
+/// Probe a single endpoint to check if it has the model
+/// Returns Ok(url) if model found, Err(()) otherwise
+pub async fn probe_endpoint_for_model(
+    client: reqwest::Client,
+    url: String,
+    model: String,
+    auth: Option<String>,
+) -> Result<String, ()> {
+    use tracing::debug;
+
+    let probe_url = format!("{}/v1/models/{}", url, model);
+    let req = client
+        .get(&probe_url)
+        .timeout(std::time::Duration::from_secs(5));
+    // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+    let auth_header_value = auth.as_ref().and_then(|a| HeaderValue::from_str(a).ok());
+    let req = apply_provider_headers(req, &url, auth_header_value.as_ref());
+    match req.send().await {
+        Ok(resp) => {
+            let status = resp.status();
+            if status.is_success() {
+                debug!(url = %url, model = %model, status = %status, "Model found on endpoint");
+                Ok(url)
+            } else {
+                debug!(
+                    url = %url,
+                    model = %model,
+                    status = %status,
+                    "Model not found on endpoint (unsuccessful status)"
+                );
+                Err(())
+            }
+        }
+        Err(e) => {
+            debug!(url = %url, model = %model, error = %e, "Probe request to endpoint failed");
+            Err(())
+        }
+    }
+}
+
 // ============================================================================
 // Re-export FunctionCallInProgress from mcp module
 // ============================================================================
```
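Detection is plain substring matching on the endpoint URL, and only Anthropic gets special treatment (the Bearer token is re-sent as `x-api-key` with an `anthropic-version` pin); everything else, including self-hosted OpenAI-compatible servers, passes `Authorization` through unchanged. A self-contained check of the detection half (the enum is repeated verbatim from the diff above; the URLs are illustrative):

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ApiProvider {
    Anthropic,
    Xai,
    OpenAi,
    Gemini,
    Generic,
}

impl ApiProvider {
    fn from_url(url: &str) -> Self {
        if url.contains("anthropic") {
            ApiProvider::Anthropic
        } else if url.contains("x.ai") {
            ApiProvider::Xai
        } else if url.contains("openai.com") {
            ApiProvider::OpenAi
        } else if url.contains("googleapis.com") {
            ApiProvider::Gemini
        } else {
            ApiProvider::Generic
        }
    }
}

fn main() {
    assert_eq!(ApiProvider::from_url("https://api.anthropic.com"), ApiProvider::Anthropic);
    assert_eq!(ApiProvider::from_url("https://api.x.ai"), ApiProvider::Xai);
    assert_eq!(ApiProvider::from_url("https://api.openai.com"), ApiProvider::OpenAi);
    assert_eq!(
        ApiProvider::from_url("https://generativelanguage.googleapis.com"),
        ApiProvider::Gemini
    );
    // Self-hosted OpenAI-compatible servers fall through to Generic.
    assert_eq!(ApiProvider::from_url("http://localhost:30000"), ApiProvider::Generic);
}
```

Since matching is purely lexical and ordered, a URL that happens to contain one of these substrings (say, a proxy host with "anthropic" in its name) will pick up that provider's headers.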
sgl-router/src/routers/router_manager.rs

```diff
@@ -410,7 +410,7 @@ impl RouterTrait for RouterManager {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let selected_model = body.model.as_deref().or(model_id);
+        let selected_model = model_id.or(Some(body.model.as_str()));
         let router = self.select_router_for_request(headers, selected_model);
         if let Some(router) = router {
```
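Beyond adapting to the non-optional `model` field, this one-liner flips the precedence: an explicitly supplied `model_id` now wins over the model named in the request body, and `selected_model` can no longer be `None`. A one-screen illustration (the model names are made up):

```rust
fn main() {
    // Old shape: body.model was Option<String>.
    let old_body_model: Option<String> = Some("gpt-4o-mini".to_string());
    let model_id: Option<&str> = Some("grok-2");

    // Old: the body's model shadowed model_id.
    assert_eq!(old_body_model.as_deref().or(model_id), Some("gpt-4o-mini"));

    // New shape: body.model is a plain String; model_id takes precedence,
    // with the body model as a guaranteed-Some fallback.
    let new_body_model = "gpt-4o-mini".to_string();
    assert_eq!(model_id.or(Some(new_body_model.as_str())), Some("grok-2"));

    let no_id: Option<&str> = None;
    assert_eq!(no_id.or(Some(new_body_model.as_str())), Some("gpt-4o-mini"));
}
```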
sgl-router/tests/responses_api_test.rs

```diff
@@ -100,7 +100,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_output_tokens: Some(64),
         max_tool_calls: None,
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -134,7 +134,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     };

     let resp = router
-        .route_responses(None, &req, req.model.as_deref())
+        .route_responses(None, &req, Some(req.model.as_str()))
         .await;
     assert_eq!(resp.status(), StatusCode::OK);
@@ -349,7 +349,7 @@ fn test_responses_request_creation() {
         max_output_tokens: Some(100),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -397,7 +397,7 @@ fn test_responses_request_sglang_extensions() {
         max_output_tokens: Some(50),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -506,7 +506,7 @@ fn test_json_serialization() {
         max_output_tokens: Some(200),
         max_tool_calls: Some(5),
         metadata: None,
-        model: Some("gpt-4".to_string()),
+        model: "gpt-4".to_string(),
         parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -545,7 +545,7 @@ fn test_json_serialization() {
         parsed.request_id,
         Some("resp_comprehensive_test".to_string())
     );
-    assert_eq!(parsed.model, Some("gpt-4".to_string()));
+    assert_eq!(parsed.model, "gpt-4");
     assert_eq!(parsed.background, Some(true));
     assert_eq!(parsed.stream, Some(true));
     assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
@@ -636,7 +636,7 @@ async fn test_multi_turn_loop_with_mcp() {
         max_output_tokens: Some(128),
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -812,7 +812,7 @@ async fn test_max_tool_calls_limit() {
         max_output_tokens: Some(128),
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1006,7 +1006,7 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_output_tokens: Some(256),
         max_tool_calls: Some(3),
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1287,7 +1287,7 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_output_tokens: Some(512),
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
```
sgl-router/tests/test_openai_routing.rs

```diff
@@ -99,7 +99,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
 #[tokio::test]
 async fn test_openai_router_creation() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -118,7 +118,7 @@ async fn test_openai_router_creation() {
 #[tokio::test]
 async fn test_openai_router_server_info() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -149,7 +149,7 @@ async fn test_openai_router_models() {
     // Use mock server for deterministic models response
     let mock_server = MockOpenAIServer::new().await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -229,7 +229,7 @@ async fn test_openai_router_responses_with_mock() {
     let storage = Arc::new(MemoryResponseStorage::new());
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -239,7 +239,7 @@ async fn test_openai_router_responses_with_mock() {
     .unwrap();

     let request1 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Say hi".to_string()),
         store: Some(true),
         ..Default::default()
@@ -255,7 +255,7 @@ async fn test_openai_router_responses_with_mock() {
     assert_eq!(body1["previous_response_id"], serde_json::Value::Null);

     let request2 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Thanks".to_string()),
         store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
@@ -490,7 +490,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     storage.store_response(previous).await.unwrap();

     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -503,7 +503,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     metadata.insert("topic".to_string(), json!("unicorns"));

     let request = ResponsesRequest {
-        model: Some("gpt-5-nano".to_string()),
+        model: "gpt-5-nano".to_string(),
         input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
@@ -595,7 +595,7 @@ async fn test_router_factory_openai_mode() {
 #[tokio::test]
 async fn test_unsupported_endpoints() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -660,7 +660,7 @@ async fn test_openai_router_chat_completion_with_mock() {
     // Create router pointing to mock server
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -702,7 +702,7 @@ async fn test_openai_e2e_with_server() {
     // Create router
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -773,7 +773,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
     let mock_server = MockOpenAIServer::new().await;
     let base_url = mock_server.base_url();
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -827,7 +827,7 @@ async fn test_openai_router_circuit_breaker() {
     };

     let router = OpenAIRouter::new(
-        "http://invalid-url-that-will-fail".to_string(),
+        vec!["http://invalid-url-that-will-fail".to_string()],
         Some(cb_config),
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -856,7 +856,7 @@ async fn test_openai_router_models_auth_forwarding() {
     let expected_auth = "Bearer test-token".to_string();
     let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -865,7 +865,8 @@ async fn test_openai_router_models_auth_forwarding() {
     .await
     .unwrap();

-    // 1) Without auth header -> expect 401
+    // 1) Without auth header -> expect 200 with empty model list
+    // (multi-endpoint aggregation silently skips failed endpoints)
     let req = Request::builder()
         .method(Method::GET)
         .uri("/models")
@@ -873,7 +874,13 @@ async fn test_openai_router_models_auth_forwarding() {
         .unwrap();
     let response = router.get_models(req).await;
-    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+    assert_eq!(response.status(), StatusCode::OK);
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+    assert_eq!(models["object"], "list");
+    assert_eq!(models["data"].as_array().unwrap().len(), 0); // Empty when auth fails

     // 2) With auth header -> expect 200
     let req = Request::builder()
```