Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8c7bb39d
Unverified
Commit
8c7bb39d
authored
Aug 05, 2025
by
Simo Lin
Committed by
GitHub
Aug 05, 2025
Browse files
[router] PD Router Simplification and Reorganization (#8838)
parent
ca47e24f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1098 additions
and
2555 deletions
+1098
-2555
sgl-router/benches/request_processing.rs
sgl-router/benches/request_processing.rs
+101
-58
sgl-router/scripts/run_benchmarks.py
sgl-router/scripts/run_benchmarks.py
+0
-3
sgl-router/src/routers/bootstrap_injector.rs
sgl-router/src/routers/bootstrap_injector.rs
+334
-0
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+1
-1
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+596
-524
sgl-router/src/routers/pd_types.rs
sgl-router/src/routers/pd_types.rs
+0
-432
sgl-router/src/routers/request_adapter.rs
sgl-router/src/routers/request_adapter.rs
+0
-1512
sgl-router/tests/benchmark_integration.rs
sgl-router/tests/benchmark_integration.rs
+66
-25
No files found.
sgl-router/benches/request_processing.rs
View file @
8c7bb39d
use
criterion
::{
black_box
,
criterion_group
,
criterion_main
,
BenchmarkId
,
Criterion
,
Throughput
};
use
criterion
::{
black_box
,
criterion_group
,
criterion_main
,
BenchmarkId
,
Criterion
,
Throughput
};
use
serde_json
::{
from_str
,
to_string
,
to_vec
};
use
serde_json
::{
from_str
,
to_string
,
to_value
,
to_vec
};
use
std
::
time
::
Instant
;
use
std
::
time
::
Instant
;
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
openai_api_types
::{
use
sglang_router_rs
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
};
use
sglang_router_rs
::
routers
::
request_adapter
::{
RouteableRequest
,
ToPdRequest
};
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
fn
default_generate_request
()
->
GenerateRequest
{
...
@@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) {
...
@@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) {
group
.finish
();
group
.finish
();
}
}
// Benchmark
request adaptation from OpenAI to PD format
// Benchmark
bootstrap injection (replaces request adaptation)
fn
bench_
request_adapta
tion
(
c
:
&
mut
Criterion
)
{
fn
bench_
bootstrap_injec
tion
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
request_adapta
tion"
);
let
mut
group
=
c
.benchmark_group
(
"
bootstrap_injec
tion"
);
let
generate_req
=
create_sample_generate_request
();
let
generate_req
=
create_sample_generate_request
();
let
chat_req
=
create_sample_chat_completion_request
();
let
chat_req
=
create_sample_chat_completion_request
();
let
completion_req
=
create_sample_completion_request
();
let
completion_req
=
create_sample_completion_request
();
let
large_chat_req
=
create_large_chat_completion_request
();
let
large_chat_req
=
create_large_chat_completion_request
();
let
worker
=
create_test_worker
();
group
.bench_function
(
"generate_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"generate_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
pd_req
=
black_box
(
generate_req
.clone
())
.to_pd_request
();
let
mut
json
=
to_value
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
pd_req
);
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"chat_completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"chat_completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
pd_req
=
black_box
(
chat_req
.clone
())
.to_pd_request
();
let
mut
json
=
to_value
(
black_box
(
&
chat_req
))
.unwrap
();
black_box
(
pd_req
);
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
pd_req
=
black_box
(
completion_req
.clone
())
.to_pd_request
();
let
mut
json
=
to_value
(
black_box
(
&
completion_req
))
.unwrap
();
black_box
(
pd_req
);
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"large_chat_completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"large_chat_completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
pd_req
=
black_box
(
large_chat_req
.clone
())
.to_pd_request
();
let
mut
json
=
to_value
(
black_box
(
&
large_chat_req
))
.unwrap
();
black_box
(
pd_req
);
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
});
});
group
.finish
();
group
.finish
();
}
}
// Benchmark
regular routing (RouteableRequest methods
)
// Benchmark
direct JSON routing (replaces regular routing
)
fn
bench_
regular
_routing
(
c
:
&
mut
Criterion
)
{
fn
bench_
direct_json
_routing
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
regular
_routing"
);
let
mut
group
=
c
.benchmark_group
(
"
direct_json
_routing"
);
let
generate_req
=
create_sample_generate_request
();
let
generate_req
=
create_sample_generate_request
();
let
chat_req
=
create_sample_chat_completion_request
();
let
chat_req
=
create_sample_chat_completion_request
();
...
@@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) {
...
@@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) {
group
.bench_function
(
"generate_to_json"
,
|
b
|
{
group
.bench_function
(
"generate_to_json"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
json
=
black_box
(
&
generate_req
)
.to_json
()
.unwrap
();
let
json
=
to_value
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"generate_to_json_string"
,
|
b
|
{
b
.iter
(||
{
let
json
=
to_string
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"generate_to_bytes"
,
|
b
|
{
group
.bench_function
(
"generate_to_bytes"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
bytes
=
black_box
(
&
generate_req
)
.to_bytes
(
)
.unwrap
();
let
bytes
=
to_vec
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
bytes
);
black_box
(
bytes
);
});
});
});
});
group
.bench_function
(
"chat_completion_to_json"
,
|
b
|
{
group
.bench_function
(
"chat_completion_to_json"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
json
=
black_box
(
&
chat_req
)
.to_json
(
)
.unwrap
();
let
json
=
to_value
(
black_box
(
&
chat_req
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"chat_completion_to_
bytes
"
,
|
b
|
{
group
.bench_function
(
"chat_completion_to_
json_string
"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
bytes
=
black_box
(
&
chat_req
)
.to_bytes
(
)
.unwrap
();
let
json
=
to_string
(
black_box
(
&
chat_req
))
.unwrap
();
black_box
(
bytes
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"completion_to_json"
,
|
b
|
{
group
.bench_function
(
"completion_to_json"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
json
=
black_box
(
&
completion_req
)
.to_json
(
)
.unwrap
();
let
json
=
to_value
(
black_box
(
&
completion_req
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
...
@@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) {
...
@@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) {
..
default_generate_request
()
..
default_generate_request
()
};
};
let
worker
=
create_test_worker
();
for
(
name
,
req
)
in
[
for
(
name
,
req
)
in
[
(
"small"
,
&
small_generate
),
(
"small"
,
&
small_generate
),
(
"medium"
,
&
medium_generate
),
(
"medium"
,
&
medium_generate
),
...
@@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) {
...
@@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) {
},
},
);
);
group
.bench_with_input
(
BenchmarkId
::
new
(
"adapt_to_pd"
,
name
),
&
req
,
|
b
,
req
|
{
group
.bench_with_input
(
b
.iter
(||
{
BenchmarkId
::
new
(
"bootstrap_inject"
,
name
),
let
pd_req
=
(
*
req
)
.clone
()
.to_pd_request
();
&
req
,
black_box
(
pd_req
);
|
b
,
req
|
{
});
b
.iter
(||
{
});
let
mut
json
=
to_value
(
req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
},
);
}
}
group
.finish
();
group
.finish
();
}
}
// Benchmark full round-trip: deserialize ->
ad
ap
t
-> serialize
// Benchmark full round-trip: deserialize ->
inject bootstr
ap -> serialize
fn
bench_full_round_trip
(
c
:
&
mut
Criterion
)
{
fn
bench_full_round_trip
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"full_round_trip"
);
let
mut
group
=
c
.benchmark_group
(
"full_round_trip"
);
let
generate_json
=
to_string
(
&
create_sample_generate_request
())
.unwrap
();
let
generate_json
=
to_string
(
&
create_sample_generate_request
())
.unwrap
();
let
chat_json
=
to_string
(
&
create_sample_chat_completion_request
())
.unwrap
();
let
chat_json
=
to_string
(
&
create_sample_chat_completion_request
())
.unwrap
();
let
completion_json
=
to_string
(
&
create_sample_completion_request
())
.unwrap
();
let
completion_json
=
to_string
(
&
create_sample_completion_request
())
.unwrap
();
let
worker
=
create_test_worker
();
group
.bench_function
(
"generate_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"generate_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
// Deserialize OpenAI request
// Deserialize OpenAI request
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
// Adapt to PD format
// Convert to JSON Value
let
pd_req
=
req
.to_pd_request
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
// Serialize PD request
// Inject bootstrap fields
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
// Serialize final request
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
...
@@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) {
...
@@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"chat_completion_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"chat_completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
req
:
ChatCompletionRequest
=
from_str
(
black_box
(
&
chat_json
))
.unwrap
();
let
req
:
ChatCompletionRequest
=
from_str
(
black_box
(
&
chat_json
))
.unwrap
();
let
pd_req
=
req
.to_pd_request
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
...
@@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
...
@@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"completion_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
req
:
CompletionRequest
=
from_str
(
black_box
(
&
completion_json
))
.unwrap
();
let
req
:
CompletionRequest
=
from_str
(
black_box
(
&
completion_json
))
.unwrap
();
let
pd_req
=
req
.to_pd_request
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
group
.bench_function
(
"generate_
regular_routing
_pipeline"
,
|
b
|
{
group
.bench_function
(
"generate_
direct_json
_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
// Deserialize OpenAI request
// Deserialize OpenAI request
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
// Convert to JSON for regular routing
// Convert to JSON for direct routing (no bootstrap injection)
let
routing_json
=
req
.to_json
()
.unwrap
();
let
routing_json
=
to_value
(
&
req
)
.unwrap
();
black_box
(
routing_json
);
let
json_string
=
to_string
(
&
routing_json
)
.unwrap
();
black_box
(
json_string
);
});
});
});
});
...
@@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) {
...
@@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) {
// Quick performance overview
// Quick performance overview
let
generate_req
=
create_sample_generate_request
();
let
generate_req
=
create_sample_generate_request
();
let
worker
=
create_test_worker
();
println!
(
"
\n
Quick Performance Overview:"
);
println!
(
"
\n
Quick Performance Overview:"
);
...
@@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) {
...
@@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) {
deserialize_time
deserialize_time
);
);
// Measure adaptation
// Measure
bootstrap injection (replaces
adaptation
)
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
for
_
in
0
..
1000
{
for
_
in
0
..
1000
{
let
_
=
black_box
(
generate_req
.clone
()
.to_pd_request
());
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
black_box
(
inject_bootstrap_fields
(
&
mut
json
,
&
worker
));
}
}
let
adap
t_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
let
injec
t_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
println!
(
" *
PD Adapta
tion (avg):
{:>
8
} ns/req"
,
adap
t_time
);
println!
(
" *
Bootstrap Injec
tion (avg): {:>
6
} ns/req"
,
injec
t_time
);
// Calculate ratios
// Calculate ratios
let
total_pipeline
=
serialize_time
+
deserialize_time
+
adap
t_time
;
let
total_pipeline
=
serialize_time
+
deserialize_time
+
injec
t_time
;
println!
(
" * Total Pipeline (avg): {:>8} ns/req"
,
total_pipeline
);
println!
(
" * Total Pipeline (avg): {:>8} ns/req"
,
total_pipeline
);
println!
(
"
\n
Performance Insights:"
);
println!
(
"
\n
Performance Insights:"
);
if
deserialize_time
>
serialize_time
*
2
{
if
deserialize_time
>
serialize_time
*
2
{
println!
(
" • Deserialization is significantly faster than serialization"
);
println!
(
" • Deserialization is significantly faster than serialization"
);
}
}
if
adap
t_time
<
serialize_time
/
10
{
if
injec
t_time
<
serialize_time
/
10
{
println!
(
println!
(
" •
PD adapta
tion overhead is negligible ({:.1}% of serialization)"
,
" •
Bootstrap injec
tion overhead is negligible ({:.1}% of serialization)"
,
(
adap
t_time
as
f64
/
serialize_time
as
f64
)
*
100.0
(
injec
t_time
as
f64
/
serialize_time
as
f64
)
*
100.0
);
);
}
}
if
total_pipeline
<
10_000
{
if
total_pipeline
<
10
0
_000
{
println!
(
" • Total pipeline latency is excellent (< 10μs)"
);
println!
(
" • Total pipeline latency is excellent (< 10
0
μs)"
);
}
}
println!
(
"
\n
Simplification Benefits:"
);
println!
(
" • Eliminated complex type conversion layer"
);
println!
(
" • Reduced memory allocations"
);
println!
(
" • Automatic field preservation (no manual mapping)"
);
println!
(
" • Direct JSON manipulation improves performance"
);
println!
(
"
\n
Recommendations:"
);
println!
(
"
\n
Recommendations:"
);
if
serialize_time
>
deserialize_time
{
if
serialize_time
>
deserialize_time
{
println!
(
" • Focus optimization efforts on serialization rather than deserialization"
);
println!
(
" • Focus optimization efforts on serialization rather than deserialization"
);
...
@@ -581,8 +624,8 @@ criterion_group!(
...
@@ -581,8 +624,8 @@ criterion_group!(
benchmark_summary
,
benchmark_summary
,
bench_json_serialization
,
bench_json_serialization
,
bench_json_deserialization
,
bench_json_deserialization
,
bench_
request_adapta
tion
,
bench_
bootstrap_injec
tion
,
bench_
regular
_routing
,
bench_
direct_json
_routing
,
bench_throughput_by_size
,
bench_throughput_by_size
,
bench_full_round_trip
bench_full_round_trip
);
);
...
...
sgl-router/scripts/run_benchmarks.py
View file @
8c7bb39d
...
@@ -121,8 +121,6 @@ class BenchmarkRunner:
...
@@ -121,8 +121,6 @@ class BenchmarkRunner:
results
[
"serialization_time"
]
=
self
.
_extract_time
(
line
)
results
[
"serialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"Deserialization (avg):"
in
line
:
elif
"Deserialization (avg):"
in
line
:
results
[
"deserialization_time"
]
=
self
.
_extract_time
(
line
)
results
[
"deserialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"PD Adaptation (avg):"
in
line
:
results
[
"adaptation_time"
]
=
self
.
_extract_time
(
line
)
elif
"Total Pipeline (avg):"
in
line
:
elif
"Total Pipeline (avg):"
in
line
:
results
[
"total_time"
]
=
self
.
_extract_time
(
line
)
results
[
"total_time"
]
=
self
.
_extract_time
(
line
)
...
@@ -145,7 +143,6 @@ class BenchmarkRunner:
...
@@ -145,7 +143,6 @@ class BenchmarkRunner:
thresholds
=
{
thresholds
=
{
"serialization_time"
:
2000
,
# 2μs max
"serialization_time"
:
2000
,
# 2μs max
"deserialization_time"
:
2000
,
# 2μs max
"deserialization_time"
:
2000
,
# 2μs max
"adaptation_time"
:
5000
,
# 5μs max
"total_time"
:
10000
,
# 10μs max
"total_time"
:
10000
,
# 10μs max
}
}
...
...
sgl-router/src/routers/bootstrap_injector.rs
0 → 100644
View file @
8c7bb39d
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
routers
::
pd_types
::
get_hostname
;
use
serde_json
::{
json
,
Value
};
/// Inject bootstrap fields directly into a JSON request
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
pub
fn
inject_bootstrap_fields
(
json
:
&
mut
Value
,
worker
:
&
dyn
Worker
)
->
Result
<
(),
String
>
{
let
batch_size
=
extract_batch_size
(
json
)
?
;
// Extract bootstrap port from prefill worker if it's a prefill type
let
bootstrap_port
=
match
worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
worker
.url
());
if
let
Some
(
batch_size
)
=
batch_size
{
// Batch scenario - create arrays of bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
json
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
json
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
.map
(|
_
|
{
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
)
})
.collect
::
<
Vec
<
_
>>
());
}
else
{
// Single scenario - create single bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
hostname
);
json
[
"bootstrap_port"
]
=
json!
(
bootstrap_port
);
json
[
"bootstrap_room"
]
=
json!
(
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
));
}
Ok
(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn
extract_batch_size
(
json
:
&
Value
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check for chat completions 'n' parameter (number of choices)
if
let
Some
(
n
)
=
json
.get
(
"n"
)
.and_then
(|
v
|
v
.as_u64
())
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
// Check for array prompts (completions API)
if
let
Some
(
prompt
)
=
json
.get
(
"prompt"
)
{
if
let
Some
(
arr
)
=
prompt
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for array texts (generate API)
if
let
Some
(
text
)
=
json
.get
(
"text"
)
{
if
let
Some
(
arr
)
=
text
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch text array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for batch input_ids (generate API)
if
let
Some
(
input_ids
)
=
json
.get
(
"input_ids"
)
{
if
let
Some
(
arr
)
=
input_ids
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch input_ids array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// No batch indicators found - single request
Ok
(
None
)
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
core
::
BasicWorker
;
use
serde_json
::
json
;
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
)
}
#[test]
fn
test_inject_bootstrap_single_request
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
"Hello world"
,
"max_tokens"
:
100
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields were added
assert_eq!
(
json
[
"bootstrap_host"
],
json!
(
"test-server"
));
assert_eq!
(
json
[
"bootstrap_port"
],
json!
(
5678
));
assert
!
(
json
[
"bootstrap_room"
]
.is_number
());
// Verify original fields preserved
assert_eq!
(
json
[
"model"
],
json!
(
"test-model"
));
assert_eq!
(
json
[
"prompt"
],
json!
(
"Hello world"
));
assert_eq!
(
json
[
"max_tokens"
],
json!
(
100
));
}
#[test]
fn
test_inject_bootstrap_batch_prompt
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
[
"Hello"
,
"World"
],
"max_tokens"
:
100
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
assert_eq!
(
json
[
"bootstrap_host"
],
json!
([
"test-server"
,
"test-server"
])
);
assert_eq!
(
json
[
"bootstrap_port"
],
json!
([
5678
,
5678
]));
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
2
);
for
room
in
bootstrap_rooms
{
assert
!
(
room
.is_number
());
let
room_val
=
room
.as_u64
()
.unwrap
();
assert
!
(
room_val
<=
i64
::
MAX
as
u64
);
}
}
#[test]
fn
test_inject_bootstrap_chat_n_parameter
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"gpt-4"
,
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
"n"
:
3
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields for n=3
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
3
);
assert_eq!
(
bootstrap_hosts
[
0
],
json!
(
"test-server"
));
let
bootstrap_ports
=
json
[
"bootstrap_port"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_ports
.len
(),
3
);
assert_eq!
(
bootstrap_ports
[
0
],
json!
(
5678
));
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
3
);
}
#[test]
fn
test_inject_bootstrap_generate_text_array
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"text"
:
[
"First prompt"
,
"Second prompt"
],
"stream"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
2
);
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
2
);
// Ensure room values are different (randomness)
assert_ne!
(
bootstrap_rooms
[
0
],
bootstrap_rooms
[
1
]);
}
#[test]
fn
test_inject_bootstrap_input_ids_array
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"input_ids"
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]],
"stream"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
2
);
}
#[test]
fn
test_extract_batch_size_empty_array_error
()
{
let
json
=
json!
({
"prompt"
:
[],
"model"
:
"test"
});
let
result
=
extract_batch_size
(
&
json
);
assert
!
(
result
.is_err
());
assert
!
(
result
.unwrap_err
()
.contains
(
"empty"
));
}
#[test]
fn
test_extract_batch_size_single_requests
()
{
// Single string prompt
let
json
=
json!
({
"prompt"
:
"Hello world"
,
"model"
:
"test"
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Single text
let
json
=
json!
({
"text"
:
"Hello world"
,
"stream"
:
false
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Chat with n=1 (default)
let
json
=
json!
({
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
"n"
:
1
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Chat without n parameter
let
json
=
json!
({
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
}
#[test]
fn
test_inject_bootstrap_preserves_sglang_fields
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
"Hello"
,
// SGLang extensions should be preserved
"top_k"
:
40
,
"min_p"
:
0.05
,
"repetition_penalty"
:
1.1
,
"regex"
:
"test_pattern"
,
"lora_path"
:
"test.bin"
,
"no_stop_trim"
:
true
,
"ignore_eos"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields added
assert
!
(
json
.get
(
"bootstrap_host"
)
.is_some
());
assert
!
(
json
.get
(
"bootstrap_port"
)
.is_some
());
assert
!
(
json
.get
(
"bootstrap_room"
)
.is_some
());
// Verify all SGLang fields preserved
assert_eq!
(
json
[
"top_k"
],
json!
(
40
));
assert_eq!
(
json
[
"min_p"
],
json!
(
0.05
));
assert_eq!
(
json
[
"repetition_penalty"
],
json!
(
1.1
));
assert_eq!
(
json
[
"regex"
],
json!
(
"test_pattern"
));
assert_eq!
(
json
[
"lora_path"
],
json!
(
"test.bin"
));
assert_eq!
(
json
[
"no_stop_trim"
],
json!
(
true
));
assert_eq!
(
json
[
"ignore_eos"
],
json!
(
false
));
}
#[test]
fn
test_bootstrap_room_range
()
{
let
worker
=
create_test_worker
();
// Test single request room generation
for
_
in
0
..
1000
{
let
mut
json
=
json!
({
"prompt"
:
"test"
});
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
room
=
json
[
"bootstrap_room"
]
.as_u64
()
.unwrap
();
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
// Test batch request room generation
for
_
in
0
..
100
{
let
mut
json
=
json!
({
"prompt"
:
[
"test1"
,
"test2"
]});
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
for
room_val
in
rooms
{
let
room
=
room_val
.as_u64
()
.unwrap
();
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
}
}
#[test]
fn
test_worker_without_bootstrap_port
()
{
let
worker
=
BasicWorker
::
new
(
"http://decode-only:8000"
.to_string
(),
WorkerType
::
Decode
,
// No bootstrap port
);
let
mut
json
=
json!
({
"prompt"
:
"Hello world"
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields with null port
assert_eq!
(
json
[
"bootstrap_host"
],
json!
(
"decode-only"
));
assert_eq!
(
json
[
"bootstrap_port"
],
json!
(
null
));
assert
!
(
json
[
"bootstrap_room"
]
.is_number
());
}
}
sgl-router/src/routers/mod.rs
View file @
8c7bb39d
...
@@ -11,10 +11,10 @@ use std::fmt::Debug;
...
@@ -11,10 +11,10 @@ use std::fmt::Debug;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
pub
mod
bootstrap_injector
;
pub
mod
factory
;
pub
mod
factory
;
pub
mod
pd_router
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
pd_types
;
pub
mod
request_adapter
;
pub
mod
router
;
pub
mod
router
;
pub
use
factory
::
RouterFactory
;
pub
use
factory
::
RouterFactory
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
8c7bb39d
// PD (Prefill-Decode) Router Implementation
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
// This module handles routing for disaggregated prefill-decode systems
use
super
::
pd_types
::{
api_path
,
Bootstrap
,
ChatReqInput
,
GenerateReqInput
,
PDRouterError
}
;
use
super
::
bootstrap_injector
::
inject_bootstrap_fields
;
use
super
::
request_adapter
::
ToPdRequest
;
use
super
::
pd_types
::{
api_path
,
PDRouterError
}
;
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
tree
::
Tree
;
use
crate
::
tree
::
Tree
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
extract
::
Request
,
extract
::
Request
,
...
@@ -46,18 +48,26 @@ pub struct PDRouter {
...
@@ -46,18 +48,26 @@ pub struct PDRouter {
impl
PDRouter
{
impl
PDRouter
{
// Dynamic worker management methods for service discovery
// Dynamic worker management methods for service discovery
// Private helper method to perform health check on a new server
async
fn
wait_for_server_health
(
&
self
,
url
:
&
str
)
->
Result
<
(),
PDRouterError
>
{
crate
::
routers
::
router
::
Router
::
wait_for_healthy_workers
(
&
[
url
.to_string
()],
self
.timeout_secs
,
self
.interval_secs
,
)
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.to_string
(),
})
}
pub
async
fn
add_prefill_server
(
pub
async
fn
add_prefill_server
(
&
self
,
&
self
,
url
:
String
,
url
:
String
,
bootstrap_port
:
Option
<
u16
>
,
bootstrap_port
:
Option
<
u16
>
,
)
->
Result
<
String
,
PDRouterError
>
{
)
->
Result
<
String
,
PDRouterError
>
{
// Wait for the new server to be healthy
// Wait for the new server to be healthy
crate
::
routers
::
router
::
Router
::
wait_for_healthy_workers
(
self
.wait_for_server_health
(
&
url
)
.await
?
;
&
[
url
.clone
()],
self
.timeout_secs
,
self
.interval_secs
,
)
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
// Create Worker for the new prefill server
// Create Worker for the new prefill server
let
worker
=
WorkerFactory
::
create_prefill
(
url
.clone
(),
bootstrap_port
);
let
worker
=
WorkerFactory
::
create_prefill
(
url
.clone
(),
bootstrap_port
);
...
@@ -88,12 +98,7 @@ impl PDRouter {
...
@@ -88,12 +98,7 @@ impl PDRouter {
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
)
->
Result
<
String
,
PDRouterError
>
{
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
)
->
Result
<
String
,
PDRouterError
>
{
// Wait for the new server to be healthy
// Wait for the new server to be healthy
crate
::
routers
::
router
::
Router
::
wait_for_healthy_workers
(
self
.wait_for_server_health
(
&
url
)
.await
?
;
&
[
url
.clone
()],
self
.timeout_secs
,
self
.interval_secs
,
)
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
// Create Worker for the new decode server
// Create Worker for the new decode server
let
worker
=
WorkerFactory
::
create_decode
(
url
.clone
());
let
worker
=
WorkerFactory
::
create_decode
(
url
.clone
());
...
@@ -332,189 +337,6 @@ impl PDRouter {
...
@@ -332,189 +337,6 @@ impl PDRouter {
.into_response
()
.into_response
()
}
}
// Route a typed generate request
pub
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
GenerateReqInput
,
route
:
&
str
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Get stream flag and return_logprob flag before moving the request
let
is_stream
=
typed_req
.stream
;
let
return_logprob
=
typed_req
.other
.get
(
"return_logprob"
)
.and_then
(|
v
|
v
.as_bool
())
.unwrap_or
(
false
);
// Extract text for cache-aware routing from the typed request
let
request_text
=
typed_req
.text
.as_ref
()
.and_then
(|
t
|
match
t
{
super
::
pd_types
::
InputText
::
Single
(
s
)
=>
Some
(
s
.as_str
()),
super
::
pd_types
::
InputText
::
Batch
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
});
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route={} prefill_url={} decode_url={}"
,
route
,
prefill
.url
(),
decode
.url
()
);
// Add bootstrap info using the trait method
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
// Convert to JSON after bootstrap injection
let
json_with_bootstrap
=
match
serde_json
::
to_value
(
&
typed_req
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json_with_bootstrap
,
route
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
// Route a typed chat request
pub
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
ChatReqInput
,
route
:
&
str
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Get stream flag and return_logprob flag before moving the request
let
is_stream
=
typed_req
.stream
;
let
return_logprob
=
typed_req
.other
.get
(
"return_logprob"
)
.and_then
(|
v
|
v
.as_bool
())
.unwrap_or
(
false
);
// Extract text for cache-aware routing from chat messages
let
request_text
=
typed_req
.other
.get
(
"messages"
)
.and_then
(|
messages
|
messages
.as_array
())
.and_then
(|
arr
|
arr
.first
())
.and_then
(|
msg
|
msg
.get
(
"content"
))
.and_then
(|
content
|
content
.as_str
());
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route={} prefill_url={} decode_url={}"
,
route
,
prefill
.url
(),
decode
.url
()
);
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
// Convert to JSON after bootstrap injection
let
json_with_bootstrap
=
match
serde_json
::
to_value
(
&
typed_req
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json_with_bootstrap
,
route
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
// Route a completion request while preserving OpenAI format
pub
async
fn
route_completion
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
CompletionRequest
,
route
:
&
str
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Get stream flag and return_logprob flag before moving the request
let
is_stream
=
typed_req
.stream
;
let
return_logprob
=
typed_req
.logprobs
.is_some
();
// Extract text for cache-aware routing from the typed request
let
request_text
=
match
&
typed_req
.prompt
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
arr
)
=>
arr
.first
()
.map
(|
s
|
s
.as_str
()),
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route={} prefill_url={} decode_url={}"
,
route
,
prefill
.url
(),
decode
.url
()
);
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
// Convert to JSON after bootstrap injection
let
json_with_bootstrap
=
match
serde_json
::
to_value
(
&
typed_req
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json_with_bootstrap
,
route
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
// Execute the dual dispatch to prefill and decode servers with retry logic
// Execute the dual dispatch to prefill and decode servers with retry logic
async
fn
execute_dual_dispatch
(
async
fn
execute_dual_dispatch
(
&
self
,
&
self
,
...
@@ -1090,7 +912,7 @@ impl PDRouter {
...
@@ -1090,7 +912,7 @@ impl PDRouter {
// Helper functions
// Helper functions
async
fn
get_worker_load
(
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
async
fn
get_worker_load
(
client
:
&
Client
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
match
client
.get
(
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
match
client
.get
(
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
Value
>
(
&
bytes
)
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
Value
>
(
&
bytes
)
{
...
@@ -1123,9 +945,96 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
...
@@ -1123,9 +945,96 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
}
}
}
}
// PD-specific endpoints
#[async_trait]
impl
PDRouter
{
impl
WorkerManagement
for
PDRouter
{
pub
async
fn
health_generate
(
&
self
)
->
Response
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
// For PD router, we don't support adding workers via this generic method
Err
(
"PD router requires specific add_prefill_server or add_decode_server methods"
.to_string
(),
)
}
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
// For PD router, we would need to know if it's a prefill or decode server
// For now, try both
if
let
Ok
(
mut
workers
)
=
self
.prefill_workers
.write
()
{
if
let
Some
(
index
)
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
worker_url
)
{
workers
.remove
(
index
);
info!
(
"Removed prefill worker: {}"
,
worker_url
);
return
;
}
}
if
let
Ok
(
mut
workers
)
=
self
.decode_workers
.write
()
{
if
let
Some
(
index
)
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
worker_url
)
{
workers
.remove
(
index
);
info!
(
"Removed decode worker: {}"
,
worker_url
);
}
}
}
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
let
mut
urls
=
Vec
::
new
();
// Add prefill worker URLs
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
for
worker
in
workers
.iter
()
{
urls
.push
(
worker
.url
()
.to_string
());
}
}
// Add decode worker URLs
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
for
worker
in
workers
.iter
()
{
urls
.push
(
worker
.url
()
.to_string
());
}
}
urls
}
}
#[async_trait]
impl
RouterTrait
for
PDRouter
{
fn
as_any
(
&
self
)
->
&
dyn
std
::
any
::
Any
{
self
}
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// This is a server readiness check - checking if we have healthy workers
// Workers handle their own health checks in the background
let
mut
all_healthy
=
true
;
let
mut
unhealthy_servers
=
Vec
::
new
();
// Check prefill servers
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
if
!
worker
.is_healthy
()
{
all_healthy
=
false
;
unhealthy_servers
.push
(
format!
(
"Prefill: {}"
,
worker
.url
()));
}
}
// Check decode servers
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
if
!
worker
.is_healthy
()
{
all_healthy
=
false
;
unhealthy_servers
.push
(
format!
(
"Decode: {}"
,
worker
.url
()));
}
}
if
all_healthy
{
(
StatusCode
::
OK
,
"All servers healthy"
)
.into_response
()
}
else
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
format!
(
"Unhealthy servers: {:?}"
,
unhealthy_servers
),
)
.into_response
()
}
}
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Test model generation capability by selecting a random pair and testing them
// Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
...
@@ -1206,7 +1115,7 @@ impl PDRouter {
...
@@ -1206,7 +1115,7 @@ impl PDRouter {
}
}
}
}
pub
async
fn
get_server_info
(
&
self
)
->
Response
{
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Get info from the first decode server to match sglang's server info format
// Get info from the first decode server to match sglang's server info format
let
first_decode_url
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
let
first_decode_url
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
...
@@ -1269,7 +1178,7 @@ impl PDRouter {
...
@@ -1269,7 +1178,7 @@ impl PDRouter {
}
}
}
}
pub
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Extract headers first to avoid Send issues
// Extract headers first to avoid Send issues
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
...
@@ -1285,32 +1194,43 @@ impl PDRouter {
...
@@ -1285,32 +1194,43 @@ impl PDRouter {
};
};
if
let
Some
(
worker_url
)
=
first_worker_url
{
if
let
Some
(
worker_url
)
=
first_worker_url
{
// Send request directly without going through Router
let
url
=
format!
(
"{}/v1/models"
,
worker_url
);
let
mut
request_builder
=
self
.client
.get
(
format!
(
"{}/v1/models"
,
worker_url
));
let
mut
request_builder
=
self
.client
.get
(
&
url
);
// Add headers
for
(
name
,
value
)
in
headers
{
for
(
name
,
value
)
in
headers
{
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
request_builder
=
request_builder
.header
(
name
,
value
);
{
request_builder
=
request_builder
.header
(
name
,
value
);
}
}
}
match
request_builder
.send
()
.await
{
match
request_builder
.send
()
.await
{
Ok
(
res
)
=>
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
Ok
(
body
)
=>
(
StatusCode
::
OK
,
body
)
.into_response
(),
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
Err
(
e
)
=>
{
match
res
.bytes
()
.await
{
error!
(
"Failed to read response body: {}"
,
e
);
Ok
(
body
)
=>
(
status
,
body
)
.into_response
(),
(
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to read response body: {}"
,
e
),
format!
(
"Failed to read response body: {}"
,
e
),
)
)
.into_response
()
,
.into_response
()
}
}
},
Ok
(
res
)
=>
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
(
status
,
format!
(
"Prefill server returned status: {}"
,
res
.status
()),
)
.into_response
()
}
Err
(
e
)
=>
{
error!
(
"Failed to get models: {}"
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to get models: {}"
,
e
),
)
.into_response
()
}
}
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to send request: {}"
,
e
),
)
.into_response
(),
}
}
}
else
{
}
else
{
(
(
...
@@ -1321,53 +1241,10 @@ impl PDRouter {
...
@@ -1321,53 +1241,10 @@ impl PDRouter {
}
}
}
}
pub
async
fn
get_loads
(
&
self
,
client
:
&
reqwest
::
Client
)
->
Response
{
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
let
p_urls
:
Vec
<
_
>
=
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
let
d_urls
:
Vec
<
_
>
=
self
.decode_workers
.read
()
.unwrap
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
let
mut
prefill_loads
=
Vec
::
new
();
let
mut
decode_loads
=
Vec
::
new
();
for
url
in
&
p_urls
{
let
load
=
get_worker_load
(
client
,
url
)
.await
.unwrap_or
(
-
1
);
prefill_loads
.push
(
serde_json
::
json!
({
"engine"
:
format!
(
"(Prefill@{})"
,
url
),
"load"
:
load
as
i64
}));
}
for
url
in
&
d_urls
{
let
load
=
get_worker_load
(
client
,
url
)
.await
.unwrap_or
(
-
1
);
decode_loads
.push
(
serde_json
::
json!
({
"engine"
:
format!
(
"(Decode@{})"
,
url
),
"load"
:
load
as
i64
}));
}
Json
(
serde_json
::
json!
({
"prefill"
:
prefill_loads
,
"decode"
:
decode_loads
}))
.into_response
()
}
pub
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Extract headers first to avoid Send issues
// Extract headers first to avoid Send issues
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
// Get first prefill worker URL to avoid holding lock across await
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
...
@@ -1380,31 +1257,43 @@ impl PDRouter {
...
@@ -1380,31 +1257,43 @@ impl PDRouter {
};
};
if
let
Some
(
worker_url
)
=
first_worker_url
{
if
let
Some
(
worker_url
)
=
first_worker_url
{
let
mut
request_builder
=
self
.client
.get
(
format!
(
"{}/get_model_info"
,
worker_url
));
let
url
=
format!
(
"{}/get_model_info"
,
worker_url
);
let
mut
request_builder
=
self
.client
.get
(
&
url
);
// Add headers
for
(
name
,
value
)
in
headers
{
for
(
name
,
value
)
in
headers
{
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
request_builder
=
request_builder
.header
(
name
,
value
);
{
request_builder
=
request_builder
.header
(
name
,
value
);
}
}
}
match
request_builder
.send
()
.await
{
match
request_builder
.send
()
.await
{
Ok
(
res
)
=>
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
Ok
(
body
)
=>
(
StatusCode
::
OK
,
body
)
.into_response
(),
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
Err
(
e
)
=>
{
match
res
.bytes
()
.await
{
error!
(
"Failed to read response body: {}"
,
e
);
Ok
(
body
)
=>
(
status
,
body
)
.into_response
(),
(
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to read response body: {}"
,
e
),
format!
(
"Failed to read response body: {}"
,
e
),
)
)
.into_response
()
,
.into_response
()
}
}
},
Ok
(
res
)
=>
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
(
status
,
format!
(
"Prefill server returned status: {}"
,
res
.status
()),
)
.into_response
()
}
Err
(
e
)
=>
{
error!
(
"Failed to get model info: {}"
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to get model info: {}"
,
e
),
)
.into_response
()
}
}
Err
(
e
)
=>
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Failed to send request: {}"
,
e
),
)
.into_response
(),
}
}
}
else
{
}
else
{
(
(
...
@@ -1415,205 +1304,319 @@ impl PDRouter {
...
@@ -1415,205 +1304,319 @@ impl PDRouter {
}
}
}
}
pub
async
fn
flush_cache
(
&
self
,
client
:
&
reqwest
::
Client
)
->
Response
{
async
fn
route_generate
(
let
mut
tasks
=
Vec
::
new
();
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
{
let
start
=
Instant
::
now
();
//
Flush cache on all prefill servers
//
Convert directly to JSON to preserve all fields automatically
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
(
)
{
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
());
Ok
(
json
)
=>
json
,
tasks
.push
(
client
.post
(
&
url
)
.send
());
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
}
}
;
// Flush cache on all decode servers
// Extract flags for routing logic
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
let
is_stream
=
body
.stream
;
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
());
let
return_logprob
=
body
.return_logprob
;
tasks
.push
(
client
.post
(
&
url
)
.send
());
}
let
results
=
futures_util
::
future
::
join_all
(
tasks
)
.await
;
// Extract text for cache-aware routing
let
request_text
=
body
.text
.as_deref
()
.or_else
(||
{
body
.prompt
.as_ref
()
.and_then
(|
p
|
match
p
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
})
});
let
mut
all_success
=
true
;
// Select servers
for
(
i
,
result
)
in
results
.into_iter
()
.enumerate
()
{
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
match
result
{
Ok
(
pair
)
=>
pair
,
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
{}
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
Ok
(
res
)
=>
{
};
all_success
=
false
;
warn!
(
"Server {} returned status {} for flush_cache"
,
i
,
res
.status
()
);
}
Err
(
e
)
=>
{
all_success
=
false
;
error!
(
"Server {} error during flush_cache: {}"
,
i
,
e
);
}
}
}
if
all_success
{
// Log routing decision
(
StatusCode
::
OK
,
"Cache flushed on all servers"
)
.into_response
()
info!
(
}
else
{
"PD routing decision route=/generate prefill_url={} decode_url={}"
,
(
prefill
.url
(),
StatusCode
::
INTERNAL_SERVER_ERROR
,
decode
.url
()
"Cache flush failed on one or more servers"
,
);
)
.into_response
()
}
}
}
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
// Inject bootstrap fields directly into JSON
use
async_trait
::
async_trait
;
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
#[async_trait]
// Execute dual dispatch
impl
WorkerManagement
for
PDRouter
{
self
.execute_dual_dispatch
(
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
headers
,
// For PD router, we don't support adding workers via this generic method
json
,
Err
(
"/generate"
,
"PD router requires specific add_prefill_server or add_decode_server methods"
prefill
.as_ref
(),
.to_string
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
)
.await
}
}
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
async
fn
route_chat
(
// For PD router, we would need to know if it's a prefill or decode server
&
self
,
// For now, try both
headers
:
Option
<&
HeaderMap
>
,
if
let
Ok
(
mut
workers
)
=
self
.prefill_workers
.write
()
{
body
:
&
ChatCompletionRequest
,
if
let
Some
(
index
)
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
worker_url
)
{
)
->
Response
{
workers
.remove
(
index
);
let
start
=
Instant
::
now
();
info!
(
"Removed prefill worker: {}"
,
worker_url
);
return
;
}
}
if
let
Ok
(
mut
workers
)
=
self
.decode_workers
.write
()
{
// Convert directly to JSON to preserve all fields automatically
if
let
Some
(
index
)
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
worker_url
)
{
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
workers
.remove
(
index
);
Ok
(
json
)
=>
json
,
info!
(
"Removed decode worker: {}"
,
worker_url
);
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Extract flags for routing logic
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.logprobs
;
// Extract text for cache-aware routing from chat messages
let
request_text
=
body
.messages
.first
()
.and_then
(|
msg
|
match
msg
{
crate
::
openai_api_types
::
ChatMessage
::
User
{
content
,
..
}
=>
{
match
content
{
crate
::
openai_api_types
::
UserMessageContent
::
Text
(
text
)
=>
Some
(
text
.as_str
()),
crate
::
openai_api_types
::
UserMessageContent
::
Parts
(
_
)
=>
None
,
// Skip complex content
}
}
}
crate
::
openai_api_types
::
ChatMessage
::
System
{
content
,
..
}
=>
Some
(
content
.as_str
()),
_
=>
None
,
});
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}"
,
prefill
.url
(),
decode
.url
()
);
// Inject bootstrap fields directly into JSON
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
}
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json
,
"/v1/chat/completions"
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
}
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
async
fn
route_completion
(
let
mut
urls
=
Vec
::
new
();
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Add prefill worker URLs
// Convert directly to JSON to preserve all fields automatically
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
for
worker
in
workers
.iter
()
{
Ok
(
json
)
=>
json
,
urls
.push
(
worker
.url
()
.to_string
());
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
}
};
}
// Add decode worker URLs
// Extract flags for routing logic
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
let
is_stream
=
body
.stream
;
for
worker
in
workers
.iter
()
{
let
return_logprob
=
body
.logprobs
.is_some
();
urls
.push
(
worker
.url
()
.to_string
());
}
// Extract text for cache-aware routing
let
request_text
=
match
&
body
.prompt
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route=/v1/completions prefill_url={} decode_url={}"
,
prefill
.url
(),
decode
.url
()
);
// Inject bootstrap fields directly into JSON
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
return
Self
::
handle_bootstrap_error
(
e
);
}
}
urls
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json
,
"/v1/completions"
,
prefill
.as_ref
(),
decode
.as_ref
(),
is_stream
,
return_logprob
,
start
,
)
.await
}
}
}
#[async_trait]
async
fn
flush_cache
(
&
self
)
->
Response
{
impl
RouterTrait
for
PDRouter
{
let
mut
results
=
Vec
::
new
();
fn
as_any
(
&
self
)
->
&
dyn
std
::
any
::
Any
{
let
mut
errors
=
Vec
::
new
();
self
}
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Get prefill worker URLs first to avoid holding lock across await
// This is a server readiness check - checking if we have healthy workers
let
prefill_urls
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
// Workers handle their own health checks in the background
workers
let
mut
all_healthy
=
true
;
.iter
()
let
mut
unhealthy_servers
=
Vec
::
new
();
.map
(|
w
|
w
.url
()
.to_string
())
.collect
::
<
Vec
<
_
>>
()
}
else
{
errors
.push
(
"Failed to access prefill workers"
.to_string
());
Vec
::
new
()
};
// Check prefill servers
// Flush prefill workers
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
for
worker_url
in
prefill_urls
{
if
!
worker
.is_healthy
()
{
let
url
=
format!
(
"{}/flush_cache"
,
worker_url
);
all_healthy
=
false
;
match
self
.client
.post
(
&
url
)
.send
()
.await
{
unhealthy_servers
.push
(
format!
(
"Prefill: {}"
,
worker
.url
()));
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
{
results
.push
(
format!
(
"Prefill {}: OK"
,
worker_url
));
}
Ok
(
res
)
=>
{
errors
.push
(
format!
(
"Prefill {} returned status: {}"
,
worker_url
,
res
.status
()
));
}
Err
(
e
)
=>
{
errors
.push
(
format!
(
"Prefill {} error: {}"
,
worker_url
,
e
));
}
}
}
}
}
// Check decode servers
// Get decode worker URLs first to avoid holding lock across await
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
let
decode_urls
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
if
!
worker
.is_healthy
()
{
workers
all_healthy
=
false
;
.iter
()
unhealthy_servers
.push
(
format!
(
"Decode: {}"
,
worker
.url
()));
.map
(|
w
|
w
.url
()
.to_string
())
.collect
::
<
Vec
<
_
>>
()
}
else
{
errors
.push
(
"Failed to access decode workers"
.to_string
());
Vec
::
new
()
};
// Flush decode workers
for
worker_url
in
decode_urls
{
let
url
=
format!
(
"{}/flush_cache"
,
worker_url
);
match
self
.client
.post
(
&
url
)
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
{
results
.push
(
format!
(
"Decode {}: OK"
,
worker_url
));
}
Ok
(
res
)
=>
{
errors
.push
(
format!
(
"Decode {} returned status: {}"
,
worker_url
,
res
.status
()
));
}
Err
(
e
)
=>
{
errors
.push
(
format!
(
"Decode {} error: {}"
,
worker_url
,
e
));
}
}
}
}
}
if
all_healthy
{
if
errors
.is_empty
()
{
(
StatusCode
::
OK
,
"All servers healthy"
)
.into_response
()
(
StatusCode
::
OK
,
format!
(
"Cache flushed successfully: {:?}"
,
results
),
)
.into_response
()
}
else
{
}
else
{
(
(
StatusCode
::
SERVICE_UNAVAILABLE
,
StatusCode
::
PARTIAL_CONTENT
,
format!
(
"Unhealthy servers: {:?}"
,
unhealthy_servers
),
format!
(
"Partial success. Results: {:?}, Errors: {:?}"
,
results
,
errors
),
)
)
.into_response
()
.into_response
()
}
}
}
}
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
// Use the existing PDRouter health_generate method
let
mut
loads
=
HashMap
::
new
();
PDRouter
::
health_generate
(
self
)
.await
let
mut
errors
=
Vec
::
new
();
}
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_server_info method
PDRouter
::
get_server_info
(
self
)
.await
}
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_models method
PDRouter
::
get_models
(
self
,
req
)
.await
}
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_model_info method
PDRouter
::
get_model_info
(
self
,
req
)
.await
}
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
{
// Convert OpenAI format to PD format
let
pd_req
=
body
.clone
()
.to_pd_request
();
PDRouter
::
route_generate
(
self
,
headers
,
pd_req
,
"/generate"
)
.await
// Get prefill worker URLs first to avoid holding lock across await
}
let
prefill_urls
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
workers
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
::
<
Vec
<
_
>>
()
}
else
{
errors
.push
(
"Failed to access prefill workers"
.to_string
());
Vec
::
new
()
};
async
fn
route_chat
(
// Get loads from prefill workers
&
self
,
for
worker_url
in
prefill_urls
{
headers
:
Option
<&
HeaderMap
>
,
match
get_worker_load
(
&
self
.client
,
&
worker_url
)
.await
{
body
:
&
ChatCompletionRequest
,
Some
(
load
)
=>
{
)
->
Response
{
loads
.insert
(
format!
(
"prefill_{}"
,
worker_url
),
load
);
// Convert OpenAI format to PD format
}
let
pd_req
=
body
.clone
()
.to_pd_request
();
None
=>
{
errors
.push
(
format!
(
"Failed to get load from prefill {}"
,
worker_url
));
}
}
}
PDRouter
::
route_chat
(
self
,
headers
,
pd_req
,
"/v1/chat/completions"
)
.await
// Get decode worker URLs first to avoid holding lock across await
}
let
decode_urls
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
workers
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
::
<
Vec
<
_
>>
()
}
else
{
errors
.push
(
"Failed to access decode workers"
.to_string
());
Vec
::
new
()
};
async
fn
route_completion
(
// Get loads from decode workers
&
self
,
for
worker_url
in
decode_urls
{
headers
:
Option
<&
HeaderMap
>
,
match
get_worker_load
(
&
self
.client
,
&
worker_url
)
.await
{
body
:
&
CompletionRequest
,
Some
(
load
)
=>
{
)
->
Response
{
loads
.insert
(
format!
(
"decode_{}"
,
worker_url
),
load
);
// Use the new method that preserves OpenAI format
}
PDRouter
::
route_completion
(
self
,
headers
,
body
.clone
(),
"/v1/completions"
)
.await
None
=>
{
}
errors
.push
(
format!
(
"Failed to get load from decode {}"
,
worker_url
));
}
}
}
async
fn
flush_cache
(
&
self
)
->
Response
{
let
response_data
=
serde_json
::
json!
(
{
// Use the existing PDRouter flush_cache method
"loads"
:
loads
,
PDRouter
::
flush_cache
(
self
,
&
self
.client
)
.await
"errors"
:
errors
}
}
);
async
fn
get_worker_loads
(
&
self
)
->
Response
{
(
StatusCode
::
OK
,
Json
(
response_data
))
.into_response
()
// Use the existing PDRouter get_loads method
PDRouter
::
get_loads
(
self
,
&
self
.client
)
.await
}
}
fn
router_type
(
&
self
)
->
&
'static
str
{
fn
router_type
(
&
self
)
->
&
'static
str
{
...
@@ -1688,7 +1691,6 @@ mod tests {
...
@@ -1688,7 +1691,6 @@ mod tests {
use
super
::
*
;
use
super
::
*
;
use
crate
::
core
::{
BasicWorker
,
WorkerType
};
use
crate
::
core
::{
BasicWorker
,
WorkerType
};
use
crate
::
policies
::{
CacheAwarePolicy
,
RandomPolicy
};
use
crate
::
policies
::{
CacheAwarePolicy
,
RandomPolicy
};
use
crate
::
routers
::
pd_types
::
SingleOrBatch
;
fn
create_test_pd_router
()
->
PDRouter
{
fn
create_test_pd_router
()
->
PDRouter
{
let
prefill_policy
=
Arc
::
new
(
RandomPolicy
::
new
());
let
prefill_policy
=
Arc
::
new
(
RandomPolicy
::
new
());
...
@@ -1935,90 +1937,6 @@ mod tests {
...
@@ -1935,90 +1937,6 @@ mod tests {
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
}
}
// ============= Bootstrap Injection Tests =============
#[test]
fn
test_bootstrap_injection_with_existing_fields
()
{
let
mut
req
=
GenerateReqInput
{
text
:
Some
(
SingleOrBatch
::
Single
(
"Test"
.to_string
())),
input_ids
:
None
,
stream
:
false
,
bootstrap_host
:
Some
(
SingleOrBatch
::
Single
(
"existing-host"
.to_string
())),
bootstrap_port
:
Some
(
SingleOrBatch
::
Single
(
Some
(
9999
))),
bootstrap_room
:
Some
(
SingleOrBatch
::
Single
(
12345
)),
other
:
Value
::
Object
(
serde_json
::
Map
::
new
()),
};
let
prefill_worker
=
create_test_worker
(
"http://new-host:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8080
),
},
true
,
);
// Bootstrap info is added regardless of existing fields
let
result
=
req
.add_bootstrap_info
(
prefill_worker
.as_ref
());
assert
!
(
result
.is_ok
());
// Bootstrap info should be updated with new values
assert_eq!
(
req
.bootstrap_host
,
Some
(
SingleOrBatch
::
Single
(
"new-host"
.to_string
()))
);
assert_eq!
(
req
.bootstrap_port
,
Some
(
SingleOrBatch
::
Single
(
Some
(
8080
))));
// Room should be regenerated (different from original)
if
let
Some
(
SingleOrBatch
::
Single
(
room
))
=
req
.bootstrap_room
{
assert_ne!
(
room
,
12345
);
}
else
{
panic!
(
"Expected single room ID"
);
}
}
#[test]
fn test_bootstrap_room_generation() {
    // Two otherwise-identical requests must receive distinct room ids from
    // add_bootstrap_info. Build them through one local factory closure so
    // the fixture is written only once.
    let fresh_request = || GenerateReqInput {
        text: Some(SingleOrBatch::Single("Test".to_string())),
        input_ids: None,
        stream: false,
        bootstrap_host: None,
        bootstrap_port: None,
        bootstrap_room: None,
        other: Value::Object(serde_json::Map::new()),
    };
    let mut req1 = fresh_request();
    let mut req2 = fresh_request();

    let prefill_worker = create_test_worker(
        "http://host:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );

    // Attach bootstrap info to both requests.
    let _ = req1.add_bootstrap_info(prefill_worker.as_ref());
    let _ = req2.add_bootstrap_info(prefill_worker.as_ref());

    // The generated room ids must differ.
    match (req1.bootstrap_room, req2.bootstrap_room) {
        (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) => {
            assert_ne!(room1, room2, "Room IDs should be unique");
        }
        _ => panic!("Expected single room IDs"),
    }
}
// ============= Worker Selection Tests =============
// ============= Worker Selection Tests =============
#[tokio::test]
#[tokio::test]
...
@@ -2196,4 +2114,158 @@ mod tests {
...
@@ -2196,4 +2114,158 @@ mod tests {
let
workers
=
router
.prefill_workers
.read
()
.unwrap
();
let
workers
=
router
.prefill_workers
.read
()
.unwrap
();
assert_eq!
(
workers
.len
(),
5
);
assert_eq!
(
workers
.len
(),
5
);
}
}
#[tokio::test]
async fn test_simplified_routing_preserves_sglang_fields() {
    // The simplified PD routing path serializes the request to JSON and
    // injects bootstrap fields in place; every SGLang extension field set on
    // the original request must survive that round trip untouched.
    use crate::openai_api_types::GenerateRequest;
    use crate::routers::bootstrap_injector::inject_bootstrap_fields;
    // Create a test worker
    let worker = BasicWorker::new(
        "http://test-server:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(5678),
        },
    );
    // Create a GenerateRequest with SGLang extensions
    let mut session_params = std::collections::HashMap::new();
    session_params.insert("test_key".to_string(), serde_json::json!("test_value"));
    let request = GenerateRequest {
        text: Some("Test prompt".to_string()),
        stream: false,
        return_logprob: true,
        // SGLang extensions
        lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some(
            "test.bin".to_string(),
        ))),
        session_params: Some(session_params.clone()),
        return_hidden_states: true,
        rid: Some("test-request-id".to_string()),
        // Other fields default to None/false
        prompt: None,
        input_ids: None,
        parameters: None,
        sampling_params: None,
    };
    // Convert to JSON (simulating the simplified routing path)
    let mut json = serde_json::to_value(&request).unwrap();
    // Inject bootstrap fields
    let result = inject_bootstrap_fields(&mut json, &worker);
    assert!(result.is_ok());
    // Verify all SGLang fields are preserved
    assert_eq!(json["text"], serde_json::json!("Test prompt"));
    assert_eq!(json["stream"], serde_json::json!(false));
    assert_eq!(json["return_logprob"], serde_json::json!(true));
    assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value
    assert_eq!(
        json["session_params"],
        serde_json::to_value(&session_params).unwrap()
    );
    assert_eq!(json["return_hidden_states"], serde_json::json!(true));
    assert_eq!(json["rid"], serde_json::json!("test-request-id"));
    // Verify bootstrap fields were added: host is the worker URL's hostname,
    // port comes from WorkerType::Prefill, room is a freshly generated number.
    assert_eq!(json["bootstrap_host"], serde_json::json!("test-server"));
    assert_eq!(json["bootstrap_port"], serde_json::json!(5678));
    assert!(json["bootstrap_room"].is_number());
}
#[tokio::test]
async fn test_simplified_routing_chat_completion() {
    // With n = 2, bootstrap injection must produce *batched* bootstrap
    // fields (arrays of length 2) while leaving all original chat fields,
    // including SGLang extensions, intact.
    use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent};
    use crate::routers::bootstrap_injector::inject_bootstrap_fields;
    // Create a test worker
    let worker = BasicWorker::new(
        "http://chat-server:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(9999),
        },
    );
    // Create a ChatCompletionRequest with SGLang extensions
    let request = ChatCompletionRequest {
        model: "gpt-4".to_string(),
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Hello world!".to_string()),
            name: None,
        }],
        stream: false,
        n: Some(2), // This should create batch bootstrap
        // SGLang extensions
        top_k: Some(50),
        separate_reasoning: false,
        stream_reasoning: true,
        // Set all other fields to defaults
        temperature: None,
        top_p: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        seed: None,
        logprobs: false,
        top_logprobs: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
    };
    // Convert to JSON (simulating the simplified routing path)
    let mut json = serde_json::to_value(&request).unwrap();
    // Inject bootstrap fields
    let result = inject_bootstrap_fields(&mut json, &worker);
    assert!(result.is_ok());
    // Verify original fields preserved
    assert_eq!(json["model"], serde_json::json!("gpt-4"));
    assert_eq!(json["stream"], serde_json::json!(false));
    assert_eq!(json["n"], serde_json::json!(2));
    assert_eq!(json["top_k"], serde_json::json!(50));
    assert_eq!(json["separate_reasoning"], serde_json::json!(false));
    assert_eq!(json["stream_reasoning"], serde_json::json!(true));
    // Verify batch bootstrap fields for n=2: host/port repeated per entry,
    // room ids generated independently per entry.
    let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
    assert_eq!(bootstrap_hosts.len(), 2);
    assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server"));
    assert_eq!(bootstrap_hosts[1], serde_json::json!("chat-server"));
    let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
    assert_eq!(bootstrap_ports.len(), 2);
    assert_eq!(bootstrap_ports[0], serde_json::json!(9999));
    assert_eq!(bootstrap_ports[1], serde_json::json!(9999));
    let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
    assert_eq!(bootstrap_rooms.len(), 2);
    // Rooms should be different (randomness)
    assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
}
}
}
sgl-router/src/routers/pd_types.rs
View file @
8c7bb39d
// Essential PDLB types extracted for PD routing
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
openai_api_types
::{
CompletionRequest
,
StringOrArray
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
// Custom error type for PD router operations
// Custom error type for PD router operations
#[derive(Debug,
thiserror::Error)]
#[derive(Debug,
thiserror::Error)]
pub
enum
PDRouterError
{
pub
enum
PDRouterError
{
...
@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
...
@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
},
},
}
}
// Bootstrap types from PDLB
/// A value that is either a single item or a batch of items.
///
/// `#[serde(untagged)]` means a `Single` serializes as the bare value and a
/// `Batch` as a JSON array, matching the wire format the SGLang server expects
/// for bootstrap fields.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    Single(T),
    Batch(Vec<T>),
}

/// Token-id input: one sequence or a batch of sequences.
pub type InputIds = SingleOrBatch<Vec<i32>>;
/// Text input: one prompt or a batch of prompts.
pub type InputText = SingleOrBatch<String>;
/// Bootstrap host name(s) taken from the prefill worker URL.
pub type BootstrapHost = SingleOrBatch<String>;
/// Bootstrap port(s); `None` when the prefill worker has no bootstrap port.
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
/// Randomly generated room id(s) pairing a prefill and decode request.
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
/// Implemented by request types that can carry PD bootstrap metadata
/// (host, port, room) used to pair a prefill request with its decode request.
pub trait Bootstrap: Send + Sync {
    /// Whether the client asked for a streaming response.
    fn is_stream(&self) -> bool;

    /// Batch size implied by the request: `Ok(Some(n))` for batched inputs,
    /// `Ok(None)` for a single input, `Err` for invalid input combinations.
    fn get_batch_size(&self) -> Result<Option<usize>, String>;

    /// Store the given bootstrap fields on the request.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    );

    /// Populate bootstrap fields from `prefill_worker`: hostname from its
    /// URL, port from its `WorkerType::Prefill` config, and freshly generated
    /// random room id(s) — one per batch entry for batched requests.
    fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
        let batch_size = self.get_batch_size()?;

        // Extract bootstrap port from prefill worker if it's a prefill type
        let bootstrap_port = match prefill_worker.worker_type() {
            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
            _ => None,
        };

        let hostname = get_hostname(prefill_worker.url());

        if let Some(batch_size) = batch_size {
            self.set_bootstrap_info(
                BootstrapHost::Batch(vec![hostname; batch_size]),
                BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
                // Use high-quality random numbers to minimize collision risk
                BootstrapRoom::Batch(
                    (0..batch_size)
                        .map(|_| {
                            // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
                            rand::random::<u64>() & (i64::MAX as u64)
                        })
                        .collect(),
                ),
            );
        } else {
            self.set_bootstrap_info(
                BootstrapHost::Single(hostname),
                BootstrapPort::Single(bootstrap_port),
                BootstrapRoom::Single(
                    // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
                    rand::random::<u64>() & (i64::MAX as u64),
                ),
            );
        }
        Ok(())
    }
}
// Request types
/// Internal PD representation of a /generate request.
///
/// Only the fields the router inspects are typed; everything else is carried
/// opaquely in `other` via `#[serde(flatten)]` and forwarded to the backend.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    /// Text prompt(s); mutually exclusive with `input_ids`.
    pub text: Option<InputText>,
    /// Pre-tokenized prompt(s); mutually exclusive with `text`.
    pub input_ids: Option<InputIds>,
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap fields injected by the router before dispatch.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// All remaining request fields, passed through untouched.
    #[serde(flatten)]
    pub other: Value,
}
impl GenerateReqInput {
    /// Determine the batch size implied by this request's inputs.
    ///
    /// Returns `Ok(Some(n))` when `text` or `input_ids` holds a batch of `n`
    /// entries, `Ok(None)` for single inputs, and `Err` when both input kinds
    /// are present or a batch (or one of its sequences) is empty.
    pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
        match (&self.text, &self.input_ids) {
            // Supplying both input kinds is ambiguous — reject it outright.
            (Some(_), Some(_)) => {
                Err("Both text and input_ids are present in the request".to_string())
            }
            // Batched text: non-empty array -> its length.
            (Some(InputText::Batch(texts)), _) => {
                if texts.is_empty() {
                    Err("Batch text array is empty".to_string())
                } else {
                    Ok(Some(texts.len()))
                }
            }
            // Batched token ids: array and every sequence must be non-empty.
            (_, Some(InputIds::Batch(ids))) => {
                if ids.is_empty() {
                    return Err("Batch input_ids array is empty".to_string());
                }
                if let Some(i) = ids.iter().position(|seq| seq.is_empty()) {
                    return Err(format!("Input sequence at index {} is empty", i));
                }
                Ok(Some(ids.len()))
            }
            // Single text, single id sequence, or no input at all.
            _ => Ok(None),
        }
    }
}
impl Bootstrap for GenerateReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // Delegates to the inherent method of the same name (inherent methods
        // take precedence over trait methods, so this does not recurse).
        self.get_batch_size()
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // Typed fields exist on this struct, so store them directly.
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
/// Internal PD representation of a chat-completion request.
///
/// Unlike [`GenerateReqInput`], the message payload is not inspected by the
/// router and lives entirely inside the flattened `other` map.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap fields injected by the router before dispatch.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// All remaining request fields, passed through untouched.
    #[serde(flatten)]
    pub other: Value,
}
impl
Bootstrap
for
ChatReqInput
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check if 'n' parameter is present and > 1
if
let
Some
(
n_value
)
=
self
.other
.get
(
"n"
)
{
if
let
Some
(
n
)
=
n_value
.as_u64
()
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
}
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
self
.bootstrap_host
=
Some
(
bootstrap_host
);
self
.bootstrap_port
=
Some
(
bootstrap_port
);
self
.bootstrap_room
=
Some
(
bootstrap_room
);
}
}
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl
Bootstrap
for
CompletionRequest
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
if
let
StringOrArray
::
Array
(
prompts
)
=
&
self
.prompt
{
if
prompts
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
prompts
.len
()));
}
// Single string prompt
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
// Insert bootstrap_host - it serializes correctly whether Single or Batch
if
let
Ok
(
host_value
)
=
serde_json
::
to_value
(
&
bootstrap_host
)
{
self
.other
.insert
(
"bootstrap_host"
.to_string
(),
host_value
);
}
// Insert bootstrap_port - it serializes correctly whether Single or Batch
if
let
Ok
(
port_value
)
=
serde_json
::
to_value
(
&
bootstrap_port
)
{
self
.other
.insert
(
"bootstrap_port"
.to_string
(),
port_value
);
}
// Insert bootstrap_room - it serializes correctly whether Single or Batch
if
let
Ok
(
room_value
)
=
serde_json
::
to_value
(
&
bootstrap_room
)
{
self
.other
.insert
(
"bootstrap_room"
.to_string
(),
room_value
);
}
}
}
#[cfg(test)]
mod
bootstrap_tests
{
use
super
::
*
;
use
crate
::
core
::
BasicWorker
;
use
crate
::
openai_api_types
::
StringOrArray
;
/// Create a default CompletionRequest for testing with minimal fields set.
/// Tests override individual fields with struct-update syntax
/// (`..default_completion_request()`).
fn default_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: String::new(),
        prompt: StringOrArray::String(String::new()),
        n: None,
        other: serde_json::Map::new(),
        suffix: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        json_schema: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
    }
}
#[test]
fn
test_completion_batch_size_with_array_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Should return batch size for array prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
Some
(
2
));
}
#[test]
fn
test_completion_batch_size_with_single_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
..
default_completion_request
()
};
// Should return None for single prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_batch_size_with_n_parameter
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
n
:
Some
(
3
),
..
default_completion_request
()
};
// Should return None for single string prompt, even with n > 1
// SGLang handles n parameter differently than batch requests
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_bootstrap_single_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Set bootstrap info - should always use single values
req
.set_bootstrap_info
(
BootstrapHost
::
Single
(
"test-server"
.to_string
()),
BootstrapPort
::
Single
(
Some
(
5678
)),
BootstrapRoom
::
Single
(
12345
),
);
// Verify single values were created
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_string
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_number
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_number
());
assert_eq!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_str
()
.unwrap
(),
"test-server"
);
assert_eq!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_u64
()
.unwrap
(),
5678
);
assert_eq!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_u64
()
.unwrap
(),
12345
);
}
#[test]
fn
test_completion_bootstrap_array_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Set bootstrap info with arrays
req
.set_bootstrap_info
(
BootstrapHost
::
Batch
(
vec!
[
"test-server"
.to_string
();
2
]),
BootstrapPort
::
Batch
(
vec!
[
Some
(
5678
);
2
]),
BootstrapRoom
::
Batch
(
vec!
[
12345
,
67890
]),
);
// Verify arrays were created correctly
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_array
());
let
hosts
=
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
hosts
.len
(),
2
);
assert_eq!
(
hosts
[
0
]
.as_str
()
.unwrap
(),
"test-server"
);
let
ports
=
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
ports
.len
(),
2
);
assert_eq!
(
ports
[
0
]
.as_u64
()
.unwrap
(),
5678
);
let
rooms
=
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
rooms
.len
(),
2
);
assert_eq!
(
rooms
[
0
]
.as_u64
()
.unwrap
(),
12345
);
assert_eq!
(
rooms
[
1
]
.as_u64
()
.unwrap
(),
67890
);
}
#[test]
fn
test_bootstrap_room_range
()
{
// Test that bootstrap_room values are within the expected range [0, 2^63 - 1]
let
worker
=
BasicWorker
::
new
(
"http://test:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8080
),
},
);
// Test single request
let
mut
single_req
=
GenerateReqInput
{
text
:
Some
(
InputText
::
Single
(
"test"
.to_string
())),
input_ids
:
None
,
stream
:
false
,
bootstrap_host
:
None
,
bootstrap_port
:
None
,
bootstrap_room
:
None
,
other
:
Value
::
Object
(
serde_json
::
Map
::
new
()),
};
for
_
in
0
..
200000
{
single_req
.add_bootstrap_info
(
&
worker
)
.unwrap
();
if
let
Some
(
BootstrapRoom
::
Single
(
room
))
=
single_req
.bootstrap_room
{
// Verify the room value is within signed 64-bit range
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
else
{
panic!
(
"Expected single bootstrap room"
);
}
}
// Test batch request
let
mut
batch_req
=
GenerateReqInput
{
text
:
Some
(
InputText
::
Batch
(
vec!
[
"test1"
.to_string
(),
"test2"
.to_string
(),
])),
input_ids
:
None
,
stream
:
false
,
bootstrap_host
:
None
,
bootstrap_port
:
None
,
bootstrap_room
:
None
,
other
:
Value
::
Object
(
serde_json
::
Map
::
new
()),
};
for
_
in
0
..
200000
{
batch_req
.add_bootstrap_info
(
&
worker
)
.unwrap
();
if
let
Some
(
BootstrapRoom
::
Batch
(
rooms
))
=
&
batch_req
.bootstrap_room
{
for
room
in
rooms
{
// Verify each room value is within signed 64-bit range
assert
!
(
*
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
}
else
{
panic!
(
"Expected batch bootstrap rooms"
);
}
}
}
}
sgl-router/src/routers/request_adapter.rs
deleted
100644 → 0
View file @
ca47e24f
// Request adapter to bridge OpenAI API types with PD routing requirements
use
super
::
pd_types
::{
Bootstrap
,
ChatReqInput
,
GenerateReqInput
,
SingleOrBatch
};
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
StringOrArray
,
};
use
serde_json
::
Value
;
/// Adapter trait to convert OpenAI requests to PD-compatible requests.
///
/// Consumes the OpenAI-typed request and produces a PD request type that
/// implements [`Bootstrap`], so the router can inject bootstrap metadata
/// before dispatching to prefill/decode workers.
pub trait ToPdRequest {
    /// The PD request type this request converts into.
    type Output: Bootstrap;
    /// Convert, consuming `self`.
    fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map
macro_rules!
insert_if_some
{
(
$map:expr
,
$
(
$field:expr
=>
$key:expr
),
*
$
(,)
?
)
=>
{
$
(
if
let
Some
(
value
)
=
$field
{
$map
.insert
(
$key
.to_string
(),
serde_json
::
to_value
(
value
)
.unwrap_or
(
Value
::
Null
));
}
)
*
};
}
// Helper macro for simple value insertions
macro_rules!
insert_value
{
(
$map:expr
,
$
(
$field:expr
=>
$key:expr
),
*
$
(,)
?
)
=>
{
$
(
$map
.insert
(
$key
.to_string
(),
$field
.into
());
)
*
};
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert a /generate request into the PD input form.
    ///
    /// Input resolution priority: `text` (SGLang native) > `prompt`
    /// (OpenAI-style) > `input_ids`. `max_new_tokens` and `temperature` are
    /// hoisted out of `parameters`/`sampling_params` to the top level, since
    /// the backend reads them there.
    fn to_pd_request(self) -> Self::Output {
        // Build the other fields first
        let mut other = serde_json::Map::new();

        // Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native format
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI style prompt
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Input IDs case
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => {
                    Some(SingleOrBatch::Single(ids))
                }
                crate::openai_api_types::InputIds::Batch(ids) => {
                    Some(SingleOrBatch::Batch(ids))
                }
            };
            (None, input_ids)
        } else {
            // No input provided
            (None, None)
        };

        // Add parameters to other - handle both old and new style
        if let Some(params) = self.parameters {
            // For generate endpoint, extract max_new_tokens to top level if present
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move max_new_tokens to top level if it exists
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                // Move temperature to top level if it exists
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only add parameters if there are remaining fields
            if !params_value.is_null()
                && params_value.as_object().map_or(false, |m| !m.is_empty())
            {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // Add sampling_params if present
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                // Extract commonly used fields to top level
                // (copied, not removed — sampling_params is kept intact)
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Add other fields
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
    type Output = GenerateReqInput;

    /// Convert an OpenAI /v1/completions request into the PD input form,
    /// mapping OpenAI parameter names onto the backend's generate parameters.
    fn to_pd_request(self) -> Self::Output {
        // Convert CompletionRequest to GenerateReqInput
        let text = match self.prompt {
            StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
            StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
        };

        // Map OpenAI parameters to generate parameters
        let mut other = serde_json::Map::new();

        // Create parameters object
        let mut params = serde_json::Map::new();

        // Map OpenAI fields to internal parameter names
        insert_if_some!(params,
            self.max_tokens => "max_new_tokens",
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "best_of",
            self.logprobs => "top_n_tokens",
            self.seed => "seed"
        );

        // Special handling for fields that need transformation
        if let Some(presence_penalty) = self.presence_penalty {
            // NOTE(review): OpenAI presence_penalty is additive while
            // repetition_penalty is multiplicative — 1.0 + penalty is an
            // approximation, not an exact semantic mapping.
            params.insert(
                "repetition_penalty".to_string(),
                (1.0 + presence_penalty).into(),
            );
        }

        if let Some(stop) = self.stop {
            // Normalize stop to an array of sequences.
            let stop_sequences = match stop {
                StringOrArray::String(s) => vec![s],
                StringOrArray::Array(v) => v,
            };
            params.insert("stop".to_string(), stop_sequences.into());
        }

        if self.echo {
            params.insert("return_full_text".to_string(), true.into());
        }

        other.insert("parameters".to_string(), Value::Object(params));

        // Store original model and stream flag
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        // Add SGLang extension fields
        insert_if_some!(other,
            // SGLang Extensions - Priority 1
            self.top_k => "top_k",
            self.min_p => "min_p",
            self.min_tokens => "min_tokens",
            self.repetition_penalty => "repetition_penalty",
            self.regex => "regex",
            self.ebnf => "ebnf",
            self.stop_token_ids => "stop_token_ids",
            // SGLang Extensions - Priority 2
            self.lora_path => "lora_path",
            self.session_params => "session_params"
        );

        // SGLang boolean extensions (CompletionRequest has these as bool, not Option<bool>)
        other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
        other.insert("ignore_eos".to_string(), self.ignore_eos.into());
        other.insert(
            "skip_special_tokens".to_string(),
            self.skip_special_tokens.into(),
        );
        other.insert(
            "return_hidden_states".to_string(),
            self.return_hidden_states.into(),
        );

        GenerateReqInput {
            text,
            input_ids: None,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
    type Output = ChatReqInput;

    /// Convert an OpenAI /v1/chat/completions request into the PD input form.
    ///
    /// Unlike the completion adapter, no parameter renaming happens here:
    /// every field is copied into the flattened `other` map under its
    /// original name and forwarded as-is.
    fn to_pd_request(self) -> Self::Output {
        let mut other = serde_json::Map::new();

        // Add required fields
        insert_if_some!(other, Some(&self.messages) => "messages");
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        // Add all optional fields
        insert_if_some!(other,
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "n",
            self.stream_options => "stream_options",
            self.stop => "stop",
            self.max_tokens => "max_tokens",
            self.max_completion_tokens => "max_completion_tokens",
            self.presence_penalty => "presence_penalty",
            self.frequency_penalty => "frequency_penalty",
            self.logit_bias => "logit_bias",
            self.user => "user",
            self.seed => "seed",
            self.top_logprobs => "top_logprobs",
            self.response_format => "response_format",
            self.tools => "tools",
            self.tool_choice => "tool_choice",
            self.parallel_tool_calls => "parallel_tool_calls",
            self.functions => "functions",
            self.function_call => "function_call",
            // SGLang Extensions - Priority 1
            self.top_k => "top_k",
            self.min_p => "min_p",
            self.min_tokens => "min_tokens",
            self.repetition_penalty => "repetition_penalty",
            self.regex => "regex",
            self.ebnf => "ebnf",
            self.stop_token_ids => "stop_token_ids",
            // SGLang Extensions - Priority 2
            self.lora_path => "lora_path",
            self.session_params => "session_params"
        );

        // Handle boolean flags
        // logprobs is only forwarded when true (false is the implicit default)
        if self.logprobs {
            other.insert("logprobs".to_string(), true.into());
        }

        // SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
        other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
        other.insert("ignore_eos".to_string(), self.ignore_eos.into());
        other.insert(
            "continue_final_message".to_string(),
            self.continue_final_message.into(),
        );
        other.insert(
            "skip_special_tokens".to_string(),
            self.skip_special_tokens.into(),
        );
        other.insert(
            "separate_reasoning".to_string(),
            self.separate_reasoning.into(),
        );
        other.insert("stream_reasoning".to_string(), self.stream_reasoning.into());
        other.insert(
            "return_hidden_states".to_string(),
            self.return_hidden_states.into(),
        );

        ChatReqInput {
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion.
///
/// Provides default serialization helpers so a request can be forwarded to a
/// backend verbatim, bypassing the PD adapter types.
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
    /// Convert to JSON for sending to backend
    fn to_json(&self) -> Result<Value, serde_json::Error> {
        serde_json::to_value(self)
    }

    /// Convert to bytes for legacy routing
    fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
        let json = serde_json::to_vec(self)?;
        Ok(bytes::Bytes::from(json))
    }
}
// All three OpenAI request types route directly using the trait's default
// serialization methods; no per-type overrides are needed.
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
openai_api_types
::
*
;
use
serde_json
::
json
;
use
std
::
collections
::
HashMap
;
// ============= Test Helper Functions =============
//
// These helper functions create default request instances with all required SGLang extension fields
// properly initialized. Use the struct spread operator `..default_*_request()` to override only
// the fields you need for specific tests, avoiding repetitive boilerplate code.
//
// Example usage:
// let req = GenerateRequest {
// text: Some("Custom text".to_string()),
// stream: true,
// ..default_generate_request()
// };
/// Create a default GenerateRequest with minimal fields set.
/// Override fields per test via `..default_generate_request()`.
fn default_generate_request() -> GenerateRequest {
    GenerateRequest {
        text: None,
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        rid: None,
    }
}
/// Create a default CompletionRequest with minimal fields set.
/// Override fields per test via `..default_completion_request()`.
fn default_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: "test-model".to_string(),
        prompt: StringOrArray::String("test prompt".to_string()),
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        json_schema: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        other: serde_json::Map::new(),
    }
}
/// Create a default ChatCompletionRequest with minimal fields set
/// (one user message, everything else at its neutral default).
/// Override fields per test via `..default_chat_completion_request()`.
fn default_chat_completion_request() -> ChatCompletionRequest {
    ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("test message".to_string()),
            name: None,
        }],
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        separate_reasoning: true,
        stream_reasoning: true,
        return_hidden_states: false,
    }
}
// ============= GenerateRequest to_pd_request Tests =============
/// A `text`-only GenerateRequest becomes a single-text PD request with
/// bootstrap fields unset and the stream/logprob flags mirrored into `other`.
#[test]
fn test_generate_to_pd_request_with_text_only() {
    let request = GenerateRequest {
        text: Some("Hello world".to_string()),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    // Text converts to the single (non-batch) variant; input_ids stays unset.
    assert!(matches!(pd.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world"));
    assert!(pd.input_ids.is_none());

    // Bootstrap routing fields are not populated by the conversion itself.
    assert!(pd.bootstrap_host.is_none());
    assert!(pd.bootstrap_port.is_none());
    assert!(pd.bootstrap_room.is_none());

    // Streaming is off by default.
    assert!(!pd.stream);

    // The passthrough map carries the stream and return_logprob flags.
    let extras = pd.other.as_object().unwrap();
    assert_eq!(extras.get("stream"), Some(&json!(false)));
    assert_eq!(extras.get("return_logprob"), Some(&json!(false)));
}
/// An OpenAI-style string `prompt` maps onto the PD `text` field, and the
/// `stream`/`return_logprob` flags survive the conversion.
#[test]
fn test_generate_to_pd_request_with_prompt_string() {
    let request = GenerateRequest {
        prompt: Some(StringOrArray::String("Test prompt".to_string())),
        stream: true,
        return_logprob: true,
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    assert!(matches!(pd.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt"));
    assert!(pd.input_ids.is_none());
    assert!(pd.stream);

    let extras = pd.other.as_object().unwrap();
    assert_eq!(extras.get("stream"), Some(&json!(true)));
    assert_eq!(extras.get("return_logprob"), Some(&json!(true)));
}
/// An array `prompt` converts to the batch variant of the PD `text` field,
/// preserving order and length.
#[test]
fn test_generate_to_pd_request_with_prompt_array() {
    let expected = ["Prompt 1", "Prompt 2", "Prompt 3"];
    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::Array(
            expected.iter().map(|s| s.to_string()).collect(),
        )),
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref batch)) = pd.text {
        assert_eq!(batch.len(), expected.len());
        // Element order must match the input order exactly.
        for (got, want) in batch.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    } else {
        panic!("Expected batch text");
    }
}
/// Pre-tokenized single input_ids pass through as a single id sequence;
/// `text` stays empty.
#[test]
fn test_generate_to_pd_request_with_single_input_ids() {
    let ids = vec![100, 200, 300, 400];
    let request = GenerateRequest {
        input_ids: Some(InputIds::Single(ids.clone())),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    assert!(pd.text.is_none());
    assert!(matches!(pd.input_ids, Some(SingleOrBatch::Single(ref got)) if got == &ids));
}
/// Batched input_ids (ragged row lengths allowed) convert to the batch
/// variant with every row preserved verbatim.
#[test]
fn test_generate_to_pd_request_with_batch_input_ids() {
    let rows = vec![vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9]];
    let request = GenerateRequest {
        input_ids: Some(InputIds::Batch(rows.clone())),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref batch)) = pd.input_ids {
        assert_eq!(batch.len(), rows.len());
        // Each row must come through unchanged, including its length.
        for (got, want) in batch.iter().zip(rows.iter()) {
            assert_eq!(got, want);
        }
    } else {
        panic!("Expected batch input_ids");
    }
}
/// When `text`, `prompt`, and `input_ids` are all supplied, the SGLang-native
/// `text` field wins and the others are dropped.
#[test]
fn test_generate_to_pd_request_priority_text_over_prompt() {
    let request = GenerateRequest {
        text: Some("SGLang text".to_string()),
        prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
        input_ids: Some(InputIds::Single(vec![1, 2, 3])),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    // text should take priority over both prompt and input_ids
    assert!(matches!(pd.text, Some(SingleOrBatch::Single(ref s)) if s == "SGLang text"));
    assert!(pd.input_ids.is_none());
}
/// With no `text`, `prompt` outranks `input_ids`: it lands in the PD `text`
/// field and the ids are discarded.
#[test]
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
    let request = GenerateRequest {
        prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
        input_ids: Some(InputIds::Single(vec![1, 2, 3])),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    // prompt should take priority over input_ids
    assert!(matches!(pd.text, Some(SingleOrBatch::Single(ref s)) if s == "OpenAI prompt"));
    assert!(pd.input_ids.is_none());
}
/// `GenerateParameters` conversion: `max_new_tokens`/`temperature` are hoisted
/// to the top level of `other`, everything else stays nested under
/// `"parameters"`.
///
/// FIX: the original float checks were `x - expected < eps` without `.abs()`,
/// which passes for ANY value below the expected one (e.g. 0.0). All float
/// comparisons now use `(x - expected).abs() < eps`, matching the style used
/// by the SGLang-extension tests later in this file.
#[test]
fn test_generate_to_pd_request_with_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        temperature: Some(0.8),
        top_p: Some(0.95),
        seed: Some(12345),
        stop: Some(vec!["END".to_string(), "STOP".to_string()]),
        repetition_penalty: Some(1.1),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        parameters: Some(params),
        ..default_generate_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Check that max_new_tokens and temperature were extracted to top level
    assert_eq!(other.get("max_new_tokens"), Some(&json!(100)));
    assert!((other.get("temperature").unwrap().as_f64().unwrap() - 0.8).abs() < 0.0001);

    // Check that other parameters remain under "parameters"
    let params = other.get("parameters").unwrap().as_object().unwrap();
    assert!((params.get("top_p").unwrap().as_f64().unwrap() - 0.95).abs() < 0.0001);
    assert_eq!(params.get("seed"), Some(&json!(12345)));
    assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"])));
    assert!((params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
}
/// `SamplingParams` conversion: `max_new_tokens`/`temperature` are hoisted to
/// the top level of `other`, while the full struct is preserved under
/// `"sampling_params"`.
///
/// FIX: float checks now use `(x - expected).abs() < eps`; the original
/// omitted `.abs()`, so any value *below* the expected one passed silently.
#[test]
fn test_generate_to_pd_request_with_sampling_params() {
    let sampling = SamplingParams {
        max_new_tokens: Some(200),
        temperature: Some(0.7),
        top_p: Some(0.9),
        top_k: Some(50),
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.2),
        repetition_penalty: Some(1.05),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        sampling_params: Some(sampling),
        ..default_generate_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Check extracted top-level fields
    assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));
    assert!((other.get("temperature").unwrap().as_f64().unwrap() - 0.7).abs() < 0.0001);

    // Check full sampling_params is preserved
    let sampling = other.get("sampling_params").unwrap().as_object().unwrap();
    assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200)));
    assert!((sampling.get("temperature").unwrap().as_f64().unwrap() - 0.7).abs() < 0.0001);
    assert!((sampling.get("top_p").unwrap().as_f64().unwrap() - 0.9).abs() < 0.0001);
    assert_eq!(sampling.get("top_k"), Some(&json!(50)));
    assert!((sampling.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.1).abs() < 0.0001);
    assert!((sampling.get("presence_penalty").unwrap().as_f64().unwrap() - 0.2).abs() < 0.0001);
}
/// When both `parameters` and `sampling_params` set `max_new_tokens` and
/// `temperature`, the values from `sampling_params` win because they are
/// processed last.
///
/// FIX: the temperature check now uses `.abs()`; the original `x - 0.9 < eps`
/// form would also accept any temperature below 0.9.
#[test]
fn test_generate_to_pd_request_sampling_params_override_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        temperature: Some(0.5),
        ..Default::default()
    };
    let sampling = SamplingParams {
        max_new_tokens: Some(200),
        temperature: Some(0.9),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: Some(sampling),
        return_logprob: false,
        ..default_generate_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Should use values from sampling_params since they're processed last
    assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));
    assert!((other.get("temperature").unwrap().as_f64().unwrap() - 0.9).abs() < 0.0001);
}
/// An all-default `GenerateParameters` contributes nothing: no `"parameters"`
/// object and no hoisted top-level keys appear in `other`.
#[test]
fn test_generate_to_pd_request_empty_parameters() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(GenerateParameters::default()),
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();

    // None of these keys should be emitted when every parameter is None/default.
    for key in ["parameters", "max_new_tokens", "temperature"] {
        assert!(!extras.contains_key(key));
    }
}
/// End-to-end conversion with every input field populated: `text` wins over
/// `prompt`, `sampling_params` overrides `parameters` for the hoisted keys,
/// and both source objects remain present in `other`.
///
/// FIX: the temperature check now uses `.abs()`; the original `x - 0.8 < eps`
/// form would also accept any temperature below 0.8.
#[test]
fn test_generate_to_pd_request_all_fields() {
    let params = GenerateParameters {
        max_new_tokens: Some(150),
        temperature: Some(0.6),
        top_k: Some(40),
        ..Default::default()
    };
    let sampling = SamplingParams {
        max_new_tokens: Some(250), // Will override parameters
        temperature: Some(0.8),    // Will override parameters
        presence_penalty: Some(0.1),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("Complex test".to_string()),
        prompt: Some(StringOrArray::String("Ignored prompt".to_string())),
        input_ids: None,
        stream: true,
        parameters: Some(params),
        sampling_params: Some(sampling),
        return_logprob: true,
        ..default_generate_request()
    };

    let pd_req = req.to_pd_request();

    // Verify all fields
    assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test"));
    assert!(pd_req.input_ids.is_none());
    assert!(pd_req.stream);
    assert!(pd_req.bootstrap_host.is_none());
    assert!(pd_req.bootstrap_port.is_none());
    assert!(pd_req.bootstrap_room.is_none());

    let other = pd_req.other.as_object().unwrap();
    assert_eq!(other.get("stream"), Some(&json!(true)));
    assert_eq!(other.get("return_logprob"), Some(&json!(true)));

    // Sampling params override parameters
    assert_eq!(other.get("max_new_tokens"), Some(&json!(250)));
    assert!((other.get("temperature").unwrap().as_f64().unwrap() - 0.8).abs() < 0.0001);
    assert!(other.contains_key("parameters"));
    assert!(other.contains_key("sampling_params"));
}
// ============= CompletionRequest to_pd_request Tests =============
/// A minimal CompletionRequest: the string prompt maps to single text, and
/// `model`/`stream` are forwarded through `other`.
#[test]
fn test_completion_to_pd_request_basic() {
    let request = CompletionRequest {
        model: "gpt-3.5-turbo".to_string(),
        prompt: StringOrArray::String("Complete this sentence".to_string()),
        ..default_completion_request()
    };

    let pd = request.to_pd_request();

    assert!(
        matches!(pd.text, Some(SingleOrBatch::Single(ref s)) if s == "Complete this sentence")
    );
    assert!(pd.input_ids.is_none());
    assert!(!pd.stream);

    let extras = pd.other.as_object().unwrap();
    assert_eq!(extras.get("model"), Some(&json!("gpt-3.5-turbo")));
    assert_eq!(extras.get("stream"), Some(&json!(false)));
}
/// An array prompt on a CompletionRequest converts to the batch text variant,
/// preserving element order.
#[test]
fn test_completion_to_pd_request_array_prompt() {
    let expected = ["First prompt", "Second prompt"];
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::Array(expected.iter().map(|s| s.to_string()).collect()),
        ..default_completion_request()
    };

    let pd = request.to_pd_request();

    if let Some(SingleOrBatch::Batch(ref batch)) = pd.text {
        assert_eq!(batch.len(), expected.len());
        for (got, want) in batch.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    } else {
        panic!("Expected batch text");
    }
}
/// OpenAI → PD parameter renames: `max_tokens` → `max_new_tokens`,
/// `n` → `best_of`, `logprobs` → `top_n_tokens`, `echo` → `return_full_text`,
/// `presence_penalty` p → `repetition_penalty` 1+p; string `stop` normalized
/// to an array; `model`/`stream` forwarded at the top level.
///
/// FIX: float checks now use `(x - expected).abs() < eps`; the original
/// omitted `.abs()`, so any value below the expected one passed silently.
#[test]
fn test_completion_to_pd_request_parameter_mapping() {
    let req = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        max_tokens: Some(150), // -> max_new_tokens
        temperature: Some(0.75),
        top_p: Some(0.92),
        n: Some(3), // -> best_of
        stream: true,
        stream_options: None,
        logprobs: Some(10), // -> top_n_tokens
        echo: true,         // -> return_full_text
        stop: Some(StringOrArray::Array(vec![
            "\\n".to_string(),
            "END".to_string(),
        ])),
        presence_penalty: Some(0.5), // -> repetition_penalty = 1.5
        frequency_penalty: Some(0.2),
        best_of: Some(5),
        logit_bias: None,
        user: Some("user123".to_string()),
        seed: Some(42),
        suffix: Some("...".to_string()),
        ..default_completion_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    let params = other.get("parameters").unwrap().as_object().unwrap();

    // Check parameter mappings
    assert_eq!(params.get("max_new_tokens"), Some(&json!(150)));
    assert!((params.get("temperature").unwrap().as_f64().unwrap() - 0.75).abs() < 0.0001);
    assert!((params.get("top_p").unwrap().as_f64().unwrap() - 0.92).abs() < 0.0001);
    // `n` (not `best_of`) drives the mapped value.
    assert_eq!(params.get("best_of"), Some(&json!(3)));
    assert_eq!(params.get("top_n_tokens"), Some(&json!(10)));
    assert_eq!(params.get("return_full_text"), Some(&json!(true)));
    assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"])));
    assert!((params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.5).abs() < 0.0001);
    assert_eq!(params.get("seed"), Some(&json!(42)));

    // Check other fields
    assert_eq!(other.get("model"), Some(&json!("test")));
    assert_eq!(other.get("stream"), Some(&json!(true)));
}
/// A single string `stop` is normalized into a one-element array under
/// `"parameters"`.
#[test]
fn test_completion_to_pd_request_stop_string() {
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        stop: Some(StringOrArray::String("STOP".to_string())),
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
        ..default_completion_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();
    let mapped = extras.get("parameters").unwrap().as_object().unwrap();

    // Single string stop should be converted to array
    assert_eq!(mapped.get("stop"), Some(&json!(vec!["STOP"])));
}
/// Without a `presence_penalty`, no derived `repetition_penalty` should be
/// emitted under `"parameters"`.
#[test]
fn test_completion_to_pd_request_no_presence_penalty() {
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("test".to_string()),
        presence_penalty: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        suffix: None,
        ..default_completion_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();
    let mapped = extras.get("parameters").unwrap().as_object().unwrap();

    // Should not have repetition_penalty if presence_penalty is None
    assert!(!mapped.contains_key("repetition_penalty"));
}
// ============= ChatCompletionRequest to_pd_request Tests =============
/// A basic chat request: messages are carried through `other` verbatim,
/// bootstrap fields stay unset, and `model`/`stream` are forwarded.
#[test]
fn test_chat_to_pd_request_basic() {
    let chat_messages = vec![
        ChatMessage::System {
            role: "system".to_string(),
            content: "You are a helpful assistant".to_string(),
            name: None,
        },
        ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Hello!".to_string()),
            name: None,
        },
    ];
    let request = ChatCompletionRequest {
        messages: chat_messages,
        model: "gpt-4".to_string(),
        ..default_chat_completion_request()
    };

    let pd = request.to_pd_request();

    assert!(!pd.stream);
    assert!(pd.bootstrap_host.is_none());
    assert!(pd.bootstrap_port.is_none());
    assert!(pd.bootstrap_room.is_none());

    let extras = pd.other.as_object().unwrap();
    assert!(extras.contains_key("messages"));
    assert_eq!(extras.get("model"), Some(&json!("gpt-4")));
    assert_eq!(extras.get("stream"), Some(&json!(false)));

    // Check messages are preserved (both of them).
    let forwarded = extras.get("messages").unwrap().as_array().unwrap();
    assert_eq!(forwarded.len(), 2);
}
/// Every optional OpenAI chat field set at once: all of them must survive the
/// conversion into the `other` passthrough map.
///
/// FIX: float checks now use `(x - expected).abs() < eps`; the original
/// omitted `.abs()`, so any value below the expected one passed silently.
#[test]
fn test_chat_to_pd_request_with_all_optional_fields() {
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: Some("test_user".to_string()),
    }];
    let mut logit_bias = HashMap::new();
    logit_bias.insert("50256".to_string(), -100.0f32);
    let tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather info".to_string()),
            parameters: json!({"type": "object"}),
        },
    };
    let req = ChatCompletionRequest {
        messages,
        model: "gpt-4".to_string(),
        temperature: Some(0.8),
        top_p: Some(0.95),
        n: Some(2),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: Some(true),
        }),
        stop: Some(StringOrArray::String("\\n\\n".to_string())),
        max_tokens: Some(200),
        max_completion_tokens: Some(150),
        presence_penalty: Some(0.1),
        frequency_penalty: Some(0.2),
        logit_bias: Some(logit_bias),
        logprobs: true,
        top_logprobs: Some(5),
        user: Some("user456".to_string()),
        seed: Some(12345),
        response_format: Some(ResponseFormat::JsonObject),
        tools: Some(vec![tool]),
        tool_choice: Some(ToolChoice::Auto),
        parallel_tool_calls: Some(false),
        functions: None,
        function_call: None,
        ..default_chat_completion_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // Check all fields are preserved
    assert!((other.get("temperature").unwrap().as_f64().unwrap() - 0.8).abs() < 0.0001);
    assert!((other.get("top_p").unwrap().as_f64().unwrap() - 0.95).abs() < 0.0001);
    assert_eq!(other.get("n"), Some(&json!(2)));
    assert_eq!(other.get("stream"), Some(&json!(true)));
    assert!(other.contains_key("stream_options"));
    assert!(other.contains_key("stop"));
    assert_eq!(other.get("max_tokens"), Some(&json!(200)));
    assert_eq!(other.get("max_completion_tokens"), Some(&json!(150)));
    assert!((other.get("presence_penalty").unwrap().as_f64().unwrap() - 0.1).abs() < 0.0001);
    assert!((other.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.2).abs() < 0.0001);
    assert!(other.contains_key("logit_bias"));
    assert_eq!(other.get("logprobs"), Some(&json!(true)));
    assert_eq!(other.get("top_logprobs"), Some(&json!(5)));
    assert_eq!(other.get("user"), Some(&json!("user456")));
    assert_eq!(other.get("seed"), Some(&json!(12345)));
    assert!(other.contains_key("response_format"));
    assert!(other.contains_key("tools"));
    assert!(other.contains_key("tool_choice"));
    assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false)));
}
/// Multimodal (text + image) message content must pass through unchanged,
/// keeping its array-of-parts shape.
#[test]
fn test_chat_to_pd_request_multimodal_content() {
    let parts = vec![
        ContentPart::Text {
            text: "What's in this image?".to_string(),
        },
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: Some("high".to_string()),
            },
        },
    ];
    let request = ChatCompletionRequest {
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Parts(parts),
            name: None,
        }],
        model: "gpt-4-vision".to_string(),
        ..default_chat_completion_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();

    // Messages with multimodal content should be preserved
    assert!(extras.contains_key("messages"));
    let forwarded = extras.get("messages").unwrap().as_array().unwrap();
    assert_eq!(forwarded.len(), 1);

    // Verify the message structure is preserved
    let first = &forwarded[0];
    assert_eq!(first["role"], "user");
    assert!(first["content"].is_array());
}
/// Chat requests use a boolean `logprobs` flag plus `top_logprobs`; both must
/// be forwarded as-is.
#[test]
fn test_chat_to_pd_request_logprobs_boolean() {
    let request = ChatCompletionRequest {
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Test".to_string()),
            name: None,
        }],
        model: "test".to_string(),
        logprobs: true, // Boolean logprobs flag
        top_logprobs: Some(3),
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
        ..default_chat_completion_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();

    assert_eq!(extras.get("logprobs"), Some(&json!(true)));
    assert_eq!(extras.get("top_logprobs"), Some(&json!(3)));
}
/// A minimal chat request forwards only the required keys; unset optional
/// fields must not appear in `other`.
#[test]
fn test_chat_to_pd_request_minimal_fields() {
    let request = ChatCompletionRequest {
        messages: vec![ChatMessage::Assistant {
            role: "assistant".to_string(),
            content: Some("I can help with that.".to_string()),
            name: None,
            tool_calls: None,
            function_call: None,
            reasoning_content: None,
        }],
        model: "gpt-3.5-turbo".to_string(),
        ..default_chat_completion_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();

    // Required fields are always forwarded.
    for key in ["messages", "model", "stream"] {
        assert!(extras.contains_key(key));
    }
    // Optional fields that were never set must be absent.
    for key in ["temperature", "top_p", "max_tokens", "stop"] {
        assert!(!extras.contains_key(key));
    }
}
/// `RouteableRequest::to_json` serializes the request with its text and
/// stream flag intact.
#[test]
fn test_routeable_request_to_json() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        ..default_generate_request()
    };

    let payload = request.to_json().unwrap();

    assert_eq!(payload["text"], "test");
    assert_eq!(payload["stream"], false);
}
// ============= Macro Tests =============
/// `insert_if_some!` writes Some(..) values under their key and silently
/// skips None values.
#[test]
fn test_insert_if_some_macro() {
    let mut fields = serde_json::Map::new();
    let populated: Option<i32> = Some(42);
    let missing: Option<i32> = None;

    insert_if_some!(fields, populated => "present", missing => "absent");

    assert_eq!(fields.get("present"), Some(&json!(42)));
    assert!(!fields.contains_key("absent"));
}
/// `insert_value!` unconditionally inserts each value under its key,
/// converting through `json!`.
#[test]
fn test_insert_value_macro() {
    let mut fields = serde_json::Map::new();
    let text_value = "test";
    let numeric_value = 42;

    insert_value!(fields, text_value => "string_field", numeric_value => "int_field");

    assert_eq!(fields.get("string_field"), Some(&json!("test")));
    assert_eq!(fields.get("int_field"), Some(&json!(42)));
}
// ============= Edge Cases and Error Handling =============
/// Explicitly-None parameter fields behave like an empty parameter set:
/// no `"parameters"` object is emitted.
#[test]
fn test_null_value_handling() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(GenerateParameters {
            max_new_tokens: None,
            temperature: None,
            ..Default::default()
        }),
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };

    let pd = request.to_pd_request();
    let extras = pd.other.as_object().unwrap();

    // Should not have parameters field if all fields are None
    assert!(!extras.contains_key("parameters"));
}
/// A 1000-element prompt batch converts without loss; spot-check both ends.
#[test]
fn test_large_batch_conversion() {
    let items: Vec<String> = (0..1000).map(|i| format!("item_{}", i)).collect();
    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::Array(items.clone())),
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    match pd.text {
        Some(SingleOrBatch::Batch(batch)) => {
            assert_eq!(batch.len(), 1000);
            assert_eq!(batch[0], "item_0");
            assert_eq!(batch[999], "item_999");
        }
        _ => panic!("Expected batch text"),
    }
}
/// Multi-script / emoji text must round-trip through the conversion
/// byte-for-byte.
#[test]
fn test_unicode_string_handling() {
    let original = "Hello 世界 🌍 नमस्ते мир".to_string();
    let request = GenerateRequest {
        text: Some(original.clone()),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    match pd.text {
        Some(SingleOrBatch::Single(converted)) => assert_eq!(converted, original),
        _ => panic!("Expected single text"),
    }
}
/// Parameters survive conversion and the hoisted `max_new_tokens` key appears
/// in `other`.
///
/// FIX: the original built a deeply nested `serde_json::Map` local
/// (`nested_params`) that was never passed into the request — dead code that
/// made the test name misleading. The dead local is removed.
/// NOTE(review): `GenerateParameters` has no field that accepts an arbitrary
/// nested JSON object here, so a true nested-parameter round-trip cannot be
/// exercised from this test — confirm against `GenerateParameters`' schema if
/// that coverage is wanted.
#[test]
fn test_deeply_nested_parameters() {
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };

    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();

    // max_new_tokens is hoisted to the top level of `other`.
    assert!(other.contains_key("max_new_tokens"));
}
// ============= Bootstrap Field Tests =============
/// Conversion never fabricates bootstrap routing data — all three bootstrap
/// fields start out unset.
#[test]
fn test_bootstrap_fields_none() {
    let request = GenerateRequest {
        text: Some("test".to_string()),
        ..default_generate_request()
    };

    let pd = request.to_pd_request();

    assert!(pd.bootstrap_host.is_none());
    assert!(pd.bootstrap_port.is_none());
    assert!(pd.bootstrap_room.is_none());
}
// ============= SGLang Extension Field Pass-Through Tests =============
#[test]
fn
test_chat_completion_sglang_extensions_passed_through
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
let
mut
session_params
=
std
::
collections
::
HashMap
::
new
();
session_params
.insert
(
"key"
.to_string
(),
serde_json
::
Value
::
String
(
"value"
.to_string
()),
);
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"test-model"
.to_string
(),
// SGLang Extensions - Priority 1
top_k
:
Some
(
40
),
min_p
:
Some
(
0.05
),
min_tokens
:
Some
(
10
),
repetition_penalty
:
Some
(
1.1
),
regex
:
Some
(
"test_regex"
.to_string
()),
ebnf
:
Some
(
"test_ebnf"
.to_string
()),
stop_token_ids
:
Some
(
vec!
[
1
,
2
,
3
]),
// SGLang Extensions - Priority 2
lora_path
:
Some
(
LoRAPath
::
Single
(
Some
(
"test_lora.bin"
.to_string
()))),
session_params
:
Some
(
session_params
.clone
()),
// Boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
no_stop_trim
:
true
,
ignore_eos
:
false
,
continue_final_message
:
true
,
skip_special_tokens
:
false
,
separate_reasoning
:
true
,
stream_reasoning
:
false
,
return_hidden_states
:
true
,
..
default_chat_completion_request
()
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Verify SGLang extensions are passed through
assert_eq!
(
other
.get
(
"top_k"
),
Some
(
&
json!
(
40
)));
assert
!
((
other
.get
(
"min_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.05
)
.abs
()
<
0.0001
);
assert_eq!
(
other
.get
(
"min_tokens"
),
Some
(
&
json!
(
10
)));
assert
!
((
other
.get
(
"repetition_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
1.1
)
.abs
()
<
0.0001
);
assert_eq!
(
other
.get
(
"regex"
),
Some
(
&
json!
(
"test_regex"
)));
assert_eq!
(
other
.get
(
"ebnf"
),
Some
(
&
json!
(
"test_ebnf"
)));
assert_eq!
(
other
.get
(
"stop_token_ids"
),
Some
(
&
json!
(
vec!
[
1
,
2
,
3
])));
assert_eq!
(
other
.get
(
"lora_path"
),
Some
(
&
json!
(
"test_lora.bin"
)));
assert_eq!
(
other
.get
(
"session_params"
),
Some
(
&
serde_json
::
to_value
(
&
session_params
)
.unwrap
())
);
// Verify boolean extensions
assert_eq!
(
other
.get
(
"no_stop_trim"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"ignore_eos"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"continue_final_message"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"skip_special_tokens"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"separate_reasoning"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"stream_reasoning"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"return_hidden_states"
),
Some
(
&
json!
(
true
)));
}
#[test]
fn
test_completion_request_sglang_extensions_passed_through
()
{
let
mut
session_params
=
std
::
collections
::
HashMap
::
new
();
session_params
.insert
(
"key"
.to_string
(),
serde_json
::
Value
::
String
(
"value"
.to_string
()),
);
let
req
=
CompletionRequest
{
prompt
:
StringOrArray
::
String
(
"Test prompt"
.to_string
()),
model
:
"test-model"
.to_string
(),
// SGLang Extensions - Priority 1
top_k
:
Some
(
40
),
min_p
:
Some
(
0.05
),
min_tokens
:
Some
(
10
),
repetition_penalty
:
Some
(
1.1
),
regex
:
Some
(
"test_regex"
.to_string
()),
ebnf
:
Some
(
"test_ebnf"
.to_string
()),
stop_token_ids
:
Some
(
vec!
[
1
,
2
,
3
]),
// SGLang Extensions - Priority 2
lora_path
:
Some
(
LoRAPath
::
Single
(
Some
(
"test_lora.bin"
.to_string
()))),
session_params
:
Some
(
session_params
.clone
()),
// Boolean extensions (CompletionRequest only has these 4 boolean fields)
no_stop_trim
:
true
,
ignore_eos
:
false
,
skip_special_tokens
:
false
,
return_hidden_states
:
true
,
..
default_completion_request
()
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Verify SGLang extensions are passed through
assert_eq!
(
other
.get
(
"top_k"
),
Some
(
&
json!
(
40
)));
assert
!
((
other
.get
(
"min_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.05
)
.abs
()
<
0.0001
);
assert_eq!
(
other
.get
(
"min_tokens"
),
Some
(
&
json!
(
10
)));
assert
!
((
other
.get
(
"repetition_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
1.1
)
.abs
()
<
0.0001
);
assert_eq!
(
other
.get
(
"regex"
),
Some
(
&
json!
(
"test_regex"
)));
assert_eq!
(
other
.get
(
"ebnf"
),
Some
(
&
json!
(
"test_ebnf"
)));
assert_eq!
(
other
.get
(
"stop_token_ids"
),
Some
(
&
json!
(
vec!
[
1
,
2
,
3
])));
assert_eq!
(
other
.get
(
"lora_path"
),
Some
(
&
json!
(
"test_lora.bin"
)));
assert_eq!
(
other
.get
(
"session_params"
),
Some
(
&
serde_json
::
to_value
(
&
session_params
)
.unwrap
())
);
// Verify boolean extensions (only the ones CompletionRequest has)
assert_eq!
(
other
.get
(
"no_stop_trim"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"ignore_eos"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"skip_special_tokens"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"return_hidden_states"
),
Some
(
&
json!
(
true
)));
}
#[test]
fn test_sglang_extensions_none_values_not_passed_through() {
    // A minimal single-message chat request where every optional SGLang
    // extension is left unset and every boolean extension keeps its default.
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: None,
    }];

    let request = ChatCompletionRequest {
        messages,
        model: "test-model".to_string(),
        // All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        lora_path: None,
        session_params: None,
        // Boolean fields use defaults (false for most, true for some with default_true)
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,   // This has default_true
        separate_reasoning: true,    // This has default_true
        stream_reasoning: true,      // This has default_true
        return_hidden_states: false,
        ..default_chat_completion_request()
    };

    let pd_request = request.to_pd_request();
    let extra_fields = pd_request.other.as_object().unwrap();

    // Optional fields set to None must not be serialized into `other` at all.
    let absent_keys = [
        "top_k",
        "min_p",
        "min_tokens",
        "repetition_penalty",
        "regex",
        "ebnf",
        "stop_token_ids",
        "lora_path",
        "session_params",
    ];
    for key in absent_keys {
        assert!(!extra_fields.contains_key(key));
    }

    // Boolean fields are always present with their values (can't be None);
    // the last three carry default_true.
    let expected_bools = [
        ("no_stop_trim", false),
        ("ignore_eos", false),
        ("continue_final_message", false),
        ("skip_special_tokens", true),
        ("separate_reasoning", true),
        ("stream_reasoning", true),
        ("return_hidden_states", false),
    ];
    for (key, expected) in expected_bools {
        assert_eq!(extra_fields.get(key), Some(&json!(expected)));
    }
}
}
sgl-router/tests/benchmark_integration.rs
View file @
8c7bb39d
// Integration test to ensure benchmarks compile and basic functionality works
// Integration test to ensure benchmarks compile and basic functionality works
// This prevents benchmarks from breaking in CI
// This prevents benchmarks from breaking in CI
//
// UPDATED: Removed deprecated ToPdRequest usage, now uses direct JSON serialization
use
serde_json
::{
from_str
,
to_string
};
use
serde_json
::{
from_str
,
to_string
,
to_value
};
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
openai_api_types
::{
use
sglang_router_rs
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
};
use
sglang_router_rs
::
routers
::
request_adapter
::{
RouteableRequest
,
ToPdRequest
}
;
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
/// Create a default GenerateRequest for benchmarks with minimal fields set
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
fn
default_generate_request
()
->
GenerateRequest
{
...
@@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest {
...
@@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest {
}
}
}
}
/// Build a prefill worker aimed at a dummy endpoint for benchmark/integration tests.
fn create_test_worker() -> BasicWorker {
    let endpoint = "http://test-server:8000".to_string();
    let worker_type = WorkerType::Prefill {
        bootstrap_port: Some(5678),
    };
    BasicWorker::new(endpoint, worker_type)
}
#[test]
#[test]
fn
test_benchmark_request_creation
()
{
fn
test_benchmark_request_creation
()
{
// Ensure all benchmark request types can be created without panicking
// Ensure all benchmark request types can be created without panicking
...
@@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() {
...
@@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() {
}
}
#[test]
#[test]
fn
test_benchmark_
request_adapta
tion
()
{
fn
test_benchmark_
bootstrap_injec
tion
()
{
// Test that
PD request adapta
tion works for benchmark types
// Test that
bootstrap injec
tion works for benchmark types
(replaces PD request adaptation)
let
generate_req
=
GenerateRequest
{
let
generate_req
=
GenerateRequest
{
text
:
Some
(
"Test prompt"
.to_string
()),
text
:
Some
(
"Test prompt"
.to_string
()),
...
@@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() {
...
@@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() {
..
default_completion_request
()
..
default_completion_request
()
};
};
// Test PD adaptation (should not panic)
let
worker
=
create_test_worker
();
let
_
pd_generate
=
generate_req
.to_pd_request
();
let
_
pd_chat
=
chat_req
.to_pd_request
();
// Test bootstrap injection (should not panic)
let
_
pd_completion
=
completion_req
.to_pd_request
();
let
mut
generate_json
=
to_value
(
&
generate_req
)
.unwrap
();
let
mut
chat_json
=
to_value
(
&
chat_req
)
.unwrap
();
let
mut
completion_json
=
to_value
(
&
completion_req
)
.unwrap
();
assert
!
(
inject_bootstrap_fields
(
&
mut
generate_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
chat_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
completion_json
,
&
worker
)
.is_ok
());
// Verify bootstrap fields were added
assert
!
(
generate_json
.get
(
"bootstrap_host"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_port"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_room"
)
.is_some
());
}
}
#[test]
fn test_benchmark_direct_json_routing() {
    // Test direct JSON routing functionality for benchmark types (replaces regular routing)
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        ..default_generate_request()
    };

    // Test direct JSON conversion (replaces regular routing methods):
    // serialize to a Value, then to a string, then borrow the raw bytes —
    // mirroring what the router does when forwarding a request body.
    let json = to_value(&generate_req).unwrap();
    let json_string = to_string(&json).unwrap();
    let bytes = json_string.as_bytes();

    // Verify conversions work and produce non-empty payloads.
    assert!(!json_string.is_empty());
    assert!(!bytes.is_empty());
}
#[test]
#[test]
...
@@ -266,23 +294,36 @@ fn test_benchmark_performance_baseline() {
...
@@ -266,23 +294,36 @@ fn test_benchmark_performance_baseline() {
..
default_generate_request
()
..
default_generate_request
()
};
};
//
Serialization should be fast (< 1ms for simple requests)
//
Test the actual simplified pipeline: to_value + bootstrap injection
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
let
_
json
=
to_string
(
&
generate_req
)
.unwrap
();
let
worker
=
create_test_worker
();
let
serialize_duration
=
start
.elapsed
();
// This mirrors the actual router pipeline
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
total_duration
=
start
.elapsed
();
assert
!
(
assert
!
(
serialize
_duration
.as_millis
()
<
1
,
total
_duration
.as_millis
()
<
5
,
"S
erialization
took too long: {:?}"
,
"S
implified pipeline
took too long: {:?}
(should be faster than old adapter approach)
"
,
serialize
_duration
total
_duration
);
);
//
PD adaptation should be very fast (< 1ms)
//
Individual components should also be fast
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
let
_
pd_req
=
generate_req
.to_pd_request
();
let
_
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
adapt_duration
=
start
.elapsed
();
let
to_value_duration
=
start
.elapsed
();
let
start
=
Instant
::
now
();
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
inject_duration
=
start
.elapsed
();
// Bootstrap injection should be faster than the JSON conversion
assert
!
(
assert
!
(
adapt_duration
.as_millis
()
<
1
,
inject_duration
<=
to_value_duration
*
3
,
"PD adaptation took too long: {:?}"
,
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})"
,
adapt_duration
inject_duration
,
to_value_duration
);
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment