Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
dee197e1
Unverified
Commit
dee197e1
authored
Sep 11, 2025
by
Keyang Ru
Committed by
GitHub
Sep 11, 2025
Browse files
[router] Add OpenAI backend support - core function (#10254)
parent
ab795ae8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1158 additions
and
16 deletions
+1158
-16
sgl-router/src/config/types.rs
sgl-router/src/config/types.rs
+8
-0
sgl-router/src/config/validation.rs
sgl-router/src/config/validation.rs
+20
-0
sgl-router/src/main.rs
sgl-router/src/main.rs
+65
-11
sgl-router/src/routers/factory.rs
sgl-router/src/routers/factory.rs
+24
-1
sgl-router/src/routers/http/mod.rs
sgl-router/src/routers/http/mod.rs
+1
-0
sgl-router/src/routers/http/openai_router.rs
sgl-router/src/routers/http/openai_router.rs
+379
-0
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+2
-0
sgl-router/tests/common/mock_mcp_server.rs
sgl-router/tests/common/mock_mcp_server.rs
+1
-4
sgl-router/tests/common/mock_openai_server.rs
sgl-router/tests/common/mock_openai_server.rs
+238
-0
sgl-router/tests/common/mock_worker.rs
sgl-router/tests/common/mock_worker.rs
+0
-0
sgl-router/tests/common/mod.rs
sgl-router/tests/common/mod.rs
+1
-0
sgl-router/tests/test_openai_routing.rs
sgl-router/tests/test_openai_routing.rs
+419
-0
No files found.
sgl-router/src/config/types.rs
View file @
dee197e1
...
@@ -101,6 +101,11 @@ pub enum RoutingMode {
...
@@ -101,6 +101,11 @@ pub enum RoutingMode {
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
decode_policy
:
Option
<
PolicyConfig
>
,
decode_policy
:
Option
<
PolicyConfig
>
,
},
},
#[serde(rename
=
"openai"
)]
OpenAI
{
/// OpenAI-compatible API base(s), provided via worker URLs
worker_urls
:
Vec
<
String
>
,
},
}
}
impl
RoutingMode
{
impl
RoutingMode
{
...
@@ -116,6 +121,8 @@ impl RoutingMode {
...
@@ -116,6 +121,8 @@ impl RoutingMode {
decode_urls
,
decode_urls
,
..
..
}
=>
prefill_urls
.len
()
+
decode_urls
.len
(),
}
=>
prefill_urls
.len
()
+
decode_urls
.len
(),
// OpenAI mode represents a single upstream
RoutingMode
::
OpenAI
{
..
}
=>
1
,
}
}
}
}
...
@@ -380,6 +387,7 @@ impl RouterConfig {
...
@@ -380,6 +387,7 @@ impl RouterConfig {
match
self
.mode
{
match
self
.mode
{
RoutingMode
::
Regular
{
..
}
=>
"regular"
,
RoutingMode
::
Regular
{
..
}
=>
"regular"
,
RoutingMode
::
PrefillDecode
{
..
}
=>
"prefill_decode"
,
RoutingMode
::
PrefillDecode
{
..
}
=>
"prefill_decode"
,
RoutingMode
::
OpenAI
{
..
}
=>
"openai"
,
}
}
}
}
...
...
sgl-router/src/config/validation.rs
View file @
dee197e1
...
@@ -95,6 +95,20 @@ impl ConfigValidator {
...
@@ -95,6 +95,20 @@ impl ConfigValidator {
Self
::
validate_policy
(
d_policy
)
?
;
Self
::
validate_policy
(
d_policy
)
?
;
}
}
}
}
RoutingMode
::
OpenAI
{
worker_urls
}
=>
{
// Require exactly one worker URL for OpenAI router
if
worker_urls
.len
()
!=
1
{
return
Err
(
ConfigError
::
ValidationFailed
{
reason
:
"OpenAI mode requires exactly one --worker-urls entry"
.to_string
(),
});
}
// Validate URL format
if
let
Err
(
e
)
=
url
::
Url
::
parse
(
&
worker_urls
[
0
])
{
return
Err
(
ConfigError
::
ValidationFailed
{
reason
:
format!
(
"Invalid OpenAI worker URL '{}': {}"
,
&
worker_urls
[
0
],
e
),
});
}
}
}
}
Ok
(())
Ok
(())
}
}
...
@@ -243,6 +257,12 @@ impl ConfigValidator {
...
@@ -243,6 +257,12 @@ impl ConfigValidator {
});
});
}
}
}
}
RoutingMode
::
OpenAI
{
..
}
=>
{
// OpenAI mode doesn't use service discovery
return
Err
(
ConfigError
::
ValidationFailed
{
reason
:
"OpenAI mode does not support service discovery"
.to_string
(),
});
}
}
}
Ok
(())
Ok
(())
...
...
sgl-router/src/main.rs
View file @
dee197e1
use
clap
::{
ArgAction
,
Parser
};
use
clap
::{
ArgAction
,
Parser
,
ValueEnum
};
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
ConfigError
,
ConfigResult
,
ConnectionMode
,
DiscoveryConfig
,
CircuitBreakerConfig
,
ConfigError
,
ConfigResult
,
ConnectionMode
,
DiscoveryConfig
,
HealthCheckConfig
,
MetricsConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
HealthCheckConfig
,
MetricsConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
...
@@ -41,6 +41,33 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
...
@@ -41,6 +41,33 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
prefill_entries
prefill_entries
}
}
#[derive(Copy,
Clone,
Debug,
Eq,
PartialEq,
ValueEnum)]
pub
enum
Backend
{
#[value(name
=
"sglang"
)]
Sglang
,
#[value(name
=
"vllm"
)]
Vllm
,
#[value(name
=
"trtllm"
)]
Trtllm
,
#[value(name
=
"openai"
)]
Openai
,
#[value(name
=
"anthropic"
)]
Anthropic
,
}
impl
std
::
fmt
::
Display
for
Backend
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
let
s
=
match
self
{
Backend
::
Sglang
=>
"sglang"
,
Backend
::
Vllm
=>
"vllm"
,
Backend
::
Trtllm
=>
"trtllm"
,
Backend
::
Openai
=>
"openai"
,
Backend
::
Anthropic
=>
"anthropic"
,
};
write!
(
f
,
"{}"
,
s
)
}
}
#[derive(Parser,
Debug)]
#[derive(Parser,
Debug)]
#[command(name
=
"sglang-router"
)]
#[command(name
=
"sglang-router"
)]
#[command(about
=
"SGLang Router - High-performance request distribution across worker nodes"
)]
#[command(about
=
"SGLang Router - High-performance request distribution across worker nodes"
)]
...
@@ -145,6 +172,10 @@ struct CliArgs {
...
@@ -145,6 +172,10 @@ struct CliArgs {
#[arg(long)]
#[arg(long)]
api_key
:
Option
<
String
>
,
api_key
:
Option
<
String
>
,
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
#[arg(long,
value_enum,
default_value_t
=
Backend::Sglang,
alias
=
"runtime"
)]
backend
:
Backend
,
/// Directory to store log files
/// Directory to store log files
#[arg(long)]
#[arg(long)]
log_dir
:
Option
<
String
>
,
log_dir
:
Option
<
String
>
,
...
@@ -339,6 +370,11 @@ impl CliArgs {
...
@@ -339,6 +370,11 @@ impl CliArgs {
RoutingMode
::
Regular
{
RoutingMode
::
Regular
{
worker_urls
:
vec!
[],
worker_urls
:
vec!
[],
}
}
}
else
if
matches!
(
self
.backend
,
Backend
::
Openai
)
{
// OpenAI backend mode - use worker_urls as base(s)
RoutingMode
::
OpenAI
{
worker_urls
:
self
.worker_urls
.clone
(),
}
}
else
if
self
.pd_disaggregation
{
}
else
if
self
.pd_disaggregation
{
let
decode_urls
=
self
.decode
.clone
();
let
decode_urls
=
self
.decode
.clone
();
...
@@ -409,8 +445,14 @@ impl CliArgs {
...
@@ -409,8 +445,14 @@ impl CliArgs {
}
}
all_urls
.extend
(
decode_urls
.clone
());
all_urls
.extend
(
decode_urls
.clone
());
}
}
RoutingMode
::
OpenAI
{
..
}
=>
{
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
}
}
}
let
connection_mode
=
Self
::
determine_connection_mode
(
&
all_urls
);
let
connection_mode
=
match
&
mode
{
RoutingMode
::
OpenAI
{
..
}
=>
ConnectionMode
::
Http
,
_
=>
Self
::
determine_connection_mode
(
&
all_urls
),
};
// Build RouterConfig
// Build RouterConfig
Ok
(
RouterConfig
{
Ok
(
RouterConfig
{
...
@@ -543,16 +585,28 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
...
@@ -543,16 +585,28 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Print startup info
// Print startup info
println!
(
"SGLang Router starting..."
);
println!
(
"SGLang Router starting..."
);
println!
(
"Host: {}:{}"
,
cli_args
.host
,
cli_args
.port
);
println!
(
"Host: {}:{}"
,
cli_args
.host
,
cli_args
.port
);
println!
(
let
mode_str
=
if
cli_args
.enable_igw
{
"Mode: {}"
,
"IGW (Inference Gateway)"
.to_string
()
if
cli_args
.enable_igw
{
}
else
if
matches!
(
cli_args
.backend
,
Backend
::
Openai
)
{
"IGW (Inference Gateway)"
"OpenAI Backend"
.to_string
()
}
else
if
cli_args
.pd_disaggregation
{
}
else
if
cli_args
.pd_disaggregation
{
"PD Disaggregated"
"PD Disaggregated"
.to_string
()
}
else
{
}
else
{
"Regular"
format!
(
"Regular ({})"
,
cli_args
.backend
)
};
println!
(
"Mode: {}"
,
mode_str
);
// Warn for runtimes that are parsed but not yet implemented
match
cli_args
.backend
{
Backend
::
Vllm
|
Backend
::
Trtllm
|
Backend
::
Anthropic
=>
{
println!
(
"WARNING: runtime '{}' not implemented yet; falling back to regular routing.
\
Provide --worker-urls or PD flags as usual."
,
cli_args
.backend
);
}
}
);
Backend
::
Sglang
|
Backend
::
Openai
=>
{}
}
if
!
cli_args
.enable_igw
{
if
!
cli_args
.enable_igw
{
println!
(
"Policy: {}"
,
cli_args
.policy
);
println!
(
"Policy: {}"
,
cli_args
.policy
);
...
...
sgl-router/src/routers/factory.rs
View file @
dee197e1
//! Factory for creating router instances
//! Factory for creating router instances
use
super
::{
use
super
::{
http
::{
pd_router
::
PDRouter
,
router
::
Router
},
http
::{
openai_router
::
OpenAIRouter
,
pd_router
::
PDRouter
,
router
::
Router
},
RouterTrait
,
RouterTrait
,
};
};
use
crate
::
config
::{
ConnectionMode
,
PolicyConfig
,
RoutingMode
};
use
crate
::
config
::{
ConnectionMode
,
PolicyConfig
,
RoutingMode
};
...
@@ -44,6 +44,9 @@ impl RouterFactory {
...
@@ -44,6 +44,9 @@ impl RouterFactory {
)
)
.await
.await
}
}
RoutingMode
::
OpenAI
{
..
}
=>
{
Err
(
"OpenAI mode requires HTTP connection_mode"
.to_string
())
}
}
}
}
}
ConnectionMode
::
Http
=>
{
ConnectionMode
::
Http
=>
{
...
@@ -69,6 +72,9 @@ impl RouterFactory {
...
@@ -69,6 +72,9 @@ impl RouterFactory {
)
)
.await
.await
}
}
RoutingMode
::
OpenAI
{
worker_urls
,
..
}
=>
{
Self
::
create_openai_router
(
worker_urls
.clone
(),
ctx
)
.await
}
}
}
}
}
}
}
...
@@ -164,6 +170,23 @@ impl RouterFactory {
...
@@ -164,6 +170,23 @@ impl RouterFactory {
Ok
(
Box
::
new
(
router
))
Ok
(
Box
::
new
(
router
))
}
}
/// Create an OpenAI router
async
fn
create_openai_router
(
worker_urls
:
Vec
<
String
>
,
ctx
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// Use the first worker URL as the OpenAI-compatible base
let
base_url
=
worker_urls
.first
()
.cloned
()
.ok_or_else
(||
"OpenAI mode requires at least one worker URL"
.to_string
())
?
;
let
router
=
OpenAIRouter
::
new
(
base_url
,
Some
(
ctx
.router_config.circuit_breaker
.clone
()))
.await
?
;
Ok
(
Box
::
new
(
router
))
}
/// Create an IGW router (placeholder for future implementation)
/// Create an IGW router (placeholder for future implementation)
async
fn
create_igw_router
(
_
ctx
:
&
Arc
<
AppContext
>
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
async
fn
create_igw_router
(
_
ctx
:
&
Arc
<
AppContext
>
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// For now, return an error indicating IGW is not yet implemented
// For now, return an error indicating IGW is not yet implemented
...
...
sgl-router/src/routers/http/mod.rs
View file @
dee197e1
//! HTTP router implementations
//! HTTP router implementations
pub
mod
openai_router
;
pub
mod
pd_router
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
pd_types
;
pub
mod
router
;
pub
mod
router
;
sgl-router/src/routers/http/openai_router.rs
0 → 100644
View file @
dee197e1
//! OpenAI router implementation (reqwest-based)
use
crate
::
config
::
CircuitBreakerConfig
;
use
crate
::
core
::{
CircuitBreaker
,
CircuitBreakerConfig
as
CoreCircuitBreakerConfig
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
async_trait
::
async_trait
;
use
axum
::{
body
::
Body
,
extract
::
Request
,
http
::{
header
::
CONTENT_TYPE
,
HeaderMap
,
HeaderValue
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
};
use
futures_util
::
StreamExt
;
use
std
::{
any
::
Any
,
sync
::
atomic
::{
AtomicBool
,
Ordering
},
};
/// Router for OpenAI backend
#[derive(Debug)]
pub
struct
OpenAIRouter
{
/// HTTP client for upstream OpenAI-compatible API
client
:
reqwest
::
Client
,
/// Base URL for identification (no trailing slash)
base_url
:
String
,
/// Circuit breaker
circuit_breaker
:
CircuitBreaker
,
/// Health status
healthy
:
AtomicBool
,
}
impl
OpenAIRouter
{
/// Create a new OpenAI router
pub
async
fn
new
(
base_url
:
String
,
circuit_breaker_config
:
Option
<
CircuitBreakerConfig
>
,
)
->
Result
<
Self
,
String
>
{
let
client
=
reqwest
::
Client
::
builder
()
.timeout
(
std
::
time
::
Duration
::
from_secs
(
300
))
.build
()
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
let
base_url
=
base_url
.trim_end_matches
(
'/'
)
.to_string
();
// Convert circuit breaker config
let
core_cb_config
=
circuit_breaker_config
.map
(|
cb
|
CoreCircuitBreakerConfig
{
failure_threshold
:
cb
.failure_threshold
,
success_threshold
:
cb
.success_threshold
,
timeout_duration
:
std
::
time
::
Duration
::
from_secs
(
cb
.timeout_duration_secs
),
window_duration
:
std
::
time
::
Duration
::
from_secs
(
cb
.window_duration_secs
),
})
.unwrap_or_default
();
let
circuit_breaker
=
CircuitBreaker
::
with_config
(
core_cb_config
);
Ok
(
Self
{
client
,
base_url
,
circuit_breaker
,
healthy
:
AtomicBool
::
new
(
true
),
})
}
}
#[async_trait]
impl
super
::
super
::
WorkerManagement
for
OpenAIRouter
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
Err
(
"Cannot add workers to OpenAI router"
.to_string
())
}
fn
remove_worker
(
&
self
,
_
worker_url
:
&
str
)
{
// No-op for OpenAI router
}
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
vec!
[
self
.base_url
.clone
()]
}
}
#[async_trait]
impl
super
::
super
::
RouterTrait
for
OpenAIRouter
{
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Simple upstream probe: GET {base}/v1/models without auth
let
url
=
format!
(
"{}/v1/models"
,
self
.base_url
);
match
self
.client
.get
(
&
url
)
.timeout
(
std
::
time
::
Duration
::
from_secs
(
2
))
.send
()
.await
{
Ok
(
resp
)
=>
{
let
code
=
resp
.status
();
// Treat success and auth-required as healthy (endpoint reachable)
if
code
.is_success
()
||
code
.as_u16
()
==
401
||
code
.as_u16
()
==
403
{
(
StatusCode
::
OK
,
"OK"
)
.into_response
()
}
else
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
format!
(
"Upstream status: {}"
,
code
),
)
.into_response
()
}
}
Err
(
e
)
=>
(
StatusCode
::
SERVICE_UNAVAILABLE
,
format!
(
"Upstream error: {}"
,
e
),
)
.into_response
(),
}
}
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// For OpenAI, health_generate is the same as health
self
.health
(
_
req
)
.await
}
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
let
info
=
serde_json
::
json!
({
"router_type"
:
"openai"
,
"workers"
:
1
,
"base_url"
:
&
self
.base_url
});
(
StatusCode
::
OK
,
info
.to_string
())
.into_response
()
}
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Proxy to upstream /v1/models; forward Authorization header if provided
let
headers
=
req
.headers
();
let
mut
upstream
=
self
.client
.get
(
format!
(
"{}/v1/models"
,
self
.base_url
));
if
let
Some
(
auth
)
=
headers
.get
(
"authorization"
)
.or_else
(||
headers
.get
(
"Authorization"
))
{
upstream
=
upstream
.header
(
"Authorization"
,
auth
);
}
match
upstream
.send
()
.await
{
Ok
(
res
)
=>
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
let
content_type
=
res
.headers
()
.get
(
CONTENT_TYPE
)
.cloned
();
match
res
.bytes
()
.await
{
Ok
(
body
)
=>
{
let
mut
response
=
Response
::
new
(
axum
::
body
::
Body
::
from
(
body
));
*
response
.status_mut
()
=
status
;
if
let
Some
(
ct
)
=
content_type
{
response
.headers_mut
()
.insert
(
CONTENT_TYPE
,
ct
);
}
response
}
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to read upstream response: {}"
,
e
),
)
.into_response
(),
}
}
Err
(
e
)
=>
(
StatusCode
::
BAD_GATEWAY
,
format!
(
"Failed to contact upstream: {}"
,
e
),
)
.into_response
(),
}
}
async
fn
get_model_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Not directly supported without model param; return 501
(
StatusCode
::
NOT_IMPLEMENTED
,
"get_model_info not implemented for OpenAI router"
,
)
.into_response
()
}
async
fn
route_generate
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
GenerateRequest
,
)
->
Response
{
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
(
StatusCode
::
NOT_IMPLEMENTED
,
"Generate endpoint not supported for OpenAI backend"
,
)
.into_response
()
}
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
)
->
Response
{
if
!
self
.circuit_breaker
.can_execute
()
{
return
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"Circuit breaker open"
)
.into_response
();
}
// Serialize request body, removing SGLang-only fields
let
mut
payload
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
{
return
(
StatusCode
::
BAD_REQUEST
,
format!
(
"Failed to serialize request: {}"
,
e
),
)
.into_response
();
}
};
if
let
Some
(
obj
)
=
payload
.as_object_mut
()
{
for
key
in
[
"top_k"
,
"min_p"
,
"min_tokens"
,
"regex"
,
"ebnf"
,
"stop_token_ids"
,
"no_stop_trim"
,
"ignore_eos"
,
"continue_final_message"
,
"skip_special_tokens"
,
"lora_path"
,
"session_params"
,
"separate_reasoning"
,
"stream_reasoning"
,
"chat_template_kwargs"
,
"return_hidden_states"
,
"repetition_penalty"
,
]
{
obj
.remove
(
key
);
}
}
let
url
=
format!
(
"{}/v1/chat/completions"
,
self
.base_url
);
let
mut
req
=
self
.client
.post
(
&
url
)
.json
(
&
payload
);
// Forward Authorization header if provided
if
let
Some
(
h
)
=
headers
{
if
let
Some
(
auth
)
=
h
.get
(
"authorization"
)
.or_else
(||
h
.get
(
"Authorization"
))
{
req
=
req
.header
(
"Authorization"
,
auth
);
}
}
// Accept SSE when stream=true
if
body
.stream
{
req
=
req
.header
(
"Accept"
,
"text/event-stream"
);
}
let
resp
=
match
req
.send
()
.await
{
Ok
(
r
)
=>
r
,
Err
(
e
)
=>
{
self
.circuit_breaker
.record_failure
();
return
(
StatusCode
::
SERVICE_UNAVAILABLE
,
format!
(
"Failed to contact upstream: {}"
,
e
),
)
.into_response
();
}
};
let
status
=
StatusCode
::
from_u16
(
resp
.status
()
.as_u16
())
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
if
!
body
.stream
{
// Capture Content-Type before consuming response body
let
content_type
=
resp
.headers
()
.get
(
CONTENT_TYPE
)
.cloned
();
match
resp
.bytes
()
.await
{
Ok
(
body
)
=>
{
self
.circuit_breaker
.record_success
();
let
mut
response
=
Response
::
new
(
axum
::
body
::
Body
::
from
(
body
));
*
response
.status_mut
()
=
status
;
if
let
Some
(
ct
)
=
content_type
{
response
.headers_mut
()
.insert
(
CONTENT_TYPE
,
ct
);
}
response
}
Err
(
e
)
=>
{
self
.circuit_breaker
.record_failure
();
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to read response: {}"
,
e
),
)
.into_response
()
}
}
}
else
{
// Stream SSE bytes to client
let
stream
=
resp
.bytes_stream
();
let
(
tx
,
rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
tokio
::
spawn
(
async
move
{
let
mut
s
=
stream
;
while
let
Some
(
chunk
)
=
s
.next
()
.await
{
match
chunk
{
Ok
(
bytes
)
=>
{
if
tx
.send
(
Ok
(
bytes
))
.is_err
()
{
break
;
}
}
Err
(
e
)
=>
{
let
_
=
tx
.send
(
Err
(
format!
(
"Stream error: {}"
,
e
)));
break
;
}
}
}
});
let
mut
response
=
Response
::
new
(
Body
::
from_stream
(
tokio_stream
::
wrappers
::
UnboundedReceiverStream
::
new
(
rx
),
));
*
response
.status_mut
()
=
status
;
response
.headers_mut
()
.insert
(
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
));
response
}
}
async
fn
route_completion
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
CompletionRequest
,
)
->
Response
{
// Completion endpoint not implemented for OpenAI backend
(
StatusCode
::
NOT_IMPLEMENTED
,
"Completion endpoint not implemented for OpenAI backend"
,
)
.into_response
()
}
async
fn
flush_cache
(
&
self
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"flush_cache not supported for OpenAI router"
,
)
.into_response
()
}
async
fn
get_worker_loads
(
&
self
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"get_worker_loads not supported for OpenAI router"
,
)
.into_response
()
}
fn
router_type
(
&
self
)
->
&
'static
str
{
"openai"
}
fn
readiness
(
&
self
)
->
Response
{
if
self
.healthy
.load
(
Ordering
::
Acquire
)
&&
self
.circuit_breaker
.can_execute
()
{
(
StatusCode
::
OK
,
"Ready"
)
.into_response
()
}
else
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"Not ready"
)
.into_response
()
}
}
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Embeddings endpoint not implemented for OpenAI backend"
,
)
.into_response
()
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Rerank endpoint not implemented for OpenAI backend"
,
)
.into_response
()
}
}
sgl-router/src/routers/mod.rs
View file @
dee197e1
...
@@ -17,6 +17,8 @@ pub mod header_utils;
...
@@ -17,6 +17,8 @@ pub mod header_utils;
pub
mod
http
;
pub
mod
http
;
pub
use
factory
::
RouterFactory
;
pub
use
factory
::
RouterFactory
;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub
use
http
::{
openai_router
,
pd_router
,
pd_types
,
router
};
/// Worker management trait for administrative operations
/// Worker management trait for administrative operations
///
///
...
...
sgl-router/tests/common/mock_mcp_server.rs
View file @
dee197e1
...
@@ -63,10 +63,7 @@ impl ServerHandler for MockSearchServer {
...
@@ -63,10 +63,7 @@ impl ServerHandler for MockSearchServer {
ServerInfo
{
ServerInfo
{
protocol_version
:
ProtocolVersion
::
V_2024_11_05
,
protocol_version
:
ProtocolVersion
::
V_2024_11_05
,
capabilities
:
ServerCapabilities
::
builder
()
.enable_tools
()
.build
(),
capabilities
:
ServerCapabilities
::
builder
()
.enable_tools
()
.build
(),
server_info
:
Implementation
{
server_info
:
Implementation
::
from_build_env
(),
name
:
"Mock MCP Server"
.to_string
(),
version
:
"1.0.0"
.to_string
(),
},
instructions
:
Some
(
"Mock server for testing"
.to_string
()),
instructions
:
Some
(
"Mock server for testing"
.to_string
()),
}
}
}
}
...
...
sgl-router/tests/common/mock_openai_server.rs
0 → 100644
View file @
dee197e1
//! Mock servers for testing
#![allow(dead_code)]
use
axum
::{
body
::
Body
,
extract
::{
Request
,
State
},
http
::{
HeaderValue
,
StatusCode
},
response
::
sse
::{
Event
,
KeepAlive
},
response
::{
IntoResponse
,
Response
,
Sse
},
routing
::
post
,
Json
,
Router
,
};
use
futures_util
::
stream
::{
self
,
StreamExt
};
use
serde_json
::
json
;
use
std
::
net
::
SocketAddr
;
use
std
::
sync
::
Arc
;
use
tokio
::
net
::
TcpListener
;
/// Mock OpenAI API server for testing
pub
struct
MockOpenAIServer
{
addr
:
SocketAddr
,
_
handle
:
tokio
::
task
::
JoinHandle
<
()
>
,
}
#[derive(Clone)]
struct
MockServerState
{
require_auth
:
bool
,
expected_auth
:
Option
<
String
>
,
}
impl
MockOpenAIServer
{
/// Create and start a new mock OpenAI server
pub
async
fn
new
()
->
Self
{
Self
::
new_with_auth
(
None
)
.await
}
/// Create and start a new mock OpenAI server with optional auth requirement
pub
async
fn
new_with_auth
(
expected_auth
:
Option
<
String
>
)
->
Self
{
let
listener
=
TcpListener
::
bind
(
"127.0.0.1:0"
)
.await
.unwrap
();
let
addr
=
listener
.local_addr
()
.unwrap
();
let
state
=
Arc
::
new
(
MockServerState
{
require_auth
:
expected_auth
.is_some
(),
expected_auth
,
});
let
app
=
Router
::
new
()
.route
(
"/v1/chat/completions"
,
post
(
mock_chat_completions
))
.route
(
"/v1/completions"
,
post
(
mock_completions
))
.route
(
"/v1/models"
,
post
(
mock_models
)
.get
(
mock_models
))
.with_state
(
state
);
let
handle
=
tokio
::
spawn
(
async
move
{
axum
::
serve
(
listener
,
app
)
.await
.unwrap
();
});
// Give the server a moment to start
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
10
))
.await
;
Self
{
addr
,
_
handle
:
handle
,
}
}
/// Get the base URL for this mock server
pub
fn
base_url
(
&
self
)
->
String
{
format!
(
"http://{}"
,
self
.addr
)
}
}
/// Mock chat completions endpoint
async
fn
mock_chat_completions
(
req
:
Request
<
Body
>
)
->
Response
{
let
(
_
,
body
)
=
req
.into_parts
();
let
body_bytes
=
match
axum
::
body
::
to_bytes
(
body
,
usize
::
MAX
)
.await
{
Ok
(
bytes
)
=>
bytes
,
Err
(
_
)
=>
return
StatusCode
::
BAD_REQUEST
.into_response
(),
};
let
request
:
serde_json
::
Value
=
match
serde_json
::
from_slice
(
&
body_bytes
)
{
Ok
(
req
)
=>
req
,
Err
(
_
)
=>
return
StatusCode
::
BAD_REQUEST
.into_response
(),
};
// Extract model from request or use default (owned String to satisfy 'static in stream)
let
model
:
String
=
request
.get
(
"model"
)
.and_then
(|
v
|
v
.as_str
())
.unwrap_or
(
"gpt-3.5-turbo"
)
.to_string
();
// If stream requested, return SSE
let
is_stream
=
request
.get
(
"stream"
)
.and_then
(|
v
|
v
.as_bool
())
.unwrap_or
(
false
);
if
is_stream
{
let
created
=
1677652288u64
;
// Single chunk then [DONE]
let
model_chunk
=
model
.clone
();
let
event_stream
=
stream
::
once
(
async
move
{
let
chunk
=
json!
({
"id"
:
"chatcmpl-123456789"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
created
,
"model"
:
model_chunk
,
"choices"
:
[{
"index"
:
0
,
"delta"
:
{
"content"
:
"Hello!"
},
"finish_reason"
:
null
}]
});
Ok
::
<
_
,
std
::
convert
::
Infallible
>
(
Event
::
default
()
.data
(
chunk
.to_string
()))
})
.chain
(
stream
::
once
(
async
{
Ok
(
Event
::
default
()
.data
(
"[DONE]"
))
}));
Sse
::
new
(
event_stream
)
.keep_alive
(
KeepAlive
::
default
())
.into_response
()
}
else
{
// Create a mock non-streaming response
let
response
=
json!
({
"id"
:
"chatcmpl-123456789"
,
"object"
:
"chat.completion"
,
"created"
:
1677652288
,
"model"
:
model
,
"choices"
:
[{
"index"
:
0
,
"message"
:
{
"role"
:
"assistant"
,
"content"
:
"Hello! I'm a mock OpenAI assistant. How can I help you today?"
},
"finish_reason"
:
"stop"
}],
"usage"
:
{
"prompt_tokens"
:
9
,
"completion_tokens"
:
12
,
"total_tokens"
:
21
}
});
Json
(
response
)
.into_response
()
}
}
/// Mock completions endpoint (legacy)
async
fn
mock_completions
(
req
:
Request
<
Body
>
)
->
Response
{
let
(
_
,
body
)
=
req
.into_parts
();
let
body_bytes
=
match
axum
::
body
::
to_bytes
(
body
,
usize
::
MAX
)
.await
{
Ok
(
bytes
)
=>
bytes
,
Err
(
_
)
=>
return
StatusCode
::
BAD_REQUEST
.into_response
(),
};
let
request
:
serde_json
::
Value
=
match
serde_json
::
from_slice
(
&
body_bytes
)
{
Ok
(
req
)
=>
req
,
Err
(
_
)
=>
return
StatusCode
::
BAD_REQUEST
.into_response
(),
};
let
model
=
request
[
"model"
]
.as_str
()
.unwrap_or
(
"text-davinci-003"
);
let
response
=
json!
({
"id"
:
"cmpl-123456789"
,
"object"
:
"text_completion"
,
"created"
:
1677652288
,
"model"
:
model
,
"choices"
:
[{
"text"
:
" This is a mock completion response."
,
"index"
:
0
,
"logprobs"
:
null
,
"finish_reason"
:
"stop"
}],
"usage"
:
{
"prompt_tokens"
:
5
,
"completion_tokens"
:
7
,
"total_tokens"
:
12
}
});
Json
(
response
)
.into_response
()
}
/// Mock models endpoint
async
fn
mock_models
(
State
(
state
):
State
<
Arc
<
MockServerState
>>
,
req
:
Request
<
Body
>
)
->
Response
{
// Optionally enforce Authorization header
if
state
.require_auth
{
let
auth
=
req
.headers
()
.get
(
"authorization"
)
.or_else
(||
req
.headers
()
.get
(
"Authorization"
))
.and_then
(|
v
|
v
.to_str
()
.ok
())
.map
(|
s
|
s
.to_string
());
let
auth_ok
=
match
(
&
state
.expected_auth
,
auth
)
{
(
Some
(
expected
),
Some
(
got
))
=>
&
got
==
expected
,
(
None
,
Some
(
_
))
=>
true
,
_
=>
false
,
};
if
!
auth_ok
{
let
mut
response
=
Response
::
new
(
Body
::
from
(
json!
({
"error"
:
{
"message"
:
"Unauthorized"
,
"type"
:
"invalid_request_error"
}
})
.to_string
(),
));
*
response
.status_mut
()
=
StatusCode
::
UNAUTHORIZED
;
response
.headers_mut
()
.insert
(
"WWW-Authenticate"
,
HeaderValue
::
from_static
(
"Bearer"
));
return
response
;
}
}
let
response
=
json!
({
"object"
:
"list"
,
"data"
:
[
{
"id"
:
"gpt-4"
,
"object"
:
"model"
,
"created"
:
1677610602
,
"owned_by"
:
"openai"
},
{
"id"
:
"gpt-3.5-turbo"
,
"object"
:
"model"
,
"created"
:
1677610602
,
"owned_by"
:
"openai"
}
]
});
Json
(
response
)
.into_response
()
}
sgl-router/tests/common/mock_worker.rs
100644 → 100755
View file @
dee197e1
File mode changed from 100644 to 100755
sgl-router/tests/common/mod.rs
View file @
dee197e1
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#![allow(dead_code)]
#![allow(dead_code)]
pub
mod
mock_mcp_server
;
pub
mod
mock_mcp_server
;
pub
mod
mock_openai_server
;
pub
mod
mock_worker
;
pub
mod
mock_worker
;
pub
mod
test_app
;
pub
mod
test_app
;
...
...
sgl-router/tests/test_openai_routing.rs
0 → 100644
View file @
dee197e1
//! Comprehensive integration tests for OpenAI backend functionality
use
axum
::{
body
::
Body
,
extract
::
Request
,
http
::{
Method
,
StatusCode
},
routing
::
post
,
Router
,
};
use
serde_json
::
json
;
use
sglang_router_rs
::{
config
::{
RouterConfig
,
RoutingMode
},
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
UserMessageContent
,
},
routers
::{
openai_router
::
OpenAIRouter
,
RouterTrait
},
};
use
std
::
sync
::
Arc
;
use
tower
::
ServiceExt
;
mod
common
;
use
common
::
mock_openai_server
::
MockOpenAIServer
;
/// Helper function to create a minimal chat completion request for testing
fn
create_minimal_chat_request
()
->
ChatCompletionRequest
{
let
val
=
json!
({
"model"
:
"gpt-3.5-turbo"
,
"messages"
:
[
{
"role"
:
"user"
,
"content"
:
"Hello"
}
],
"max_tokens"
:
100
});
serde_json
::
from_value
(
val
)
.unwrap
()
}
/// Helper function to create a minimal completion request for testing
fn
create_minimal_completion_request
()
->
CompletionRequest
{
CompletionRequest
{
model
:
"gpt-3.5-turbo"
.to_string
(),
prompt
:
sglang_router_rs
::
protocols
::
spec
::
StringOrArray
::
String
(
"Hello"
.to_string
()),
suffix
:
None
,
max_tokens
:
Some
(
100
),
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
top_k
:
None
,
min_p
:
None
,
min_tokens
:
None
,
repetition_penalty
:
None
,
regex
:
None
,
ebnf
:
None
,
json_schema
:
None
,
stop_token_ids
:
None
,
no_stop_trim
:
false
,
ignore_eos
:
false
,
skip_special_tokens
:
true
,
lora_path
:
None
,
session_params
:
None
,
return_hidden_states
:
false
,
other
:
serde_json
::
Map
::
new
(),
}
}
// ============= Basic Unit Tests =============
/// Test basic OpenAI router creation and configuration
#[tokio::test]
async
fn
test_openai_router_creation
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
.await
;
assert
!
(
router
.is_ok
(),
"Router creation should succeed"
);
let
router
=
router
.unwrap
();
assert_eq!
(
router
.router_type
(),
"openai"
);
assert
!
(
!
router
.is_pd_mode
());
}
/// Test health endpoints
#[tokio::test]
async
fn
test_openai_router_health
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
.await
.unwrap
();
let
req
=
Request
::
builder
()
.method
(
Method
::
GET
)
.uri
(
"/health"
)
.body
(
Body
::
empty
())
.unwrap
();
let
response
=
router
.health
(
req
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
}
/// Test server info endpoint
#[tokio::test]
async
fn
test_openai_router_server_info
()
{
let
router
=
OpenAIRouter
::
new
(
"https://api.openai.com"
.to_string
(),
None
)
.await
.unwrap
();
let
req
=
Request
::
builder
()
.method
(
Method
::
GET
)
.uri
(
"/info"
)
.body
(
Body
::
empty
())
.unwrap
();
let
response
=
router
.get_server_info
(
req
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
let
(
_
,
body
)
=
response
.into_parts
();
let
body_bytes
=
axum
::
body
::
to_bytes
(
body
,
usize
::
MAX
)
.await
.unwrap
();
let
body_str
=
String
::
from_utf8
(
body_bytes
.to_vec
())
.unwrap
();
assert
!
(
body_str
.contains
(
"openai"
));
}
/// Test models endpoint
#[tokio::test]
async fn test_openai_router_models() {
    // Use mock server for deterministic models response
    let mock_server = MockOpenAIServer::new().await;
    let router = OpenAIRouter::new(mock_server.base_url(), None)
        .await
        .unwrap();

    let models_req = Request::builder()
        .method(Method::GET)
        .uri("/models")
        .body(Body::empty())
        .unwrap();

    let resp = router.get_models(models_req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let (_, body) = resp.into_parts();
    let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    let models: serde_json::Value = serde_json::from_str(&text).unwrap();

    // OpenAI-style list envelope: {"object": "list", "data": [...]}
    assert_eq!(models["object"], "list");
    assert!(models["data"].is_array());
}
/// Test router factory with OpenAI routing mode
#[tokio::test]
async fn test_router_factory_openai_mode() {
    let mode = RoutingMode::OpenAI {
        worker_urls: vec![String::from("https://api.openai.com")],
    };
    let config = RouterConfig::new(mode, sglang_router_rs::config::PolicyConfig::Random);
    let ctx = common::create_test_context(config);

    let created = sglang_router_rs::routers::RouterFactory::create_router(&ctx).await;
    assert!(
        created.is_ok(),
        "Router factory should create OpenAI router successfully"
    );
    assert_eq!(created.unwrap().router_type(), "openai");
}
/// Test that unsupported endpoints return proper error codes
#[tokio::test]
async fn test_unsupported_endpoints() {
    let router = OpenAIRouter::new(String::from("https://api.openai.com"), None)
        .await
        .unwrap();

    // Test generate endpoint (SGLang-specific, should not be supported)
    let generate_request = GenerateRequest {
        prompt: None,
        text: Some(String::from("Hello world")),
        input_ids: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        rid: None,
    };
    let resp = router.route_generate(None, &generate_request).await;
    assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);

    // Test completion endpoint (should also not be supported)
    let completion_request = create_minimal_completion_request();
    let resp = router.route_completion(None, &completion_request).await;
    assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
}
// ============= Mock Server E2E Tests =============
/// Test chat completion with mock OpenAI server
#[tokio::test]
async fn test_openai_router_chat_completion_with_mock() {
    // Start a mock OpenAI server and point the router at it.
    let mock_server = MockOpenAIServer::new().await;
    let router = OpenAIRouter::new(mock_server.base_url(), None)
        .await
        .unwrap();

    // Build a minimal chat completion request with a single user turn.
    let mut chat_request = create_minimal_chat_request();
    chat_request.messages = vec![ChatMessage::User {
        role: String::from("user"),
        content: UserMessageContent::Text(String::from("Hello, how are you?")),
        name: None,
    }];
    chat_request.temperature = Some(0.7);

    // Route the request; the mock server should answer successfully.
    let resp = router.route_chat(None, &chat_request).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let (_, body) = resp.into_parts();
    let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    let chat_response: serde_json::Value = serde_json::from_str(&text).unwrap();

    // Verify it's a valid chat completion response
    assert_eq!(chat_response["object"], "chat.completion");
    assert_eq!(chat_response["model"], "gpt-3.5-turbo");
    assert!(!chat_response["choices"].as_array().unwrap().is_empty());
}
/// Test full E2E flow with Axum server
#[tokio::test]
async fn test_openai_e2e_with_server() {
    // Start mock OpenAI server
    let mock_server = MockOpenAIServer::new().await;

    // Create router targeting the mock backend; Arc so the handler
    // closure can clone a shared handle per request.
    let router = Arc::new(
        OpenAIRouter::new(mock_server.base_url(), None)
            .await
            .unwrap(),
    );

    // Create Axum app with a chat completions endpoint that deserializes
    // the body and forwards it (plus the request headers) to the router.
    let app = Router::new().route(
        "/v1/chat/completions",
        post(move |req: Request<Body>| {
            let router = router.clone();
            async move {
                let (parts, body) = req.into_parts();
                let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
                let text = String::from_utf8(bytes.to_vec()).unwrap();
                let chat_request: ChatCompletionRequest = serde_json::from_str(&text).unwrap();
                router.route_chat(Some(&parts.headers), &chat_request).await
            }
        }),
    );

    // Make a request to the server
    let payload = json!({
        "model": "gpt-3.5-turbo",
        "messages": [
            { "role": "user", "content": "Hello, world!" }
        ],
        "max_tokens": 100
    });
    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(payload.to_string()))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();

    // Verify the response structure
    assert_eq!(response_json["object"], "chat.completion");
    assert_eq!(response_json["model"], "gpt-3.5-turbo");
    assert!(!response_json["choices"].as_array().unwrap().is_empty());
}
/// Test streaming chat completions pass-through with mock server
#[tokio::test]
async fn test_openai_router_chat_streaming_with_mock() {
    let mock_server = MockOpenAIServer::new().await;
    let router = OpenAIRouter::new(mock_server.base_url(), None)
        .await
        .unwrap();

    // Build a streaming chat request
    let chat_request: ChatCompletionRequest = serde_json::from_value(json!({
        "model": "gpt-3.5-turbo",
        "messages": [
            { "role": "user", "content": "Hello" }
        ],
        "max_tokens": 10,
        "stream": true
    }))
    .unwrap();

    let response = router.route_chat(None, &chat_request).await;
    assert_eq!(response.status(), StatusCode::OK);

    // The streamed response must be served as Server-Sent Events.
    let content_type = response
        .headers()
        .get("content-type")
        .unwrap()
        .to_str()
        .unwrap()
        .to_ascii_lowercase();
    assert!(content_type.contains("text/event-stream"));

    // Read entire stream body and assert chunks + DONE
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(body.to_vec()).unwrap();
    assert!(text.contains("chat.completion.chunk"));
    assert!(text.contains("[DONE]"));
}
/// Test circuit breaker functionality
#[tokio::test]
async fn test_openai_router_circuit_breaker() {
    // Aggressive circuit breaker config so it trips quickly against an
    // unreachable upstream.
    let cb_config = sglang_router_rs::config::CircuitBreakerConfig {
        failure_threshold: 2,
        success_threshold: 1,
        timeout_duration_secs: 1,
        window_duration_secs: 10,
    };
    let router = OpenAIRouter::new(
        String::from("http://invalid-url-that-will-fail"),
        Some(cb_config),
    )
    .await
    .unwrap();

    let chat_request = create_minimal_chat_request();

    // First few requests should fail and record failures
    for _ in 0..3 {
        let status = router.route_chat(None, &chat_request).await.status();
        // Should get either an error or circuit breaker response
        assert!(
            status == StatusCode::INTERNAL_SERVER_ERROR
                || status == StatusCode::SERVICE_UNAVAILABLE
        );
    }
}
/// Test that Authorization header is forwarded in /v1/models
#[tokio::test]
async fn test_openai_router_models_auth_forwarding() {
    // Start a mock server that requires Authorization
    let expected_auth = String::from("Bearer test-token");
    let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
    let router = OpenAIRouter::new(mock_server.base_url(), None)
        .await
        .unwrap();

    // 1) Without auth header -> expect 401
    let bare_req = Request::builder()
        .method(Method::GET)
        .uri("/models")
        .body(Body::empty())
        .unwrap();
    let response = router.get_models(bare_req).await;
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

    // 2) With auth header -> expect 200
    let auth_req = Request::builder()
        .method(Method::GET)
        .uri("/models")
        .header("Authorization", expected_auth)
        .body(Body::empty())
        .unwrap();
    let response = router.get_models(auth_req).await;
    assert_eq!(response.status(), StatusCode::OK);

    // The authorized response should be a valid OpenAI model list.
    let (_, body) = response.into_parts();
    let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    let models: serde_json::Value = serde_json::from_str(&text).unwrap();
    assert_eq!(models["object"], "list");
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment