Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a69b6370
Unverified
Commit
a69b6370
authored
Aug 06, 2025
by
Simo Lin
Committed by
GitHub
Aug 06, 2025
Browse files
[router] fix req handling order, improve serialization, remove retry (#8888)
parent
2d120f8b
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
429 additions
and
853 deletions
+429
-853
sgl-router/benches/request_processing.rs
sgl-router/benches/request_processing.rs
+83
-26
sgl-router/scripts/run_benchmarks.py
sgl-router/scripts/run_benchmarks.py
+3
-0
sgl-router/src/policies/cache_aware.rs
sgl-router/src/policies/cache_aware.rs
+4
-0
sgl-router/src/policies/mod.rs
sgl-router/src/policies/mod.rs
+5
-0
sgl-router/src/routers/bootstrap_injector.rs
sgl-router/src/routers/bootstrap_injector.rs
+0
-334
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+0
-1
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+305
-389
sgl-router/src/routers/pd_types.rs
sgl-router/src/routers/pd_types.rs
+28
-0
sgl-router/src/server.rs
sgl-router/src/server.rs
+1
-1
sgl-router/tests/benchmark_integration.rs
sgl-router/tests/benchmark_integration.rs
+0
-102
No files found.
sgl-router/benches/request_processing.rs
View file @
a69b6370
...
@@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
...
@@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
use
serde_json
::{
from_str
,
to_string
,
to_value
,
to_vec
};
use
serde_json
::{
from_str
,
to_string
,
to_value
,
to_vec
};
use
std
::
time
::
Instant
;
use
std
::
time
::
Instant
;
use
sglang_router_rs
::
core
::{
BasicWorker
,
WorkerType
};
use
sglang_router_rs
::
core
::{
BasicWorker
,
Worker
,
WorkerType
};
use
sglang_router_rs
::
openai_api_types
::{
use
sglang_router_rs
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
};
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_b
ootstrap
_fields
;
use
sglang_router_rs
::
routers
::
pd_types
::{
generate_room_id
,
get_hostname
,
RequestWithB
ootstrap
}
;
fn
create_test_worker
()
->
BasicWorker
{
fn
create_test_worker
()
->
BasicWorker
{
BasicWorker
::
new
(
BasicWorker
::
new
(
...
@@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker {
...
@@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker {
)
)
}
}
/// Helper function to get bootstrap info from a worker.
///
/// Returns `(hostname, bootstrap_port)`:
/// * `hostname` is whatever `get_hostname` derives from the worker's URL.
/// * `bootstrap_port` is the port carried by a `WorkerType::Prefill` worker,
///   and `None` for every other worker type.
///
/// Benchmarks call this once up front so the lookup is not part of the
/// measured loop.
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
    let hostname = get_hostname(worker.url());
    // `Option<u16>` is `Copy`, so the bound value can be returned directly;
    // no `.clone()` is needed.
    let bootstrap_port = match worker.worker_type() {
        WorkerType::Prefill { bootstrap_port } => bootstrap_port,
        _ => None,
    };
    (hostname, bootstrap_port)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
fn
default_generate_request
()
->
GenerateRequest
{
GenerateRequest
{
GenerateRequest
{
...
@@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) {
...
@@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) {
let
completion_req
=
create_sample_completion_request
();
let
completion_req
=
create_sample_completion_request
();
let
large_chat_req
=
create_large_chat_completion_request
();
let
large_chat_req
=
create_large_chat_completion_request
();
let
worker
=
create_test_worker
();
let
worker
=
create_test_worker
();
let
(
hostname
,
bootstrap_port
)
=
get_bootstrap_info
(
&
worker
);
group
.bench_function
(
"generate_bootstrap_injection"
,
|
b
|
{
group
.bench_function
(
"generate_bootstrap_injection"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
mut
json
=
to_value
(
black_box
(
&
generate_req
))
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
generate_req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
json
=
to_value
(
black_box
(
&
request_with_bootstrap
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"chat_completion_bootstrap_injection"
,
|
b
|
{
group
.bench_function
(
"chat_completion_bootstrap_injection"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
mut
json
=
to_value
(
black_box
(
&
chat_req
))
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
chat_req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
json
=
to_value
(
black_box
(
&
request_with_bootstrap
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"completion_bootstrap_injection"
,
|
b
|
{
group
.bench_function
(
"completion_bootstrap_injection"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
mut
json
=
to_value
(
black_box
(
&
completion_req
))
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
completion_req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
json
=
to_value
(
black_box
(
&
request_with_bootstrap
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
group
.bench_function
(
"large_chat_completion_bootstrap_injection"
,
|
b
|
{
group
.bench_function
(
"large_chat_completion_bootstrap_injection"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
mut
json
=
to_value
(
black_box
(
&
large_chat_req
))
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
large_chat_req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
json
=
to_value
(
black_box
(
&
request_with_bootstrap
))
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
});
});
...
@@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
...
@@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
};
};
let
worker
=
create_test_worker
();
let
worker
=
create_test_worker
();
let
(
hostname
,
bootstrap_port
)
=
get_bootstrap_info
(
&
worker
);
for
(
name
,
req
)
in
[
for
(
name
,
req
)
in
[
(
"small"
,
&
small_generate
),
(
"small"
,
&
small_generate
),
...
@@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
...
@@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
]
{
]
{
let
json
=
to_string
(
req
)
.unwrap
();
let
json
=
to_string
(
req
)
.unwrap
();
let
size_bytes
=
json
.len
();
let
size_bytes
=
json
.len
();
let
hostname_clone
=
hostname
.clone
();
group
.throughput
(
Throughput
::
Bytes
(
size_bytes
as
u64
));
group
.throughput
(
Throughput
::
Bytes
(
size_bytes
as
u64
));
group
.bench_with_input
(
BenchmarkId
::
new
(
"serialize"
,
name
),
&
req
,
|
b
,
req
|
{
group
.bench_with_input
(
BenchmarkId
::
new
(
"serialize"
,
name
),
&
req
,
|
b
,
req
|
{
...
@@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) {
...
@@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) {
group
.bench_with_input
(
group
.bench_with_input
(
BenchmarkId
::
new
(
"bootstrap_inject"
,
name
),
BenchmarkId
::
new
(
"bootstrap_inject"
,
name
),
&
req
,
&
req
,
|
b
,
req
|
{
move
|
b
,
req
|
{
let
hostname
=
hostname_clone
.clone
();
b
.iter
(||
{
b
.iter
(||
{
let
mut
json
=
to_value
(
req
)
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
json
=
to_value
(
&
request_with_bootstrap
)
.unwrap
();
black_box
(
json
);
black_box
(
json
);
});
});
},
},
...
@@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
...
@@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
let
chat_json
=
to_string
(
&
create_sample_chat_completion_request
())
.unwrap
();
let
chat_json
=
to_string
(
&
create_sample_chat_completion_request
())
.unwrap
();
let
completion_json
=
to_string
(
&
create_sample_completion_request
())
.unwrap
();
let
completion_json
=
to_string
(
&
create_sample_completion_request
())
.unwrap
();
let
worker
=
create_test_worker
();
let
worker
=
create_test_worker
();
let
(
hostname
,
bootstrap_port
)
=
get_bootstrap_info
(
&
worker
);
group
.bench_function
(
"generate_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"generate_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
// Deserialize OpenAI request
// Deserialize OpenAI request
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
let
req
:
GenerateRequest
=
from_str
(
black_box
(
&
generate_json
))
.unwrap
();
// Convert to JSON Value
// Create wrapper with bootstrap fields
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
// Inject bootstrap fields
original
:
&
req
,
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
// Serialize final request
// Serialize final request
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
let
pd_json
=
to_string
(
&
request_with_bootstrap
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
...
@@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
...
@@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"chat_completion_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"chat_completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
req
:
ChatCompletionRequest
=
from_str
(
black_box
(
&
chat_json
))
.unwrap
();
let
req
:
ChatCompletionRequest
=
from_str
(
black_box
(
&
chat_json
))
.unwrap
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
req
,
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
pd_json
=
to_string
(
&
request_with_bootstrap
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
...
@@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
...
@@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
group
.bench_function
(
"completion_openai_to_pd_pipeline"
,
|
b
|
{
group
.bench_function
(
"completion_openai_to_pd_pipeline"
,
|
b
|
{
b
.iter
(||
{
b
.iter
(||
{
let
req
:
CompletionRequest
=
from_str
(
black_box
(
&
completion_json
))
.unwrap
();
let
req
:
CompletionRequest
=
from_str
(
black_box
(
&
completion_json
))
.unwrap
();
let
mut
json
=
to_value
(
&
req
)
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
inject_bootstrap_fields
(
&
mut
json
,
&
worker
)
.unwrap
();
original
:
&
req
,
let
pd_json
=
to_string
(
&
json
)
.unwrap
();
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
pd_json
=
to_string
(
&
request_with_bootstrap
)
.unwrap
();
black_box
(
pd_json
);
black_box
(
pd_json
);
});
});
});
});
...
@@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) {
...
@@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) {
);
);
// Measure bootstrap injection (replaces adaptation)
// Measure bootstrap injection (replaces adaptation)
let
(
hostname
,
bootstrap_port
)
=
get_bootstrap_info
(
&
worker
);
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
for
_
in
0
..
1000
{
for
_
in
0
..
1000
{
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
request_with_bootstrap
=
RequestWithBootstrap
{
let
_
=
black_box
(
inject_bootstrap_fields
(
&
mut
json
,
&
worker
));
original
:
&
generate_req
,
bootstrap_host
:
hostname
.clone
(),
bootstrap_port
,
bootstrap_room
:
generate_room_id
(),
};
let
_
=
black_box
(
to_value
(
&
request_with_bootstrap
)
.unwrap
());
}
}
let
inject_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
let
inject_time
=
start
.elapsed
()
.as_nanos
()
/
1000
;
println!
(
" * Bootstrap Injection (avg): {:>6} ns/req"
,
inject_time
);
println!
(
" * Bootstrap Injection (avg): {:>6} ns/req"
,
inject_time
);
...
...
sgl-router/scripts/run_benchmarks.py
View file @
a69b6370
...
@@ -121,6 +121,8 @@ class BenchmarkRunner:
...
@@ -121,6 +121,8 @@ class BenchmarkRunner:
results
[
"serialization_time"
]
=
self
.
_extract_time
(
line
)
results
[
"serialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"Deserialization (avg):"
in
line
:
elif
"Deserialization (avg):"
in
line
:
results
[
"deserialization_time"
]
=
self
.
_extract_time
(
line
)
results
[
"deserialization_time"
]
=
self
.
_extract_time
(
line
)
elif
"Bootstrap Injection (avg):"
in
line
:
results
[
"bootstrap_injection_time"
]
=
self
.
_extract_time
(
line
)
elif
"Total Pipeline (avg):"
in
line
:
elif
"Total Pipeline (avg):"
in
line
:
results
[
"total_time"
]
=
self
.
_extract_time
(
line
)
results
[
"total_time"
]
=
self
.
_extract_time
(
line
)
...
@@ -143,6 +145,7 @@ class BenchmarkRunner:
...
@@ -143,6 +145,7 @@ class BenchmarkRunner:
thresholds
=
{
thresholds
=
{
"serialization_time"
:
2000
,
# 2μs max
"serialization_time"
:
2000
,
# 2μs max
"deserialization_time"
:
2000
,
# 2μs max
"deserialization_time"
:
2000
,
# 2μs max
"bootstrap_injection_time"
:
5000
,
# 5μs max
"total_time"
:
10000
,
# 10μs max
"total_time"
:
10000
,
# 10μs max
}
}
...
...
sgl-router/src/policies/cache_aware.rs
View file @
a69b6370
...
@@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
...
@@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
"cache_aware"
"cache_aware"
}
}
/// Reports whether this policy wants the request text when routing.
///
/// Always returns `true` for the cache-aware policy.
fn needs_request_text(&self) -> bool {
    true // Cache-aware policy needs request text for cache affinity
}
fn
on_request_complete
(
&
self
,
worker_url
:
&
str
,
success
:
bool
)
{
fn
on_request_complete
(
&
self
,
worker_url
:
&
str
,
success
:
bool
)
{
// Could track success rates per worker for more intelligent routing
// Could track success rates per worker for more intelligent routing
if
!
success
{
if
!
success
{
...
...
sgl-router/src/policies/mod.rs
View file @
a69b6370
...
@@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
...
@@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Get policy name for metrics and debugging
/// Get policy name for metrics and debugging
fn
name
(
&
self
)
->
&
'static
str
;
fn
name
(
&
self
)
->
&
'static
str
;
/// Check if this policy needs request text for routing decisions.
///
/// This is a provided method: the default implementation returns `false`,
/// so only policies that actually use the request text need to override it.
fn needs_request_text(&self) -> bool {
    false // Default: most policies don't need request text
}
/// Update worker load information
/// Update worker load information
///
///
/// This is called periodically with current load information for load-aware policies.
/// This is called periodically with current load information for load-aware policies.
...
...
sgl-router/src/routers/bootstrap_injector.rs
deleted
100644 → 0
View file @
2d120f8b
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use
crate
::
core
::{
Worker
,
WorkerType
};
use
crate
::
routers
::
pd_types
::
get_hostname
;
use
serde_json
::{
json
,
Value
};
/// Inject bootstrap fields directly into a JSON request
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
pub
fn
inject_bootstrap_fields
(
json
:
&
mut
Value
,
worker
:
&
dyn
Worker
)
->
Result
<
(),
String
>
{
let
batch_size
=
extract_batch_size
(
json
)
?
;
// Extract bootstrap port from prefill worker if it's a prefill type
let
bootstrap_port
=
match
worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
worker
.url
());
if
let
Some
(
batch_size
)
=
batch_size
{
// Batch scenario - create arrays of bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
json
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
json
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
.map
(|
_
|
{
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
)
})
.collect
::
<
Vec
<
_
>>
());
}
else
{
// Single scenario - create single bootstrap values
json
[
"bootstrap_host"
]
=
json!
(
hostname
);
json
[
"bootstrap_port"
]
=
json!
(
bootstrap_port
);
json
[
"bootstrap_room"
]
=
json!
(
rand
::
random
::
<
u64
>
()
&
(
i64
::
MAX
as
u64
));
}
Ok
(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn
extract_batch_size
(
json
:
&
Value
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check for chat completions 'n' parameter (number of choices)
if
let
Some
(
n
)
=
json
.get
(
"n"
)
.and_then
(|
v
|
v
.as_u64
())
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
// Check for array prompts (completions API)
if
let
Some
(
prompt
)
=
json
.get
(
"prompt"
)
{
if
let
Some
(
arr
)
=
prompt
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch prompt array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for array texts (generate API)
if
let
Some
(
text
)
=
json
.get
(
"text"
)
{
if
let
Some
(
arr
)
=
text
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch text array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// Check for batch input_ids (generate API)
if
let
Some
(
input_ids
)
=
json
.get
(
"input_ids"
)
{
if
let
Some
(
arr
)
=
input_ids
.as_array
()
{
if
arr
.is_empty
()
{
return
Err
(
"Batch input_ids array is empty"
.to_string
());
}
return
Ok
(
Some
(
arr
.len
()));
}
}
// No batch indicators found - single request
Ok
(
None
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorker;
    use serde_json::json;

    /// Upper bound every generated room id must respect.
    const ROOM_LIMIT: u64 = i64::MAX as u64;

    /// Prefill worker fixture: host `test-server`, bootstrap port 5678.
    fn create_test_worker() -> BasicWorker {
        let kind = WorkerType::Prefill {
            bootstrap_port: Some(5678),
        };
        BasicWorker::new("http://test-server:8000".to_string(), kind)
    }

    /// Runs the injector against the prefill fixture, asserts it succeeded,
    /// and hands back the mutated payload.
    fn inject_ok(mut payload: serde_json::Value) -> serde_json::Value {
        let prefill = create_test_worker();
        let outcome = inject_bootstrap_fields(&mut payload, &prefill);
        assert!(outcome.is_ok());
        payload
    }

    /// Borrows `payload[key]` as an array, panicking if it is not one.
    fn array_at<'a>(payload: &'a serde_json::Value, key: &str) -> &'a Vec<serde_json::Value> {
        payload[key].as_array().unwrap()
    }

    /// A scalar request gets scalar bootstrap fields and keeps its own keys.
    #[test]
    fn test_inject_bootstrap_single_request() {
        let payload = inject_ok(json!({
            "model": "test-model",
            "prompt": "Hello world",
            "max_tokens": 100
        }));

        // Verify bootstrap fields were added
        assert_eq!(payload["bootstrap_host"], json!("test-server"));
        assert_eq!(payload["bootstrap_port"], json!(5678));
        assert!(payload["bootstrap_room"].is_number());

        // Verify original fields preserved
        assert_eq!(payload["model"], json!("test-model"));
        assert_eq!(payload["prompt"], json!("Hello world"));
        assert_eq!(payload["max_tokens"], json!(100));
    }

    /// An array `prompt` yields one bootstrap entry per prompt.
    #[test]
    fn test_inject_bootstrap_batch_prompt() {
        let payload = inject_ok(json!({
            "model": "test-model",
            "prompt": ["Hello", "World"],
            "max_tokens": 100
        }));

        // Verify batch bootstrap fields
        assert_eq!(
            payload["bootstrap_host"],
            json!(["test-server", "test-server"])
        );
        assert_eq!(payload["bootstrap_port"], json!([5678, 5678]));

        let rooms = array_at(&payload, "bootstrap_room");
        assert_eq!(rooms.len(), 2);
        for entry in rooms {
            assert!(entry.is_number());
            let id = entry.as_u64().unwrap();
            assert!(id <= ROOM_LIMIT);
        }
    }

    /// A chat request with `n` = 3 is expanded to three bootstrap entries.
    #[test]
    fn test_inject_bootstrap_chat_n_parameter() {
        let payload = inject_ok(json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "n": 3
        }));

        // Verify batch bootstrap fields for n=3
        let hosts = array_at(&payload, "bootstrap_host");
        assert_eq!(hosts.len(), 3);
        assert_eq!(hosts[0], json!("test-server"));

        let ports = array_at(&payload, "bootstrap_port");
        assert_eq!(ports.len(), 3);
        assert_eq!(ports[0], json!(5678));

        let rooms = array_at(&payload, "bootstrap_room");
        assert_eq!(rooms.len(), 3);
    }

    /// An array `text` (generate API) is treated as a batch.
    #[test]
    fn test_inject_bootstrap_generate_text_array() {
        let payload = inject_ok(json!({
            "text": ["First prompt", "Second prompt"],
            "stream": false
        }));

        // Verify batch bootstrap fields
        assert_eq!(array_at(&payload, "bootstrap_host").len(), 2);

        let rooms = array_at(&payload, "bootstrap_room");
        assert_eq!(rooms.len(), 2);

        // Ensure room values are different (randomness)
        assert_ne!(rooms[0], rooms[1]);
    }

    /// Nested `input_ids` (generate API) is treated as a batch.
    #[test]
    fn test_inject_bootstrap_input_ids_array() {
        let payload = inject_ok(json!({
            "input_ids": [[1, 2, 3], [4, 5, 6]],
            "stream": false
        }));

        // Verify batch bootstrap fields
        assert_eq!(array_at(&payload, "bootstrap_host").len(), 2);
    }

    /// An empty batch array is rejected with a message mentioning "empty".
    #[test]
    fn test_extract_batch_size_empty_array_error() {
        let payload = json!({
            "prompt": [],
            "model": "test"
        });

        let outcome = extract_batch_size(&payload);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("empty"));
    }

    /// Scalar requests of every flavour report no batch size.
    #[test]
    fn test_extract_batch_size_single_requests() {
        let singles = [
            // Single string prompt
            json!({
                "prompt": "Hello world",
                "model": "test"
            }),
            // Single text
            json!({
                "text": "Hello world",
                "stream": false
            }),
            // Chat with n=1 (default)
            json!({
                "messages": [{"role": "user", "content": "Hello"}],
                "n": 1
            }),
            // Chat without n parameter
            json!({
                "messages": [{"role": "user", "content": "Hello"}]
            }),
        ];

        for payload in singles.iter() {
            assert_eq!(extract_batch_size(payload).unwrap(), None);
        }
    }

    /// SGLang-specific extension fields survive injection untouched.
    #[test]
    fn test_inject_bootstrap_preserves_sglang_fields() {
        let payload = inject_ok(json!({
            "model": "test-model",
            "prompt": "Hello",
            // SGLang extensions should be preserved
            "top_k": 40,
            "min_p": 0.05,
            "repetition_penalty": 1.1,
            "regex": "test_pattern",
            "lora_path": "test.bin",
            "no_stop_trim": true,
            "ignore_eos": false
        }));

        // Verify bootstrap fields added
        for key in ["bootstrap_host", "bootstrap_port", "bootstrap_room"].iter() {
            assert!(payload.get(*key).is_some());
        }

        // Verify all SGLang fields preserved
        assert_eq!(payload["top_k"], json!(40));
        assert_eq!(payload["min_p"], json!(0.05));
        assert_eq!(payload["repetition_penalty"], json!(1.1));
        assert_eq!(payload["regex"], json!("test_pattern"));
        assert_eq!(payload["lora_path"], json!("test.bin"));
        assert_eq!(payload["no_stop_trim"], json!(true));
        assert_eq!(payload["ignore_eos"], json!(false));
    }

    /// Room ids never exceed `i64::MAX`, for single and batch requests alike.
    #[test]
    fn test_bootstrap_room_range() {
        let prefill = create_test_worker();

        // Test single request room generation
        for _ in 0..1000 {
            let mut payload = json!({"prompt": "test"});
            inject_bootstrap_fields(&mut payload, &prefill).unwrap();

            let room = payload["bootstrap_room"].as_u64().unwrap();
            assert!(room <= ROOM_LIMIT, "Room {} exceeds i64::MAX", room);
        }

        // Test batch request room generation
        for _ in 0..100 {
            let mut payload = json!({"prompt": ["test1", "test2"]});
            inject_bootstrap_fields(&mut payload, &prefill).unwrap();

            for entry in array_at(&payload, "bootstrap_room") {
                let room = entry.as_u64().unwrap();
                assert!(room <= ROOM_LIMIT, "Room {} exceeds i64::MAX", room);
            }
        }
    }

    /// A non-prefill worker produces a `null` bootstrap port.
    #[test]
    fn test_worker_without_bootstrap_port() {
        // Decode workers carry no bootstrap port
        let decode_only = BasicWorker::new("http://decode-only:8000".to_string(), WorkerType::Decode);

        let mut payload = json!({
            "prompt": "Hello world"
        });

        let outcome = inject_bootstrap_fields(&mut payload, &decode_only);
        assert!(outcome.is_ok());

        // Verify bootstrap fields with null port
        assert_eq!(payload["bootstrap_host"], json!("decode-only"));
        assert_eq!(payload["bootstrap_port"], json!(null));
        assert!(payload["bootstrap_room"].is_number());
    }
}
sgl-router/src/routers/mod.rs
View file @
a69b6370
...
@@ -11,7 +11,6 @@ use std::fmt::Debug;
...
@@ -11,7 +11,6 @@ use std::fmt::Debug;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
pub
mod
bootstrap_injector
;
pub
mod
factory
;
pub
mod
factory
;
pub
mod
pd_router
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
pd_types
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
a69b6370
// PD (Prefill-Decode) Router Implementation
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
// This module handles routing for disaggregated prefill-decode systems
use
super
::
bootstrap_injector
::
inject_bootstrap_fields
;
use
super
::
pd_types
::{
api_path
,
PDRouterError
};
use
super
::
pd_types
::{
api_path
,
PDRouterError
};
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
...
@@ -19,7 +17,6 @@ use axum::{
...
@@ -19,7 +17,6 @@ use axum::{
Json
,
Json
,
};
};
use
futures_util
::
StreamExt
;
use
futures_util
::
StreamExt
;
use
rand
::
Rng
;
use
reqwest
::
Client
;
use
reqwest
::
Client
;
use
serde_json
::
Value
;
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
...
@@ -316,17 +313,6 @@ impl PDRouter {
...
@@ -316,17 +313,6 @@ impl PDRouter {
.into_response
()
.into_response
()
}
}
// Helper to handle bootstrap injection errors
fn
handle_bootstrap_error
(
error
:
impl
std
::
fmt
::
Display
)
->
Response
{
error!
(
"Failed to add bootstrap info error={}"
,
error
);
RouterMetrics
::
record_pd_error
(
"bootstrap_injection"
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Bootstrap injection failed: {}"
,
error
),
)
.into_response
()
}
// Helper to handle serialization errors
// Helper to handle serialization errors
fn
handle_serialization_error
(
error
:
impl
std
::
fmt
::
Display
)
->
Response
{
fn
handle_serialization_error
(
error
:
impl
std
::
fmt
::
Display
)
->
Response
{
error!
(
"Failed to serialize request error={}"
,
error
);
error!
(
"Failed to serialize request error={}"
,
error
);
...
@@ -337,110 +323,87 @@ impl PDRouter {
...
@@ -337,110 +323,87 @@ impl PDRouter {
.into_response
()
.into_response
()
}
}
// Execute the dual dispatch to prefill and decode servers with retry logic
// Helper to determine batch size from a GenerateRequest
async
fn
execute_dual_dispatch
(
fn
get_generate_batch_size
(
req
:
&
GenerateRequest
)
->
Option
<
usize
>
{
&
self
,
// Check prompt array
headers
:
Option
<&
HeaderMap
>
,
if
let
Some
(
prompt
)
=
&
req
.prompt
{
json_request
:
Value
,
if
let
crate
::
openai_api_types
::
StringOrArray
::
Array
(
arr
)
=
prompt
{
route
:
&
str
,
if
!
arr
.is_empty
()
{
prefill
:
&
dyn
Worker
,
return
Some
(
arr
.len
());
decode
:
&
dyn
Worker
,
}
is_stream
:
bool
,
}
return_logprob
:
bool
,
}
start_time
:
Instant
,
// Check text array
)
->
Response
{
if
let
Some
(
text
)
=
&
req
.text
{
for
attempt
in
0
..
self
.retry_config.max_retries
{
if
text
.contains
(
"["
)
&&
text
.contains
(
"]"
)
{
if
attempt
>
0
{
// This is a simplified check - in reality we'd need to parse JSON
// Calculate backoff with exponential growth and jitter
return
None
;
// For now, fall back to non-batch
let
base_backoff
=
self
.retry_config.initial_backoff_ms
as
f64
*
self
.retry_config
.backoff_multiplier
.powf
((
attempt
-
1
)
as
f32
)
as
f64
;
let
backoff_ms
=
base_backoff
.min
(
self
.retry_config.max_backoff_ms
as
f64
)
as
u64
;
// Add jitter to prevent thundering herd
let
jitter
=
{
let
mut
rng
=
rand
::
thread_rng
();
rng
.gen_range
(
0
..
backoff_ms
/
2
)
};
let
total_backoff
=
Duration
::
from_millis
(
backoff_ms
+
jitter
);
info!
(
"Retrying request (attempt {}/{}) after {:?} backoff"
,
attempt
+
1
,
self
.retry_config.max_retries
,
total_backoff
);
tokio
::
time
::
sleep
(
total_backoff
)
.await
;
}
}
}
debug!
(
None
"Executing request attempt {}/{}"
,
attempt
+
1
,
self
.retry_config.max_retries
);
let
result
=
self
.execute_dual_dispatch_inner
(
headers
,
json_request
.clone
(),
route
,
prefill
,
decode
,
is_stream
,
return_logprob
,
start_time
,
)
.await
;
// Check if we should retry based on the response status
let
status
=
result
.status
();
debug!
(
"Request attempt {} returned status: {}"
,
attempt
+
1
,
status
);
// Don't retry client errors (4xx) or successful responses
if
status
.is_client_error
()
||
status
.is_success
()
{
debug!
(
"Returning response with status {} (no retry needed)"
,
status
);
return
result
;
}
}
// Check if this is the last attempt
// Helper to determine batch size from a ChatCompletionRequest
if
attempt
==
self
.retry_config.max_retries
-
1
{
fn
get_chat_batch_size
(
req
:
&
ChatCompletionRequest
)
->
Option
<
usize
>
{
warn!
(
"Final attempt failed with status {}"
,
status
);
// Check 'n' parameter for multiple responses
return
result
;
if
let
Some
(
n
)
=
req
.n
{
if
n
>
1
{
return
Some
(
n
as
usize
);
}
}
None
}
}
// Log retry decision for retryable errors
// Helper to determine batch size from a CompletionRequest
if
status
.is_server_error
()
fn
get_completion_batch_size
(
req
:
&
CompletionRequest
)
->
Option
<
usize
>
{
||
status
==
StatusCode
::
BAD_GATEWAY
// Check prompt array
||
status
==
StatusCode
::
GATEWAY_TIMEOUT
if
let
crate
::
openai_api_types
::
StringOrArray
::
Array
(
arr
)
=
&
req
.prompt
{
{
if
!
arr
.is_empty
()
{
warn!
(
return
Some
(
arr
.len
());
"Retryable error status: {} on attempt {}/{}. Will retry."
,
status
,
attempt
+
1
,
self
.retry_config.max_retries
);
}
else
{
// Don't retry other statuses
debug!
(
"Status {} is not retryable, returning response"
,
status
);
return
result
;
}
}
}
}
None
}
// This should never be reached due to the loop logic, but just in case
// Helper to create request with bootstrap fields
unreachable!
(
"Retry loop completed without returning"
)
fn
create_request_with_bootstrap
<
T
:
serde
::
Serialize
>
(
request
:
&
T
,
prefill_worker
:
&
dyn
Worker
,
batch_size
:
Option
<
usize
>
,
)
->
Result
<
serde_json
::
Value
,
serde_json
::
Error
>
{
// Get bootstrap port from prefill worker
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
crate
::
core
::
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
super
::
pd_types
::
get_hostname
(
prefill_worker
.url
());
// Create optimized request with bootstrap fields
if
let
Some
(
batch_size
)
=
batch_size
{
// Batch request
let
request_with_bootstrap
=
super
::
pd_types
::
BatchRequestWithBootstrap
{
original
:
request
,
bootstrap_host
:
vec!
[
hostname
;
batch_size
],
bootstrap_port
:
vec!
[
bootstrap_port
;
batch_size
],
bootstrap_room
:
(
0
..
batch_size
)
.map
(|
_
|
super
::
pd_types
::
generate_room_id
())
.collect
(),
};
serde_json
::
to_value
(
&
request_with_bootstrap
)
}
else
{
// Single request
let
request_with_bootstrap
=
super
::
pd_types
::
RequestWithBootstrap
{
original
:
request
,
bootstrap_host
:
hostname
,
bootstrap_port
,
bootstrap_room
:
super
::
pd_types
::
generate_room_id
(),
};
serde_json
::
to_value
(
&
request_with_bootstrap
)
}
}
}
//
Inner implementation of dual dispatch (extracted for retry logic)
//
Execute the dual dispatch to prefill and decode servers
async
fn
execute_dual_dispatch
_inner
(
async
fn
execute_dual_dispatch
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
json_request
:
Value
,
json_request
:
Value
,
...
@@ -467,6 +430,9 @@ impl PDRouter {
...
@@ -467,6 +430,9 @@ impl PDRouter {
prefill
.url
(),
prefill
.url
(),
decode
.url
()
decode
.url
()
);
);
if
return_logprob
{
// When we need logprobs, wait for both responses
let
(
prefill_result
,
decode_result
)
=
let
(
prefill_result
,
decode_result
)
=
tokio
::
join!
(
prefill_request
.send
(),
decode_request
.send
());
tokio
::
join!
(
prefill_request
.send
(),
decode_request
.send
());
debug!
(
"Received responses from both servers"
);
debug!
(
"Received responses from both servers"
);
...
@@ -478,17 +444,8 @@ impl PDRouter {
...
@@ -478,17 +444,8 @@ impl PDRouter {
RouterMetrics
::
record_pd_prefill_request
(
prefill
.url
());
RouterMetrics
::
record_pd_prefill_request
(
prefill
.url
());
RouterMetrics
::
record_pd_decode_request
(
decode
.url
());
RouterMetrics
::
record_pd_decode_request
(
decode
.url
());
// Process prefill response
// Process decode response with prefill for logprobs
let
(
_
prefill_status
,
prefill_body
)
=
match
self
debug!
(
"Processing decode response with logprobs"
);
.process_prefill_response
(
prefill_result
,
prefill
.url
(),
return_logprob
)
.await
{
Ok
(
result
)
=>
result
,
Err
(
error_response
)
=>
return
error_response
,
};
// Process decode response
debug!
(
"Processing decode response"
);
match
decode_result
{
match
decode_result
{
Ok
(
res
)
=>
{
Ok
(
res
)
=>
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
...
@@ -509,40 +466,45 @@ impl PDRouter {
...
@@ -509,40 +466,45 @@ impl PDRouter {
return
(
status
,
error_body
)
.into_response
();
return
(
status
,
error_body
)
.into_response
();
}
}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
return
(
status
,
format!
(
"Decode server error: {}"
,
e
))
.into_response
();
return
(
status
,
format!
(
"Decode server error: {}"
,
e
))
.into_response
();
}
}
}
}
}
}
// Process prefill response for logprobs
let
prefill_body
=
match
self
.process_prefill_response
(
prefill_result
,
prefill
.url
(),
return_logprob
)
.await
{
Ok
((
_
,
body
))
=>
body
,
Err
(
error_response
)
=>
return
error_response
,
};
if
is_stream
{
if
is_stream
{
// Streaming response
// Streaming response with logprobs
let
prefill_logprobs
=
if
return_logprob
{
let
prefill_logprobs
=
prefill_body
prefill_body
.as_ref
()
.as_ref
()
.and_then
(|
body
|
serde_json
::
from_slice
::
<
Value
>
(
body
)
.ok
())
.and_then
(|
body
|
serde_json
::
from_slice
::
<
Value
>
(
body
)
.ok
())
.and_then
(|
json
|
{
.and_then
(|
json
|
{
json
.pointer
(
"/meta_info/input_token_logprobs"
)
.cloned
()
json
.pointer
(
"/meta_info/input_token_logprobs"
)
.cloned
()
})
});
}
else
{
None
};
let
decode_url
=
if
!
return_logprob
{
Some
(
decode
.url
()
.to_string
())
}
else
{
None
};
Self
::
create_streaming_response
(
Self
::
create_streaming_response
(
res
.bytes_stream
(),
res
.bytes_stream
(),
status
,
status
,
prefill_logprobs
,
prefill_logprobs
,
return_logprob
,
return_logprob
,
decode_url
,
None
,
)
)
}
else
{
}
else
{
// Non-streaming response - use helper
// Non-streaming response with logprobs
self
.process_non_streaming_response
(
res
,
status
,
return_logprob
,
prefill_body
)
self
.process_non_streaming_response
(
res
,
status
,
return_logprob
,
prefill_body
,
)
.await
.await
}
}
}
}
...
@@ -560,6 +522,101 @@ impl PDRouter {
...
@@ -560,6 +522,101 @@ impl PDRouter {
.into_response
()
.into_response
()
}
}
}
}
}
else
{
// When we don't need logprobs, only wait for decode response
// Send both requests concurrently but don't wait for prefill
// Add headers to minimize response size when we don't need the body
let
prefill_future
=
prefill_request
.header
(
"Connection"
,
"close"
)
.send
();
let
decode_future
=
decode_request
.send
();
tokio
::
spawn
(
async
move
{
if
let
Ok
(
response
)
=
prefill_future
.await
{
// Consume with a short timeout to free connection quickly
let
consume_future
=
async
{
let
_
=
response
.bytes
()
.await
;
};
// Give it 100ms to consume, then abandon
let
_
=
tokio
::
time
::
timeout
(
Duration
::
from_millis
(
100
),
consume_future
)
.await
;
}
});
// Wait only for decode response
let
decode_result
=
decode_future
.await
;
debug!
(
"Received decode response"
);
// Update metrics
let
duration
=
start_time
.elapsed
();
RouterMetrics
::
record_pd_request_duration
(
route
,
duration
);
RouterMetrics
::
record_pd_request
(
route
);
RouterMetrics
::
record_pd_prefill_request
(
prefill
.url
());
RouterMetrics
::
record_pd_decode_request
(
decode
.url
());
// Process decode response immediately
debug!
(
"Processing decode response (no logprobs)"
);
match
decode_result
{
Ok
(
res
)
=>
{
let
status
=
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
StatusCode
::
INTERNAL_SERVER_ERROR
);
debug!
(
"Decode response status: {}"
,
status
);
if
!
status
.is_success
()
{
RouterMetrics
::
record_pd_decode_error
(
decode
.url
());
error!
(
"Decode server returned error status decode_url={} status={}"
,
decode
.url
(),
status
);
// Return the error response from decode server
match
res
.bytes
()
.await
{
Ok
(
error_body
)
=>
(
status
,
error_body
)
.into_response
(),
Err
(
e
)
=>
{
(
status
,
format!
(
"Decode server error: {}"
,
e
))
.into_response
()
}
}
}
else
if
is_stream
{
// Streaming response without logprobs - direct passthrough
let
decode_url
=
decode
.url
()
.to_string
();
Self
::
create_streaming_response
(
res
.bytes_stream
(),
status
,
None
,
false
,
Some
(
decode_url
),
)
}
else
{
// Non-streaming response without logprobs - direct passthrough like fast version
match
res
.bytes
()
.await
{
Ok
(
decode_body
)
=>
(
status
,
decode_body
)
.into_response
(),
Err
(
e
)
=>
{
error!
(
"Failed to read decode response: {}"
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
"Failed to read response"
)
.into_response
()
}
}
}
}
Err
(
e
)
=>
{
error!
(
decode_url
=
%
decode
.url
(),
error
=
%
e
,
"Decode request failed"
);
RouterMetrics
::
record_pd_decode_error
(
decode
.url
());
(
StatusCode
::
BAD_GATEWAY
,
format!
(
"Decode server error: {}"
,
e
),
)
.into_response
()
}
}
}
}
/// Returns `true` when the prefill policy or the decode policy reports that
/// it needs the request text; the decode policy is only consulted if the
/// prefill policy does not need it.
fn policies_need_request_text(&self) -> bool {
    let prefill_needs_text = self.prefill_policy.needs_request_text();
    prefill_needs_text || self.decode_policy.needs_request_text()
}
}
// Select a pair of prefill and decode servers
// Select a pair of prefill and decode servers
...
@@ -1311,23 +1368,23 @@ impl RouterTrait for PDRouter {
...
@@ -1311,23 +1368,23 @@ impl RouterTrait for PDRouter {
)
->
Response
{
)
->
Response
{
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
// Convert directly to JSON to preserve all fields automatically
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Extract flags for routing logic
// Extract flags for routing logic
let
is_stream
=
body
.stream
;
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.return_logprob
;
let
return_logprob
=
body
.return_logprob
;
// Extract text for cache-aware routing
// Extract text for cache-aware routing only if needed
let
request_text
=
body
.text
.as_deref
()
.or_else
(||
{
let
request_text
=
if
self
.policies_need_request_text
()
{
body
.text
.as_deref
()
.or_else
(||
{
body
.prompt
.as_ref
()
.and_then
(|
p
|
match
p
{
body
.prompt
.as_ref
()
.and_then
(|
p
|
match
p
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
{
v
.first
()
.map
(|
s
|
s
.as_str
())
}
})
})
});
})
}
else
{
None
};
// Select servers
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
...
@@ -1342,10 +1399,12 @@ impl RouterTrait for PDRouter {
...
@@ -1342,10 +1399,12 @@ impl RouterTrait for PDRouter {
decode
.url
()
decode
.url
()
);
);
// Inject bootstrap fields directly into JSON
// Create optimized request with bootstrap fields
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
let
batch_size
=
Self
::
get_generate_batch_size
(
body
);
return
Self
::
handle_bootstrap_error
(
e
);
let
json
=
match
Self
::
create_request_with_bootstrap
(
body
,
prefill
.as_ref
(),
batch_size
)
{
}
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
// Execute dual dispatch
self
.execute_dual_dispatch
(
self
.execute_dual_dispatch
(
...
@@ -1368,27 +1427,29 @@ impl RouterTrait for PDRouter {
...
@@ -1368,27 +1427,29 @@ impl RouterTrait for PDRouter {
)
->
Response
{
)
->
Response
{
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
// Convert directly to JSON to preserve all fields automatically
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Extract flags for routing logic
// Extract flags for routing logic
let
is_stream
=
body
.stream
;
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.logprobs
;
let
return_logprob
=
body
.logprobs
;
// Extract text for cache-aware routing from chat messages
// Extract text for cache-aware routing from chat messages only if needed
let
request_text
=
body
.messages
.first
()
.and_then
(|
msg
|
match
msg
{
let
request_text
=
if
self
.policies_need_request_text
()
{
body
.messages
.first
()
.and_then
(|
msg
|
match
msg
{
crate
::
openai_api_types
::
ChatMessage
::
User
{
content
,
..
}
=>
{
crate
::
openai_api_types
::
ChatMessage
::
User
{
content
,
..
}
=>
{
match
content
{
match
content
{
crate
::
openai_api_types
::
UserMessageContent
::
Text
(
text
)
=>
Some
(
text
.as_str
()),
crate
::
openai_api_types
::
UserMessageContent
::
Text
(
text
)
=>
{
Some
(
text
.as_str
())
}
crate
::
openai_api_types
::
UserMessageContent
::
Parts
(
_
)
=>
None
,
// Skip complex content
crate
::
openai_api_types
::
UserMessageContent
::
Parts
(
_
)
=>
None
,
// Skip complex content
}
}
}
}
crate
::
openai_api_types
::
ChatMessage
::
System
{
content
,
..
}
=>
Some
(
content
.as_str
()),
crate
::
openai_api_types
::
ChatMessage
::
System
{
content
,
..
}
=>
{
Some
(
content
.as_str
())
}
_
=>
None
,
_
=>
None
,
});
})
}
else
{
None
};
// Select servers
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
...
@@ -1403,10 +1464,12 @@ impl RouterTrait for PDRouter {
...
@@ -1403,10 +1464,12 @@ impl RouterTrait for PDRouter {
decode
.url
()
decode
.url
()
);
);
// Inject bootstrap fields directly into JSON
// Create optimized request with bootstrap fields
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
let
batch_size
=
Self
::
get_chat_batch_size
(
body
);
return
Self
::
handle_bootstrap_error
(
e
);
let
json
=
match
Self
::
create_request_with_bootstrap
(
body
,
prefill
.as_ref
(),
batch_size
)
{
}
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
// Execute dual dispatch
self
.execute_dual_dispatch
(
self
.execute_dual_dispatch
(
...
@@ -1429,20 +1492,18 @@ impl RouterTrait for PDRouter {
...
@@ -1429,20 +1492,18 @@ impl RouterTrait for PDRouter {
)
->
Response
{
)
->
Response
{
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
// Convert directly to JSON to preserve all fields automatically
let
mut
json
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Extract flags for routing logic
// Extract flags for routing logic
let
is_stream
=
body
.stream
;
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.logprobs
.is_some
();
let
return_logprob
=
body
.logprobs
.is_some
();
// Extract text for cache-aware routing
// Extract text for cache-aware routing only if needed
let
request_text
=
match
&
body
.prompt
{
let
request_text
=
if
self
.policies_need_request_text
()
{
match
&
body
.prompt
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
}
}
else
{
None
};
};
// Select servers
// Select servers
...
@@ -1458,10 +1519,12 @@ impl RouterTrait for PDRouter {
...
@@ -1458,10 +1519,12 @@ impl RouterTrait for PDRouter {
decode
.url
()
decode
.url
()
);
);
// Inject bootstrap fields directly into JSON
// Create optimized request with bootstrap fields
if
let
Err
(
e
)
=
inject_bootstrap_fields
(
&
mut
json
,
prefill
.as_ref
())
{
let
batch_size
=
Self
::
get_completion_batch_size
(
body
);
return
Self
::
handle_bootstrap_error
(
e
);
let
json
=
match
Self
::
create_request_with_bootstrap
(
body
,
prefill
.as_ref
(),
batch_size
)
{
}
Ok
(
json
)
=>
json
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
// Execute dual dispatch
self
.execute_dual_dispatch
(
self
.execute_dual_dispatch
(
...
@@ -1937,6 +2000,13 @@ mod tests {
...
@@ -1937,6 +2000,13 @@ mod tests {
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
}
}
// ============= Bootstrap Injection Tests =============
// Note: The previous bootstrap injection tests were removed because we've moved to the
// optimized bootstrap injection approach, which no longer uses the Bootstrap trait on GenerateReqInput.
// TODO: Add new tests for the optimized bootstrap injection approach using the
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers.
// ============= Worker Selection Tests =============
// ============= Worker Selection Tests =============
#[tokio::test]
#[tokio::test]
...
@@ -2114,158 +2184,4 @@ mod tests {
...
@@ -2114,158 +2184,4 @@ mod tests {
let
workers
=
router
.prefill_workers
.read
()
.unwrap
();
let
workers
=
router
.prefill_workers
.read
()
.unwrap
();
assert_eq!
(
workers
.len
(),
5
);
assert_eq!
(
workers
.len
(),
5
);
}
}
#[tokio::test]
async
fn
test_simplified_routing_preserves_sglang_fields
()
{
use
crate
::
openai_api_types
::
GenerateRequest
;
use
crate
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
// Create a test worker
let
worker
=
BasicWorker
::
new
(
"http://test-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
5678
),
},
);
// Create a GenerateRequest with SGLang extensions
let
mut
session_params
=
std
::
collections
::
HashMap
::
new
();
session_params
.insert
(
"test_key"
.to_string
(),
serde_json
::
json!
(
"test_value"
));
let
request
=
GenerateRequest
{
text
:
Some
(
"Test prompt"
.to_string
()),
stream
:
false
,
return_logprob
:
true
,
// SGLang extensions
lora_path
:
Some
(
crate
::
openai_api_types
::
LoRAPath
::
Single
(
Some
(
"test.bin"
.to_string
(),
))),
session_params
:
Some
(
session_params
.clone
()),
return_hidden_states
:
true
,
rid
:
Some
(
"test-request-id"
.to_string
()),
// Other fields default to None/false
prompt
:
None
,
input_ids
:
None
,
parameters
:
None
,
sampling_params
:
None
,
};
// Convert to JSON (simulating the simplified routing path)
let
mut
json
=
serde_json
::
to_value
(
&
request
)
.unwrap
();
// Inject bootstrap fields
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify all SGLang fields are preserved
assert_eq!
(
json
[
"text"
],
serde_json
::
json!
(
"Test prompt"
));
assert_eq!
(
json
[
"stream"
],
serde_json
::
json!
(
false
));
assert_eq!
(
json
[
"return_logprob"
],
serde_json
::
json!
(
true
));
assert_eq!
(
json
[
"lora_path"
],
serde_json
::
json!
(
"test.bin"
));
// LoRAPath::Single serializes as just the inner value
assert_eq!
(
json
[
"session_params"
],
serde_json
::
to_value
(
&
session_params
)
.unwrap
()
);
assert_eq!
(
json
[
"return_hidden_states"
],
serde_json
::
json!
(
true
));
assert_eq!
(
json
[
"rid"
],
serde_json
::
json!
(
"test-request-id"
));
// Verify bootstrap fields were added
assert_eq!
(
json
[
"bootstrap_host"
],
serde_json
::
json!
(
"test-server"
));
assert_eq!
(
json
[
"bootstrap_port"
],
serde_json
::
json!
(
5678
));
assert
!
(
json
[
"bootstrap_room"
]
.is_number
());
}
#[tokio::test]
async
fn
test_simplified_routing_chat_completion
()
{
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
ChatMessage
,
UserMessageContent
};
use
crate
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
// Create a test worker
let
worker
=
BasicWorker
::
new
(
"http://chat-server:8000"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
9999
),
},
);
// Create a ChatCompletionRequest with SGLang extensions
let
request
=
ChatCompletionRequest
{
model
:
"gpt-4"
.to_string
(),
messages
:
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Hello world!"
.to_string
()),
name
:
None
,
}],
stream
:
false
,
n
:
Some
(
2
),
// This should create batch bootstrap
// SGLang extensions
top_k
:
Some
(
50
),
separate_reasoning
:
false
,
stream_reasoning
:
true
,
// Set all other fields to defaults
temperature
:
None
,
top_p
:
None
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
None
,
max_completion_tokens
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
logprobs
:
false
,
top_logprobs
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
min_p
:
None
,
min_tokens
:
None
,
repetition_penalty
:
None
,
regex
:
None
,
ebnf
:
None
,
stop_token_ids
:
None
,
no_stop_trim
:
false
,
ignore_eos
:
false
,
continue_final_message
:
false
,
skip_special_tokens
:
true
,
lora_path
:
None
,
session_params
:
None
,
return_hidden_states
:
false
,
};
// Convert to JSON (simulating the simplified routing path)
let
mut
json
=
serde_json
::
to_value
(
&
request
)
.unwrap
();
// Inject bootstrap fields
let
result
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
assert
!
(
result
.is_ok
());
// Verify original fields preserved
assert_eq!
(
json
[
"model"
],
serde_json
::
json!
(
"gpt-4"
));
assert_eq!
(
json
[
"stream"
],
serde_json
::
json!
(
false
));
assert_eq!
(
json
[
"n"
],
serde_json
::
json!
(
2
));
assert_eq!
(
json
[
"top_k"
],
serde_json
::
json!
(
50
));
assert_eq!
(
json
[
"separate_reasoning"
],
serde_json
::
json!
(
false
));
assert_eq!
(
json
[
"stream_reasoning"
],
serde_json
::
json!
(
true
));
// Verify batch bootstrap fields for n=2
let
bootstrap_hosts
=
json
[
"bootstrap_host"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_hosts
.len
(),
2
);
assert_eq!
(
bootstrap_hosts
[
0
],
serde_json
::
json!
(
"chat-server"
));
assert_eq!
(
bootstrap_hosts
[
1
],
serde_json
::
json!
(
"chat-server"
));
let
bootstrap_ports
=
json
[
"bootstrap_port"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_ports
.len
(),
2
);
assert_eq!
(
bootstrap_ports
[
0
],
serde_json
::
json!
(
9999
));
assert_eq!
(
bootstrap_ports
[
1
],
serde_json
::
json!
(
9999
));
let
bootstrap_rooms
=
json
[
"bootstrap_room"
]
.as_array
()
.unwrap
();
assert_eq!
(
bootstrap_rooms
.len
(),
2
);
// Rooms should be different (randomness)
assert_ne!
(
bootstrap_rooms
[
0
],
bootstrap_rooms
[
1
]);
}
}
}
sgl-router/src/routers/pd_types.rs
View file @
a69b6370
...
@@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String {
...
@@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String {
url
.split
(
':'
)
.next
()
.unwrap_or
(
"localhost"
)
.to_string
()
url
.split
(
':'
)
.next
()
.unwrap_or
(
"localhost"
)
.to_string
()
}
}
use
serde
::
Serialize
;
/// Bootstrap wrapper for a single request.
///
/// Borrows the original request and, via `#[serde(flatten)]`, serializes its
/// fields alongside the three bootstrap fields declared here.
#[derive(Serialize)]
pub struct RequestWithBootstrap<'a, T>
where
    T: Serialize,
{
    /// The wrapped request, flattened into the output.
    #[serde(flatten)]
    pub original: &'a T,
    /// Host name to report as the bootstrap host.
    pub bootstrap_host: String,
    /// Bootstrap port, if one is known.
    pub bootstrap_port: Option<u16>,
    /// Bootstrap room id for this request.
    pub bootstrap_room: u64,
}
/// Bootstrap wrapper for batch requests.
///
/// Same layout as the single-request wrapper, except that each bootstrap
/// field is a vector. The wrapped request is borrowed and flattened into the
/// output via `#[serde(flatten)]`.
#[derive(Serialize)]
pub struct BatchRequestWithBootstrap<'a, T>
where
    T: Serialize,
{
    /// The wrapped request, flattened into the output.
    #[serde(flatten)]
    pub original: &'a T,
    /// Bootstrap host names.
    pub bootstrap_host: Vec<String>,
    /// Bootstrap ports; an entry may be `None`.
    pub bootstrap_port: Vec<Option<u16>>,
    /// Bootstrap room ids.
    pub bootstrap_room: Vec<u64>,
}
/// Generates a random bootstrap room id.
///
/// The top bit of a random `u64` is cleared, so the result lies in
/// `[0, 2^63 - 1]`, matching Python's `random.randint(0, 2**63 - 1)`.
pub fn generate_room_id() -> u64 {
    // `u64::MAX >> 1` is the same 63-bit mask as `i64::MAX as u64`.
    const ROOM_ID_MASK: u64 = u64::MAX >> 1;
    let raw: u64 = rand::random();
    raw & ROOM_ID_MASK
}
// PD-specific routing policies
// PD-specific routing policies
#[derive(Debug,
Clone,
PartialEq)]
#[derive(Debug,
Clone,
PartialEq)]
pub
enum
PDSelectionPolicy
{
pub
enum
PDSelectionPolicy
{
...
...
sgl-router/src/server.rs
View file @
a69b6370
...
@@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let
client
=
Client
::
builder
()
let
client
=
Client
::
builder
()
.pool_idle_timeout
(
Some
(
Duration
::
from_secs
(
50
)))
.pool_idle_timeout
(
Some
(
Duration
::
from_secs
(
50
)))
.pool_max_idle_per_host
(
1
00
)
// Increase
from default of 1 to allow more concurrent
connections
.pool_max_idle_per_host
(
5
00
)
// Increase
to 500
connections
per host
.timeout
(
Duration
::
from_secs
(
config
.request_timeout_secs
))
.timeout
(
Duration
::
from_secs
(
config
.request_timeout_secs
))
.connect_timeout
(
Duration
::
from_secs
(
10
))
// Separate connection timeout
.connect_timeout
(
Duration
::
from_secs
(
10
))
// Separate connection timeout
.tcp_nodelay
(
true
)
.tcp_nodelay
(
true
)
...
...
sgl-router/tests/benchmark_integration.rs
View file @
a69b6370
...
@@ -9,7 +9,6 @@ use sglang_router_rs::openai_api_types::{
...
@@ -9,7 +9,6 @@ use sglang_router_rs::openai_api_types::{
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
ChatCompletionRequest
,
ChatMessage
,
CompletionRequest
,
GenerateParameters
,
GenerateRequest
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
SamplingParams
,
StringOrArray
,
UserMessageContent
,
};
};
use
sglang_router_rs
::
routers
::
bootstrap_injector
::
inject_bootstrap_fields
;
/// Create a default GenerateRequest for benchmarks with minimal fields set
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn
default_generate_request
()
->
GenerateRequest
{
fn
default_generate_request
()
->
GenerateRequest
{
...
@@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() {
...
@@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() {
assert_eq!
(
generate_req
.return_logprob
,
deserialized
.return_logprob
);
assert_eq!
(
generate_req
.return_logprob
,
deserialized
.return_logprob
);
}
}
#[test]
fn
test_benchmark_bootstrap_injection
()
{
// Test that bootstrap injection works for benchmark types (replaces PD request adaptation)
let
generate_req
=
GenerateRequest
{
text
:
Some
(
"Test prompt"
.to_string
()),
..
default_generate_request
()
};
let
chat_req
=
ChatCompletionRequest
{
model
:
"test-model"
.to_string
(),
messages
:
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Test message"
.to_string
()),
name
:
None
,
}],
max_tokens
:
Some
(
150
),
max_completion_tokens
:
Some
(
150
),
temperature
:
Some
(
0.7
),
top_p
:
Some
(
1.0
),
n
:
Some
(
1
),
presence_penalty
:
Some
(
0.0
),
frequency_penalty
:
Some
(
0.0
),
parallel_tool_calls
:
Some
(
true
),
..
default_chat_completion_request
()
};
let
completion_req
=
CompletionRequest
{
model
:
"test-model"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"Test prompt"
.to_string
()),
max_tokens
:
Some
(
50
),
temperature
:
Some
(
0.8
),
top_p
:
Some
(
1.0
),
n
:
Some
(
1
),
presence_penalty
:
Some
(
0.0
),
frequency_penalty
:
Some
(
0.0
),
best_of
:
Some
(
1
),
..
default_completion_request
()
};
let
worker
=
create_test_worker
();
// Test bootstrap injection (should not panic)
let
mut
generate_json
=
to_value
(
&
generate_req
)
.unwrap
();
let
mut
chat_json
=
to_value
(
&
chat_req
)
.unwrap
();
let
mut
completion_json
=
to_value
(
&
completion_req
)
.unwrap
();
assert
!
(
inject_bootstrap_fields
(
&
mut
generate_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
chat_json
,
&
worker
)
.is_ok
());
assert
!
(
inject_bootstrap_fields
(
&
mut
completion_json
,
&
worker
)
.is_ok
());
// Verify bootstrap fields were added
assert
!
(
generate_json
.get
(
"bootstrap_host"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_port"
)
.is_some
());
assert
!
(
generate_json
.get
(
"bootstrap_room"
)
.is_some
());
}
#[test]
#[test]
fn
test_benchmark_direct_json_routing
()
{
fn
test_benchmark_direct_json_routing
()
{
// Test direct JSON routing functionality for benchmark types (replaces regular routing)
// Test direct JSON routing functionality for benchmark types (replaces regular routing)
...
@@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() {
...
@@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() {
assert
!
(
!
json_string
.is_empty
());
assert
!
(
!
json_string
.is_empty
());
assert
!
(
!
bytes
.is_empty
());
assert
!
(
!
bytes
.is_empty
());
}
}
#[test]
fn
test_benchmark_performance_baseline
()
{
// Basic performance sanity check - ensure operations complete quickly
use
std
::
time
::
Instant
;
let
generate_req
=
GenerateRequest
{
text
:
Some
(
"Short test prompt"
.to_string
()),
..
default_generate_request
()
};
// Test the actual simplified pipeline: to_value + bootstrap injection
let
start
=
Instant
::
now
();
let
worker
=
create_test_worker
();
// This mirrors the actual router pipeline
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
total_duration
=
start
.elapsed
();
assert
!
(
total_duration
.as_millis
()
<
5
,
"Simplified pipeline took too long: {:?} (should be faster than old adapter approach)"
,
total_duration
);
// Individual components should also be fast
let
start
=
Instant
::
now
();
let
_
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
to_value_duration
=
start
.elapsed
();
let
start
=
Instant
::
now
();
let
mut
json
=
to_value
(
&
generate_req
)
.unwrap
();
let
_
=
inject_bootstrap_fields
(
&
mut
json
,
&
worker
);
let
inject_duration
=
start
.elapsed
();
// Bootstrap injection should be faster than the JSON conversion
assert
!
(
inject_duration
<=
to_value_duration
*
3
,
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})"
,
inject_duration
,
to_value_duration
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment