sglang · Commit 33b3c0f8 (Unverified)
Authored Sep 30, 2025 by Simo Lin · Committed by GitHub Sep 29, 2025

[router] grpc router generate endpoint support (#11070)

Co-authored-by: Chang Su <chang.s.su@oracle.com>

Parent: e5281f84
Showing 2 changed files with 480 additions and 89 deletions (+480 -89):

sgl-router/src/grpc_client/sglang_scheduler.rs  +132 -3
sgl-router/src/routers/grpc/router.rs           +348 -86
sgl-router/src/grpc_client/sglang_scheduler.rs (view file @ 33b3c0f8)
+use std::convert::TryFrom;
 use std::time::Duration;
 use tonic::{transport::Channel, Request};
 use tracing::debug;

-use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat};
+use crate::protocols::spec::{
+    ChatCompletionRequest, GenerateRequest, ResponseFormat,
+    SamplingParams as GenerateSamplingParams, StringOrArray,
+};

 // Include the generated protobuf code
 pub mod proto {
...
@@ -112,6 +116,37 @@ impl SglangSchedulerClient {
         Ok(grpc_request)
     }

+    /// Build a basic GenerateRequest from the SGLang spec GenerateRequest
+    pub fn build_plain_generate_request(
+        &self,
+        request_id: String,
+        body: &GenerateRequest,
+        original_text: Option<String>,
+        token_ids: Vec<u32>,
+    ) -> Result<proto::GenerateRequest, String> {
+        let sampling_params =
+            Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?;
+
+        let grpc_request = proto::GenerateRequest {
+            request_id,
+            tokenized: Some(proto::TokenizedInput {
+                original_text: original_text.unwrap_or_default(),
+                input_ids: token_ids,
+            }),
+            sampling_params: Some(sampling_params),
+            return_logprob: body.return_logprob,
+            logprob_start_len: -1,
+            top_logprobs_num: 0,
+            token_ids_logprob: vec![],
+            return_hidden_states: body.return_hidden_states,
+            stream: body.stream,
+            log_metrics: true,
+            ..Default::default()
+        };
+
+        Ok(grpc_request)
+    }
+
     /// Build gRPC SamplingParams from OpenAI request
     fn build_grpc_sampling_params(
         &self,
...
@@ -165,8 +200,8 @@ impl SglangSchedulerClient {
     /// Extract stop strings from request
     fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
         match &request.stop {
-            Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()],
-            Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(),
+            Some(StringOrArray::String(s)) => vec![s.clone()],
+            Some(StringOrArray::Array(arr)) => arr.clone(),
             None => vec![],
         }
     }
...
@@ -218,6 +253,100 @@ impl SglangSchedulerClient {
             _ => Err("Multiple constraints are not allowed.".to_string()),
         }
     }

+    fn build_single_constraint_from_plain(
+        params: &GenerateSamplingParams,
+    ) -> Result<Option<proto::sampling_params::Constraint>, String> {
+        let mut constraints = Vec::new();
+
+        if let Some(json_schema) = &params.json_schema {
+            constraints.push(proto::sampling_params::Constraint::JsonSchema(
+                json_schema.clone(),
+            ));
+        }
+        if let Some(regex) = &params.regex {
+            constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
+        }
+        if let Some(ebnf) = &params.ebnf {
+            constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
+                ebnf.clone(),
+            ));
+        }
+
+        match constraints.len() {
+            0 => Ok(None),
+            1 => Ok(constraints.pop()),
+            _ => Err("Multiple structured constraints are not allowed".to_string()),
+        }
+    }
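The constraint builder above enforces that a plain generate request sets at most one of `json_schema`, `regex`, or `ebnf`. A minimal standalone sketch of that "collect candidates, then allow at most one" rule on toy types (illustrative only, not code from this commit):

// Illustrative toy version of build_single_constraint_from_plain's exclusivity check.
#[derive(Debug, PartialEq)]
enum Constraint {
    JsonSchema(String),
    Regex(String),
    EbnfGrammar(String),
}

fn single_constraint(
    json_schema: Option<&str>,
    regex: Option<&str>,
    ebnf: Option<&str>,
) -> Result<Option<Constraint>, String> {
    let mut constraints = Vec::new();
    if let Some(s) = json_schema {
        constraints.push(Constraint::JsonSchema(s.to_string()));
    }
    if let Some(r) = regex {
        constraints.push(Constraint::Regex(r.to_string()));
    }
    if let Some(e) = ebnf {
        constraints.push(Constraint::EbnfGrammar(e.to_string()));
    }
    match constraints.len() {
        0 => Ok(None),                 // no structured output requested
        1 => Ok(constraints.pop()),    // exactly one constraint: use it
        _ => Err("Multiple structured constraints are not allowed".to_string()),
    }
}

fn main() {
    assert_eq!(single_constraint(None, None, None), Ok(None));
    assert_eq!(
        single_constraint(None, Some("a+"), None),
        Ok(Some(Constraint::Regex("a+".to_string())))
    );
    assert!(single_constraint(Some("{}"), Some("a+"), None).is_err());
}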
+    fn build_sampling_params_from_plain(
+        params: Option<&GenerateSamplingParams>,
+    ) -> Result<proto::SamplingParams, String> {
+        let mut sampling = proto::SamplingParams {
+            temperature: 1.0,
+            top_p: 1.0,
+            top_k: -1,
+            repetition_penalty: 1.0,
+            n: 1,
+            ..Default::default()
+        };
+
+        let Some(p) = params else {
+            return Ok(sampling);
+        };
+
+        // Simple field mappings using a macro
+        macro_rules! map_field {
+            ($field:ident) => {
+                if let Some(val) = p.$field {
+                    sampling.$field = val;
+                }
+            };
+        }
+
+        map_field!(temperature);
+        map_field!(top_p);
+        map_field!(top_k);
+        map_field!(frequency_penalty);
+        map_field!(presence_penalty);
+        map_field!(repetition_penalty);
+        map_field!(min_p);
+        map_field!(ignore_eos);
+        map_field!(skip_special_tokens);
+        map_field!(no_stop_trim);
+
+        // Handle stop sequences
+        if let Some(stop) = &p.stop {
+            match stop {
+                StringOrArray::String(s) => sampling.stop.push(s.clone()),
+                StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()),
+            }
+        }
+
+        // Handle stop token IDs
+        if let Some(stop_token_ids) = &p.stop_token_ids {
+            sampling.stop_token_ids = stop_token_ids.clone();
+        }
+
+        // Handle max_new_tokens with conversion
+        if let Some(max_new_tokens) = p.max_new_tokens {
+            sampling.max_new_tokens = Some(i32::try_from(max_new_tokens).map_err(|_| {
+                "max_new_tokens must fit into a 32-bit signed integer".to_string()
+            })?);
+        }
+
+        // Handle min_tokens with conversion
+        if let Some(min_tokens) = p.min_tokens {
+            sampling.min_new_tokens = i32::try_from(min_tokens)
+                .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
+        }
+
+        // Handle constraints (exactly one allowed)
+        sampling.constraint = Self::build_single_constraint_from_plain(p)?;
+
+        Ok(sampling)
+    }
 }
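The `map_field!` macro above copies an optional, user-supplied value over a protocol default only when the caller actually set it. A self-contained sketch of that pattern on toy types (illustrative only, not the commit's code):

// Toy demonstration of the Option-to-default mapping that map_field! performs.
#[derive(Default)]
struct UserParams {
    temperature: Option<f32>,
    top_p: Option<f32>,
    top_k: Option<i32>,
}

struct SamplingParams {
    temperature: f32,
    top_p: f32,
    top_k: i32,
}

fn merge(user: &UserParams) -> SamplingParams {
    // Start from the same defaults the gRPC message is given in the diff.
    let mut sampling = SamplingParams {
        temperature: 1.0,
        top_p: 1.0,
        top_k: -1,
    };

    // Copy a field only when the caller provided a value.
    macro_rules! map_field {
        ($field:ident) => {
            if let Some(val) = user.$field {
                sampling.$field = val;
            }
        };
    }

    map_field!(temperature);
    map_field!(top_p);
    map_field!(top_k);
    sampling
}

fn main() {
    let merged = merge(&UserParams {
        top_k: Some(40),
        ..Default::default()
    });
    assert_eq!(merged.top_k, 40); // explicitly set field is taken
    println!("temperature={} top_k={}", merged.temperature, merged.top_k);
}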
 #[cfg(test)]
...
sgl-router/src/routers/grpc/router.rs (view file @ 33b3c0f8)
...
@@ -27,12 +27,15 @@ use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
 use crate::server::AppContext;
 use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
-use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};
+use crate::tokenizer::stop::{
+    SequenceDecoderOutput, StopSequenceDecoder, StopSequenceDecoderBuilder,
+};
 use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
 use crate::tool_parser::ParserRegistry;
-use serde_json::Value;
-use std::time::{SystemTime, UNIX_EPOCH};
+use proto::generate_response::Response::{Chunk, Complete, Error};
+use serde_json::{json, Value};
+use std::time::{Instant, SystemTime, UNIX_EPOCH};
 use tokio_stream::StreamExt;
 use uuid::Uuid;
...
@@ -124,28 +127,9 @@ impl GrpcRouter {
         debug!("Selected worker: {}", worker.url());

         // Step 2: Get gRPC client from worker
-        let client = match worker.get_grpc_client().await {
-            Ok(Some(client_arc)) => {
-                // Clone the client from inside the Arc<Mutex<>>
-                let client = client_arc.lock().await.clone();
-                client
-            }
-            Ok(None) => {
-                error!("Selected worker is not a gRPC worker");
-                return (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    "Selected worker is not configured for gRPC",
-                )
-                    .into_response();
-            }
-            Err(e) => {
-                error!("Failed to get gRPC client from worker: {}", e);
-                return (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    format!("Failed to get gRPC client: {}", e),
-                )
-                    .into_response();
-            }
-        };
+        let client = match Self::get_grpc_client_from_worker(&worker).await {
+            Ok(client) => client,
+            Err(response) => return response,
+        };

         // Step 3: Process messages and apply chat template
...
@@ -209,6 +193,112 @@ impl GrpcRouter {
         }
     }

+    /// Main route_generate implementation
+    async fn route_generate_impl(
+        &self,
+        _headers: Option<&HeaderMap>,
+        body: &GenerateRequest,
+        model_id: Option<&str>,
+    ) -> Response {
+        debug!("Processing generate request for model: {:?}", model_id);
+
+        // Step 1: Resolve input (text, prompt, or input_ids)
+        let (original_text, token_ids) = match self.resolve_generate_input(body) {
+            Ok(res) => res,
+            Err(msg) => {
+                error!("Invalid generate request: {}", msg);
+                return (StatusCode::BAD_REQUEST, msg).into_response();
+            }
+        };
+        debug!("Resolved input with {} tokens", token_ids.len());
+
+        // Step 2: Select worker (fail fast if no workers available)
+        let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
+            Some(w) => w,
+            None => {
+                warn!("No available workers for model: {:?}", model_id);
+                return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
+            }
+        };
+        debug!("Selected worker: {}", worker.url());
+
+        // Step 3: Get gRPC client from worker
+        let client = match Self::get_grpc_client_from_worker(&worker).await {
+            Ok(client) => client,
+            Err(response) => return response,
+        };
+
+        // Step 4: Build the gRPC request
+        let request_id = body
+            .rid
+            .clone()
+            .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
+        let request = match client.build_plain_generate_request(
+            request_id.clone(),
+            body,
+            original_text.clone(),
+            token_ids,
+        ) {
+            Ok(req) => req,
+            Err(e) => {
+                error!("Failed to build generate request: {}", e);
+                return (StatusCode::BAD_REQUEST, e).into_response();
+            }
+        };
+
+        // Step 5: Get weight version for response metadata
+        let weight_version = worker
+            .metadata()
+            .labels
+            .get("weight_version")
+            .cloned()
+            .unwrap_or_else(|| "default".to_string());
+
+        // Step 6: Handle streaming vs non-streaming
+        if body.stream {
+            // TODO: Implement streaming support for generate endpoint
+            return (
+                StatusCode::NOT_IMPLEMENTED,
+                "Streaming generate over gRPC is not supported yet",
+            )
+                .into_response();
+        }

+        self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
+            .await
+    }
+    /// Get gRPC client from worker, returning appropriate error response on failure
+    async fn get_grpc_client_from_worker(
+        worker: &Arc<dyn Worker>,
+    ) -> Result<SglangSchedulerClient, Response> {
+        let client_arc = worker
+            .get_grpc_client()
+            .await
+            .map_err(|e| {
+                error!("Failed to get gRPC client from worker: {}", e);
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    format!("Failed to get gRPC client: {}", e),
+                )
+                    .into_response()
+            })?
+            .ok_or_else(|| {
+                error!("Selected worker is not a gRPC worker");
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    "Selected worker is not configured for gRPC",
+                )
+                    .into_response()
+            })?;
+        let client = client_arc.lock().await.clone();
+        Ok(client)
+    }
+
     /// Select a worker for the request
     fn select_worker_for_request(
         &self,
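The helper above collapses the worker's `Result<Option<client>, error>` into a single `Result<client, Response>` with `map_err`, `ok_or_else`, and `?`, so both routes share one error path. A toy version of that control flow (illustrative only; `Response` replaced by `String`):

// Flattening Result<Option<T>, E> into Result<T, Err> in one expression.
fn lookup(raw: Result<Option<i32>, String>) -> Result<i32, String> {
    let value = raw
        .map_err(|e| format!("transport error: {}", e))? // Err(e) -> early return
        .ok_or_else(|| "worker has no gRPC client".to_string())?; // Ok(None) -> early return
    Ok(value) // Ok(Some(v)) falls through
}

fn main() {
    assert_eq!(lookup(Ok(Some(7))), Ok(7));
    assert_eq!(
        lookup(Ok(None)),
        Err("worker has no gRPC client".to_string())
    );
    assert!(lookup(Err("boom".into())).is_err());
}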
...
@@ -265,7 +355,7 @@ impl GrpcRouter {
         Self::process_tool_call_arguments(&mut transformed_messages)?;

         // Convert tools to JSON values for template processing
-        let tools_json: Option<Vec<serde_json::Value>> = request
+        let tools_json: Option<Vec<Value>> = request
             .tools
             .as_ref()
             .map(|tools| {
...
@@ -284,7 +374,7 @@ impl GrpcRouter {
         if let Some(reasoning_effort) = &request.reasoning_effort {
             combined_template_kwargs.insert(
                 "reasoning_effort".to_string(),
-                serde_json::Value::String(reasoning_effort.clone()),
+                Value::String(reasoning_effort.clone()),
             );
         }
...
@@ -413,9 +503,9 @@ impl GrpcRouter {
                         part.as_object()
                             .and_then(|obj| obj.get("type")?.as_str())
                             .and_then(|type_str| match type_str {
-                                "image_url" => Some(serde_json::json!({"type": "image"})),
-                                "video_url" => Some(serde_json::json!({"type": "video"})),
-                                "audio_url" => Some(serde_json::json!({"type": "audio"})),
+                                "image_url" => Some(json!({"type": "image"})),
+                                "video_url" => Some(json!({"type": "video"})),
+                                "audio_url" => Some(json!({"type": "audio"})),
                                 _ => None,
                             })
                             .unwrap_or_else(|| part.clone())
...
@@ -456,7 +546,7 @@ impl GrpcRouter {
         };

         // Parse JSON string to object (like Python json.loads)
-        match serde_json::from_str::<serde_json::Value>(args_str) {
+        match serde_json::from_str::<Value>(args_str) {
             Ok(parsed) => *args = parsed,
             Err(e) => {
                 return Err(format!(
...
@@ -483,13 +573,63 @@ impl GrpcRouter {
         None
     }

-    /// Create a StopSequenceDecoder from the chat completion request
+    /// Resolve the generate input into optional original text and token IDs
+    fn resolve_generate_input(
+        &self,
+        request: &GenerateRequest,
+    ) -> Result<(Option<String>, Vec<u32>), String> {
+        if let Some(text) = &request.text {
+            return self
+                .tokenize_single_text(text)
+                .map(|(original, ids)| (Some(original), ids));
+        }
+
+        // Handle input_ids - validate and convert
+        if let Some(input_ids) = &request.input_ids {
+            return match input_ids {
+                crate::protocols::spec::InputIds::Single(ids) => ids
+                    .iter()
+                    .map(|&id| u32::try_from(id))
+                    .collect::<Result<Vec<u32>, _>>()
+                    .map(|converted| (None, converted))
+                    .map_err(|_| "input_ids must be non-negative".to_string()),
+                crate::protocols::spec::InputIds::Batch(_) => {
+                    Err("Batch input_ids are not supported over gRPC generate yet".to_string())
+                }
+            };
+        }
+
+        Err("Either `text` or `input_ids` must be provided".to_string())
+    }
+
+    fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec<u32>), String> {
+        let encoding = self
+            .tokenizer
+            .encode(text)
+            .map_err(|e| format!("Tokenization failed: {}", e))?;
+        Ok((text.to_string(), encoding.token_ids().to_vec()))
+    }
+
+    fn internal_error_static(msg: &'static str) -> Response {
+        error!("{}", msg);
+        (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
+    }
+
+    fn internal_error_message(message: String) -> Response {
+        error!("{}", message);
+        (StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
+    }
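The `input_ids` path in resolve_generate_input relies on `u32::try_from` plus a `collect::<Result<...>>` to reject any id that does not fit in a `u32`. A standalone sketch of just that conversion step (illustrative only; plain `i64` input assumed):

// Convert every id, failing on the first one that is negative or too large for u32.
fn to_u32_ids(ids: &[i64]) -> Result<Vec<u32>, String> {
    ids.iter()
        .map(|&id| u32::try_from(id))
        .collect::<Result<Vec<u32>, _>>()
        .map_err(|_| "input_ids must be non-negative".to_string())
}

fn main() {
    assert_eq!(to_u32_ids(&[1, 2, 3]), Ok(vec![1, 2, 3]));
    assert!(to_u32_ids(&[1, -2, 3]).is_err());
}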
+    /// Create a StopSequenceDecoder from stop parameters
     fn create_stop_decoder(
         &self,
-        original_request: &ChatCompletionRequest,
-    ) -> crate::tokenizer::stop::StopSequenceDecoder {
-        // Extract stop sequences from request
-        let stop_sequences: Vec<String> = match &original_request.stop {
+        stop: Option<&StringOrArray>,
+        stop_token_ids: Option<&Vec<u32>>,
+        skip_special_tokens: bool,
+        no_stop_trim: bool,
+    ) -> StopSequenceDecoder {
+        // Extract stop sequences
+        let stop_sequences: Vec<String> = match stop {
             Some(StringOrArray::String(s)) => vec![s.clone()],
             Some(StringOrArray::Array(arr)) => arr.clone(),
             None => vec![],
...
@@ -497,11 +637,11 @@ impl GrpcRouter {

         // Build stop sequence decoder
         let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
-            .skip_special_tokens(original_request.skip_special_tokens);
+            .skip_special_tokens(skip_special_tokens);

         // Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
         for seq in stop_sequences {
-            builder = if original_request.no_stop_trim {
+            builder = if no_stop_trim {
                 builder.visible_stop_sequence(seq)
             } else {
                 builder.stop_sequence(seq)
...
@@ -509,9 +649,9 @@ impl GrpcRouter {
         }

         // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
-        if let Some(stop_token_ids) = &original_request.stop_token_ids {
-            for &token_id in stop_token_ids {
-                builder = if original_request.no_stop_trim {
+        if let Some(token_ids) = stop_token_ids {
+            for &token_id in token_ids {
+                builder = if no_stop_trim {
                     builder.visible_stop_token(token_id)
                 } else {
                     builder.stop_token(token_id)
...
@@ -524,7 +664,7 @@ impl GrpcRouter {
     /// Process a chunk of tokens through the stop decoder
     fn process_chunk_tokens(
-        stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
+        stop_decoder: &mut StopSequenceDecoder,
         token_ids: &[u32],
     ) -> (String, bool) {
         let mut chunk_text = String::new();
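`process_chunk_tokens` and the new non-streaming generate handler both fold `SequenceDecoderOutput` values into accumulated text, stopping early once the decoder reports a stop. A toy version of that fold (the enum here is a stand-in for the real `tokenizer::stop` type, not the library's definition):

// Illustrative accumulation loop over decoder outputs with early stop.
enum SequenceDecoderOutput {
    Text(String),            // emit text and keep going
    StoppedWithText(String), // emit final text, then stop
    Stopped,                 // stop without extra text
    Held,                    // nothing to emit yet
}

fn collect_text(outputs: Vec<SequenceDecoderOutput>) -> String {
    let mut decoded_text = String::new();
    for output in outputs {
        match output {
            SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
            SequenceDecoderOutput::StoppedWithText(t) => {
                decoded_text.push_str(&t);
                break;
            }
            SequenceDecoderOutput::Stopped => break,
            SequenceDecoderOutput::Held => {}
        }
    }
    decoded_text
}

fn main() {
    let outputs = vec![
        SequenceDecoderOutput::Text("Hello".to_string()),
        SequenceDecoderOutput::Held,
        SequenceDecoderOutput::StoppedWithText(", world".to_string()),
        SequenceDecoderOutput::Text(" (never reached)".to_string()),
    ];
    assert_eq!(collect_text(outputs), "Hello, world");
}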
...
@@ -562,7 +702,12 @@ impl GrpcRouter {
         request: proto::GenerateRequest,
         original_request: &ChatCompletionRequest,
     ) -> Response {
-        let mut stop_decoder = self.create_stop_decoder(original_request);
+        let mut stop_decoder = self.create_stop_decoder(
+            original_request.stop.as_ref(),
+            original_request.stop_token_ids.as_ref(),
+            original_request.skip_special_tokens,
+            original_request.no_stop_trim,
+        );

         // Process streaming tokens
         let mut grpc_stream = match client.generate(request).await {
...
@@ -589,7 +734,7 @@ impl GrpcRouter {
         };

         match gen_response.response {
-            Some(proto::generate_response::Response::Chunk(chunk)) => {
+            Some(Chunk(chunk)) => {
                 // Process tokens and check if we should stop
                 let (chunk_text, should_stop) =
                     Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
...
@@ -599,7 +744,7 @@ impl GrpcRouter {
                 }
                 continue;
             }
-            Some(proto::generate_response::Response::Complete(_complete)) => {
+            Some(Complete(_complete)) => {
                 // Flush any remaining text
                 if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
                     if !text.is_empty() {
...
@@ -609,7 +754,7 @@ impl GrpcRouter {
                 }
                 break;
             }
-            Some(proto::generate_response::Response::Error(error)) => {
+            Some(Error(error)) => {
                 error!("Generation error: {}", error.message);
                 break;
             }
...
@@ -629,26 +774,19 @@ impl GrpcRouter {
         request: proto::GenerateRequest,
         original_request: &ChatCompletionRequest,
     ) -> Response {
-        let mut stop_decoder = self.create_stop_decoder(original_request);
-
-        // Small helpers to log + return a uniform 500
-        let fail_str = |msg: &'static str| -> Response {
-            error!("{}", msg);
-            (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
-        };
-        let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response {
-            error!("{}{}", prefix, e);
-            (
-                StatusCode::INTERNAL_SERVER_ERROR,
-                format!("{}{}", prefix, e),
-            )
-                .into_response()
-        };
+        let mut stop_decoder = self.create_stop_decoder(
+            original_request.stop.as_ref(),
+            original_request.stop_token_ids.as_ref(),
+            original_request.skip_special_tokens,
+            original_request.no_stop_trim,
+        );

         // Start generation
         let mut stream = match client.generate(request).await {
             Ok(s) => s,
-            Err(e) => return fail_fmt("Failed to start generation: ", &e),
+            Err(e) => {
+                return Self::internal_error_message(format!("Failed to start generation: {}", e))
+            }
         };

         // Collect all responses (for n>1 support)
...
@@ -656,28 +794,33 @@ impl GrpcRouter {
         while let Some(response) = stream.next().await {
             match response {
                 Ok(gen_response) => match gen_response.response {
-                    Some(proto::generate_response::Response::Complete(complete)) => {
+                    Some(Complete(complete)) => {
                         all_responses.push(complete);
                     }
-                    Some(proto::generate_response::Response::Error(err)) => {
-                        error!("Generation failed for one choice: {}", err.message);
-                        return (
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            format!("Generation failed: {}", err.message),
-                        )
-                            .into_response();
-                    }
-                    Some(proto::generate_response::Response::Chunk(_)) => {
-                        return fail_str("Unexpected chunk response for non-streaming request")
-                    }
-                    None => return fail_str("Empty response from server"),
+                    Some(Error(err)) => {
+                        return Self::internal_error_message(format!(
+                            "Generation failed: {}",
+                            err.message
+                        ));
+                    }
+                    Some(Chunk(_)) => {
+                        return Self::internal_error_static(
+                            "Unexpected chunk response for non-streaming request",
+                        )
+                    }
+                    None => return Self::internal_error_static("Empty response from server"),
                 },
-                Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
+                Err(e) => {
+                    return Self::internal_error_message(format!(
+                        "Failed to get GenerateResponse: {}",
+                        e
+                    ))
+                }
             }
         }

         if all_responses.is_empty() {
-            return fail_str("No responses from server");
+            return Self::internal_error_static("No responses from server");
         }

         // Process each response into a ChatChoice
...
@@ -689,12 +832,10 @@ impl GrpcRouter {
         {
             Ok(choice) => choices.push(choice),
             Err(e) => {
-                error!("Failed to process choice {}: {}", index, e);
-                return (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    format!("Failed to process choice {}: {}", index, e),
-                )
-                    .into_response();
+                return Self::internal_error_message(format!(
+                    "Failed to process choice {}: {}",
+                    index, e
+                ));
             }
         }
     }
...
@@ -730,6 +871,127 @@ impl GrpcRouter {
         Json(response).into_response()
     }

+    /// Submit request and handle non-streaming response for the `/generate` endpoint
+    async fn handle_non_streaming_generate(
+        &self,
+        mut client: SglangSchedulerClient,
+        request: proto::GenerateRequest,
+        original_request: &GenerateRequest,
+        request_id: String,
+        weight_version: String,
+    ) -> Response {
+        let start_time = Instant::now();
+
+        let mut stream = match client.generate(request).await {
+            Ok(stream) => stream,
+            Err(e) => {
+                return Self::internal_error_message(format!("Failed to start generation: {}", e))
+            }
+        };
+
+        let mut final_completion: Option<proto::GenerateComplete> = None;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(gen_response) => match gen_response.response {
+                    Some(Complete(complete)) => {
+                        final_completion = Some(complete);
+                        break;
+                    }
+                    Some(Error(err)) => {
+                        return Self::internal_error_message(format!(
+                            "Generation failed: {}",
+                            err.message
+                        ));
+                    }
+                    Some(Chunk(_)) | None => continue,
+                },
+                Err(e) => {
+                    return Self::internal_error_message(format!(
+                        "Failed to receive generate response: {}",
+                        e
+                    ))
+                }
+            }
+        }
+
+        let mut complete = match final_completion {
+            Some(c) => c,
+            None => {
+                return Self::internal_error_static("No completion received from scheduler");
+            }
+        };
+
+        // Create stop decoder from sampling params
+        let params = original_request.sampling_params.as_ref();
+        let mut stop_decoder = self.create_stop_decoder(
+            params.and_then(|p| p.stop.as_ref()),
+            params.and_then(|p| p.stop_token_ids.as_ref()),
+            params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
+            params.and_then(|p| p.no_stop_trim).unwrap_or(false),
+        );
+
+        // Process tokens through stop decoder
+        let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
+            Ok(outputs) => outputs,
+            Err(e) => {
+                return Self::internal_error_message(format!("Failed to process tokens: {}", e))
+            }
+        };
+
+        // Accumulate text with early breaks
+        let mut decoded_text = String::new();
+        for output in outputs {
+            match output {
+                SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
+                SequenceDecoderOutput::StoppedWithText(t) => {
+                    decoded_text.push_str(&t);
+                    break;
+                }
+                SequenceDecoderOutput::Stopped => break,
+                SequenceDecoderOutput::Held => {}
+            }
+        }
+
+        // Flush remaining text
+        if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+            decoded_text.push_str(&t);
+        }
+
+        let output_ids = complete.output_ids.clone();
+
+        // Build base meta_info using json! macro
+        let mut meta_info = json!({
+            "finish_reason": complete.finish_reason.clone(),
+            "prompt_tokens": complete.prompt_tokens,
+            "completion_tokens": complete.completion_tokens,
+            "cached_tokens": complete.cached_tokens,
+            "id": request_id,
+            "weight_version": weight_version,
+            "e2e_latency": start_time.elapsed().as_secs_f64(),
+        });
+        let meta_obj = meta_info.as_object_mut().unwrap();
+
+        // Add matched_stop if present
+        if let Some(matched) = complete.matched_stop.take() {
+            use proto::generate_complete::MatchedStop;
+            let matched_value = match matched {
+                MatchedStop::MatchedTokenId(id) => json!(id),
+                MatchedStop::MatchedStopStr(s) => json!(s),
+            };
+            meta_obj.insert("matched_stop".to_string(), matched_value);
+        }
+
+        let response_body = json!({
+            "text": decoded_text,
+            "output_ids": output_ids,
+            "meta_info": meta_info,
+        });
+
+        Json(response_body).into_response()
+    }
+
     /// Convert proto LogProbs to OpenAI ChatLogProbs format
     /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
     fn convert_proto_to_openai_logprobs(
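For reference, the non-streaming `/generate` handler above returns a JSON body built with the `json!` calls shown in the hunk. A standalone sketch of that shape with made-up values (requires the serde_json crate; illustrative only):

// Mirror of the response construction: text, output_ids, and a meta_info object,
// with matched_stop inserted only when the scheduler reports one.
use serde_json::json;

fn main() {
    let mut meta_info = json!({
        "finish_reason": "stop",
        "prompt_tokens": 12,
        "completion_tokens": 34,
        "cached_tokens": 0,
        "id": "gen-123e4567-e89b-12d3-a456-426614174000",
        "weight_version": "default",
        "e2e_latency": 0.42,
    });
    // matched_stop is either a stop token id or a stop string.
    meta_info
        .as_object_mut()
        .unwrap()
        .insert("matched_stop".to_string(), json!("</s>"));

    let response_body = json!({
        "text": "Hello, world",
        "output_ids": [9906, 11, 1917],
        "meta_info": meta_info,
    });
    println!("{}", serde_json::to_string_pretty(&response_body).unwrap());
}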
...
@@ -803,7 +1065,7 @@ impl GrpcRouter {
         complete: &proto::GenerateComplete,
         index: usize,
         original_request: &ChatCompletionRequest,
-        stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
+        stop_decoder: &mut StopSequenceDecoder,
     ) -> Result<ChatChoice, String> {
         stop_decoder.reset();

         // Decode tokens
...
@@ -1002,11 +1264,11 @@ impl RouterTrait for GrpcRouter {
     async fn route_generate(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &GenerateRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &GenerateRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        self.route_generate_impl(headers, body, model_id).await
     }

     async fn route_chat(
...
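With `route_generate` now delegating to `route_generate_impl`, the gRPC router serves plain `/generate` requests (non-streaming only for now). A hedged client-side sketch; the router address is an assumption, and the body fields follow the GenerateRequest/SamplingParams spec fields referenced in this diff (requires the reqwest crate with its "json" feature, tokio, and serde_json):

// Illustrative call against a locally running sgl-router.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "text": "The capital of France is",
        "stream": false,
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 32,
            "stop": ["\n"]
        }
    });

    let resp = reqwest::Client::new()
        .post("http://localhost:30000/generate") // assumed router address
        .json(&body)
        .send()
        .await?
        .error_for_status()?;

    // Expected non-streaming shape: {"text": ..., "output_ids": [...], "meta_info": {...}}
    let value: serde_json::Value = resp.json().await?;
    println!("{}", serde_json::to_string_pretty(&value)?);
    Ok(())
}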