Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8c7bb39d
Unverified
Commit
8c7bb39d
authored
Aug 05, 2025
by
Simo Lin
Committed by
GitHub
Aug 05, 2025
Browse files
[router] PD Router Simplification and Reorganization (#8838)
parent
ca47e24f
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1098 additions
and
2555 deletions
+1098
-2555
sgl-router/benches/request_processing.rs
sgl-router/benches/request_processing.rs
+101
-58
sgl-router/scripts/run_benchmarks.py
sgl-router/scripts/run_benchmarks.py
+0
-3
sgl-router/src/routers/bootstrap_injector.rs
sgl-router/src/routers/bootstrap_injector.rs
+334
-0
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+1
-1
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+596
-524
sgl-router/src/routers/pd_types.rs
sgl-router/src/routers/pd_types.rs
+0
-432
sgl-router/src/routers/request_adapter.rs
sgl-router/src/routers/request_adapter.rs
+0
-1512
sgl-router/tests/benchmark_integration.rs
sgl-router/tests/benchmark_integration.rs
+66
-25
No files found.
sgl-router/benches/request_processing.rs
View file @
8c7bb39d
use
criterion
::{
black_box
,
criterion_group
,
criterion_main
,
BenchmarkId
,
Criterion
,
Throughput
};
use
serde_json
::{
from_str
,
to_string
,
to_vec
};
use
serde_json
::{
from_str
,
to_string
,
to_value
,
to_vec
};
use
std
::
time
::
Instant
;
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
use
sglang_router_rs
::
routers
::
request_adapter
::{
RouteableRequest
,
ToPdRequest
};
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
...
...
@@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) {
group
.finish
();
}
// Benchmark
request adaptation from OpenAI to PD format
fn
bench_
request_adapta
tion
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
request_adapta
tion"
);
// Benchmark
bootstrap injection (replaces request adaptation)
fn
bench_
bootstrap_injec
tion
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
bootstrap_injec
tion"
);
let
generate_req
=
create_sample_generate_request
();
let
chat_req
=
create_sample_chat_completion_request
();
let
completion_req
=
create_sample_completion_request
();
let
large_chat_req
=
create_large_chat_completion_request
();
let
worker
=
create_test_worker
();
group
.bench_function
(
"generate_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"generate_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
let
pd_req
=
black_box
(
generate_req
.clone
())
.to_pd_request
();
black_box
(
pd_req
);
let
mut
json
=
to_value
(
black_box
(
&
generate_req
))
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"chat_completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"chat_completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
let
pd_req
=
black_box
(
chat_req
.clone
())
.to_pd_request
();
black_box
(
pd_req
);
let
mut
json
=
to_value
(
black_box
(
&
chat_req
))
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
let
pd_req
=
black_box
(
completion_req
.clone
())
.to_pd_request
();
black_box
(
pd_req
);
let
mut
json
=
to_value
(
black_box
(
&
completion_req
))
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"large_chat_completion_
to_pd
"
,
|
b
|
{
group
.bench_function
(
"large_chat_completion_
bootstrap_injection
"
,
|
b
|
{
b
.iter
(||
{
let
pd_req
=
black_box
(
large_chat_req
.clone
())
.to_pd_request
();
black_box
(
pd_req
);
let
mut
json
=
to_value
(
black_box
(
&
large_chat_req
))
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
});
group
.finish
();
}
// Benchmark
regular routing (RouteableRequest methods
)
fn
bench_
regular
_routing
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
regular
_routing"
);
// Benchmark
direct JSON routing (replaces regular routing
)
fn
bench_
direct_json
_routing
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"
direct_json
_routing"
);
let
generate_req
=
create_sample_generate_request
();
let
chat_req
=
create_sample_chat_completion_request
();
...
...
@@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) {
group
.bench_function
(
"generate_to_json"
,
|
b
|
{
b
.iter
(||
{
let
json
=
black_box
(
&
generate_req
)
.to_json
()
.unwrap
();
let
json
=
to_value
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"generate_to_json_string"
,
|
b
|
{
b
.iter
(||
{
let
json
=
to_string
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"generate_to_bytes"
,
|
b
|
{
b
.iter
(||
{
let
bytes
=
black_box
(
&
generate_req
)
.to_bytes
(
)
.unwrap
();
let
bytes
=
to_vec
(
black_box
(
&
generate_req
))
.unwrap
();
black_box
(
bytes
);
});
});
group
.bench_function
(
"chat_completion_to_json"
,
|
b
|
{
b
.iter
(||
{
let
json
=
black_box
(
&
chat_req
)
.to_json
(
)
.unwrap
();
let
json
=
to_value
(
black_box
(
&
chat_req
))
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"chat_completion_to_
bytes
"
,
|
b
|
{
group
.bench_function
(
"chat_completion_to_
json_string
"
,
|
b
|
{
b
.iter
(||
{
let
bytes
=
black_box
(
&
chat_req
)
.to_bytes
(
)
.unwrap
();
black_box
(
bytes
);
let
json
=
to_string
(
black_box
(
&
chat_req
))
.unwrap
();
black_box
(
json
);
});
});
group
.bench_function
(
"completion_to_json"
,
|
b
|
{
b
.iter
(||
{
let
json
=
black_box
(
&
completion_req
)
.to_json
(
)
.unwrap
();
let
json
=
to_value
(
black_box
(
&
completion_req
))
.unwrap
();
black_box
(
json
);
});
});
...
...
@@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) {
..
default_generate_request
()
};
let
worker
=
create_test_worker
();
for
(
name
,
req
)
in
[
(
"small"
,
&
small_generate
),
(
"medium"
,
&
medium_generate
),
...
...
@@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) {
},
);
group
.bench_with_input
(
BenchmarkId
::
new
(
"adapt_to_pd"
,
name
),
&
req
,
|
b
,
req
|
{
b
.iter
(||
{
let
pd_req
=
(
*
req
)
.clone
()
.to_pd_request
();
black_box
(
pd_req
);
});
});
group
.bench_with_input
(
BenchmarkId
::
new
(
"bootstrap_inject"
,
name
),
&
req
,
|
b
,
req
|
{
b
.iter
(||
{
let
mut
json
=
to_value
(
req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
black_box
(
json
);
});
},
);
}
group
.finish
();
}
// Benchmark full round-trip: deserialize ->
ad
ap
t
-> serialize
// Benchmark full round-trip: deserialize ->
inject bootstr
ap -> serialize
fn
bench_full_round_trip
(
c
:
&
mut
Criterion
)
{
let
mut
group
=
c
.benchmark_group
(
"full_round_trip"
);
let
generate_json
=
to_string
(
&
create_sample_generate_request
())
.unwrap
();
let
chat_json
=
to_string
(
&
create_sample_chat_completion_request
())
.unwrap
();
let
completion_json
=
to_string
(
&
create_sample_completion_request
())
.unwrap
();
let
worker
=
create_test_worker
();
group
.bench_function
(
"generate_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
// Deserialize OpenAI request
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
// Adapt to PD format
let
pd_req
=
req
.to_pd_request
();
// Serialize PD request
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
// Convert to JSON Value
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
// Inject bootstrap fields
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
// Serialize final request
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
});
});
...
...
@@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"chat_completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
let
req
:
ChatCompletionRequest
=
from_str
(
black_box
(
&
chat_json
))
.unwrap
();
let
pd_req
=
req
.to_pd_request
();
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
});
});
...
...
@@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
let
req
:
CompletionRequest
=
from_str
(
black_box
(
&
completion_json
))
.unwrap
();
let
pd_req
=
req
.to_pd_request
();
let
pd_json
=
to_string
(
&
pd_req
)
.unwrap
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
black_box
(
pd_json
);
});
});
group
.bench_function
(
"generate_
regular_routing
_pipeline"
,
|
b
|
{
group
.bench_function
(
"generate_
direct_json
_pipeline"
,
|
b
|
{
b
.iter
(||
{
// Deserialize OpenAI request
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
// Convert to JSON for regular routing
let
routing_json
=
req
.to_json
()
.unwrap
();
black_box
(
routing_json
);
// Convert to JSON for direct routing (no bootstrap injection)
let
routing_json
=
to_value
(
&
req
)
.unwrap
();
let
json_string
=
to_string
(
&
routing_json
)
.unwrap
();
black_box
(
json_string
);
});
});
...
...
@@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) {
// Quick performance overview
let
generate_req
=
create_sample_generate_request
();
let
worker
=
create_test_worker
();
println!
(
"
\n
Quick Performance Overview:"
);
...
...
@@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) {
deserialize_time
);
// Measure adaptation
// Measure
bootstrap injection (replaces
adaptation
)
let
start
=
Instant
::
now
();
for
_
in
0
..
1000
{
let
_
=
black_box
(
generate_req
.clone
()
.to_pd_request
());
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
black_box
(
inject_bootstrap_fields
(
&
mut
json
,
&
worker
));
}
let
adap
t_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
println!
(
" *
PD Adapta
tion (avg):
{:>
8
} ns/req"
,
adap
t_time
);
let
injec
t_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
println!
(
" *
Bootstrap Injec
tion (avg): {:>
6
} ns/req"
,
injec
t_time
);
// Calculate ratios
let
total_pipeline
=
serialize_time
+
deserialize_time
+
adap
t_time
;
let
total_pipeline
=
serialize_time
+
deserialize_time
+
injec
t_time
;
println!
(
" * Total Pipeline (avg): {:>8} ns/req"
,
total_pipeline
);
println!
(
"
\n
Performance Insights:"
);
if
deserialize_time
>
serialize_time
*
2
{
println!
(
" • Deserialization is significantly faster than serialization"
);
}
if
adap
t_time
<
serialize_time
/
10
{
if
injec
t_time
<
serialize_time
/
10
{
println!
(
" •
PD adapta
tion overhead is negligible ({:.1}% of serialization)"
,
(
adap
t_time
as
f64
/
serialize_time
as
f64
)
*
100.0
" •
Bootstrap injec
tion overhead is negligible ({:.1}% of serialization)"
,
(
injec
t_time
as
f64
/
serialize_time
as
f64
)
*
100.0
);
}
if
total_pipeline
<
10_000
{
println!
(
" • Total pipeline latency is excellent (< 10μs)"
);
if
total_pipeline
<
10
0
_000
{
println!
(
" • Total pipeline latency is excellent (< 10
0
μs)"
);
}
println!
(
"
\n
Simplification Benefits:"
);
println!
(
" • Eliminated complex type conversion layer"
);
println!
(
" • Reduced memory allocations"
);
println!
(
" • Automatic field preservation (no manual mapping)"
);
println!
(
" • Direct JSON manipulation improves performance"
);
println!
(
"
\n
Recommendations:"
);
if
serialize_time
>
deserialize_time
{
println!
(
" • Focus optimization efforts on serialization rather than deserialization"
);
...
...
@@ -581,8 +624,8 @@ criterion_group!(
benchmark_summary
,
bench_json_serialization
,
bench_json_deserialization
,
bench_
request_adapta
tion
,
bench_
regular
_routing
,
bench_
bootstrap_injec
tion
,
bench_
direct_json
_routing
,
bench_throughput_by_size
,
bench_full_round_trip
);
...
...
sgl-router/scripts/run_benchmarks.py
View file @
8c7bb39d
...
...
@@ -121,8 +121,6 @@ class BenchmarkRunner:
results
[
"serialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"Deserialization (avg):"
in
line
:
results
[
"deserialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"PD Adaptation (avg):"
in
line
:
results
[
"adaptation_time"
]
=
self
.
_extract_time
(
line
)
elif
"Total Pipeline (avg):"
in
line
:
results
[
"total_time"
]
=
self
.
_extract_time
(
line
)
...
...
@@ -145,7 +143,6 @@ class BenchmarkRunner:
thresholds
=
{
"serialization_time"
:
2000
,
# 2μs max
"deserialization_time"
:
2000
,
# 2μs max
"adaptation_time"
:
5000
,
# 5μs max
"total_time"
:
10000
,
# 10μs max
}
...
...
sgl-router/src/routers/bootstrap_injector.rs
0 → 100644
View file @
8c7bb39d
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
routers
::
pd_types
::
get_hostname
;
use
serde_json
::{
json
,
Value
};
/// Inject bootstrap fields directly into a JSON request
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
pub
fn
inject_bootstrap_fields
(
json
:
&
mut
Value
,
worker
:
&
dyn
Worker
)
->
Result
<
(),
String
>
{
let
batch_size
=
extract_batch_size
(
json
)
?
;
// Extract bootstrap port from prefill worker if it's a prefill type
let
bootstrap_port
=
match
worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
worker
.url
());
if
let
Some
(
batch_size
)
=
batch_size
{
// Batch scenario - create arrays of bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
json
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
json
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
.map
(|
_
|
{
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
)
})
.collect
::
<
Vec
<
_
>>
());
}
else
{
// Single scenario - create single bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
hostname
);
json
[
"bootstrap_port"
]
=
json!
(
bootstrap_port
);
json
[
"bootstrap_room"
]
=
json!
(
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
));
}
Ok
(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn
extract_batch_size
(
json
:
&
Value
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check for chat completions 'n' parameter (number of choices)
if
let
Some
(
n
)
=
json
.get
(
"n"
)
.and_then
(|
v
|
v
.as_u64
())
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
// Check for array prompts (completions API)
if
let
Some
(
prompt
)
=
json
.get
(
"prompt"
)
{
if
let
Some
(
arr
)
=
prompt
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for array texts (generate API)
if
let
Some
(
text
)
=
json
.get
(
"text"
)
{
if
let
Some
(
arr
)
=
text
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch text array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for batch input_ids (generate API)
if
let
Some
(
input_ids
)
=
json
.get
(
"input_ids"
)
{
if
let
Some
(
arr
)
=
input_ids
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch input_ids array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// No batch indicators found - single request
Ok
(
None
)
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
core
::
BasicWorker
;
use
serde_json
::
json
;
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
)
}
#[test]
fn
test_inject_bootstrap_single_request
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
"Hello world"
,
"max_tokens"
:
100
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields were added
assert_eq!
(
json
[
"bootstrap_host"
],
json!
(
"test-server"
));
assert_eq!
(
json
[
"bootstrap_port"
],
json!
(
5678
));
assert
!
(
json
[
"bootstrap_room"
]
.is_number
());
// Verify original fields preserved
assert_eq!
(
json
[
"model"
],
json!
(
"test-model"
));
assert_eq!
(
json
[
"prompt"
],
json!
(
"Hello world"
));
assert_eq!
(
json
[
"max_tokens"
],
json!
(
100
));
}
#[test]
fn
test_inject_bootstrap_batch_prompt
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
[
"Hello"
,
"World"
],
"max_tokens"
:
100
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
assert_eq!
(
json
[
"bootstrap_host"
],
json!
([
"test-server"
,
"test-server"
])
);
assert_eq!
(
json
[
"bootstrap_port"
],
json!
([
5678
,
5678
]));
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
2
);
for
room
in
bootstrap_rooms
{
assert
!
(
room
.is_number
());
let
room_val
=
room
.as_u64
()
.unwrap
();
assert
!
(
room_val
<=
i64
::
MAX
as
u64
);
}
}
#[test]
fn
test_inject_bootstrap_chat_n_parameter
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"gpt-4"
,
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
"n"
:
3
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields for n=3
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
3
);
assert_eq!
(
bootstrap_hosts
[
0
],
json!
(
"test-server"
));
let
bootstrap_ports
=
json
[
"bootstrap_port"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_ports
.len
(),
3
);
assert_eq!
(
bootstrap_ports
[
0
],
json!
(
5678
));
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
3
);
}
#[test]
fn
test_inject_bootstrap_generate_text_array
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"text"
:
[
"First prompt"
,
"Second prompt"
],
"stream"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
2
);
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
2
);
// Ensure room values are different (randomness)
assert_ne!
(
bootstrap_rooms
[
0
],
bootstrap_rooms
[
1
]);
}
#[test]
fn
test_inject_bootstrap_input_ids_array
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"input_ids"
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]],
"stream"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify batch bootstrap fields
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
2
);
}
#[test]
fn
test_extract_batch_size_empty_array_error
()
{
let
json
=
json!
({
"prompt"
:
[],
"model"
:
"test"
});
let
result
=
extract_batch_size
(
&
json
);
assert
!
(
result
.is_err
());
assert
!
(
result
.unwrap_err
()
.contains
(
"empty"
));
}
#[test]
fn
test_extract_batch_size_single_requests
()
{
// Single string prompt
let
json
=
json!
({
"prompt"
:
"Hello world"
,
"model"
:
"test"
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Single text
let
json
=
json!
({
"text"
:
"Hello world"
,
"stream"
:
false
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Chat with n=1 (default)
let
json
=
json!
({
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
"n"
:
1
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
// Chat without n parameter
let
json
=
json!
({
"messages"
:
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
});
assert_eq!
(
extract_batch_size
(
&
json
)
.unwrap
(),
None
);
}
#[test]
fn
test_inject_bootstrap_preserves_sglang_fields
()
{
let
worker
=
create_test_worker
();
let
mut
json
=
json!
({
"model"
:
"test-model"
,
"prompt"
:
"Hello"
,
// SGLang extensions should be preserved
"top_k"
:
40
,
"min_p"
:
0.05
,
"repetition_penalty"
:
1.1
,
"regex"
:
"test_pattern"
,
"lora_path"
:
"test.bin"
,
"no_stop_trim"
:
true
,
"ignore_eos"
:
false
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields added
assert
!
(
json
.get
(
"bootstrap_host"
)
.is_some
());
assert
!
(
json
.get
(
"bootstrap_port"
)
.is_some
());
assert
!
(
json
.get
(
"bootstrap_room"
)
.is_some
());
// Verify all SGLang fields preserved
assert_eq!
(
json
[
"top_k"
],
json!
(
40
));
assert_eq!
(
json
[
"min_p"
],
json!
(
0.05
));
assert_eq!
(
json
[
"repetition_penalty"
],
json!
(
1.1
));
assert_eq!
(
json
[
"regex"
],
json!
(
"test_pattern"
));
assert_eq!
(
json
[
"lora_path"
],
json!
(
"test.bin"
));
assert_eq!
(
json
[
"no_stop_trim"
],
json!
(
true
));
assert_eq!
(
json
[
"ignore_eos"
],
json!
(
false
));
}
#[test]
fn
test_bootstrap_room_range
()
{
let
worker
=
create_test_worker
();
// Test single request room generation
for
_
in
0
..
1000
{
let
mut
json
=
json!
({
"prompt"
:
"test"
});
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
room
=
json
[
"bootstrap_room"
]
.as_u64
()
.unwrap
();
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
// Test batch request room generation
for
_
in
0
..
100
{
let
mut
json
=
json!
({
"prompt"
:
[
"test1"
,
"test2"
]});
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
let
rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
for
room_val
in
rooms
{
let
room
=
room_val
.as_u64
()
.unwrap
();
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
}
}
#[test]
fn
test_worker_without_bootstrap_port
()
{
let
worker
=
BasicWorker
::
new
(
"http://decode-only:8000"
.to_string
(),
WorkerType
::
Decode
,
// No bootstrap port
);
let
mut
json
=
json!
({
"prompt"
:
"Hello world"
});
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify bootstrap fields with null port
assert_eq!
(
json
[
"bootstrap_host"
],
json!
(
"decode-only"
));
assert_eq!
(
json
[
"bootstrap_port"
],
json!
(
null
));
assert
!
(
json
[
"bootstrap_room"
]
.is_number
());
}
}
sgl-router/src/routers/mod.rs
View file @
8c7bb39d
...
...
@@ -11,10 +11,10 @@ use std::fmt::Debug;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
pub
mod
bootstrap_injector
;
pub
mod
factory
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
request_adapter
;
pub
mod
router
;
pub
use
factory
::
RouterFactory
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
8c7bb39d
This diff is collapsed.
Click to expand it.
sgl-router/src/routers/pd_types.rs
View file @
8c7bb39d
// Essential PDLB types extracted for PD routing
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
openai_api_types
::{
CompletionRequest
,
StringOrArray
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
// Custom error type for PD router operations
#[derive(Debug,
thiserror::Error)]
pub
enum
PDRouterError
{
...
...
@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
balance_rel_threshold
:
f32
,
},
}
// Bootstrap types from PDLB
#[derive(Debug,
Deserialize,
Serialize,
PartialEq)]
#[serde(untagged)]
pub
enum
SingleOrBatch
<
T
>
{
Single
(
T
),
Batch
(
Vec
<
T
>
),
}
pub
type
InputIds
=
SingleOrBatch
<
Vec
<
i32
>>
;
pub
type
InputText
=
SingleOrBatch
<
String
>
;
pub
type
BootstrapHost
=
SingleOrBatch
<
String
>
;
pub
type
BootstrapPort
=
SingleOrBatch
<
Option
<
u16
>>
;
pub
type
BootstrapRoom
=
SingleOrBatch
<
u64
>
;
// Bootstrap trait for request handling
pub
trait
Bootstrap
:
Send
+
Sync
{
fn
is_stream
(
&
self
)
->
bool
;
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
;
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
);
fn
add_bootstrap_info
(
&
mut
self
,
prefill_worker
:
&
dyn
Worker
)
->
Result
<
(),
String
>
{
let
batch_size
=
self
.get_batch_size
()
?
;
// Extract bootstrap port from prefill worker if it's a prefill type
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
prefill_worker
.url
());
if
let
Some
(
batch_size
)
=
batch_size
{
self
.set_bootstrap_info
(
BootstrapHost
::
Batch
(
vec!
[
hostname
;
batch_size
]),
BootstrapPort
::
Batch
(
vec!
[
bootstrap_port
;
batch_size
]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom
::
Batch
(
(
0
..
batch_size
)
.map
(|
_
|
{
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
)
})
.collect
(),
),
);
}
else
{
self
.set_bootstrap_info
(
BootstrapHost
::
Single
(
hostname
),
BootstrapPort
::
Single
(
bootstrap_port
),
BootstrapRoom
::
Single
(
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
),
),
);
}
Ok
(())
}
}
// Request types
#[derive(Debug,
Deserialize,
Serialize)]
pub
struct
GenerateReqInput
{
pub
text
:
Option
<
InputText
>
,
pub
input_ids
:
Option
<
InputIds
>
,
#[serde(default)]
pub
stream
:
bool
,
pub
bootstrap_host
:
Option
<
BootstrapHost
>
,
pub
bootstrap_port
:
Option
<
BootstrapPort
>
,
pub
bootstrap_room
:
Option
<
BootstrapRoom
>
,
#[serde(flatten)]
pub
other
:
Value
,
}
impl
GenerateReqInput
{
pub
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
if
self
.text
.is_some
()
&&
self
.input_ids
.is_some
()
{
return
Err
(
"Both text and input_ids are present in the request"
.to_string
());
}
// Check text batch
if
let
Some
(
InputText
::
Batch
(
texts
))
=
&
self
.text
{
if
texts
.is_empty
()
{
return
Err
(
"Batch text array is empty"
.to_string
());
}
return
Ok
(
Some
(
texts
.len
()));
}
// Check input_ids batch
if
let
Some
(
InputIds
::
Batch
(
ids
))
=
&
self
.input_ids
{
if
ids
.is_empty
()
{
return
Err
(
"Batch input_ids array is empty"
.to_string
());
}
// Validate each sequence is not empty
for
(
i
,
seq
)
in
ids
.iter
()
.enumerate
()
{
if
seq
.is_empty
()
{
return
Err
(
format!
(
"Input sequence at index {} is empty"
,
i
));
}
}
return
Ok
(
Some
(
ids
.len
()));
}
Ok
(
None
)
}
}
impl
Bootstrap
for
GenerateReqInput
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
self
.get_batch_size
()
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
self
.bootstrap_host
=
Some
(
bootstrap_host
);
self
.bootstrap_port
=
Some
(
bootstrap_port
);
self
.bootstrap_room
=
Some
(
bootstrap_room
);
}
}
#[derive(Debug,
Deserialize,
Serialize)]
pub
struct
ChatReqInput
{
#[serde(default)]
pub
stream
:
bool
,
pub
bootstrap_host
:
Option
<
BootstrapHost
>
,
pub
bootstrap_port
:
Option
<
BootstrapPort
>
,
pub
bootstrap_room
:
Option
<
BootstrapRoom
>
,
#[serde(flatten)]
pub
other
:
Value
,
}
impl
Bootstrap
for
ChatReqInput
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check if 'n' parameter is present and > 1
if
let
Some
(
n_value
)
=
self
.other
.get
(
"n"
)
{
if
let
Some
(
n
)
=
n_value
.as_u64
()
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
}
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
self
.bootstrap_host
=
Some
(
bootstrap_host
);
self
.bootstrap_port
=
Some
(
bootstrap_port
);
self
.bootstrap_room
=
Some
(
bootstrap_room
);
}
}
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl
Bootstrap
for
CompletionRequest
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
if
let
StringOrArray
::
Array
(
prompts
)
=
&
self
.prompt
{
if
prompts
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
prompts
.len
()));
}
// Single string prompt
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
// Insert bootstrap_host - it serializes correctly whether Single or Batch
if
let
Ok
(
host_value
)
=
serde_json
::
to_value
(
&
bootstrap_host
)
{
self
.other
.insert
(
"bootstrap_host"
.to_string
(),
host_value
);
}
// Insert bootstrap_port - it serializes correctly whether Single or Batch
if
let
Ok
(
port_value
)
=
serde_json
::
to_value
(
&
bootstrap_port
)
{
self
.other
.insert
(
"bootstrap_port"
.to_string
(),
port_value
);
}
// Insert bootstrap_room - it serializes correctly whether Single or Batch
if
let
Ok
(
room_value
)
=
serde_json
::
to_value
(
&
bootstrap_room
)
{
self
.other
.insert
(
"bootstrap_room"
.to_string
(),
room_value
);
}
}
}
#[cfg(test)]
mod
bootstrap_tests
{
use
super
::
*
;
use
crate
::
core
::
BasicWorker
;
use
crate
::
openai_api_types
::
StringOrArray
;
/// Create a default CompletionRequest for testing with minimal fields set
fn
default_completion_request
()
->
CompletionRequest
{
CompletionRequest
{
model
:
String
::
new
(),
prompt
:
StringOrArray
::
String
(
String
::
new
()),
n
:
None
,
other
:
serde_json
::
Map
::
new
(),
suffix
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
// SGLang Extensions
top_k
:
None
,
min_p
:
None
,
min_tokens
:
None
,
repetition_penalty
:
None
,
regex
:
None
,
ebnf
:
None
,
json_schema
:
None
,
stop_token_ids
:
None
,
no_stop_trim
:
false
,
ignore_eos
:
false
,
skip_special_tokens
:
true
,
// SGLang Extensions
lora_path
:
None
,
session_params
:
None
,
return_hidden_states
:
false
,
}
}
#[test]
fn
test_completion_batch_size_with_array_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Should return batch size for array prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
Some
(
2
));
}
#[test]
fn
test_completion_batch_size_with_single_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
..
default_completion_request
()
};
// Should return None for single prompt
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_batch_size_with_n_parameter
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"single prompt"
.to_string
()),
n
:
Some
(
3
),
..
default_completion_request
()
};
// Should return None for single string prompt, even with n > 1
// SGLang handles n parameter differently than batch requests
assert_eq!
(
req
.get_batch_size
()
.unwrap
(),
None
);
}
#[test]
fn
test_completion_bootstrap_single_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Set bootstrap info - should always use single values
req
.set_bootstrap_info
(
BootstrapHost
::
Single
(
"test-server"
.to_string
()),
BootstrapPort
::
Single
(
Some
(
5678
)),
BootstrapRoom
::
Single
(
12345
),
);
// Verify single values were created
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_string
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_number
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_number
());
assert_eq!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_str
()
.unwrap
(),
"test-server"
);
assert_eq!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_u64
()
.unwrap
(),
5678
);
assert_eq!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_u64
()
.unwrap
(),
12345
);
}
#[test]
fn
test_completion_bootstrap_array_values
()
{
let
mut
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"prompt1"
.to_string
(),
"prompt2"
.to_string
()]),
..
default_completion_request
()
};
// Set bootstrap info with arrays
req
.set_bootstrap_info
(
BootstrapHost
::
Batch
(
vec!
[
"test-server"
.to_string
();
2
]),
BootstrapPort
::
Batch
(
vec!
[
Some
(
5678
);
2
]),
BootstrapRoom
::
Batch
(
vec!
[
12345
,
67890
]),
);
// Verify arrays were created correctly
assert
!
(
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.is_array
());
assert
!
(
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.is_array
());
let
hosts
=
req
.other
.get
(
"bootstrap_host"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
hosts
.len
(),
2
);
assert_eq!
(
hosts
[
0
]
.as_str
()
.unwrap
(),
"test-server"
);
let
ports
=
req
.other
.get
(
"bootstrap_port"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
ports
.len
(),
2
);
assert_eq!
(
ports
[
0
]
.as_u64
()
.unwrap
(),
5678
);
let
rooms
=
req
.other
.get
(
"bootstrap_room"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
rooms
.len
(),
2
);
assert_eq!
(
rooms
[
0
]
.as_u64
()
.unwrap
(),
12345
);
assert_eq!
(
rooms
[
1
]
.as_u64
()
.unwrap
(),
67890
);
}
#[test]
fn
test_bootstrap_room_range
()
{
// Test that bootstrap_room values are within the expected range [0, 2^63 - 1]
let
worker
=
BasicWorker
::
new
(
"http://test:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8080
),
},
);
// Test single request
let
mut
single_req
=
GenerateReqInput
{
text
:
Some
(
InputText
::
Single
(
"test"
.to_string
())),
input_ids
:
None
,
stream
:
false
,
bootstrap_host
:
None
,
bootstrap_port
:
None
,
bootstrap_room
:
None
,
other
:
Value
::
Object
(
serde_json
::
Map
::
new
()),
};
for
_
in
0
..
200000
{
single_req
.add_bootstrap_info
(
&
worker
)
.unwrap
();
if
let
Some
(
BootstrapRoom
::
Single
(
room
))
=
single_req
.bootstrap_room
{
// Verify the room value is within signed 64-bit range
assert
!
(
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
else
{
panic!
(
"Expected single bootstrap room"
);
}
}
// Test batch request
let
mut
batch_req
=
GenerateReqInput
{
text
:
Some
(
InputText
::
Batch
(
vec!
[
"test1"
.to_string
(),
"test2"
.to_string
(),
])),
input_ids
:
None
,
stream
:
false
,
bootstrap_host
:
None
,
bootstrap_port
:
None
,
bootstrap_room
:
None
,
other
:
Value
::
Object
(
serde_json
::
Map
::
new
()),
};
for
_
in
0
..
200000
{
batch_req
.add_bootstrap_info
(
&
worker
)
.unwrap
();
if
let
Some
(
BootstrapRoom
::
Batch
(
rooms
))
=
&
batch_req
.bootstrap_room
{
for
room
in
rooms
{
// Verify each room value is within signed 64-bit range
assert
!
(
*
room
<=
i64
::
MAX
as
u64
,
"Room {} exceeds i64::MAX"
,
room
);
}
}
else
{
panic!
(
"Expected batch bootstrap rooms"
);
}
}
}
}
sgl-router/src/routers/request_adapter.rs
deleted
100644 → 0
View file @
ca47e24f
This diff is collapsed.
Click to expand it.
sgl-router/tests/benchmark_integration.rs
View file @
8c7bb39d
// Integration test to ensure benchmarks compile and basic functionality works
// This prevents benchmarks from breaking in CI
//
// UPDATED: Removed deprecated ToPdRequest usage, now uses direct JSON serialization
use
serde_json
::{
from_str
,
to_string
};
use
serde_json
::{
from_str
,
to_string
,
to_value
};
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
use
sglang_router_rs
::
routers
::
request_adapter
::{
RouteableRequest
,
ToPdRequest
}
;
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
...
...
@@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest {
}
}
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
)
}
#[test]
fn
test_benchmark_request_creation
()
{
// Ensure all benchmark request types can be created without panicking
...
...
@@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() {
}
#[test]
fn
test_benchmark_
request_adapta
tion
()
{
// Test that
PD request adapta
tion works for benchmark types
fn
test_benchmark_
bootstrap_injec
tion
()
{
// Test that
bootstrap injec
tion works for benchmark types
(replaces PD request adaptation)
let
generate_req
=
GenerateRequest
{
text
:
Some
(
"Test prompt"
.to_string
()),
...
...
@@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() {
..
default_completion_request
()
};
// Test PD adaptation (should not panic)
let
_
pd_generate
=
generate_req
.to_pd_request
();
let
_
pd_chat
=
chat_req
.to_pd_request
();
let
_
pd_completion
=
completion_req
.to_pd_request
();
let
worker
=
create_test_worker
();
// Test bootstrap injection (should not panic)
let
mut
generate_json
=
to_value
(
&
generate_req
)
.unwrap
();
let
mut
chat_json
=
to_value
(
&
chat_req
)
.unwrap
();
let
mut
completion_json
=
to_value
(
&
completion_req
)
.unwrap
();
assert
!
(
inject_bootstrap_fields
(
&
mut
generate_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
chat_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
completion_json
,
&
worker
)
.is_ok
());
// Verify bootstrap fields were added
assert
!
(
generate_json
.get
(
"bootstrap_host"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_port"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_room"
)
.is_some
());
}
#[test]
fn
test_benchmark_
regular
_routing
()
{
// Test
regular
routing functionality for benchmark types
fn
test_benchmark_
direct_json
_routing
()
{
// Test
direct JSON
routing functionality for benchmark types
(replaces regular routing)
let
generate_req
=
GenerateRequest
{
text
:
Some
(
"Test prompt"
.to_string
()),
..
default_generate_request
()
};
// Test regular routing methods (should not panic)
let
_
json
=
generate_req
.to_json
();
let
_
bytes
=
generate_req
.to_bytes
();
// Test direct JSON conversion (replaces regular routing methods)
let
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
json_string
=
to_string
(
&
json
)
.unwrap
();
let
bytes
=
json_string
.as_bytes
();
// Verify conversions work
assert
!
(
!
json_string
.is_empty
());
assert
!
(
!
bytes
.is_empty
());
}
#[test]
...
...
@@ -266,23 +294,36 @@ fn test_benchmark_performance_baseline() {
..
default_generate_request
()
};
//
Serialization should be fast (< 1ms for simple requests)
//
Test the actual simplified pipeline: to_value + bootstrap injection
let
start
=
Instant
::
now
();
let
_
json
=
to_string
(
&
generate_req
)
.unwrap
();
let
serialize_duration
=
start
.elapsed
();
let
worker
=
create_test_worker
();
// This mirrors the actual router pipeline
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
total_duration
=
start
.elapsed
();
assert
!
(
serialize
_duration
.as_millis
()
<
1
,
"S
erialization
took too long: {:?}"
,
serialize
_duration
total
_duration
.as_millis
()
<
5
,
"S
implified pipeline
took too long: {:?}
(should be faster than old adapter approach)
"
,
total
_duration
);
//
PD adaptation should be very fast (< 1ms)
//
Individual components should also be fast
let
start
=
Instant
::
now
();
let
_
pd_req
=
generate_req
.to_pd_request
();
let
adapt_duration
=
start
.elapsed
();
let
_
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
to_value_duration
=
start
.elapsed
();
let
start
=
Instant
::
now
();
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
inject_duration
=
start
.elapsed
();
// Bootstrap injection should be faster than the JSON conversion
assert
!
(
adapt_duration
.as_millis
()
<
1
,
"PD adaptation took too long: {:?}"
,
adapt_duration
inject_duration
<=
to_value_duration
*
3
,
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})"
,
inject_duration
,
to_value_duration
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment