change / sglang · Commits

Commit 5d62b56f (Unverified)
Authored Aug 05, 2025 by Simo Lin; committed by GitHub on Aug 05, 2025
Parent: 3ae8e3ea

[router] complete router oai spec (#8828)

Showing 5 changed files with 856 additions and 365 deletions (+856, −365)
sgl-router/benches/request_processing.rs     +114 −55
sgl-router/src/openai_api_types.rs           +201 −2
sgl-router/src/routers/pd_types.rs           +34  −72
sgl-router/src/routers/request_adapter.rs    +392 −162
sgl-router/tests/benchmark_integration.rs    +115 −74
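Most of the added lines below are test scaffolding: each file gains default_*_request() helpers so that tests override only the fields they exercise, using Rust's struct-update (functional record update) syntax. A minimal standalone sketch of the pattern, with an illustrative stand-in type rather than one of the router's structs:

    // Illustrative only: a tiny stand-in type, not one of the router's structs.
    #[derive(Debug)]
    struct Req {
        text: Option<String>,
        stream: bool,
        n: u32,
    }

    // Helper returns a fully initialized default instance.
    fn default_req() -> Req {
        Req { text: None, stream: false, n: 1 }
    }

    fn main() {
        // `..default_req()` fills every field not listed explicitly.
        let req = Req { stream: true, ..default_req() };
        assert!(req.stream);
        assert_eq!(req.n, 1);
    }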
==== sgl-router/benches/request_processing.rs ====
@@ -8,12 +8,116 @@ use sglang_router_rs::openai_api_types::{
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};

/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
    GenerateRequest {
        text: None,
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        rid: None,
    }
}

/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
fn default_chat_completion_request() -> ChatCompletionRequest {
    ChatCompletionRequest {
        model: String::new(),
        messages: vec![],
        max_tokens: None,
        max_completion_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        response_format: None,
        seed: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        function_call: None,
        functions: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        separate_reasoning: true,
        stream_reasoning: true,
        return_hidden_states: false,
    }
}

/// Create a default CompletionRequest for benchmarks with minimal fields set
fn default_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: String::new(),
        prompt: StringOrArray::String(String::new()),
        suffix: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        json_schema: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        other: serde_json::Map::new(),
    }
}

// Sample request data for benchmarks
fn create_sample_generate_request() -> GenerateRequest {
    GenerateRequest {
        text: Some("Write a story about artificial intelligence".to_string()),
        input_ids: None,
        prompt: None,
        parameters: Some(GenerateParameters {
            max_new_tokens: Some(100),
            temperature: Some(0.8),
@@ -31,8 +135,7 @@ fn create_sample_generate_request() -> GenerateRequest {
            repetition_penalty: Some(1.0),
            ..Default::default()
        }),
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    }
}
@@ -58,22 +161,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest {
        temperature: Some(0.7),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        response_format: None,
        seed: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: Some(true),
        function_call: None,
        functions: None,
        ..default_chat_completion_request()
    }
}
@@ -81,23 +172,14 @@ fn create_sample_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: "text-davinci-003".to_string(),
        prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
        suffix: None,
        max_tokens: Some(50),
        temperature: Some(0.8),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        best_of: Some(1),
        logit_bias: None,
        user: None,
        seed: None,
        other: serde_json::Map::new(),
        ..default_completion_request()
    }
}
@@ -121,6 +203,7 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
            name: None,
            tool_calls: None,
            function_call: None,
            reasoning_content: None,
        });
    }
@@ -132,22 +215,13 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
        temperature: Some(0.7),
        top_p: Some(0.95),
        n: Some(1),
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: Some(0.1),
        frequency_penalty: Some(0.1),
        logit_bias: None,
        logprobs: false,
        top_logprobs: Some(5),
        user: Some("benchmark_user".to_string()),
        response_format: None,
        seed: Some(42),
        tools: None,
        tool_choice: None,
        parallel_tool_calls: Some(true),
        function_call: None,
        functions: None,
        ..default_chat_completion_request()
    }
}
@@ -331,32 +405,17 @@ fn bench_throughput_by_size(c: &mut Criterion) {
    // Create requests of different sizes
    let small_generate = GenerateRequest {
        text: Some("Hi".to_string()),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };
    let medium_generate = GenerateRequest {
        text: Some("Write a medium length story about AI".repeat(10)),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };
    let large_generate = GenerateRequest {
        text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };
    for (name, req) in [
==== sgl-router/src/openai_api_types.rs ====
@@ -6,6 +6,21 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

/// Helper function for serde default value
fn default_true() -> bool {
    true
}

// ============= SGLang-Specific Types =============

/// LoRA adapter path - can be single path or batch of paths
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
    Single(Option<String>),
    Batch(Vec<Option<String>>),
}

/// Common trait for all generation requests
pub trait GenerationRequest: Send + Sync {
    /// Check if the request is for streaming
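Because LoRAPath is #[serde(untagged)], serde carries no discriminator in the JSON and matches the value's shape against each variant in declaration order: a bare string (or null) becomes Single, an array becomes Batch. A quick sketch of that behavior (the input literals are illustrative, not from the repo):

    // Untagged enums have no tag in the JSON; the shape decides the variant.
    let single: LoRAPath = serde_json::from_str(r#""adapters/my-lora""#).unwrap();      // Single
    let batch: LoRAPath = serde_json::from_str(r#"["a.bin", null, "b.bin"]"#).unwrap(); // Batch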
@@ -92,6 +107,64 @@ pub struct CompletionRequest {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    // ============= SGLang Extensions =============
    /// Top-k sampling parameter (-1 to disable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,

    /// Min-p nucleus sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,

    /// Minimum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,

    /// Repetition penalty for reducing repetitive text
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,

    /// Regex constraint for output generation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,

    /// EBNF grammar constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,

    /// JSON schema constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,

    /// Specific token IDs to use as stop conditions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,

    /// Skip trimming stop tokens from output
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Ignore end-of-sequence tokens during generation
    #[serde(default)]
    pub ignore_eos: bool,

    /// Skip special tokens during detokenization
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,

    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,

    /// Additional fields including bootstrap info for PD routing
    #[serde(flatten)]
    pub other: serde_json::Map<String, serde_json::Value>,
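Two serde details in the struct above do the heavy lifting: booleans tagged #[serde(default)] or #[serde(default = "default_true")] take their defaults when absent from the body, and the #[serde(flatten)] `other` map collects any keys the struct does not declare, which is how bootstrap info rides along for PD routing. A hedged sketch, assuming the struct's remaining optional fields also deserialize from an absent state:

    // Assumption for illustration: a minimal body with one undeclared key.
    let req: CompletionRequest = serde_json::from_str(
        r#"{"model": "m", "prompt": "hi", "bootstrap_host": "10.0.0.1"}"#,
    ).unwrap();
    assert!(!req.ignore_eos);                           // #[serde(default)] -> false
    assert!(req.skip_special_tokens);                   // #[serde(default = "default_true")] -> true
    assert!(req.other.contains_key("bootstrap_host"));  // flattened catch-all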
@@ -166,7 +239,7 @@ pub struct ChatCompletionRequest {
    /// Modify the likelihood of specified tokens appearing in the completion
    #[serde(skip_serializing_if = "Option::is_none")]
-   pub logit_bias: Option<HashMap<String, i32>>,
+   pub logit_bias: Option<HashMap<String, f32>>,

    /// A unique identifier representing your end-user
    #[serde(skip_serializing_if = "Option::is_none")]
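Widening the bias values from i32 to f32 means fractional biases now deserialize instead of being rejected. A sketch (body illustrative, assuming the request's other fields fall back to their serde defaults):

    // A fractional bias like -0.5 parses only with the f32 value type.
    let req: ChatCompletionRequest = serde_json::from_str(
        r#"{"model": "m", "messages": [], "logit_bias": {"50256": -0.5}}"#,
    ).unwrap();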
@@ -207,6 +280,72 @@ pub struct ChatCompletionRequest {
    /// Deprecated: use tool_choice instead
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,

    // ============= SGLang Extensions =============
    /// Top-k sampling parameter (-1 to disable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,

    /// Min-p nucleus sampling parameter
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,

    /// Minimum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,

    /// Repetition penalty for reducing repetitive text
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,

    /// Regex constraint for output generation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,

    /// EBNF grammar constraint for structured output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,

    /// Specific token IDs to use as stop conditions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,

    /// Skip trimming stop tokens from output
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Ignore end-of-sequence tokens during generation
    #[serde(default)]
    pub ignore_eos: bool,

    /// Continue generating from final assistant message
    #[serde(default)]
    pub continue_final_message: bool,

    /// Skip special tokens during detokenization
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,

    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,

    /// Separate reasoning content from final answer (O1-style models)
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,

    /// Stream reasoning tokens during generation
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -234,6 +373,9 @@ pub enum ChatMessage {
        tool_calls: Option<Vec<ToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        function_call: Option<FunctionCallResponse>,
        /// Reasoning content for O1-style models (SGLang extension)
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    Tool {
        role: String, // "tool"
@@ -378,7 +520,20 @@ impl GenerationRequest for ChatCompletionRequest {
                    Some(texts.join(" "))
                }
            },
-           ChatMessage::Assistant { content, .. } => content.clone(),
+           ChatMessage::Assistant { content, reasoning_content, .. } => {
+               // Combine content and reasoning content for routing decisions
+               let main_content = content.clone().unwrap_or_default();
+               let reasoning = reasoning_content.clone().unwrap_or_default();
+               if main_content.is_empty() && reasoning.is_empty() {
+                   None
+               } else {
+                   Some(format!("{} {}", main_content, reasoning).trim().to_string())
+               }
+           }
            ChatMessage::Tool { content, .. } => Some(content.clone()),
            ChatMessage::Function { content, .. } => Some(content.clone()),
        })
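The combining logic above joins content and reasoning with a single space and then trims, so a message carrying only one of the two parts yields that part without stray whitespace. The step in isolation:

    // Isolated check of the combine-and-trim step shown above.
    let main_content = String::new();
    let reasoning = "step-by-step plan".to_string();
    let combined = format!("{} {}", main_content, reasoning).trim().to_string();
    assert_eq!(combined, "step-by-step plan");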
@@ -418,6 +573,23 @@ pub struct GenerateRequest {
    /// Whether to return logprobs
    #[serde(default)]
    pub return_logprob: bool,

    // ============= SGLang Extensions =============
    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<LoRAPath>,

    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, serde_json::Value>>,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,

    /// Request ID for tracking
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -485,6 +657,18 @@ pub struct SamplingParams {
    pub skip_special_tokens: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_stop_trim: Option<bool>,
}

impl GenerationRequest for GenerateRequest {
@@ -561,6 +745,12 @@ pub struct CompletionChoice {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
    /// Information about which stop condition was matched
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
    /// Hidden states from the model (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
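matched_stop stays a raw serde_json::Value because, as the inline comment notes, the engine may report either the matched stop string or a matched token id; both shapes round-trip without a dedicated enum. For illustration:

    // Either JSON shape is representable; callers inspect the Value's type.
    let as_string = serde_json::json!("</s>");
    let as_token_id = serde_json::json!(128009);
    assert!(as_string.is_string() && as_token_id.is_i64());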
@@ -591,6 +781,12 @@ pub struct ChatChoice {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
    /// Information about which stop condition was matched
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
    /// Hidden states from the model (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -681,6 +877,9 @@ pub struct ChatMessageDelta {
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCallDelta>,
    /// Reasoning content delta for O1-style models (SGLang extension)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
==== sgl-router/src/routers/pd_types.rs ====
@@ -278,11 +278,11 @@ mod bootstrap_tests {
    use crate::core::BasicWorker;
    use crate::openai_api_types::StringOrArray;

-   #[test]
-   fn test_completion_batch_size_with_array_prompt() {
-       let req = CompletionRequest {
-           model: "test".to_string(),
-           prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
+   /// Create a default CompletionRequest for testing with minimal fields set
+   fn default_completion_request() -> CompletionRequest {
+       CompletionRequest {
+           model: String::new(),
+           prompt: StringOrArray::String(String::new()),
+           n: None,
+           other: serde_json::Map::new(),
+           suffix: None,
@@ -300,6 +300,31 @@ mod bootstrap_tests {
            logit_bias: None,
            user: None,
            seed: None,
            // SGLang Extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            json_schema: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            skip_special_tokens: true,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            return_hidden_states: false,
        }
    }

    #[test]
    fn test_completion_batch_size_with_array_prompt() {
        let req = CompletionRequest {
            model: "test".to_string(),
            prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
            ..default_completion_request()
        };
        // Should return batch size for array prompt
@@ -311,23 +336,7 @@ mod bootstrap_tests {
        let req = CompletionRequest {
            model: "test".to_string(),
            prompt: StringOrArray::String("single prompt".to_string()),
            n: None,
            other: serde_json::Map::new(),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            ..default_completion_request()
        };
        // Should return None for single prompt
@@ -340,22 +349,7 @@ mod bootstrap_tests {
            model: "test".to_string(),
            prompt: StringOrArray::String("single prompt".to_string()),
            n: Some(3),
            other: serde_json::Map::new(),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            ..default_completion_request()
        };
        // Should return None for single string prompt, even with n > 1
@@ -368,23 +362,7 @@ mod bootstrap_tests {
        let mut req = CompletionRequest {
            model: "test".to_string(),
            prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
            n: None,
            other: serde_json::Map::new(),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            ..default_completion_request()
        };
        // Set bootstrap info - should always use single values
@@ -418,23 +396,7 @@ mod bootstrap_tests {
        let mut req = CompletionRequest {
            model: "test".to_string(),
            prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
            n: None,
            other: serde_json::Map::new(),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            ..default_completion_request()
        };
        // Set bootstrap info with arrays
==== sgl-router/src/routers/request_adapter.rs ====
@@ -176,6 +176,33 @@ impl ToPdRequest for CompletionRequest {
            self.stream => "stream"
        );

        // Add SGLang extension fields
        insert_if_some!(
            other,
            // SGLang Extensions - Priority 1
            self.top_k => "top_k",
            self.min_p => "min_p",
            self.min_tokens => "min_tokens",
            self.repetition_penalty => "repetition_penalty",
            self.regex => "regex",
            self.ebnf => "ebnf",
            self.stop_token_ids => "stop_token_ids",
            // SGLang Extensions - Priority 2
            self.lora_path => "lora_path",
            self.session_params => "session_params"
        );

        // SGLang boolean extensions (CompletionRequest has these as bool, not Option<bool>)
        other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
        other.insert("ignore_eos".to_string(), self.ignore_eos.into());
        other.insert(
            "skip_special_tokens".to_string(),
            self.skip_special_tokens.into(),
        );
        other.insert(
            "return_hidden_states".to_string(),
            self.return_hidden_states.into(),
        );

        GenerateReqInput {
            text,
            input_ids: None,
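The insert_if_some! macro's definition is not part of this diff. A minimal sketch of the pattern its call sites imply (an assumption for illustration, not the repo's actual macro): copy each Some field into the JSON map under the given key, and skip fields that are None.

    // Hypothetical reconstruction of the insert_if_some! pattern used above.
    macro_rules! insert_if_some {
        ($map:expr, $($field:expr => $key:expr),+ $(,)?) => {
            $(
                if let Some(value) = $field {
                    // Serialize the field and store it under the given key.
                    $map.insert($key.to_string(), serde_json::to_value(value).unwrap());
                }
            )+
        };
    }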
@@ -226,14 +253,46 @@ impl ToPdRequest for ChatCompletionRequest {
            self.tool_choice => "tool_choice",
            self.parallel_tool_calls => "parallel_tool_calls",
            self.functions => "functions",
-           self.function_call => "function_call"
+           self.function_call => "function_call",
+           // SGLang Extensions - Priority 1
+           self.top_k => "top_k",
+           self.min_p => "min_p",
+           self.min_tokens => "min_tokens",
+           self.repetition_penalty => "repetition_penalty",
+           self.regex => "regex",
+           self.ebnf => "ebnf",
+           self.stop_token_ids => "stop_token_ids",
+           // SGLang Extensions - Priority 2
+           self.lora_path => "lora_path",
+           self.session_params => "session_params"
        );

-       // Handle boolean logprobs flag
+       // Handle boolean flags
        if self.logprobs {
            other.insert("logprobs".to_string(), true.into());
        }

+       // SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
+       other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
+       other.insert("ignore_eos".to_string(), self.ignore_eos.into());
+       other.insert(
+           "continue_final_message".to_string(),
+           self.continue_final_message.into(),
+       );
+       other.insert(
+           "skip_special_tokens".to_string(),
+           self.skip_special_tokens.into(),
+       );
+       other.insert(
+           "separate_reasoning".to_string(),
+           self.separate_reasoning.into(),
+       );
+       other.insert("stream_reasoning".to_string(), self.stream_reasoning.into());
+       other.insert(
+           "return_hidden_states".to_string(),
+           self.return_hidden_states.into(),
+       );

        ChatReqInput {
            stream: self.stream,
            bootstrap_host: None,
@@ -271,18 +330,136 @@ mod tests {
    use serde_json::json;
    use std::collections::HashMap;

    // ============= Test Helper Functions =============
    //
    // These helper functions create default request instances with all required SGLang extension fields
    // properly initialized. Use the struct spread operator `..default_*_request()` to override only
    // the fields you need for specific tests, avoiding repetitive boilerplate code.
    //
    // Example usage:
    // let req = GenerateRequest {
    //     text: Some("Custom text".to_string()),
    //     stream: true,
    //     ..default_generate_request()
    // };

    /// Create a default GenerateRequest with minimal fields set
    fn default_generate_request() -> GenerateRequest {
        GenerateRequest {
            text: None,
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            return_hidden_states: false,
            rid: None,
        }
    }

    /// Create a default CompletionRequest with minimal fields set
    fn default_completion_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt: StringOrArray::String("test prompt".to_string()),
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            suffix: None,
            // SGLang Extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            json_schema: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            skip_special_tokens: true,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            return_hidden_states: false,
            other: serde_json::Map::new(),
        }
    }

    /// Create a default ChatCompletionRequest with minimal fields set
    fn default_chat_completion_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![ChatMessage::User {
                role: "user".to_string(),
                content: UserMessageContent::Text("test message".to_string()),
                name: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
            seed: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            // SGLang Extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            continue_final_message: false,
            skip_special_tokens: true,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            separate_reasoning: true,
            stream_reasoning: true,
            return_hidden_states: false,
        }
    }

    // ============= GenerateRequest to_pd_request Tests =============
    #[test]
    fn test_generate_to_pd_request_with_text_only() {
        let req = GenerateRequest {
            text: Some("Hello world".to_string()),
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -308,13 +485,10 @@ mod tests {
    #[test]
    fn test_generate_to_pd_request_with_prompt_string() {
        let req = GenerateRequest {
            text: None,
            prompt: Some(StringOrArray::String("Test prompt".to_string())),
            input_ids: None,
            stream: true,
            parameters: None,
            sampling_params: None,
            return_logprob: true,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -342,6 +516,7 @@ mod tests {
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -360,13 +535,8 @@ mod tests {
    #[test]
    fn test_generate_to_pd_request_with_single_input_ids() {
        let req = GenerateRequest {
            text: None,
            prompt: None,
            input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -381,17 +551,12 @@ mod tests {
    #[test]
    fn test_generate_to_pd_request_with_batch_input_ids() {
        let req = GenerateRequest {
            text: None,
            prompt: None,
            input_ids: Some(InputIds::Batch(vec![
                vec![1, 2, 3],
                vec![4, 5, 6, 7],
                vec![8, 9],
            ])),
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -413,10 +578,7 @@ mod tests {
            text: Some("SGLang text".to_string()),
            prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
            input_ids: Some(InputIds::Single(vec![1, 2, 3])),
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -429,13 +591,9 @@ mod tests {
    #[test]
    fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
        let req = GenerateRequest {
            text: None,
            prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
            input_ids: Some(InputIds::Single(vec![1, 2, 3])),
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -459,12 +617,8 @@ mod tests {
        let req = GenerateRequest {
            text: Some("test".to_string()),
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: Some(params),
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -497,12 +651,8 @@ mod tests {
        let req = GenerateRequest {
            text: Some("test".to_string()),
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: Some(sampling),
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -546,6 +696,7 @@ mod tests {
            parameters: Some(params),
            sampling_params: Some(sampling),
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -568,6 +719,7 @@ mod tests {
            parameters: Some(params),
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -603,6 +755,7 @@ mod tests {
            parameters: Some(params),
            sampling_params: Some(sampling),
            return_logprob: true,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -632,23 +785,7 @@ mod tests {
        let req = CompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            prompt: StringOrArray::String("Complete this sentence".to_string()),
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            suffix: None,
            other: serde_json::Map::new(),
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -672,23 +809,7 @@ mod tests {
                "First prompt".to_string(),
                "Second prompt".to_string(),
            ]),
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            suffix: None,
            other: serde_json::Map::new(),
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -727,7 +848,7 @@ mod tests {
            user: Some("user123".to_string()),
            seed: Some(42),
            suffix: Some("...".to_string()),
            other: serde_json::Map::new(),
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -771,7 +892,7 @@ mod tests {
            user: None,
            seed: None,
            suffix: None,
            other: serde_json::Map::new(),
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -803,7 +924,7 @@ mod tests {
            user: None,
            seed: None,
            suffix: None,
            other: serde_json::Map::new(),
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -834,27 +955,7 @@ mod tests {
        let req = ChatCompletionRequest {
            messages,
            model: "gpt-4".to_string(),
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
            seed: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -883,7 +984,7 @@ mod tests {
        }];
        let mut logit_bias = HashMap::new();
-       logit_bias.insert("50256".to_string(), -100);
+       logit_bias.insert("50256".to_string(), -100.0f32);
        let tool = Tool {
            tool_type: "function".to_string(),
@@ -920,6 +1021,7 @@ mod tests {
            parallel_tool_calls: Some(false),
            functions: None,
            function_call: None,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -968,27 +1070,7 @@ mod tests {
        let req = ChatCompletionRequest {
            messages,
            model: "gpt-4-vision".to_string(),
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
            seed: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -1037,6 +1119,7 @@ mod tests {
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -1054,32 +1137,13 @@ mod tests {
            name: None,
            tool_calls: None,
            function_call: None,
            reasoning_content: None,
        }];
        let req = ChatCompletionRequest {
            messages,
            model: "gpt-3.5-turbo".to_string(),
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
            seed: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
@@ -1101,12 +1165,7 @@ mod tests {
    fn test_routeable_request_to_json() {
        let req = GenerateRequest {
            text: Some("test".to_string()),
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let json = req.to_json().unwrap();
@@ -1166,6 +1225,7 @@ mod tests {
            parameters: Some(params),
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -1187,6 +1247,7 @@ mod tests {
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -1206,12 +1267,7 @@ mod tests {
        let req = GenerateRequest {
            text: Some(unicode_text.clone()),
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -1250,6 +1306,7 @@ mod tests {
            parameters: Some(params),
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -1265,12 +1322,7 @@ mod tests {
    fn test_bootstrap_fields_none() {
        let req = GenerateRequest {
            text: Some("test".to_string()),
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            ..default_generate_request()
        };
        let pd_req = req.to_pd_request();
@@ -1279,4 +1331,182 @@ mod tests {
        assert_eq!(pd_req.bootstrap_port, None);
        assert_eq!(pd_req.bootstrap_room, None);
    }

    // ============= SGLang Extension Field Pass-Through Tests =============
    #[test]
    fn test_chat_completion_sglang_extensions_passed_through() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Test".to_string()),
            name: None,
        }];
        let mut session_params = std::collections::HashMap::new();
        session_params.insert(
            "key".to_string(),
            serde_json::Value::String("value".to_string()),
        );
        let req = ChatCompletionRequest {
            messages,
            model: "test-model".to_string(),
            // SGLang Extensions - Priority 1
            top_k: Some(40),
            min_p: Some(0.05),
            min_tokens: Some(10),
            repetition_penalty: Some(1.1),
            regex: Some("test_regex".to_string()),
            ebnf: Some("test_ebnf".to_string()),
            stop_token_ids: Some(vec![1, 2, 3]),
            // SGLang Extensions - Priority 2
            lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
            session_params: Some(session_params.clone()),
            // Boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
            no_stop_trim: true,
            ignore_eos: false,
            continue_final_message: true,
            skip_special_tokens: false,
            separate_reasoning: true,
            stream_reasoning: false,
            return_hidden_states: true,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
        let other = pd_req.other.as_object().unwrap();

        // Verify SGLang extensions are passed through
        assert_eq!(other.get("top_k"), Some(&json!(40)));
        assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
        assert_eq!(other.get("min_tokens"), Some(&json!(10)));
        assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
        assert_eq!(other.get("regex"), Some(&json!("test_regex")));
        assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
        assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
        assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
        assert_eq!(
            other.get("session_params"),
            Some(&serde_json::to_value(&session_params).unwrap())
        );

        // Verify boolean extensions
        assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
        assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
        assert_eq!(other.get("continue_final_message"), Some(&json!(true)));
        assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
        assert_eq!(other.get("separate_reasoning"), Some(&json!(true)));
        assert_eq!(other.get("stream_reasoning"), Some(&json!(false)));
        assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
    }

    #[test]
    fn test_completion_request_sglang_extensions_passed_through() {
        let mut session_params = std::collections::HashMap::new();
        session_params.insert(
            "key".to_string(),
            serde_json::Value::String("value".to_string()),
        );
        let req = CompletionRequest {
            prompt: StringOrArray::String("Test prompt".to_string()),
            model: "test-model".to_string(),
            // SGLang Extensions - Priority 1
            top_k: Some(40),
            min_p: Some(0.05),
            min_tokens: Some(10),
            repetition_penalty: Some(1.1),
            regex: Some("test_regex".to_string()),
            ebnf: Some("test_ebnf".to_string()),
            stop_token_ids: Some(vec![1, 2, 3]),
            // SGLang Extensions - Priority 2
            lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
            session_params: Some(session_params.clone()),
            // Boolean extensions (CompletionRequest only has these 4 boolean fields)
            no_stop_trim: true,
            ignore_eos: false,
            skip_special_tokens: false,
            return_hidden_states: true,
            ..default_completion_request()
        };
        let pd_req = req.to_pd_request();
        let other = pd_req.other.as_object().unwrap();

        // Verify SGLang extensions are passed through
        assert_eq!(other.get("top_k"), Some(&json!(40)));
        assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
        assert_eq!(other.get("min_tokens"), Some(&json!(10)));
        assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
        assert_eq!(other.get("regex"), Some(&json!("test_regex")));
        assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
        assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
        assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
        assert_eq!(
            other.get("session_params"),
            Some(&serde_json::to_value(&session_params).unwrap())
        );

        // Verify boolean extensions (only the ones CompletionRequest has)
        assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
        assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
        assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
        assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
    }

    #[test]
    fn test_sglang_extensions_none_values_not_passed_through() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Test".to_string()),
            name: None,
        }];
        let req = ChatCompletionRequest {
            messages,
            model: "test-model".to_string(),
            // All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            stop_token_ids: None,
            lora_path: None,
            session_params: None,
            // Boolean fields use defaults (false for most, true for some with default_true)
            no_stop_trim: false,
            ignore_eos: false,
            continue_final_message: false,
            skip_special_tokens: true, // This has default_true
            separate_reasoning: true,  // This has default_true
            stream_reasoning: true,    // This has default_true
            return_hidden_states: false,
            ..default_chat_completion_request()
        };
        let pd_req = req.to_pd_request();
        let other = pd_req.other.as_object().unwrap();

        // Verify None values are not included
        assert!(!other.contains_key("top_k"));
        assert!(!other.contains_key("min_p"));
        assert!(!other.contains_key("min_tokens"));
        assert!(!other.contains_key("repetition_penalty"));
        assert!(!other.contains_key("regex"));
        assert!(!other.contains_key("ebnf"));
        assert!(!other.contains_key("stop_token_ids"));
        assert!(!other.contains_key("lora_path"));
        assert!(!other.contains_key("session_params"));

        // Boolean fields are always present with their values (can't be None)
        assert_eq!(other.get("no_stop_trim"), Some(&json!(false)));
        assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
        assert_eq!(other.get("continue_final_message"), Some(&json!(false)));
        assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true
        assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true
        assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true
        assert_eq!(other.get("return_hidden_states"), Some(&json!(false)));
    }
}
==== sgl-router/tests/benchmark_integration.rs ====
@@ -8,14 +8,118 @@ use sglang_router_rs::openai_api_types::{
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};

/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
    GenerateRequest {
        text: None,
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        rid: None,
    }
}

/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
fn default_chat_completion_request() -> ChatCompletionRequest {
    ChatCompletionRequest {
        model: String::new(),
        messages: vec![],
        max_tokens: None,
        max_completion_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        response_format: None,
        seed: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        function_call: None,
        functions: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        separate_reasoning: true,
        stream_reasoning: true,
        return_hidden_states: false,
    }
}

/// Create a default CompletionRequest for benchmarks with minimal fields set
fn default_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: String::new(),
        prompt: StringOrArray::String(String::new()),
        suffix: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        // SGLang Extensions
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        json_schema: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        skip_special_tokens: true,
        // SGLang Extensions
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
        other: serde_json::Map::new(),
    }
}

#[test]
fn test_benchmark_request_creation() {
    // Ensure all benchmark request types can be created without panicking
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        input_ids: None,
        prompt: None,
        parameters: Some(GenerateParameters {
            max_new_tokens: Some(100),
            temperature: Some(0.8),
@@ -33,8 +137,7 @@ fn test_benchmark_request_creation() {
            repetition_penalty: Some(1.0),
            ..Default::default()
        }),
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };
    let chat_req = ChatCompletionRequest {
@@ -49,44 +152,23 @@ fn test_benchmark_request_creation() {
        temperature: Some(0.7),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        response_format: None,
        seed: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: Some(true),
        function_call: None,
        functions: None,
        ..default_chat_completion_request()
    };
    let completion_req = CompletionRequest {
        model: "test-model".to_string(),
        prompt: StringOrArray::String("Test prompt".to_string()),
        suffix: None,
        max_tokens: Some(50),
        temperature: Some(0.8),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        best_of: Some(1),
        logit_bias: None,
        user: None,
        seed: None,
        other: serde_json::Map::new(),
        ..default_completion_request()
    };

    // Test serialization works
@@ -101,12 +183,7 @@ fn test_benchmark_serialization_roundtrip() {
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };

    // Serialize and deserialize
@@ -125,12 +202,7 @@ fn test_benchmark_request_adaptation() {
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };
    let chat_req = ChatCompletionRequest {
@@ -145,44 +217,23 @@ fn test_benchmark_request_adaptation() {
        temperature: Some(0.7),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        logit_bias: None,
        logprobs: false,
        top_logprobs: None,
        user: None,
        response_format: None,
        seed: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: Some(true),
        function_call: None,
        functions: None,
        ..default_chat_completion_request()
    };
    let completion_req = CompletionRequest {
        model: "test-model".to_string(),
        prompt: StringOrArray::String("Test prompt".to_string()),
        suffix: None,
        max_tokens: Some(50),
        temperature: Some(0.8),
        top_p: Some(1.0),
        n: Some(1),
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        best_of: Some(1),
        logit_bias: None,
        user: None,
        seed: None,
        other: serde_json::Map::new(),
        ..default_completion_request()
    };

    // Test PD adaptation (should not panic)
@@ -197,12 +248,7 @@ fn test_benchmark_regular_routing() {
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };

    // Test regular routing methods (should not panic)
@@ -217,12 +263,7 @@ fn test_benchmark_performance_baseline() {
    let generate_req = GenerateRequest {
        text: Some("Short test prompt".to_string()),
        input_ids: None,
        prompt: None,
        parameters: None,
        sampling_params: None,
        stream: false,
        return_logprob: false,
        ..default_generate_request()
    };

    // Serialization should be fast (< 1ms for simple requests)