Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
366043db
Unverified
Commit
366043db
authored
Sep 12, 2025
by
Keyang Ru
Committed by
GitHub
Sep 12, 2025
Browse files
[router] Add get and cancel method for response api (#10387)
parent
2f173ea0
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
581 additions
and
5 deletions
+581
-5
sgl-router/src/routers/grpc/pd_router.rs
sgl-router/src/routers/grpc/pd_router.rs
+8
-0
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+8
-0
sgl-router/src/routers/http/openai_router.rs
sgl-router/src/routers/http/openai_router.rs
+16
-0
sgl-router/src/routers/http/pd_router.rs
sgl-router/src/routers/http/pd_router.rs
+16
-0
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+121
-1
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+28
-0
sgl-router/src/server.rs
sgl-router/src/server.rs
+58
-2
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+202
-0
sgl-router/tests/common/mock_worker.rs
sgl-router/tests/common/mock_worker.rs
+124
-2
No files found.
sgl-router/src/routers/grpc/pd_router.rs
View file @
366043db
...
@@ -301,6 +301,14 @@ impl RouterTrait for GrpcPDRouter {
...
@@ -301,6 +301,14 @@ impl RouterTrait for GrpcPDRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
async
fn
cancel_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/grpc/router.rs
View file @
366043db
...
@@ -234,6 +234,14 @@ impl RouterTrait for GrpcRouter {
...
@@ -234,6 +234,14 @@ impl RouterTrait for GrpcRouter {
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
async
fn
cancel_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
(
StatusCode
::
NOT_IMPLEMENTED
)
.into_response
()
}
}
...
...
sgl-router/src/routers/http/openai_router.rs
View file @
366043db
...
@@ -351,6 +351,22 @@ impl super::super::RouterTrait for OpenAIRouter {
...
@@ -351,6 +351,22 @@ impl super::super::RouterTrait for OpenAIRouter {
.into_response
()
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses retrieve endpoint not implemented for OpenAI router"
,
)
.into_response
()
}
async
fn
cancel_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses cancel endpoint not implemented for OpenAI router"
,
)
.into_response
()
}
async
fn
flush_cache
(
&
self
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_IMPLEMENTED
,
...
...
sgl-router/src/routers/http/pd_router.rs
View file @
366043db
...
@@ -1922,6 +1922,22 @@ impl RouterTrait for PDRouter {
...
@@ -1922,6 +1922,22 @@ impl RouterTrait for PDRouter {
.into_response
()
.into_response
()
}
}
async
fn
get_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses retrieve endpoint not implemented for PD router"
,
)
.into_response
()
}
async
fn
cancel_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses cancel endpoint not implemented for PD router"
,
)
.into_response
()
}
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
todo!
()
todo!
()
}
}
...
...
sgl-router/src/routers/http/router.rs
View file @
366043db
...
@@ -15,7 +15,9 @@ use axum::body::to_bytes;
...
@@ -15,7 +15,9 @@ use axum::body::to_bytes;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
extract
::
Request
,
extract
::
Request
,
http
::{
header
::
CONTENT_LENGTH
,
header
::
CONTENT_TYPE
,
HeaderMap
,
HeaderValue
,
StatusCode
},
http
::{
header
::
CONTENT_LENGTH
,
header
::
CONTENT_TYPE
,
HeaderMap
,
HeaderValue
,
Method
,
StatusCode
,
},
response
::{
IntoResponse
,
Response
},
response
::{
IntoResponse
,
Response
},
Json
,
Json
,
};
};
...
@@ -600,6 +602,114 @@ impl Router {
...
@@ -600,6 +602,114 @@ impl Router {
response
response
}
}
// Helper: return base worker URL (strips DP suffix when enabled).
//
// When `dp_aware` is off the URL is returned unchanged. When it is on, the
// prefix reported by `extract_dp_rank` is returned; if that call fails the
// original URL is used as-is.
fn worker_base_url(&self, worker_url: &str) -> String {
    if !self.dp_aware {
        return worker_url.to_string();
    }

    match Self::extract_dp_rank(worker_url) {
        Ok((prefix, _)) => prefix.to_string(),
        Err(_) => worker_url.to_string(),
    }
}
// Generic simple routing for GET/POST without JSON body.
//
// Tries every worker URL in turn and returns the first 2xx reply. If no
// worker answers successfully, the outcome of the last attempt is returned
// (the upstream non-2xx reply, or a 500 describing the transport/read error).
//
// Early exits:
//   * 503 when there are no workers;
//   * 405 when `method` is neither GET nor POST.
//
// All incoming headers are forwarded except `content-type` and
// `content-length`.
async fn route_simple_request(
    &self,
    headers: Option<&HeaderMap>,
    endpoint: &str,
    method: Method,
) -> Response {
    // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
    // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
    let worker_urls = self.get_worker_urls();
    if worker_urls.is_empty() {
        return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
    }

    // The method is the same for every worker, so validate it once up front.
    let is_get = method == Method::GET;
    if !is_get && method != Method::POST {
        return (
            StatusCode::METHOD_NOT_ALLOWED,
            "Unsupported method for simple routing",
        )
            .into_response();
    }

    // Outcome of the most recent failed attempt; returned if nobody succeeds.
    let mut fallback: Option<Response> = None;

    for worker_url in worker_urls {
        let target = format!("{}/{}", self.worker_base_url(&worker_url), endpoint);
        let mut outbound = if is_get {
            self.client.get(target)
        } else {
            self.client.post(target)
        };

        for (name, value) in headers.into_iter().flatten() {
            let header_name = name.as_str();
            let is_body_header = header_name.eq_ignore_ascii_case("content-type")
                || header_name.eq_ignore_ascii_case("content-length");
            if !is_body_header {
                outbound = outbound.header(name, value);
            }
        }

        let upstream = match outbound.send().await {
            Ok(res) => res,
            Err(e) => {
                fallback = Some(
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
                );
                continue;
            }
        };

        let status = StatusCode::from_u16(upstream.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        let response_headers = header_utils::preserve_response_headers(upstream.headers());

        let body = match upstream.bytes().await {
            Ok(body) => body,
            Err(e) => {
                fallback = Some(
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read response: {}", e),
                    )
                        .into_response(),
                );
                continue;
            }
        };

        let mut response = Response::new(axum::body::Body::from(body));
        *response.status_mut() = status;
        *response.headers_mut() = response_headers;

        if status.is_success() {
            return response;
        }
        fallback = Some(response);
    }

    match fallback {
        Some(response) => response,
        None => (StatusCode::BAD_GATEWAY, "No worker response").into_response(),
    }
}
// Route a GET request with provided headers to a specific endpoint.
// Thin wrapper that delegates to `route_simple_request` with `Method::GET`.
async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
    let verb = Method::GET;
    self.route_simple_request(headers, endpoint, verb).await
}
// Route a POST request with empty body to a specific endpoint.
// Thin wrapper that delegates to `route_simple_request` with `Method::POST`.
async fn route_post_empty_request(
    &self,
    headers: Option<&HeaderMap>,
    endpoint: &str,
) -> Response {
    let verb = Method::POST;
    self.route_simple_request(headers, endpoint, verb).await
}
// TODO (rui): Better accommodate to the Worker abstraction
// TODO (rui): Better accommodate to the Worker abstraction
fn
extract_dp_rank
(
worker_url
:
&
str
)
->
Result
<
(
&
str
,
usize
),
String
>
{
fn
extract_dp_rank
(
worker_url
:
&
str
)
->
Result
<
(
&
str
,
usize
),
String
>
{
let
parts
:
Vec
<&
str
>
=
worker_url
.split
(
'@'
)
.collect
();
let
parts
:
Vec
<&
str
>
=
worker_url
.split
(
'@'
)
.collect
();
...
@@ -1310,6 +1420,16 @@ impl RouterTrait for Router {
...
@@ -1310,6 +1420,16 @@ impl RouterTrait for Router {
.await
.await
}
}
/// Retrieve a response by id by issuing `GET v1/responses/{response_id}`
/// through `route_get_request`, forwarding the caller's headers.
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    self.route_get_request(headers, &format!("v1/responses/{}", response_id))
        .await
}
/// Cancel a response by id by issuing an empty-bodied
/// `POST v1/responses/{response_id}/cancel` through
/// `route_post_empty_request`, forwarding the caller's headers.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    self.route_post_empty_request(headers, &format!("v1/responses/{}/cancel", response_id))
        .await
}
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
async
fn
route_embeddings
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
Body
)
->
Response
{
todo!
()
todo!
()
}
}
...
...
sgl-router/src/routers/mod.rs
View file @
366043db
...
@@ -95,6 +95,34 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
...
@@ -95,6 +95,34 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
)
->
Response
;
/// Retrieve a stored/background response by id.
///
/// `headers` are the incoming request headers, if any; `response_id` is the
/// identifier taken from the request path. Required: every router must
/// provide an implementation.
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
/// Cancel a background response by id.
///
/// `headers` are the incoming request headers, if any; `response_id` is the
/// identifier taken from the request path. Required: every router must
/// provide an implementation.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
/// Delete a response by id
async
fn
delete_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses delete endpoint not implemented"
,
)
.into_response
()
}
/// List input items of a response by id
async
fn
list_response_input_items
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Responses list input items endpoint not implemented"
,
)
.into_response
()
}
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_rerank
(
async
fn
route_rerank
(
...
...
sgl-router/src/server.rs
View file @
366043db
...
@@ -16,10 +16,10 @@ use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
...
@@ -16,10 +16,10 @@ use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use
crate
::
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
};
use
crate
::
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
};
use
crate
::
tool_parser
::
ParserRegistry
;
use
crate
::
tool_parser
::
ParserRegistry
;
use
axum
::{
use
axum
::{
extract
::{
Query
,
Request
,
State
},
extract
::{
Path
,
Query
,
Request
,
State
},
http
::
StatusCode
,
http
::
StatusCode
,
response
::{
IntoResponse
,
Response
},
response
::{
IntoResponse
,
Response
},
routing
::{
get
,
post
},
routing
::{
delete
,
get
,
post
},
Json
,
Router
,
Json
,
Router
,
};
};
use
reqwest
::
Client
;
use
reqwest
::
Client
;
...
@@ -208,6 +208,52 @@ async fn v1_responses(
...
@@ -208,6 +208,52 @@ async fn v1_responses(
.await
.await
}
}
/// Handler for `GET /v1/responses/{response_id}`: hands the path id and the
/// request headers to the configured router's `get_response`.
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.get_response(Some(&headers), &response_id).await
}
/// Handler for `POST /v1/responses/{response_id}/cancel`: hands the path id
/// and the request headers to the configured router's `cancel_response`.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.cancel_response(Some(&headers), &response_id).await
}
/// Handler for `DELETE /v1/responses/{response_id}`: hands the path id and
/// the request headers to the configured router's `delete_response`.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    // Python server does not support this yet
    let router = &state.router;
    router.delete_response(Some(&headers), &response_id).await
}
/// Handler for `GET /v1/responses/{response_id}/input`: hands the path id and
/// the request headers to the configured router's
/// `list_response_input_items`.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    // Python server does not support this yet
    let router = &state.router;
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}
// Worker management endpoints
// Worker management endpoints
async
fn
add_worker
(
async
fn
add_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
...
@@ -419,6 +465,16 @@ pub fn build_app(
...
@@ -419,6 +465,16 @@ pub fn build_app(
.route
(
"/rerank"
,
post
(
rerank
))
.route
(
"/rerank"
,
post
(
rerank
))
.route
(
"/v1/rerank"
,
post
(
v1_rerank
))
.route
(
"/v1/rerank"
,
post
(
v1_rerank
))
.route
(
"/v1/responses"
,
post
(
v1_responses
))
.route
(
"/v1/responses"
,
post
(
v1_responses
))
.route
(
"/v1/responses/{response_id}"
,
get
(
v1_responses_get
))
.route
(
"/v1/responses/{response_id}/cancel"
,
post
(
v1_responses_cancel
),
)
.route
(
"/v1/responses/{response_id}"
,
delete
(
v1_responses_delete
))
.route
(
"/v1/responses/{response_id}/input"
,
get
(
v1_responses_list_input_items
),
)
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
app_state
.clone
(),
app_state
.clone
(),
crate
::
middleware
::
concurrency_limit_middleware
,
crate
::
middleware
::
concurrency_limit_middleware
,
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
366043db
...
@@ -994,6 +994,7 @@ mod router_policy_tests {
...
@@ -994,6 +994,7 @@ mod router_policy_tests {
#[cfg(test)]
#[cfg(test)]
mod
responses_endpoint_tests
{
mod
responses_endpoint_tests
{
use
super
::
*
;
use
super
::
*
;
use
reqwest
::
Client
as
HttpClient
;
#[tokio::test]
#[tokio::test]
async
fn
test_v1_responses_non_streaming
()
{
async
fn
test_v1_responses_non_streaming
()
{
...
@@ -1074,6 +1075,207 @@ mod responses_endpoint_tests {
...
@@ -1074,6 +1075,207 @@ mod responses_endpoint_tests {
// We don't fully consume the stream in this test harness.
// We don't fully consume the stream in this test harness.
ctx
.shutdown
()
.await
;
ctx
.shutdown
()
.await
;
}
}
/// Creating a response and then fetching it by id through the router must
/// succeed and return a `response` object.
#[tokio::test]
async fn test_v1_responses_get() {
    let ctx = TestContext::new(vec![MockWorkerConfig {
        port: 18952,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    }])
    .await;

    let app = ctx.create_app().await;

    // First create a response to obtain an id.
    //
    // The mock worker only records an id when the request is a background
    // request carrying an explicit `request_id`, and its GET handler answers
    // 404 for ids it has not recorded. A plain (non-background) create would
    // therefore make the retrieval below fail.
    let requested_id = "resp_get_18952";
    let payload = json!({
        "input": "Hello Responses API",
        "model": "mock-model",
        "background": true,
        "store": true,
        "request_id": requested_id,
    });
    let req = Request::builder()
        .method("POST")
        .uri("/v1/responses")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap()))
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let resp_id = body_json["id"].as_str().unwrap().to_string();
    // The mock echoes the supplied request id as the response id.
    assert_eq!(resp_id, requested_id);

    // Retrieve the response
    let req = Request::builder()
        .method("GET")
        .uri(format!("/v1/responses/{}", resp_id))
        .body(Body::empty())
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let get_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(get_json["object"], "response");

    ctx.shutdown().await;
}
/// Creating a response and then cancelling it by id through the router must
/// succeed and report the `cancelled` status.
#[tokio::test]
async fn test_v1_responses_cancel() {
    let ctx = TestContext::new(vec![MockWorkerConfig {
        port: 18953,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    }])
    .await;

    let app = ctx.create_app().await;

    // First create a response to obtain an id.
    //
    // The mock worker only records an id when the request is a background
    // request carrying an explicit `request_id`, and its cancel handler
    // answers 404 for ids it has not recorded. A plain (non-background)
    // create would therefore make the cancel below fail.
    let requested_id = "resp_cancel_18953";
    let payload = json!({
        "input": "Hello Responses API",
        "model": "mock-model",
        "background": true,
        "store": true,
        "request_id": requested_id,
    });
    let req = Request::builder()
        .method("POST")
        .uri("/v1/responses")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap()))
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let resp_id = body_json["id"].as_str().unwrap().to_string();
    // The mock echoes the supplied request id as the response id.
    assert_eq!(resp_id, requested_id);

    // Cancel the response
    let req = Request::builder()
        .method("POST")
        .uri(format!("/v1/responses/{}/cancel", resp_id))
        .body(Body::empty())
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let cancel_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(cancel_json["status"], "cancelled");

    ctx.shutdown().await;
}
/// The delete and list-input-items endpoints are routed but must answer
/// `501 Not Implemented`.
#[tokio::test]
async fn test_v1_responses_delete_and_list_not_implemented() {
    let ctx = TestContext::new(vec![MockWorkerConfig {
        port: 18954,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    }])
    .await;

    let app = ctx.create_app().await;

    // Use an arbitrary id for delete/list
    let resp_id = "resp-test-123";

    // (HTTP method, request URI) pairs, exercised in this order.
    let unsupported = vec![
        ("DELETE", format!("/v1/responses/{}", resp_id)),
        ("GET", format!("/v1/responses/{}/input", resp_id)),
    ];

    for (verb, uri) in unsupported {
        let req = Request::builder()
            .method(verb)
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
    }

    ctx.shutdown().await;
}
/// With two workers, a GET through the router must succeed even though only
/// one of the workers holds the stored response.
#[tokio::test]
async fn test_v1_responses_get_multi_worker_fanout() {
    // Start two mock workers
    let worker_configs = vec![
        MockWorkerConfig {
            port: 18960,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        },
        MockWorkerConfig {
            port: 18961,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        },
    ];
    let ctx = TestContext::new(worker_configs).await;

    let app = ctx.create_app().await;

    // Create a background response with a known id
    let rid = format!("resp_{}", 18960); // arbitrary unique id
    let payload = json!({
        "input": "Hello Responses API",
        "model": "mock-model",
        "background": true,
        "store": true,
        "request_id": rid,
    });
    let create_req = Request::builder()
        .method("POST")
        .uri("/v1/responses")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap()))
        .unwrap();
    let create_resp = app.clone().oneshot(create_req).await.unwrap();
    assert_eq!(create_resp.status(), StatusCode::OK);

    // Using the router, GET should succeed by fanning out across workers
    let get_req = Request::builder()
        .method("GET")
        .uri(format!("/v1/responses/{}", rid))
        .body(Body::empty())
        .unwrap();
    let get_resp = app.clone().oneshot(get_req).await.unwrap();
    assert_eq!(get_resp.status(), StatusCode::OK);

    // Validate only one worker holds the metadata: direct calls
    let client = HttpClient::new();
    let mut ok_count = 0usize;
    for url in ctx.router.get_worker_urls() {
        let direct_url = format!("{}/v1/responses/{}", url, rid);
        let direct = client.get(direct_url).send().await.unwrap();
        if direct.status() == StatusCode::OK {
            ok_count += 1;
        }
    }
    assert_eq!(ok_count, 1, "exactly one worker should store the response");

    ctx.shutdown().await;
}
}
}
#[cfg(test)]
#[cfg(test)]
...
...
sgl-router/tests/common/mock_worker.rs
View file @
366043db
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#![allow(dead_code)]
#![allow(dead_code)]
use
axum
::{
use
axum
::{
extract
::{
Json
,
State
},
extract
::{
Json
,
Path
,
State
},
http
::
StatusCode
,
http
::
StatusCode
,
response
::
sse
::{
Event
,
KeepAlive
},
response
::
sse
::{
Event
,
KeepAlive
},
response
::{
IntoResponse
,
Response
,
Sse
},
response
::{
IntoResponse
,
Response
,
Sse
},
...
@@ -11,8 +11,9 @@ use axum::{
...
@@ -11,8 +11,9 @@ use axum::{
};
};
use
futures_util
::
stream
::{
self
,
StreamExt
};
use
futures_util
::
stream
::{
self
,
StreamExt
};
use
serde_json
::
json
;
use
serde_json
::
json
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
convert
::
Infallible
;
use
std
::
convert
::
Infallible
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
{
Arc
,
Mutex
,
OnceLock
}
;
use
std
::
time
::{
SystemTime
,
UNIX_EPOCH
};
use
std
::
time
::{
SystemTime
,
UNIX_EPOCH
};
use
tokio
::
sync
::
RwLock
;
use
tokio
::
sync
::
RwLock
;
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
...
@@ -83,6 +84,11 @@ impl MockWorker {
...
@@ -83,6 +84,11 @@ impl MockWorker {
.route
(
"/v1/completions"
,
post
(
completions_handler
))
.route
(
"/v1/completions"
,
post
(
completions_handler
))
.route
(
"/v1/rerank"
,
post
(
rerank_handler
))
.route
(
"/v1/rerank"
,
post
(
rerank_handler
))
.route
(
"/v1/responses"
,
post
(
responses_handler
))
.route
(
"/v1/responses"
,
post
(
responses_handler
))
.route
(
"/v1/responses/{response_id}"
,
get
(
responses_get_handler
))
.route
(
"/v1/responses/{response_id}/cancel"
,
post
(
responses_cancel_handler
),
)
.route
(
"/flush_cache"
,
post
(
flush_cache_handler
))
.route
(
"/flush_cache"
,
post
(
flush_cache_handler
))
.route
(
"/v1/models"
,
get
(
v1_models_handler
))
.route
(
"/v1/models"
,
get
(
v1_models_handler
))
.with_state
(
config
);
.with_state
(
config
);
...
@@ -584,6 +590,21 @@ async fn responses_handler(
...
@@ -584,6 +590,21 @@ async fn responses_handler(
.unwrap
()
.unwrap
()
.as_secs
()
as
i64
;
.as_secs
()
as
i64
;
// Background storage simulation
let
is_background
=
payload
.get
(
"background"
)
.and_then
(|
v
|
v
.as_bool
())
.unwrap_or
(
false
);
let
req_id
=
payload
.get
(
"request_id"
)
.and_then
(|
v
|
v
.as_str
())
.map
(|
s
|
s
.to_string
());
if
is_background
{
if
let
Some
(
id
)
=
&
req_id
{
store_response_for_port
(
config
.port
,
id
);
}
}
if
is_stream
{
if
is_stream
{
let
request_id
=
format!
(
"resp-{}"
,
Uuid
::
new_v4
());
let
request_id
=
format!
(
"resp-{}"
,
Uuid
::
new_v4
());
...
@@ -610,6 +631,18 @@ async fn responses_handler(
...
@@ -610,6 +631,18 @@ async fn responses_handler(
Sse
::
new
(
stream
)
Sse
::
new
(
stream
)
.keep_alive
(
KeepAlive
::
default
())
.keep_alive
(
KeepAlive
::
default
())
.into_response
()
.into_response
()
}
else
if
is_background
{
let
rid
=
req_id
.unwrap_or_else
(||
format!
(
"resp-{}"
,
Uuid
::
new_v4
()));
Json
(
json!
({
"id"
:
rid
,
"object"
:
"response"
,
"created_at"
:
timestamp
,
"model"
:
"mock-model"
,
"output"
:
[],
"status"
:
"queued"
,
"usage"
:
null
}))
.into_response
()
}
else
{
}
else
{
Json
(
json!
({
Json
(
json!
({
"id"
:
format!
(
"resp-{}"
,
Uuid
::
new_v4
()),
"id"
:
format!
(
"resp-{}"
,
Uuid
::
new_v4
()),
...
@@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
...
@@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
.into_response
()
.into_response
()
}
}
async
fn
responses_get_handler
(
State
(
config
):
State
<
Arc
<
RwLock
<
MockWorkerConfig
>>>
,
Path
(
response_id
):
Path
<
String
>
,
)
->
Response
{
let
config
=
config
.read
()
.await
;
if
should_fail
(
&
config
)
.await
{
return
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
Json
(
json!
({
"error"
:
"Random failure for testing"
})),
)
.into_response
();
}
let
timestamp
=
SystemTime
::
now
()
.duration_since
(
UNIX_EPOCH
)
.unwrap
()
.as_secs
()
as
i64
;
// Only return 200 if this worker "stores" the response id
if
response_exists_for_port
(
config
.port
,
&
response_id
)
{
Json
(
json!
({
"id"
:
response_id
,
"object"
:
"response"
,
"created_at"
:
timestamp
,
"model"
:
"mock-model"
,
"output"
:
[],
"status"
:
"completed"
,
"usage"
:
{
"input_tokens"
:
0
,
"output_tokens"
:
0
,
"total_tokens"
:
0
}
}))
.into_response
()
}
else
{
StatusCode
::
NOT_FOUND
.into_response
()
}
}
async
fn
responses_cancel_handler
(
State
(
config
):
State
<
Arc
<
RwLock
<
MockWorkerConfig
>>>
,
Path
(
response_id
):
Path
<
String
>
,
)
->
Response
{
let
config
=
config
.read
()
.await
;
if
should_fail
(
&
config
)
.await
{
return
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
Json
(
json!
({
"error"
:
"Random failure for testing"
})),
)
.into_response
();
}
let
timestamp
=
SystemTime
::
now
()
.duration_since
(
UNIX_EPOCH
)
.unwrap
()
.as_secs
()
as
i64
;
if
response_exists_for_port
(
config
.port
,
&
response_id
)
{
Json
(
json!
({
"id"
:
response_id
,
"object"
:
"response"
,
"created_at"
:
timestamp
,
"model"
:
"mock-model"
,
"output"
:
[],
"status"
:
"cancelled"
,
"usage"
:
null
}))
.into_response
()
}
else
{
StatusCode
::
NOT_FOUND
.into_response
()
}
}
// --- Simple in-memory response store per worker port (for tests) ---
//
// Process-wide map from a worker's port to the set of response ids recorded
// for that port. Created lazily on first use.
static RESP_STORE: OnceLock<Mutex<HashMap<u16, HashSet<String>>>> = OnceLock::new();

/// Returns the shared store, initialising it to an empty map on first access.
fn get_store() -> &'static Mutex<HashMap<u16, HashSet<String>>> {
    RESP_STORE.get_or_init(Mutex::default)
}

/// Records `response_id` as stored by the worker listening on `port`.
/// Recording the same id twice is a no-op.
///
/// Panics if the store's mutex is poisoned.
fn store_response_for_port(port: u16, response_id: &str) {
    let mut store = get_store().lock().unwrap();
    let ids = store.entry(port).or_default();
    ids.insert(response_id.to_owned());
}

/// Reports whether `response_id` was previously recorded for `port`.
///
/// Panics if the store's mutex is poisoned.
fn response_exists_for_port(port: u16, response_id: &str) -> bool {
    let store = get_store().lock().unwrap();
    store
        .get(&port)
        .is_some_and(|ids| ids.contains(response_id))
}
// Minimal rerank handler returning mock results; router shapes final response
async
fn
rerank_handler
(
async
fn
rerank_handler
(
State
(
config
):
State
<
Arc
<
RwLock
<
MockWorkerConfig
>>>
,
State
(
config
):
State
<
Arc
<
RwLock
<
MockWorkerConfig
>>>
,
Json
(
payload
):
Json
<
serde_json
::
Value
>
,
Json
(
payload
):
Json
<
serde_json
::
Value
>
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment