sglang — commit 01c9ee1a (unverified)
Authored Oct 08, 2025 by Simo Lin; committed via GitHub on Oct 08, 2025
Parent: d6837aea

[router] refactor generate to use new pipeline arch (#11323)

Showing 7 changed files with 713 additions and 1181 deletions:
sgl-router/src/protocols/spec.rs           +58   -33
sgl-router/src/routers/grpc/context.rs      +5   -28
sgl-router/src/routers/grpc/pd_router.rs   +15  -622
sgl-router/src/routers/grpc/pipeline.rs   +189   -43
sgl-router/src/routers/grpc/router.rs      +13  -421
sgl-router/src/routers/grpc/streaming.rs  +368   -33
sgl-router/src/routers/grpc/utils.rs       +65    -1
sgl-router/src/protocols/spec.rs

@@ -2066,39 +2066,64 @@ impl GenerationRequest for GenerateRequest {
     }
 }

-// TODO(generate): Define GenerateResponse and GenerateChoice structs
-//
-// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964)
-//
-// #[derive(Debug, Clone, Serialize, Deserialize)]
-// pub struct GenerateResponse {
-//     pub id: String,
-//     pub object: String, // "text.completion"
-//     pub created: u64,
-//     pub model: String,
-//     pub choices: Vec<GenerateChoice>,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub usage: Option<Usage>,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub system_fingerprint: Option<String>,
-// }
-//
-// #[derive(Debug, Clone, Serialize, Deserialize)]
-// pub struct GenerateChoice {
-//     pub index: u32,
-//     pub text: String,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub output_ids: Option<Vec<u32>>,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub finish_reason: Option<String>,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub logprobs: Option<TopLogprobs>,
-//     #[serde(skip_serializing_if = "Option::is_none")]
-//     pub matched_stop: Option<Value>,
-// }
-//
-// Note: Verify if similar structs already exist elsewhere before implementing.
-// May need streaming variant (GenerateStreamResponse) as well.
+// ============================================================================
+// SGLang Generate Response Types
+// ============================================================================
+
+/// SGLang generate response (single completion or array for n>1)
+///
+/// Format for n=1:
+/// ```json
+/// {
+///     "text": "...",
+///     "output_ids": [...],
+///     "meta_info": { ... }
+/// }
+/// ```
+///
+/// Format for n>1:
+/// ```json
+/// [
+///     {"text": "...", "output_ids": [...], "meta_info": {...}},
+///     {"text": "...", "output_ids": [...], "meta_info": {...}}
+/// ]
+/// ```
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GenerateResponse {
+    pub text: String,
+    pub output_ids: Vec<u32>,
+    pub meta_info: GenerateMetaInfo,
+}
+
+/// Metadata for a single generate completion
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GenerateMetaInfo {
+    pub id: String,
+    pub finish_reason: GenerateFinishReason,
+    pub prompt_tokens: u32,
+    pub weight_version: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
+    pub completion_tokens: u32,
+    pub cached_tokens: u32,
+    pub e2e_latency: f64,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub matched_stop: Option<Value>,
+}
+
+/// Finish reason for generate endpoint
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum GenerateFinishReason {
+    Length { length: u32 },
+    Stop,
+    #[serde(untagged)]
+    Other(Value),
+}
+
 // Constants for rerank API
 pub const DEFAULT_MODEL_NAME: &str = "default";

...
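Note: the wire format of GenerateFinishReason follows directly from the serde attributes in the hunk above (internally tagged on "type", lowercase, with an untagged fallback). A minimal standalone sketch of the resulting JSON; the enum is copied from the diff, the assertions are illustrative, and variant-level #[serde(untagged)] assumes a reasonably recent serde:

    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};

    #[derive(Debug, Clone, Serialize, Deserialize)]
    #[serde(tag = "type", rename_all = "lowercase")]
    pub enum GenerateFinishReason {
        Length { length: u32 },
        Stop,
        #[serde(untagged)]
        Other(Value),
    }

    fn main() {
        // Stop serializes as a bare tagged object.
        assert_eq!(
            serde_json::to_value(GenerateFinishReason::Stop).unwrap(),
            json!({"type": "stop"})
        );
        // Length carries its payload next to the tag.
        assert_eq!(
            serde_json::to_value(GenerateFinishReason::Length { length: 100 }).unwrap(),
            json!({"type": "length", "length": 100})
        );
        // Unknown shapes fall through to the untagged Other variant.
        let other: GenerateFinishReason =
            serde_json::from_value(json!({"type": "abort"})).unwrap();
        assert!(matches!(other, GenerateFinishReason::Other(_)));
    }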
sgl-router/src/routers/grpc/context.rs

@@ -12,7 +12,9 @@ use serde_json::Value;
 use crate::core::Worker;
 use crate::grpc_client::{proto, SglangSchedulerClient};
-use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest};
+use crate::protocols::spec::{
+    ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse,
+};
 use crate::reasoning_parser::ReasoningParserFactory;
 use crate::tokenizer::stop::StopSequenceDecoder;
 use crate::tokenizer::traits::Tokenizer;
...

@@ -226,14 +228,6 @@ impl RequestContext {
         }
     }

-    /// Try to get chat request
-    pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> {
-        match &self.input.request_type {
-            RequestType::Chat(req) => Some(req.as_ref()),
-            _ => None,
-        }
-    }
-
     /// Get generate request (panics if not generate)
     pub fn generate_request(&self) -> &GenerateRequest {
         match &self.input.request_type {
...

@@ -242,14 +236,6 @@ impl RequestContext {
         }
     }

-    /// Try to get generate request
-    pub fn try_generate_request(&self) -> Option<&GenerateRequest> {
-        match &self.input.request_type {
-            RequestType::Generate(req) => Some(req.as_ref()),
-            _ => None,
-        }
-    }
-
     /// Check if request is streaming
     pub fn is_streaming(&self) -> bool {
         match &self.input.request_type {
...

@@ -257,16 +243,6 @@ impl RequestContext {
             RequestType::Generate(req) => req.stream,
         }
     }
-
-    /// Check if request is chat
-    pub fn is_chat(&self) -> bool {
-        matches!(&self.input.request_type, RequestType::Chat(_))
-    }
-
-    /// Check if request is generate
-    pub fn is_generate(&self) -> bool {
-        matches!(&self.input.request_type, RequestType::Generate(_))
-    }
 }

 // ============================================================================
...

@@ -394,5 +370,6 @@ pub enum ExecutionResult {
 /// Final processed response
 pub enum FinalResponse {
     Chat(ChatCompletionResponse),
-    Generate(Box<GenerateRequest>),
+    /// Generate response is a Vec of GenerateResponse (n=1 returns single item, n>1 returns multiple)
+    Generate(Vec<GenerateResponse>),
 }
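Note: with the variant change above, a generate request's final payload is just the Vec, with n encoded as its length. A minimal sketch with stand-in types (this GenerateResponse is a hypothetical simplification, not the full struct from spec.rs):

    use serde::Serialize;
    use serde_json::Value;

    #[derive(Serialize)]
    struct GenerateResponse {
        text: String,
        output_ids: Vec<u32>,
    }

    enum FinalResponse {
        Chat(Value),                     // stand-in for ChatCompletionResponse
        Generate(Vec<GenerateResponse>), // n=1 -> one element, n>1 -> several
    }

    fn to_json_body(resp: FinalResponse) -> Value {
        match resp {
            FinalResponse::Chat(v) => v,
            // Serializing the Vec directly yields the SGLang array format.
            FinalResponse::Generate(completions) => serde_json::to_value(completions).unwrap(),
        }
    }

    fn main() {
        let body = to_json_body(FinalResponse::Generate(vec![GenerateResponse {
            text: "hi".into(),
            output_ids: vec![1, 2],
        }]));
        assert!(body.is_array());
    }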
sgl-router/src/routers/grpc/pd_router.rs

(Diff collapsed in the page view; not shown here.)
sgl-router/src/routers/grpc/pipeline.rs

@@ -11,15 +11,20 @@ use super::context::*;
 use super::processing;
 use super::streaming;
 use super::utils;
-use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
+use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
 use crate::grpc_client::proto;
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::{
-    ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage,
+    ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
+    GenerateResponse, InputIds, Usage,
 };
+use crate::tokenizer::stop::SequenceDecoderOutput;
+use crate::tokenizer::traits::Tokenizer;
+use proto::generate_complete::MatchedStop;
+use proto::DisaggregatedParams;
 use rand::Rng;
 use std::sync::Arc;
-use std::time::{SystemTime, UNIX_EPOCH};
+use std::time::{Instant, SystemTime, UNIX_EPOCH};
 use uuid::Uuid;

 // ============================================================================
...

@@ -208,7 +213,7 @@ impl PreparationStage {
     fn tokenize_single_text(
         &self,
-        tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>,
+        tokenizer: &Arc<dyn Tokenizer>,
         text: &str,
     ) -> Result<(String, Vec<u32>), String> {
         let encoding = tokenizer
...

@@ -302,7 +307,7 @@ impl WorkerSelectionStage {
         &self,
         model_id: Option<&str>,
         text: Option<&str>,
-    ) -> Option<Arc<dyn crate::core::Worker>> {
+    ) -> Option<Arc<dyn Worker>> {
         // Get workers for the specified model, filtered by connection mode
         let workers = self.worker_registry.get_workers_filtered(
             model_id,
...

@@ -312,7 +317,7 @@ impl WorkerSelectionStage {
         );

         // Filter by availability (health + circuit breaker)
-        let available: Vec<Arc<dyn crate::core::Worker>> = workers
+        let available: Vec<Arc<dyn Worker>> = workers
             .iter()
             .filter(|w| w.is_available())
             .cloned()
...

@@ -337,7 +342,7 @@ impl WorkerSelectionStage {
         &self,
         model_id: Option<&str>,
         text: Option<&str>,
-    ) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> {
+    ) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
         // Get prefill workers - use None for WorkerType filter to get all types,
         // then filter manually (since Prefill is a struct variant)
         let all_workers = self.worker_registry.get_workers_filtered(
...

@@ -537,10 +542,8 @@ impl RequestBuildingStage {
     fn inject_bootstrap_metadata(
         &self,
         request: &mut proto::GenerateRequest,
-        prefill_worker: &Arc<dyn crate::core::Worker>,
+        prefill_worker: &Arc<dyn Worker>,
     ) {
-        use proto::DisaggregatedParams;
-
         let hostname = prefill_worker.bootstrap_host();
         let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
...

@@ -935,40 +938,183 @@ impl ResponseProcessingStage {
     async fn process_generate_response(
         &self,
-        _ctx: &mut RequestContext,
+        ctx: &mut RequestContext,
     ) -> Result<Option<Response>, Response> {
-        // TODO(generate): Implement generate response processing
-        //
-        // Required implementation:
-        // 1. Extract execution_result from ctx
-        // 2. Check is_streaming flag
-        // 3. For streaming:
-        //    - Add StreamingProcessor::process_streaming_generate() method
-        //    - Similar to process_streaming_response but WITHOUT tool/reasoning parsing
-        //    - Return Err(sse_response) for early exit
-        // 4. For non-streaming:
-        //    - Collect stream responses using utils::collect_stream_responses()
-        //    - Process through stop decoder (sequential with reset for n>1, like chat)
-        //    - Build GenerateResponse struct (see TODO in protocols/spec.rs)
-        //    - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response))
-        //
-        // Reference implementation: router.rs:297-595
-        // Key differences from chat:
-        // - No tool parsing
-        // - No reasoning parsing
-        // - Different response format (GenerateResponse instead of ChatCompletionResponse)
-        // - Still needs: stop decoder, logprobs, finish_reason, matched_stop
-        Err((
-            axum::http::StatusCode::NOT_IMPLEMENTED,
-            axum::Json(serde_json::json!({
-                "error": {
-                    "message": "Generate response processing not yet implemented in pipeline",
-                    "type": "not_implemented",
-                    "code": 501
-                }
-            })),
-        )
-            .into_response())
-    }
+        let start_time = Instant::now();
+        let is_streaming = ctx.is_streaming();
+
+        // Extract execution result
+        let execution_result = ctx
+            .state
+            .response
+            .execution_result
+            .take()
+            .ok_or_else(|| utils::internal_error_static("No execution result"))?;
+
+        if is_streaming {
+            // Get dispatch metadata for consistent response fields
+            let dispatch = ctx
+                .state
+                .dispatch
+                .as_ref()
+                .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
+
+            let generate_request = ctx.generate_request().clone();
+
+            // Streaming: Use StreamingProcessor and return SSE response (done)
+            return Ok(Some(
+                self.streaming_processor.clone().process_streaming_generate(
+                    execution_result,
+                    generate_request,
+                    dispatch.clone(),
+                ),
+            ));
+        }
+
+        // Non-streaming: Collect all responses
+        let request_logprobs = ctx.generate_request().return_logprob;
+        let all_responses = match execution_result {
+            ExecutionResult::Single { stream } => {
+                utils::collect_stream_responses(stream, "Single").await?
+            }
+            ExecutionResult::Dual { prefill, decode } => {
+                // Collect prefill for input_logprobs
+                let prefill_responses =
+                    utils::collect_stream_responses(prefill, "Prefill").await?;
+
+                // Collect decode for actual output
+                let mut decode_responses =
+                    utils::collect_stream_responses(*decode, "Decode").await?;
+
+                // Merge prefill input_logprobs if requested
+                if request_logprobs {
+                    if let Some(prefill_input_logprobs) = prefill_responses
+                        .first()
+                        .and_then(|r| r.input_logprobs.clone())
+                    {
+                        for response in &mut decode_responses {
+                            response.input_logprobs = Some(prefill_input_logprobs.clone());
+                        }
+                    }
+                }
+
+                decode_responses
+            }
+        };
+
+        if all_responses.is_empty() {
+            return Err(utils::internal_error_static("No responses from server"));
+        }
+
+        // Get stop decoder for processing
+        let stop_decoder = ctx
+            .state
+            .response
+            .stop_decoder
+            .as_mut()
+            .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
+
+        // Get dispatch metadata
+        let dispatch = ctx
+            .state
+            .dispatch
+            .as_ref()
+            .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
+
+        // Process each completion (similar to router.rs:336-400)
+        let mut result_array = Vec::new();
+        for mut complete in all_responses {
+            stop_decoder.reset();
+
+            // Process tokens through stop decoder
+            let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
+                Ok(outputs) => outputs,
+                Err(e) => {
+                    return Err(utils::internal_error_message(format!(
+                        "Failed to process tokens: {}",
+                        e
+                    )))
+                }
+            };

+            // Accumulate text with early breaks
+            let mut decoded_text = String::new();
+            for output in outputs {
+                match output {
+                    SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
+                    SequenceDecoderOutput::StoppedWithText(t) => {
+                        decoded_text.push_str(&t);
+                        break;
+                    }
+                    SequenceDecoderOutput::Stopped => break,
+                    SequenceDecoderOutput::Held => {}
+                }
+            }
+
+            // Flush remaining text
+            if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+                decoded_text.push_str(&t);
+            }
+
+            let output_ids = std::mem::take(&mut complete.output_ids);
+            let finish_reason_str = std::mem::take(&mut complete.finish_reason);
+
+            // Parse finish_reason from string to proper type
+            let finish_reason =
+                utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
+
+            // Handle matched_stop if present
+            let matched_stop = complete.matched_stop.take().map(|matched| match matched {
+                MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
+                MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
+            });
+
+            // Extract logprobs if requested (convert proto types to Generate format)
+            let input_token_logprobs = if request_logprobs {
+                complete
+                    .input_logprobs
+                    .as_ref()
+                    .map(utils::convert_generate_input_logprobs)
+            } else {
+                None
+            };
+            let output_token_logprobs = if request_logprobs {
+                complete
+                    .output_logprobs
+                    .as_ref()
+                    .map(utils::convert_generate_output_logprobs)
+            } else {
+                None
+            };
+
+            // Build GenerateResponse struct
+            let meta_info = GenerateMetaInfo {
+                id: dispatch.request_id.clone(),
+                finish_reason,
+                prompt_tokens: complete.prompt_tokens as u32,
+                weight_version: dispatch
+                    .weight_version
+                    .clone()
+                    .unwrap_or_else(|| "default".to_string()),
+                input_token_logprobs,
+                output_token_logprobs,
+                completion_tokens: complete.completion_tokens as u32,
+                cached_tokens: complete.cached_tokens as u32,
+                e2e_latency: start_time.elapsed().as_secs_f64(),
+                matched_stop,
+            };
+
+            result_array.push(GenerateResponse {
+                text: decoded_text,
+                output_ids,
+                meta_info,
+            });
+        }
+
+        // Store the final response
+        ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
+
+        Ok(None)
+    }
 }
...

@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline {
         // Extract final response
         match ctx.state.response.final_response {
-            Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(),
+            Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
             Some(FinalResponse::Chat(_)) => {
                 utils::internal_error_static("Internal error: wrong response type")
             }
...
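Note: the result_array built in process_generate_response serializes, via axum::Json, to the SGLang generate array format. A sketch of the shape, with purely illustrative values:

    use serde_json::json;

    fn main() {
        // Hypothetical n=1 result; every field value here is illustrative only.
        let result_array = vec![json!({
            "text": "Hello",
            "output_ids": [9906],
            "meta_info": {
                "id": "gen-1234",
                "finish_reason": {"type": "stop"},
                "prompt_tokens": 5,
                "weight_version": "default",
                "completion_tokens": 1,
                "cached_tokens": 0,
                "e2e_latency": 0.042,
                "matched_stop": 128009
            }
        })];
        // axum::Json(result_array) would write exactly this array as the body.
        println!("{}", serde_json::to_string_pretty(&result_array).unwrap());
    }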
sgl-router/src/routers/grpc/router.rs

@@ -8,28 +8,21 @@ use axum::{
     extract::Request,
     http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
-    Json,
 };
 use tracing::debug;

 use crate::config::types::RetryConfig;
-use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
+use crate::core::WorkerRegistry;
-use crate::grpc_client::{proto, SglangSchedulerClient};
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
     RerankRequest,
     ResponsesGetParams, ResponsesRequest,
 };
 use crate::reasoning_parser::ReasoningParserFactory;
-use crate::routers::{grpc, RouterTrait};
+use crate::routers::RouterTrait;
 use crate::server::AppContext;
-use crate::tokenizer::stop::SequenceDecoderOutput;
 use crate::tokenizer::traits::Tokenizer;
 use crate::tool_parser::ToolParserFactory;
-use grpc::utils;
-use serde_json::json;
-use std::time::Instant;
-use uuid::Uuid;

 /// gRPC router implementation for SGLang
 #[derive(Clone)]
...

@@ -45,9 +38,7 @@ pub struct GrpcRouter {
     retry_config: RetryConfig,
     configured_reasoning_parser: Option<String>,
     configured_tool_parser: Option<String>,
-    // Pipeline for non-streaming requests
     pipeline: super::pipeline::ChatCompletionPipeline,
-    // Shared components for pipeline
     shared_components: Arc<super::context::SharedComponents>,
 }
...

@@ -149,420 +140,21 @@ impl GrpcRouter {
     /// Main route_generate implementation
     async fn route_generate_impl(
         &self,
-        _headers: Option<&HeaderMap>,
+        headers: Option<&HeaderMap>,
         body: &GenerateRequest,
         model_id: Option<&str>,
     ) -> Response {
         debug!("Processing generate request for model: {:?}", model_id);

-        // Step 1: Resolve input (text, prompt, or input_ids)
-        let (original_text, token_ids) = match self.resolve_generate_input(body) {
-            Ok(res) => res,
-            Err(msg) => {
-                return utils::bad_request_error(msg);
-            }
-        };
-        debug!("Resolved input with {} tokens", token_ids.len());
-
-        // Step 2: Select worker (fail fast if no workers available)
-        let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
-            Some(w) => w,
-            None => {
-                return utils::service_unavailable_error(format!(
-                    "No available workers for model: {:?}",
-                    model_id
-                ));
-            }
-        };
-        debug!("Selected worker: {}", worker.url());
-
-        // Step 3: Get gRPC client from worker
-        let client = match utils::get_grpc_client_from_worker(&worker).await {
-            Ok(client) => client,
-            Err(response) => return response,
-        };
-
-        // Step 4: Build the gRPC request
-        let request_id = body
-            .rid
-            .clone()
-            .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
-        let request = match client.build_plain_generate_request(
-            request_id.clone(),
-            body,
-            original_text.clone(),
-            token_ids,
-        ) {
-            Ok(req) => req,
-            Err(e) => {
-                return utils::bad_request_error(e);
-            }
-        };
-
-        // Step 5: Get weight version for response metadata
-        let weight_version = worker
-            .metadata()
-            .labels
-            .get("weight_version")
-            .cloned()
-            .unwrap_or_else(|| "default".to_string());
-
-        // Step 6: Handle streaming vs non-streaming
-        if body.stream {
-            self.handle_streaming_generate(client, request, body, request_id, weight_version)
-                .await
-        } else {
-            self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
-                .await
-        }
+        // Use pipeline for ALL requests (streaming and non-streaming)
+        self.pipeline
+            .execute_generate(
+                body.clone(),
+                headers.cloned(),
+                model_id.map(|s| s.to_string()),
+                self.shared_components.clone(),
+            )
+            .await
     }

-    /// Select a worker for the request
-    fn select_worker_for_request(
-        &self,
-        model_id: Option<&str>,
-        text: Option<&str>,
-    ) -> Option<Arc<dyn Worker>> {
-        // Get workers for the specified model, filtered by connection mode
-        let workers = self.worker_registry.get_workers_filtered(
-            model_id,
-            Some(WorkerType::Regular),
-            Some(ConnectionMode::Grpc { port: None }),
-            false, // get all workers, we'll filter by is_available() next
-        );
-
-        // Filter by availability (health + circuit breaker)
-        let available: Vec<Arc<dyn Worker>> = workers
-            .iter()
-            .filter(|w| w.is_available())
-            .cloned()
-            .collect();
-
-        if available.is_empty() {
-            return None;
-        }
-
-        // Get the appropriate policy for this model
-        let policy = match model_id {
-            Some(model) => self.policy_registry.get_policy_or_default(model),
-            None => self.policy_registry.get_default_policy(),
-        };
-
-        // Select worker using the policy
-        let idx = policy.select_worker(&available, text)?;
-        Some(available[idx].clone())
-    }
-
-    /// Resolve the generate input into optional original text and token IDs
-    fn resolve_generate_input(
-        &self,
-        request: &GenerateRequest,
-    ) -> Result<(Option<String>, Vec<u32>), String> {
-        if let Some(text) = &request.text {
-            return self
-                .tokenize_single_text(text)
-                .map(|(original, ids)| (Some(original), ids));
-        }
-
-        // Handle input_ids - validate and convert
-        if let Some(input_ids) = &request.input_ids {
-            return match input_ids {
-                InputIds::Single(ids) => ids
-                    .iter()
-                    .map(|&id| u32::try_from(id))
-                    .collect::<Result<Vec<u32>, _>>()
-                    .map(|converted| (None, converted))
-                    .map_err(|_| "input_ids must be non-negative".to_string()),
-                InputIds::Batch(_) => {
-                    Err("Batch input_ids are not supported over gRPC generate yet".to_string())
-                }
-            };
-        }
-
-        Err("Either `text` or `input_ids` must be provided".to_string())
-    }
-
-    fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec<u32>), String> {
-        let encoding = self
-            .tokenizer
-            .encode(text)
-            .map_err(|e| format!("Tokenization failed: {}", e))?;
-        Ok((text.to_string(), encoding.token_ids().to_vec()))
-    }
-
-    /// Submit request and handle non-streaming response for the `/generate` endpoint
-    async fn handle_non_streaming_generate(
-        &self,
-        mut client: SglangSchedulerClient,
-        request: proto::GenerateRequest,
-        original_request: &GenerateRequest,
-        request_id: String,
-        weight_version: String,
-    ) -> Response {
-        let start_time = Instant::now();
-        let stream = match client.generate(request).await {
-            Ok(stream) => stream,
-            Err(e) => {
-                return utils::internal_error_message(format!("Failed to start generation: {}", e))
-            }
-        };
-
-        // Collect all responses using utils helper
-        let responses = match utils::collect_stream_responses(stream, "Generate").await {
-            Ok(responses) => responses,
-            Err(error_response) => return error_response,
-        };
-
-        if responses.is_empty() {
-            return utils::internal_error_static("No completion received from scheduler");
-        }
-
-        // Create stop decoder from sampling params
-        let params = original_request.sampling_params.as_ref();
-        let mut stop_decoder = utils::create_stop_decoder(
-            &self.tokenizer,
-            params.and_then(|p| p.stop.as_ref()),
-            params.and_then(|p| p.stop_token_ids.as_ref()),
-            params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
-            params.and_then(|p| p.no_stop_trim).unwrap_or(false),
-        );
-
-        // Process each completion
-        let mut result_array = Vec::new();
-        for mut complete in responses {
-            stop_decoder.reset();
-
-            // Process tokens through stop decoder
-            let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
-                Ok(outputs) => outputs,
-                Err(e) => {
-                    return utils::internal_error_message(format!(
-                        "Failed to process tokens: {}",
-                        e
-                    ))
-                }
-            };
-
-            // Accumulate text with early breaks
-            let mut decoded_text = String::new();
-            for output in outputs {
-                match output {
-                    SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
-                    SequenceDecoderOutput::StoppedWithText(t) => {
-                        decoded_text.push_str(&t);
-                        break;
-                    }
-                    SequenceDecoderOutput::Stopped => break,
-                    SequenceDecoderOutput::Held => {}
-                }
-            }
-
-            // Flush remaining text
-            if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
-                decoded_text.push_str(&t);
-            }
-
-            let output_ids = std::mem::take(&mut complete.output_ids);
-            let finish_reason = std::mem::take(&mut complete.finish_reason);
-
-            // Build base meta_info using json! macro
-            let mut meta_info = json!({
-                "id": request_id.clone(),
-                "finish_reason": finish_reason,
-                "prompt_tokens": complete.prompt_tokens,
-                "weight_version": weight_version.clone(),
-                "completion_tokens": complete.completion_tokens,
-                "cached_tokens": complete.cached_tokens,
-                "e2e_latency": start_time.elapsed().as_secs_f64(),
-            });
-
-            let meta_obj = meta_info.as_object_mut().unwrap();
-
-            // Add matched_stop if present
-            if let Some(matched) = complete.matched_stop.take() {
-                use proto::generate_complete::MatchedStop;
-                let matched_value = match matched {
-                    MatchedStop::MatchedTokenId(id) => json!(id),
-                    MatchedStop::MatchedStopStr(s) => json!(s),
-                };
-                meta_obj.insert("matched_stop".to_string(), matched_value);
-            }
-
-            result_array.push(json!({
-                "text": decoded_text,
-                "output_ids": output_ids,
-                "meta_info": meta_info,
-            }));
-        }
-
-        Json(result_array).into_response()
-    }
-
-    /// Submit request and handle streaming response for the `/generate` endpoint
-    async fn handle_streaming_generate(
-        &self,
-        mut client: SglangSchedulerClient,
-        request: proto::GenerateRequest,
-        original_request: &GenerateRequest,
-        request_id: String,
-        weight_version: String,
-    ) -> Response {
-        let tokenizer = self.tokenizer.clone();
-        let return_logprob = original_request.return_logprob;
-
-        // Create channel for SSE streaming
-        let (tx, rx) =
-            tokio::sync::mpsc::unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
-
-        // Start the stream
-        let stream = match client.generate(request).await {
-            Ok(stream) => stream,
-            Err(e) => {
-                return utils::internal_error_message(format!("Failed to start generation: {}", e))
-            }
-        };
-
-        // Spawn async task to process stream
-        tokio::spawn(async move {
-            let result = Self::process_generate_streaming(
-                tokenizer,
-                stream,
-                request_id,
-                weight_version,
-                return_logprob,
-                &tx,
-            )
-            .await;
-
-            if let Err(e) = result {
-                let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
-                let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
-            }
-
-            // Send [DONE] marker
-            let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
-        });
-
-        // Create SSE response stream
-        let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
-        Response::builder()
-            .status(StatusCode::OK)
-            .header("Content-Type", "text/event-stream")
-            .header("Cache-Control", "no-cache")
-            .header("Connection", "keep-alive")
-            .body(axum::body::Body::from_stream(body_stream))
-            .unwrap()
-    }
-
-    /// Process streaming chunks for generate endpoint
-    async fn process_generate_streaming(
-        tokenizer: Arc<dyn Tokenizer>,
-        mut stream: impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
-            + Unpin,
-        request_id: String,
-        weight_version: String,
-        _include_logprobs: bool,
-        tx: &tokio::sync::mpsc::UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
-    ) -> Result<(), String> {
-        use proto::generate_response::Response::{Chunk, Complete, Error};
-        use std::time::Instant;
-        use tokio_stream::StreamExt;
-
-        let start_time = Instant::now();
-
-        // Track state per index for n>1 case
-        use std::collections::HashMap;
-        let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
-        let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
-
-        while let Some(response) = stream.next().await {
-            let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
-
-            match gen_response.response {
-                Some(Chunk(chunk)) => {
-                    let index = chunk.index;
-
-                    // Update completion tokens for this index
-                    let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
-                    *completion_tokens += chunk.token_ids.len() as u32;
-
-                    // Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
-                    let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
-
-                    // Accumulate text for this index
-                    let accumulated_text = accumulated_texts.entry(index).or_default();
-                    accumulated_text.push_str(&chunk_text);
-
-                    // Generate unique ID per index
-                    let index_id = format!("{}-{}", request_id, index);
-
-                    // Build streaming response chunk (SGLang format)
-                    let chunk_response = serde_json::json!({
-                        "text": accumulated_text.clone(),
-                        "output_ids": chunk.token_ids,
-                        "meta_info": {
-                            "id": index_id,
-                            "finish_reason": null,
-                            "prompt_tokens": chunk.prompt_tokens,
-                            "weight_version": weight_version,
-                            "completion_tokens": *completion_tokens,
-                            "cached_tokens": chunk.cached_tokens
-                        },
-                        "index": index
-                    });
-
-                    let sse_chunk = format!(
-                        "data: {}\n\n",
-                        serde_json::to_string(&chunk_response).unwrap()
-                    );
-                    tx.send(Ok(bytes::Bytes::from(sse_chunk)))
-                        .map_err(|_| "Failed to send chunk".to_string())?;
-                }
-                Some(Complete(complete)) => {
-                    let index = complete.index;
-                    let accumulated_text =
-                        accumulated_texts.get(&index).cloned().unwrap_or_default();
-                    let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
-                    let index_id = format!("{}-{}", request_id, index);
-                    let e2e_latency = start_time.elapsed().as_secs_f64();
-
-                    // Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
-                    let finish_response = serde_json::json!({
-                        "text": accumulated_text,
-                        "output_ids": complete.output_ids
-                            [complete.output_ids.len().saturating_sub(1)..]
-                            .to_vec(),
-                        "meta_info": {
-                            "id": index_id,
-                            "finish_reason": complete.finish_reason,
-                            "prompt_tokens": complete.prompt_tokens,
-                            "weight_version": weight_version,
-                            "completion_tokens": completion_tokens,
-                            "cached_tokens": complete.cached_tokens,
-                            "e2e_latency": e2e_latency
-                        },
-                        "index": index
-                    });
-
-                    let sse_chunk = format!(
-                        "data: {}\n\n",
-                        serde_json::to_string(&finish_response).unwrap()
-                    );
-                    tx.send(Ok(bytes::Bytes::from(sse_chunk)))
-                        .map_err(|_| "Failed to send finish chunk".to_string())?;
-
-                    // Continue to process all completions if n>1
-                }
-                Some(Error(error)) => {
-                    return Err(error.message);
-                }
-                None => continue,
-            }
-        }
-
-        Ok(())
-    }
 }
...
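Note: the streaming path removed above emitted plain SSE frames ("data: {json}\n\n" per chunk, closed by "data: [DONE]"), a format the new streaming.rs path is expected to preserve. A minimal client-side sketch of consuming one frame; parse_generate_sse_line is a hypothetical helper, not part of the commit:

    use serde_json::Value;

    /// Parse one SSE line from /generate; returns None for the [DONE]
    /// sentinel or anything that is not a data frame. Hypothetical helper.
    fn parse_generate_sse_line(line: &str) -> Option<Value> {
        let payload = line.strip_prefix("data: ")?;
        if payload.trim() == "[DONE]" {
            return None;
        }
        serde_json::from_str(payload).ok()
    }

    fn main() {
        let frame = r#"data: {"text":"Hi","output_ids":[13347],"meta_info":{"id":"gen-1-0","finish_reason":null},"index":0}"#;
        let value = parse_generate_sse_line(frame).expect("valid data frame");
        assert_eq!(value["index"], 0);
        assert!(parse_generate_sse_line("data: [DONE]").is_none());
    }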
sgl-router/src/routers/grpc/streaming.rs

(Diff collapsed in the page view; not shown here.)
sgl-router/src/routers/grpc/utils.rs

@@ -5,7 +5,7 @@ use crate::core::Worker;
 use crate::grpc_client::{proto, SglangSchedulerClient};
 use crate::protocols::spec::{
     ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
-    StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
+    GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
 };
 use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
 use crate::tokenizer::traits::Tokenizer;
...

@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs(
     })
 }

+/// Convert proto::OutputLogProbs to Generate format Vec<Vec<Option<f64>>>
+///
+/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
+/// Each inner vec contains [logprob (f64), token_id (i32), ...]
+pub fn convert_generate_output_logprobs(
+    proto_logprobs: &proto::OutputLogProbs,
+) -> Vec<Vec<Option<f64>>> {
+    proto_logprobs
+        .token_logprobs
+        .iter()
+        .zip(proto_logprobs.token_ids.iter())
+        .map(|(&logprob, &token_id)| vec![Some(logprob as f64), Some(token_id as f64)])
+        .collect()
+}
+
+/// Convert proto::InputLogProbs to Generate format Vec<Vec<Option<f64>>>
+///
+/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
+/// First token has null logprob: [[null, token_id], [logprob, token_id], ...]
+pub fn convert_generate_input_logprobs(
+    proto_logprobs: &proto::InputLogProbs,
+) -> Vec<Vec<Option<f64>>> {
+    proto_logprobs
+        .token_logprobs
+        .iter()
+        .zip(proto_logprobs.token_ids.iter())
+        .map(|(token_logprob, &token_id)| {
+            // InputTokenLogProb has optional value field
+            let logprob_value = token_logprob.value.map(|v| v as f64);
+            vec![logprob_value, Some(token_id as f64)]
+        })
+        .collect()
+}
+
+/// Parse finish_reason string into GenerateFinishReason enum
+///
+/// Uses serde to deserialize the finish_reason, which handles all tagged variants automatically.
+/// The GenerateFinishReason enum is tagged with `#[serde(tag = "type", rename_all = "lowercase")]`,
+/// so it expects JSON objects like:
+/// - `{"type":"stop"}` -> Stop
+/// - `{"type":"length","length":100}` -> Length { length: 100 }
+/// - Any other JSON -> Other(...)
+///
+/// For backward compatibility, also handles simple string "stop" -> Stop
+pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason {
+    if reason_str == "stop" {
+        return GenerateFinishReason::Stop;
+    }
+    if reason_str == "length" {
+        return GenerateFinishReason::Length {
+            length: completion_tokens.max(0) as u32,
+        };
+    }
+    match serde_json::from_str::<GenerateFinishReason>(reason_str) {
+        Ok(finish_reason) => finish_reason,
+        Err(_) => match serde_json::from_str::<Value>(reason_str) {
+            Ok(json_value) => GenerateFinishReason::Other(json_value),
+            Err(_) => GenerateFinishReason::Other(Value::String(reason_str.to_string())),
+        },
+    }
+}

 #[cfg(test)]
 mod tests {
     use super::*;
...
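Note: a few test cases sketching the fallback chain of parse_finish_reason; they could sit in the tests module shown above, but they are not part of the commit:

    #[test]
    fn parse_finish_reason_fallbacks() {
        // Fast paths for the two plain strings the scheduler commonly sends.
        assert!(matches!(
            parse_finish_reason("stop", 7),
            GenerateFinishReason::Stop
        ));
        assert!(matches!(
            parse_finish_reason("length", 128),
            GenerateFinishReason::Length { length: 128 }
        ));
        // A tagged JSON object deserializes directly via serde.
        assert!(matches!(
            parse_finish_reason(r#"{"type":"length","length":5}"#, 0),
            GenerateFinishReason::Length { length: 5 }
        ));
        // Anything else is preserved in Other (here, as a raw string).
        assert!(matches!(
            parse_finish_reason("abort", 0),
            GenerateFinishReason::Other(_)
        ));
    }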