Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f6e07f27
Unverified
Commit
f6e07f27
authored
Jul 23, 2025
by
Simo Lin
Committed by
GitHub
Jul 23, 2025
Browse files
[router] fix pd model completion request (#8303)
parent
5dd0f870
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
320 additions
and
15 deletions
+320
-15
sgl-router/benches/request_processing.rs
sgl-router/benches/request_processing.rs
+1
-0
sgl-router/src/openai_api_types.rs
sgl-router/src/openai_api_types.rs
+4
-0
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+75
-15
sgl-router/src/routers/pd_types.rs
sgl-router/src/routers/pd_types.rs
+233
-0
sgl-router/src/routers/request_adapter.rs
sgl-router/src/routers/request_adapter.rs
+5
-0
sgl-router/tests/benchmark_integration.rs
sgl-router/tests/benchmark_integration.rs
+2
-0
No files found.
sgl-router/benches/request_processing.rs
View file @
f6e07f27
...
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
...
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
logit_bias
:
None
,
logit_bias
:
None
,
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
other
:
serde_json
::
Map
::
new
(),
}
}
}
}
...
...
sgl-router/src/openai_api_types.rs
View file @
f6e07f27
...
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
...
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
/// If specified, our system will make a best effort to sample deterministically
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
seed
:
Option
<
i64
>
,
pub
seed
:
Option
<
i64
>
,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub
other
:
serde_json
::
Map
<
String
,
serde_json
::
Value
>
,
}
}
impl
GenerationRequest
for
CompletionRequest
{
impl
GenerationRequest
for
CompletionRequest
{
...
...
sgl-router/src/routers/pd_router.rs
View file @
f6e07f27
...
@@ -420,6 +420,77 @@ impl PDRouter {
...
@@ -420,6 +420,77 @@ impl PDRouter {
.await
.await
}
}
// Route a completion request while preserving OpenAI format
pub
async
fn
route_completion
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
mut
typed_req
:
CompletionRequest
,
route
:
&
str
,
)
->
HttpResponse
{
let
start
=
Instant
::
now
();
// Get stream flag and return_logprob flag before moving the request
let
is_stream
=
typed_req
.stream
;
let
return_logprob
=
typed_req
.logprobs
.is_some
();
// Extract text for cache-aware routing from the typed request
let
request_text
=
match
&
typed_req
.prompt
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
arr
)
=>
arr
.first
()
.map
(|
s
|
s
.as_str
()),
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
client
,
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
{
error!
(
"Failed to select PD pair: {}"
,
e
);
RouterMetrics
::
record_pd_error
(
"server_selection"
);
return
HttpResponse
::
ServiceUnavailable
()
.body
(
format!
(
"No available servers: {}"
,
e
));
}
};
// Log routing decision
info!
(
"PD routing: {} -> prefill={}, decode={}"
,
route
,
prefill
.url
(),
decode
.url
()
);
// Add bootstrap info using the trait method
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
())
{
error!
(
"Failed to add bootstrap info: {}"
,
e
);
RouterMetrics
::
record_pd_error
(
"bootstrap_injection"
);
return
HttpResponse
::
InternalServerError
()
.body
(
format!
(
"Bootstrap injection failed: {}"
,
e
));
}
// Convert to JSON after bootstrap injection
let
json_with_bootstrap
=
match
serde_json
::
to_value
(
&
typed_req
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
{
error!
(
"Failed to serialize request: {}"
,
e
);
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to serialize request"
);
}
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
client
,
req
,
json_with_bootstrap
,
route
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
// Execute the dual dispatch to prefill and decode servers
// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
async
fn
execute_dual_dispatch
(
async
fn
execute_dual_dispatch
(
...
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
...
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
req
:
&
HttpRequest
,
req
:
&
HttpRequest
,
body
:
serde_json
::
Value
,
body
:
serde_json
::
Value
,
)
->
HttpResponse
{
)
->
HttpResponse
{
match
serde_json
::
from_value
::
<
CompletionRequest
>
(
body
.clone
()
)
{
match
serde_json
::
from_value
::
<
CompletionRequest
>
(
body
)
{
Ok
(
openai_req
)
=>
{
Ok
(
openai_req
)
=>
{
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
// Use the new method that preserves OpenAI format
let
pd_req
=
openai_req
.to_pd_request
();
PDRouter
::
route_completion
(
self
,
client
,
req
,
openai_req
,
"/v1/completions"
)
.await
PDRouter
::
route_generate
(
self
,
client
,
req
,
pd_req
,
"/v1/completions"
)
.await
}
Err
(
_
)
=>
{
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
match
serde_json
::
from_value
::
<
GenerateReqInput
>
(
body
)
{
Ok
(
pd_req
)
=>
{
PDRouter
::
route_generate
(
self
,
client
,
req
,
pd_req
,
"/v1/completions"
)
.await
}
Err
(
e
)
=>
{
HttpResponse
::
BadRequest
()
.body
(
format!
(
"Invalid request format: {}"
,
e
))
}
}
}
}
Err
(
e
)
=>
HttpResponse
::
BadRequest
()
.body
(
format!
(
"Invalid request format: {}"
,
e
)),
}
}
}
}
...
...
sgl-router/src/routers/pd_types.rs
View file @
f6e07f27
// Essential PDLB types extracted for PD routing
// Essential PDLB types extracted for PD routing
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
openai_api_types
::{
CompletionRequest
,
StringOrArray
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
serde_json
::
Value
;
...
@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
...
@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
self
.bootstrap_room
=
Some
(
bootstrap_room
);
self
.bootstrap_room
=
Some
(
bootstrap_room
);
}
}
}
}
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl
Bootstrap
for
CompletionRequest
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
if
let
StringOrArray
::
Array
(
prompts
)
=
&
self
.prompt
{
if
prompts
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
prompts
.len
()));
}
// Single string prompt
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
// Insert bootstrap_host - it serializes correctly whether Single or Batch
if
let
Ok
(
host_value
)
=
serde_json
::
to_value
(
&
bootstrap_host
)
{
self
.other
.insert
(
"bootstrap_host"
.to_string
(),
host_value
);
}
// Insert bootstrap_port - it serializes correctly whether Single or Batch
if
let
Ok
(
port_value
)
=
serde_json
::
to_value
(
&
bootstrap_port
)
{
self
.other
.insert
(
"bootstrap_port"
.to_string
(),
port_value
);
}
// Insert bootstrap_room - it serializes correctly whether Single or Batch
if
let
Ok
(
room_value
)
=
serde_json
::
to_value
(
&
bootstrap_room
)
{
self
.other
.insert
(
"bootstrap_room"
.to_string
(),
room_value
);
}
}
}
#[cfg(test)]
mod
bootstrap_tests
{
use
super
::
*
;
use
crate
::
openai_api_types
::
StringOrArray
;
#[test]
fn
test_completion_batch_size_with_array_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
n
:
None
,
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
};
// Should return batch size for array prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
Some
(
2
));
}
#[test]
fn
test_completion_batch_size_with_single_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
n
:
None
,
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
};
// Should return None for single prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_batch_size_with_n_parameter
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
n
:
Some
(
3
),
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
};
// Should return None for single string prompt, even with n > 1
// SGLang handles n parameter differently than batch requests
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_bootstrap_single_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
n
:
None
,
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
};
// Set bootstrap info - should always use single values
req
.set_bootstrap_info
(
BootstrapHost
::
Single
(
"test-server"
.to_string
()),
BootstrapPort
::
Single
(
Some
(
5678
)),
BootstrapRoom
::
Single
(
12345
),
);
// Verify single values were created
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_string
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_number
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_number
());
assert_eq!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_str
()
.unwrap
(),
"test-server"
);
assert_eq!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_u64
()
.unwrap
(),
5678
);
assert_eq!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_u64
()
.unwrap
(),
12345
);
}
#[test]
fn
test_completion_bootstrap_array_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
n
:
None
,
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
};
// Set bootstrap info with arrays
req
.set_bootstrap_info
(
BootstrapHost
::
Batch
(
vec!
[
"test-server"
.to_string
();
2
]),
BootstrapPort
::
Batch
(
vec!
[
Some
(
5678
);
2
]),
BootstrapRoom
::
Batch
(
vec!
[
12345
,
67890
]),
);
// Verify arrays were created correctly
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_array
());
let
hosts
=
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
hosts
.len
(),
2
);
assert_eq!
(
hosts
[
0
]
.as_str
()
.unwrap
(),
"test-server"
);
let
ports
=
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
ports
.len
(),
2
);
assert_eq!
(
ports
[
0
]
.as_u64
()
.unwrap
(),
5678
);
let
rooms
=
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
rooms
.len
(),
2
);
assert_eq!
(
rooms
[
0
]
.as_u64
()
.unwrap
(),
12345
);
assert_eq!
(
rooms
[
1
]
.as_u64
()
.unwrap
(),
67890
);
}
}
sgl-router/src/routers/request_adapter.rs
View file @
f6e07f27
...
@@ -648,6 +648,7 @@ mod tests {
...
@@ -648,6 +648,7 @@ mod tests {
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
suffix
:
None
,
suffix
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
let
pd_req
=
req
.to_pd_request
();
let
pd_req
=
req
.to_pd_request
();
...
@@ -687,6 +688,7 @@ mod tests {
...
@@ -687,6 +688,7 @@ mod tests {
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
suffix
:
None
,
suffix
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
let
pd_req
=
req
.to_pd_request
();
let
pd_req
=
req
.to_pd_request
();
...
@@ -725,6 +727,7 @@ mod tests {
...
@@ -725,6 +727,7 @@ mod tests {
user
:
Some
(
"user123"
.to_string
()),
user
:
Some
(
"user123"
.to_string
()),
seed
:
Some
(
42
),
seed
:
Some
(
42
),
suffix
:
Some
(
"..."
.to_string
()),
suffix
:
Some
(
"..."
.to_string
()),
other
:
serde_json
::
Map
::
new
(),
};
};
let
pd_req
=
req
.to_pd_request
();
let
pd_req
=
req
.to_pd_request
();
...
@@ -768,6 +771,7 @@ mod tests {
...
@@ -768,6 +771,7 @@ mod tests {
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
suffix
:
None
,
suffix
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
let
pd_req
=
req
.to_pd_request
();
let
pd_req
=
req
.to_pd_request
();
...
@@ -799,6 +803,7 @@ mod tests {
...
@@ -799,6 +803,7 @@ mod tests {
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
suffix
:
None
,
suffix
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
let
pd_req
=
req
.to_pd_request
();
let
pd_req
=
req
.to_pd_request
();
...
...
sgl-router/tests/benchmark_integration.rs
View file @
f6e07f27
...
@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
...
@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
logit_bias
:
None
,
logit_bias
:
None
,
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
// Test serialization works
// Test serialization works
...
@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
...
@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
logit_bias
:
None
,
logit_bias
:
None
,
user
:
None
,
user
:
None
,
seed
:
None
,
seed
:
None
,
other
:
serde_json
::
Map
::
new
(),
};
};
// Test PD adaptation (should not panic)
// Test PD adaptation (should not panic)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment