Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4634fd59
Unverified
Commit
4634fd59
authored
Sep 13, 2025
by
Frank Fang
Committed by
GitHub
Sep 12, 2025
Browse files
[router] Add Rerank Routing Logic in Regular Router (#10219)
parent
efedbe6c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
545 additions
and
40 deletions
+545
-40
sgl-router/src/protocols/spec.rs
sgl-router/src/protocols/spec.rs
+63
-26
sgl-router/src/routers/grpc/pd_router.rs
sgl-router/src/routers/grpc/pd_router.rs
+5
-1
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+5
-1
sgl-router/src/routers/http/openai_router.rs
sgl-router/src/routers/http/openai_router.rs
+4
-2
sgl-router/src/routers/http/pd_router.rs
sgl-router/src/routers/http/pd_router.rs
+21
-4
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+42
-3
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+2
-2
sgl-router/src/server.rs
sgl-router/src/server.rs
+23
-1
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+329
-0
sgl-router/tests/common/mock_worker.rs
sgl-router/tests/common/mock_worker.rs
+51
-0
No files found.
sgl-router/src/protocols/spec.rs
View file @
4634fd59
...
@@ -1891,7 +1891,7 @@ pub struct RerankResponse {
...
@@ -1891,7 +1891,7 @@ pub struct RerankResponse {
pub
object
:
String
,
pub
object
:
String
,
/// Response ID
/// Response ID
pub
id
:
String
,
pub
id
:
Option
<
StringOrArray
>
,
/// Creation timestamp
/// Creation timestamp
pub
created
:
i64
,
pub
created
:
i64
,
...
@@ -1976,7 +1976,11 @@ impl RerankRequest {
...
@@ -1976,7 +1976,11 @@ impl RerankRequest {
}
}
impl
RerankResponse
{
impl
RerankResponse
{
pub
fn
new
(
results
:
Vec
<
RerankResult
>
,
model
:
String
,
request_id
:
String
)
->
Self
{
pub
fn
new
(
results
:
Vec
<
RerankResult
>
,
model
:
String
,
request_id
:
Option
<
StringOrArray
>
,
)
->
Self
{
RerankResponse
{
RerankResponse
{
results
,
results
,
model
,
model
,
...
@@ -2000,6 +2004,13 @@ impl RerankResponse {
...
@@ -2000,6 +2004,13 @@ impl RerankResponse {
pub
fn
apply_top_k
(
&
mut
self
,
k
:
usize
)
{
pub
fn
apply_top_k
(
&
mut
self
,
k
:
usize
)
{
self
.results
.truncate
(
k
);
self
.results
.truncate
(
k
);
}
}
/// Clears the `document` field of every entry in `results`.
///
/// The entries themselves stay in place and keep their other fields;
/// only `document` is reset to `None`.
pub fn drop_documents(&mut self) {
    for entry in &mut self.results {
        entry.document = None;
    }
}
}
}
// ==================================================================
// ==================================================================
...
@@ -2268,12 +2279,15 @@ mod tests {
...
@@ -2268,12 +2279,15 @@ mod tests {
let
response
=
RerankResponse
::
new
(
let
response
=
RerankResponse
::
new
(
results
.clone
(),
results
.clone
(),
"test-model"
.to_string
(),
"test-model"
.to_string
(),
"req-123"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
()
))
,
);
);
assert_eq!
(
response
.results
.len
(),
2
);
assert_eq!
(
response
.results
.len
(),
2
);
assert_eq!
(
response
.model
,
"test-model"
);
assert_eq!
(
response
.model
,
"test-model"
);
assert_eq!
(
response
.id
,
"req-123"
);
assert_eq!
(
response
.id
,
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
()))
);
assert_eq!
(
response
.object
,
"rerank"
);
assert_eq!
(
response
.object
,
"rerank"
);
assert
!
(
response
.created
>
0
);
assert
!
(
response
.created
>
0
);
}
}
...
@@ -2287,8 +2301,11 @@ mod tests {
...
@@ -2287,8 +2301,11 @@ mod tests {
meta_info
:
None
,
meta_info
:
None
,
}];
}];
let
response
=
let
response
=
RerankResponse
::
new
(
RerankResponse
::
new
(
results
,
"test-model"
.to_string
(),
"req-123"
.to_string
());
results
,
"test-model"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
let
serialized
=
serde_json
::
to_string
(
&
response
)
.unwrap
();
let
serialized
=
serde_json
::
to_string
(
&
response
)
.unwrap
();
let
deserialized
:
RerankResponse
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
let
deserialized
:
RerankResponse
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
...
@@ -2322,8 +2339,11 @@ mod tests {
...
@@ -2322,8 +2339,11 @@ mod tests {
},
},
];
];
let
mut
response
=
let
mut
response
=
RerankResponse
::
new
(
RerankResponse
::
new
(
results
,
"test-model"
.to_string
(),
"req-123"
.to_string
());
results
,
"test-model"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
response
.sort_by_score
();
response
.sort_by_score
();
...
@@ -2358,8 +2378,11 @@ mod tests {
...
@@ -2358,8 +2378,11 @@ mod tests {
},
},
];
];
let
mut
response
=
let
mut
response
=
RerankResponse
::
new
(
RerankResponse
::
new
(
results
,
"test-model"
.to_string
(),
"req-123"
.to_string
());
results
,
"test-model"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
response
.apply_top_k
(
2
);
response
.apply_top_k
(
2
);
...
@@ -2377,14 +2400,36 @@ mod tests {
...
@@ -2377,14 +2400,36 @@ mod tests {
meta_info
:
None
,
meta_info
:
None
,
}];
}];
let
mut
response
=
let
mut
response
=
RerankResponse
::
new
(
RerankResponse
::
new
(
results
,
"test-model"
.to_string
(),
"req-123"
.to_string
());
results
,
"test-model"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
response
.apply_top_k
(
5
);
response
.apply_top_k
(
5
);
assert_eq!
(
response
.results
.len
(),
1
);
assert_eq!
(
response
.results
.len
(),
1
);
}
}
/// `drop_documents` must reset the `document` of a result to `None`.
#[test]
fn test_rerank_response_drop_documents() {
    // A single result that starts out carrying its document text.
    let with_document = RerankResult {
        score: 0.8,
        document: Some("doc1".to_string()),
        index: 0,
        meta_info: None,
    };
    let request_id = Some(StringOrArray::String("req-123".to_string()));

    let mut response =
        RerankResponse::new(vec![with_document], "test-model".to_string(), request_id);
    response.drop_documents();

    // The document text must be gone after the call.
    assert_eq!(response.results[0].document, None);
}
// ==================================================================
// ==================================================================
// = RERANK RESULT TESTS =
// = RERANK RESULT TESTS =
// ==================================================================
// ==================================================================
...
@@ -2570,8 +2615,11 @@ mod tests {
...
@@ -2570,8 +2615,11 @@ mod tests {
meta_info
:
None
,
meta_info
:
None
,
}];
}];
let
mut
response
=
let
mut
response
=
RerankResponse
::
new
(
RerankResponse
::
new
(
results
,
"test-model"
.to_string
(),
"req-123"
.to_string
());
results
,
"test-model"
.to_string
(),
Some
(
StringOrArray
::
String
(
"req-123"
.to_string
())),
);
response
.usage
=
Some
(
UsageInfo
{
response
.usage
=
Some
(
UsageInfo
{
prompt_tokens
:
100
,
prompt_tokens
:
100
,
...
@@ -2645,18 +2693,7 @@ mod tests {
...
@@ -2645,18 +2693,7 @@ mod tests {
];
];
// Create response
// Create response
let
mut
response
=
RerankResponse
::
new
(
let
mut
response
=
RerankResponse
::
new
(
results
,
request
.model
.clone
(),
request
.rid
.clone
());
results
,
request
.model
.clone
(),
request
.rid
.as_ref
()
.and_then
(|
r
|
match
r
{
StringOrArray
::
String
(
s
)
=>
Some
(
s
.clone
()),
StringOrArray
::
Array
(
arr
)
=>
arr
.first
()
.cloned
(),
})
.unwrap_or_else
(||
"unknown"
.to_string
()),
);
// Sort by score
// Sort by score
response
.sort_by_score
();
response
.sort_by_score
();
...
...
sgl-router/src/routers/grpc/pd_router.rs
View file @
4634fd59
...
@@ -301,7 +301,11 @@ impl RouterTrait for GrpcPDRouter {
...
@@ -301,7 +301,11 @@ impl RouterTrait for GrpcPDRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
crate
::
protocols
::
spec
::
RerankRequest
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/grpc/router.rs
View file @
4634fd59
...
@@ -234,7 +234,11 @@ impl RouterTrait for GrpcRouter {
...
@@ -234,7 +234,11 @@ impl RouterTrait for GrpcRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
crate
::
protocols
::
spec
::
RerankRequest
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/http/openai_router.rs
View file @
4634fd59
...
@@ -2,7 +2,9 @@
...
@@ -2,7 +2,9 @@
use
crate
::
config
::
CircuitBreakerConfig
;
use
crate
::
config
::
CircuitBreakerConfig
;
use
crate
::
core
::{
CircuitBreaker
,
CircuitBreakerConfig
as
CoreCircuitBreakerConfig
};
use
crate
::
core
::{
CircuitBreaker
,
CircuitBreakerConfig
as
CoreCircuitBreakerConfig
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
};
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
...
@@ -381,7 +383,7 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -381,7 +383,7 @@ impl super::super::RouterTrait for OpenAIRouter {
.into_response
()
.into_response
()
}
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
RerankRequest
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_IMPLEMENTED
,
"Rerank endpoint not implemented for OpenAI backend"
,
"Rerank endpoint not implemented for OpenAI backend"
,
...
...
sgl-router/src/routers/http/pd_router.rs
View file @
4634fd59
...
@@ -9,8 +9,8 @@ use crate::core::{
...
@@ -9,8 +9,8 @@ use crate::core::{
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
Re
sponses
Request
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateRequest
,
Re
rank
Request
,
StringOrArray
,
UserMessageContent
,
ResponsesRequest
,
StringOrArray
,
UserMessageContent
,
};
};
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
...
@@ -1946,8 +1946,25 @@ impl RouterTrait for PDRouter {
...
@@ -1946,8 +1946,25 @@ impl RouterTrait for PDRouter {
todo!
()
todo!
()
}
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
{
todo!
()
// Extract text for cache-aware routing
let
req_text
=
if
self
.policies_need_request_text
()
{
Some
(
body
.query
.clone
())
}
else
{
None
};
// Create context
let
context
=
PDRequestContext
{
route
:
"/v1/rerank"
,
batch_size
:
None
,
is_stream
:
false
,
return_logprob
:
false
,
request_text
:
req_text
,
};
// Execute with retry and bootstrap injection
self
.execute_dual_dispatch
(
headers
,
body
,
context
)
.await
}
}
async
fn
flush_cache
(
&
self
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
...
...
sgl-router/src/routers/http/router.rs
View file @
4634fd59
...
@@ -6,10 +6,12 @@ use crate::core::{
...
@@ -6,10 +6,12 @@ use crate::core::{
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
ResponsesRequest
,
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
RerankRequest
,
RerankResponse
,
RerankResult
,
ResponsesRequest
,
};
};
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::
header_utils
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
axum
::
body
::
to_bytes
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
extract
::
Request
,
extract
::
Request
,
...
@@ -1124,6 +1126,25 @@ impl Router {
...
@@ -1124,6 +1126,25 @@ impl Router {
}
}
}
}
}
}
/// Converts a worker reply into the router's `RerankResponse` payload.
///
/// The worker body is read in full (no size limit: `usize::MAX`) and parsed
/// as a JSON array of `RerankResult`. The worker's status line and headers
/// are discarded; only the body is used.
///
/// Post-processing, in order:
/// 1. `sort_by_score()` on the assembled response,
/// 2. truncate to `req.top_k` entries when it is set,
/// 3. strip document text when `req.return_documents` is false.
///
/// # Errors
/// Returns an error if the body cannot be read or is not a valid
/// `Vec<RerankResult>` JSON document.
async fn build_rerank_response(
    req: &RerankRequest,
    response: Response,
) -> anyhow::Result<Response> {
    let raw = to_bytes(response.into_body(), usize::MAX).await?;
    let scored: Vec<RerankResult> = serde_json::from_slice(&raw)?;

    let mut reranked = RerankResponse::new(scored, req.model.clone(), req.rid.clone());
    reranked.sort_by_score();

    if let Some(limit) = req.top_k {
        reranked.apply_top_k(limit);
    }

    if !req.return_documents {
        reranked.drop_documents();
    }

    Ok(Json(reranked).into_response())
}
}
}
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
...
@@ -1223,8 +1244,26 @@ impl RouterTrait for Router {
...
@@ -1223,8 +1244,26 @@ impl RouterTrait for Router {
todo!
()
todo!
()
}
}
async
fn
route_rerank
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
{
todo!
()
if
let
Err
(
e
)
=
body
.validate
()
{
return
(
StatusCode
::
BAD_REQUEST
,
e
)
.into_response
();
}
let
response
=
self
.route_typed_request
(
headers
,
body
,
"/v1/rerank"
)
.await
;
if
response
.status
()
.is_success
()
{
match
Self
::
build_rerank_response
(
body
,
response
)
.await
{
Ok
(
rerank_response
)
=>
rerank_response
,
Err
(
e
)
=>
{
error!
(
"Failed to build rerank response: {}"
,
e
);
return
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
"Failed to build rerank response"
.to_string
(),
)
.into_response
();
}
}
}
else
{
response
}
}
}
async
fn
flush_cache
(
&
self
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
...
...
sgl-router/src/routers/mod.rs
View file @
4634fd59
...
@@ -10,7 +10,7 @@ use axum::{
...
@@ -10,7 +10,7 @@ use axum::{
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
ResponsesRequest
,
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
};
};
pub
mod
factory
;
pub
mod
factory
;
...
@@ -89,7 +89,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
...
@@ -89,7 +89,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
;
/// Flush cache on all workers
/// Flush cache on all workers
async
fn
flush_cache
(
&
self
)
->
Response
;
async
fn
flush_cache
(
&
self
)
->
Response
;
...
...
sgl-router/src/server.rs
View file @
4634fd59
...
@@ -3,7 +3,8 @@ use crate::logging::{self, LoggingConfig};
...
@@ -3,7 +3,8 @@ use crate::logging::{self, LoggingConfig};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
ResponsesRequest
,
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
V1RerankReqInput
,
};
};
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
...
@@ -152,6 +153,25 @@ async fn v1_completions(
...
@@ -152,6 +153,25 @@ async fn v1_completions(
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
)
.await
}
}
/// Axum handler that accepts a native `RerankRequest` JSON body and hands it,
/// together with the incoming request headers, to the configured router.
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<RerankRequest>,
) -> Response {
    let request_headers = Some(&headers);
    state.router.route_rerank(request_headers, &body).await
}
async
fn
v1_rerank
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
V1RerankReqInput
>
,
)
->
Response
{
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
.into
())
.await
}
async
fn
v1_responses
(
async
fn
v1_responses
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
headers
:
http
::
HeaderMap
,
headers
:
http
::
HeaderMap
,
...
@@ -237,6 +257,8 @@ pub fn build_app(
...
@@ -237,6 +257,8 @@ pub fn build_app(
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/v1/chat/completions"
,
post
(
v1_chat_completions
))
.route
(
"/v1/chat/completions"
,
post
(
v1_chat_completions
))
.route
(
"/v1/completions"
,
post
(
v1_completions
))
.route
(
"/v1/completions"
,
post
(
v1_completions
))
.route
(
"/rerank"
,
post
(
rerank
))
.route
(
"/v1/rerank"
,
post
(
v1_rerank
))
.route
(
"/v1/responses"
,
post
(
v1_responses
))
.route
(
"/v1/responses"
,
post
(
v1_responses
))
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
app_state
.clone
(),
app_state
.clone
(),
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
4634fd59
...
@@ -1752,3 +1752,332 @@ mod request_id_tests {
...
@@ -1752,3 +1752,332 @@ mod request_id_tests {
ctx
.shutdown
()
.await
;
ctx
.shutdown
()
.await
;
}
}
}
}
#[cfg(test)]
mod rerank_tests {
    use super::*;
    // Note: RerankRequest and RerankResult are available for future use

    /// Builds a `POST` request to `uri` with `payload` serialized as the JSON body.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Reads a response body to the end and parses it as JSON.
    async fn read_json(body: axum::body::Body) -> serde_json::Value {
        let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// True when every result's `score` is >= the score of the result after it.
    fn scores_descending(results: &[serde_json::Value]) -> bool {
        results
            .windows(2)
            .all(|pair| pair[0]["score"].as_f64().unwrap() >= pair[1]["score"].as_f64().unwrap())
    }

    /// Happy path on `/rerank`: model is echoed, both results come back sorted.
    #[tokio::test]
    async fn test_rerank_success() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18105,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "machine learning algorithms",
            "documents": [
                "Introduction to machine learning concepts",
                "Deep learning neural networks tutorial"
            ],
            "model": "test-rerank-model",
            "top_k": 2,
            "return_documents": true,
            "rid": "test-request-123"
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;

        // Response envelope
        assert!(body_json.get("results").is_some());
        assert!(body_json.get("model").is_some());
        assert_eq!(body_json["model"], "test-rerank-model");

        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 2);

        // Highest score first
        assert!(scores_descending(results));

        ctx.shutdown().await;
    }

    /// `top_k` limits the number of results returned.
    #[tokio::test]
    async fn test_rerank_with_top_k() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18106,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": [
                "Document 1",
                "Document 2",
                "Document 3"
            ],
            "model": "test-model",
            "top_k": 1,
            "return_documents": true
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;

        // Three documents in, only top_k = 1 out
        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 1);

        ctx.shutdown().await;
    }

    /// With `return_documents: false` no result carries a `document` key.
    #[tokio::test]
    async fn test_rerank_without_documents() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18107,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": ["Document 1", "Document 2"],
            "model": "test-model",
            "return_documents": false
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;

        // `Value::get` yields `None` only for a missing key, so this checks
        // that `document` is absent from the JSON (not merely null).
        let results = body_json["results"].as_array().unwrap();
        assert!(results
            .iter()
            .all(|result| result.get("document").is_none()));

        ctx.shutdown().await;
    }

    /// A worker that always fails surfaces as a 500 to the client.
    #[tokio::test]
    async fn test_rerank_worker_failure() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18108,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 1.0, // Always fail
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": ["Document 1"],
            "model": "test-model"
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();

        // The worker's error status is passed through
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);

        ctx.shutdown().await;
    }

    /// `/v1/rerank` accepts the simplified V1 payload (query + documents only).
    #[tokio::test]
    async fn test_v1_rerank_compatibility() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18110,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        // V1 format: no model, top_k, or return_documents fields
        let payload = json!({
            "query": "machine learning algorithms",
            "documents": [
                "Introduction to machine learning concepts",
                "Deep learning neural networks tutorial",
                "Statistical learning theory basics"
            ]
        });

        let resp = app.oneshot(post_json("/v1/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;

        // Response envelope
        assert!(body_json.get("results").is_some());
        assert!(body_json.get("model").is_some());

        // The V1 endpoint reports the default model name
        assert_eq!(body_json["model"], "default");

        // Every submitted document is returned
        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 3);

        // Highest score first across all adjacent pairs
        assert!(scores_descending(results));

        // The V1 endpoint includes documents by default
        assert!(results
            .iter()
            .all(|result| result.get("document").is_some()));

        ctx.shutdown().await;
    }

    /// Each malformed request is rejected with 400 before reaching a worker result.
    #[tokio::test]
    async fn test_rerank_invalid_request() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 18111,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        // Empty query string
        let empty_query = json!({
            "query": "",
            "documents": ["Document 1", "Document 2"],
            "model": "test-model"
        });
        let resp = app
            .clone()
            .oneshot(post_json("/rerank", &empty_query))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Whitespace-only query
        let blank_query = json!({
            "query": " ",
            "documents": ["Document 1", "Document 2"],
            "model": "test-model"
        });
        let resp = app
            .clone()
            .oneshot(post_json("/rerank", &blank_query))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Empty documents list
        let no_documents = json!({
            "query": "test query",
            "documents": [],
            "model": "test-model"
        });
        let resp = app
            .clone()
            .oneshot(post_json("/rerank", &no_documents))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // top_k of zero
        let zero_top_k = json!({
            "query": "test query",
            "documents": ["Document 1", "Document 2"],
            "model": "test-model",
            "top_k": 0
        });
        let resp = app
            .oneshot(post_json("/rerank", &zero_top_k))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        ctx.shutdown().await;
    }
}
sgl-router/tests/common/mock_worker.rs
View file @
4634fd59
...
@@ -81,6 +81,7 @@ impl MockWorker {
...
@@ -81,6 +81,7 @@ impl MockWorker {
.route
(
"/generate"
,
post
(
generate_handler
))
.route
(
"/generate"
,
post
(
generate_handler
))
.route
(
"/v1/chat/completions"
,
post
(
chat_completions_handler
))
.route
(
"/v1/chat/completions"
,
post
(
chat_completions_handler
))
.route
(
"/v1/completions"
,
post
(
completions_handler
))
.route
(
"/v1/completions"
,
post
(
completions_handler
))
.route
(
"/v1/rerank"
,
post
(
rerank_handler
))
.route
(
"/v1/responses"
,
post
(
responses_handler
))
.route
(
"/v1/responses"
,
post
(
responses_handler
))
.route
(
"/flush_cache"
,
post
(
flush_cache_handler
))
.route
(
"/flush_cache"
,
post
(
flush_cache_handler
))
.route
(
"/v1/models"
,
get
(
v1_models_handler
))
.route
(
"/v1/models"
,
get
(
v1_models_handler
))
...
@@ -687,6 +688,56 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
...
@@ -687,6 +688,56 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
.into_response
()
.into_response
()
}
}
/// Mock rerank endpoint.
///
/// Honors the worker's configured `response_delay_ms` and `fail_rate`, then
/// answers with one synthetic result per entry of the request's `documents`
/// array (an absent or non-array `documents` yields an empty result list).
/// Scores start at 0.95 and drop by 0.1 per document index; the list is
/// returned as a bare JSON array ordered by descending score.
async fn rerank_handler(
    State(config): State<Arc<RwLock<MockWorkerConfig>>>,
    Json(payload): Json<serde_json::Value>,
) -> impl IntoResponse {
    let settings = config.read().await;

    // Simulated latency
    if settings.response_delay_ms > 0 {
        let delay = tokio::time::Duration::from_millis(settings.response_delay_ms);
        tokio::time::sleep(delay).await;
    }

    // Simulated random failure
    if rand::random::<f32>() < settings.fail_rate {
        return (StatusCode::INTERNAL_SERVER_ERROR, "Simulated failure").into_response();
    }

    // Documents to score; fall back to an empty slice when missing or malformed.
    let documents: &[serde_json::Value] = payload
        .get("documents")
        .and_then(|d| d.as_array())
        .map(Vec::as_slice)
        .unwrap_or(&[]);

    // One result per document, score decreasing with position.
    let mut mock_results: Vec<serde_json::Value> = documents
        .iter()
        .enumerate()
        .map(|(position, doc)| {
            let score = 0.95 - (position as f32 * 0.1);
            let confidence = if score > 0.9 { "high" } else { "medium" };
            serde_json::json!({
                "score": score,
                "document": doc.as_str().unwrap_or(""),
                "index": position,
                "meta_info": {
                    "confidence": confidence
                }
            })
        })
        .collect();

    // Highest score first, mimicking a real ranking.
    let score_of = |entry: &serde_json::Value| entry["score"].as_f64().unwrap();
    mock_results.sort_by(|a, b| score_of(b).partial_cmp(&score_of(a)).unwrap());

    (StatusCode::OK, Json(mock_results)).into_response()
}
impl
Default
for
MockWorkerConfig
{
impl
Default
for
MockWorkerConfig
{
fn
default
()
->
Self
{
fn
default
()
->
Self
{
Self
{
Self
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment