Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0c3db889
Unverified
Commit
0c3db889
authored
Sep 26, 2025
by
Chang Su
Committed by
GitHub
Sep 26, 2025
Browse files
[router][grpc] Add helpfer functions for decoder in router.rs and fix specs (#10971)
parent
2bdaf482
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
307 additions
and
106 deletions
+307
-106
python/sglang/srt/grpc/sglang_scheduler.proto
python/sglang/srt/grpc/sglang_scheduler.proto
+6
-6
python/sglang/srt/grpc/sglang_scheduler_pb2.py
python/sglang/srt/grpc/sglang_scheduler_pb2.py
+68
-68
sgl-router/src/grpc_client/sglang_scheduler.rs
sgl-router/src/grpc_client/sglang_scheduler.rs
+9
-6
sgl-router/src/proto/sglang_scheduler.proto
sgl-router/src/proto/sglang_scheduler.proto
+5
-5
sgl-router/src/protocols/spec.rs
sgl-router/src/protocols/spec.rs
+3
-3
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+216
-18
No files found.
python/sglang/srt/grpc/sglang_scheduler.proto
View file @
0c3db889
...
@@ -36,9 +36,9 @@ message SamplingParams {
...
@@ -36,9 +36,9 @@ message SamplingParams {
float
presence_penalty
=
6
;
float
presence_penalty
=
6
;
float
repetition_penalty
=
7
;
float
repetition_penalty
=
7
;
int32
max_new_tokens
=
8
;
optional
int32
max_new_tokens
=
8
;
repeated
string
stop
=
9
;
repeated
string
stop
=
9
;
repeated
int32
stop_token_ids
=
10
;
repeated
u
int32
stop_token_ids
=
10
;
bool
skip_special_tokens
=
11
;
bool
skip_special_tokens
=
11
;
bool
spaces_between_special_tokens
=
12
;
bool
spaces_between_special_tokens
=
12
;
...
@@ -98,7 +98,7 @@ message GenerateRequest {
...
@@ -98,7 +98,7 @@ message GenerateRequest {
bool
return_logprob
=
5
;
bool
return_logprob
=
5
;
int32
logprob_start_len
=
6
;
int32
logprob_start_len
=
6
;
int32
top_logprobs_num
=
7
;
int32
top_logprobs_num
=
7
;
repeated
int32
token_ids_logprob
=
8
;
repeated
u
int32
token_ids_logprob
=
8
;
bool
return_hidden_states
=
9
;
bool
return_hidden_states
=
9
;
// For disaggregated serving
// For disaggregated serving
...
@@ -129,7 +129,7 @@ message GenerateRequest {
...
@@ -129,7 +129,7 @@ message GenerateRequest {
message
TokenizedInput
{
message
TokenizedInput
{
string
original_text
=
1
;
// For reference
string
original_text
=
1
;
// For reference
repeated
int32
input_ids
=
2
;
repeated
u
int32
input_ids
=
2
;
}
}
message
MultimodalInputs
{
message
MultimodalInputs
{
...
@@ -167,7 +167,7 @@ message GenerateResponse {
...
@@ -167,7 +167,7 @@ message GenerateResponse {
message
GenerateStreamChunk
{
message
GenerateStreamChunk
{
// Generated tokens (incremental chunk)
// Generated tokens (incremental chunk)
repeated
int32
token_ids
=
1
;
repeated
u
int32
token_ids
=
1
;
// Cumulative counts
// Cumulative counts
int32
prompt_tokens
=
2
;
int32
prompt_tokens
=
2
;
...
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
...
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
message
GenerateComplete
{
message
GenerateComplete
{
// Final output
// Final output
repeated
int32
output_ids
=
1
;
repeated
u
int32
output_ids
=
1
;
// Finish reason
// Finish reason
enum
FinishReason
{
enum
FinishReason
{
...
...
python/sglang/srt/grpc/sglang_scheduler_pb2.py
View file @
0c3db889
This diff is collapsed.
Click to expand it.
sgl-router/src/grpc_client/sglang_scheduler.rs
View file @
0c3db889
...
@@ -20,7 +20,7 @@ pub struct SglangSchedulerClient {
...
@@ -20,7 +20,7 @@ pub struct SglangSchedulerClient {
impl
SglangSchedulerClient
{
impl
SglangSchedulerClient
{
/// Create a new client and connect to the scheduler
/// Create a new client and connect to the scheduler
pub
async
fn
connect
(
endpoint
:
&
str
)
->
Result
<
Self
,
Box
<
dyn
std
::
error
::
Error
>>
{
pub
async
fn
connect
(
endpoint
:
&
str
)
->
Result
<
Self
,
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
>>
{
debug!
(
"Connecting to SGLang scheduler at {}"
,
endpoint
);
debug!
(
"Connecting to SGLang scheduler at {}"
,
endpoint
);
// Convert grpc:// to http:// for tonic
// Convert grpc:// to http:// for tonic
...
@@ -41,10 +41,11 @@ impl SglangSchedulerClient {
...
@@ -41,10 +41,11 @@ impl SglangSchedulerClient {
}
}
/// Submit a generation request (returns streaming response)
/// Submit a generation request (returns streaming response)
pub
async
fn
generate
_stream
(
pub
async
fn
generate
(
&
mut
self
,
&
mut
self
,
req
:
proto
::
GenerateRequest
,
req
:
proto
::
GenerateRequest
,
)
->
Result
<
tonic
::
Streaming
<
proto
::
GenerateResponse
>
,
Box
<
dyn
std
::
error
::
Error
>>
{
)
->
Result
<
tonic
::
Streaming
<
proto
::
GenerateResponse
>
,
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
>>
{
let
request
=
Request
::
new
(
req
);
let
request
=
Request
::
new
(
req
);
let
response
=
self
.client
.generate
(
request
)
.await
?
;
let
response
=
self
.client
.generate
(
request
)
.await
?
;
Ok
(
response
.into_inner
())
Ok
(
response
.into_inner
())
...
@@ -53,7 +54,7 @@ impl SglangSchedulerClient {
...
@@ -53,7 +54,7 @@ impl SglangSchedulerClient {
/// Perform health check
/// Perform health check
pub
async
fn
health_check
(
pub
async
fn
health_check
(
&
mut
self
,
&
mut
self
,
)
->
Result
<
proto
::
HealthCheckResponse
,
Box
<
dyn
std
::
error
::
Error
>>
{
)
->
Result
<
proto
::
HealthCheckResponse
,
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
>>
{
debug!
(
"Sending health check request"
);
debug!
(
"Sending health check request"
);
let
request
=
Request
::
new
(
proto
::
HealthCheckRequest
{
let
request
=
Request
::
new
(
proto
::
HealthCheckRequest
{
tokenized
:
Some
(
proto
::
TokenizedInput
{
tokenized
:
Some
(
proto
::
TokenizedInput
{
...
@@ -72,7 +73,7 @@ impl SglangSchedulerClient {
...
@@ -72,7 +73,7 @@ impl SglangSchedulerClient {
&
mut
self
,
&
mut
self
,
request_id
:
String
,
request_id
:
String
,
reason
:
String
,
reason
:
String
,
)
->
Result
<
(),
Box
<
dyn
std
::
error
::
Error
>>
{
)
->
Result
<
(),
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
>>
{
let
request
=
Request
::
new
(
proto
::
AbortRequest
{
request_id
,
reason
});
let
request
=
Request
::
new
(
proto
::
AbortRequest
{
request_id
,
reason
});
self
.client
.abort
(
request
)
.await
?
;
self
.client
.abort
(
request
)
.await
?
;
...
@@ -85,7 +86,7 @@ impl SglangSchedulerClient {
...
@@ -85,7 +86,7 @@ impl SglangSchedulerClient {
request_id
:
String
,
request_id
:
String
,
body
:
&
ChatCompletionRequest
,
body
:
&
ChatCompletionRequest
,
processed_text
:
String
,
processed_text
:
String
,
token_ids
:
Vec
<
i
32
>
,
token_ids
:
Vec
<
u
32
>
,
multimodal_inputs
:
Option
<
proto
::
MultimodalInputs
>
,
multimodal_inputs
:
Option
<
proto
::
MultimodalInputs
>
,
tool_call_constraint
:
Option
<
(
String
,
String
)
>
,
// (constraint_type, constraint_value)
tool_call_constraint
:
Option
<
(
String
,
String
)
>
,
// (constraint_type, constraint_value)
)
->
Result
<
proto
::
GenerateRequest
,
String
>
{
)
->
Result
<
proto
::
GenerateRequest
,
String
>
{
...
@@ -153,6 +154,8 @@ impl SglangSchedulerClient {
...
@@ -153,6 +154,8 @@ impl SglangSchedulerClient {
stop
:
stop_sequences
,
stop
:
stop_sequences
,
stop_token_ids
:
request
.stop_token_ids
.clone
()
.unwrap_or_default
(),
stop_token_ids
:
request
.stop_token_ids
.clone
()
.unwrap_or_default
(),
skip_special_tokens
,
skip_special_tokens
,
ignore_eos
:
request
.ignore_eos
,
no_stop_trim
:
request
.no_stop_trim
,
n
:
request
.n
.unwrap_or
(
1
)
as
i32
,
n
:
request
.n
.unwrap_or
(
1
)
as
i32
,
constraint
:
self
.build_constraint
(
request
,
tool_call_constraint
)
?
,
constraint
:
self
.build_constraint
(
request
,
tool_call_constraint
)
?
,
..
Default
::
default
()
..
Default
::
default
()
...
...
sgl-router/src/proto/sglang_scheduler.proto
View file @
0c3db889
...
@@ -38,7 +38,7 @@ message SamplingParams {
...
@@ -38,7 +38,7 @@ message SamplingParams {
optional
int32
max_new_tokens
=
8
;
optional
int32
max_new_tokens
=
8
;
repeated
string
stop
=
9
;
repeated
string
stop
=
9
;
repeated
int32
stop_token_ids
=
10
;
repeated
u
int32
stop_token_ids
=
10
;
bool
skip_special_tokens
=
11
;
bool
skip_special_tokens
=
11
;
bool
spaces_between_special_tokens
=
12
;
bool
spaces_between_special_tokens
=
12
;
...
@@ -98,7 +98,7 @@ message GenerateRequest {
...
@@ -98,7 +98,7 @@ message GenerateRequest {
bool
return_logprob
=
5
;
bool
return_logprob
=
5
;
int32
logprob_start_len
=
6
;
int32
logprob_start_len
=
6
;
int32
top_logprobs_num
=
7
;
int32
top_logprobs_num
=
7
;
repeated
int32
token_ids_logprob
=
8
;
repeated
u
int32
token_ids_logprob
=
8
;
bool
return_hidden_states
=
9
;
bool
return_hidden_states
=
9
;
// For disaggregated serving
// For disaggregated serving
...
@@ -129,7 +129,7 @@ message GenerateRequest {
...
@@ -129,7 +129,7 @@ message GenerateRequest {
message
TokenizedInput
{
message
TokenizedInput
{
string
original_text
=
1
;
// For reference
string
original_text
=
1
;
// For reference
repeated
int32
input_ids
=
2
;
repeated
u
int32
input_ids
=
2
;
}
}
message
MultimodalInputs
{
message
MultimodalInputs
{
...
@@ -167,7 +167,7 @@ message GenerateResponse {
...
@@ -167,7 +167,7 @@ message GenerateResponse {
message
GenerateStreamChunk
{
message
GenerateStreamChunk
{
// Generated tokens (incremental chunk)
// Generated tokens (incremental chunk)
repeated
int32
token_ids
=
1
;
repeated
u
int32
token_ids
=
1
;
// Cumulative counts
// Cumulative counts
int32
prompt_tokens
=
2
;
int32
prompt_tokens
=
2
;
...
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
...
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
message
GenerateComplete
{
message
GenerateComplete
{
// Final output
// Final output
repeated
int32
output_ids
=
1
;
repeated
u
int32
output_ids
=
1
;
// Finish reason
// Finish reason
enum
FinishReason
{
enum
FinishReason
{
...
...
sgl-router/src/protocols/spec.rs
View file @
0c3db889
...
@@ -313,7 +313,7 @@ pub struct ChatCompletionRequest {
...
@@ -313,7 +313,7 @@ pub struct ChatCompletionRequest {
/// Specific token IDs to use as stop conditions
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
stop_token_ids
:
Option
<
Vec
<
i
32
>>
,
pub
stop_token_ids
:
Option
<
Vec
<
u
32
>>
,
/// Skip trimming stop tokens from output
/// Skip trimming stop tokens from output
#[serde(default)]
#[serde(default)]
...
@@ -564,7 +564,7 @@ pub struct CompletionRequest {
...
@@ -564,7 +564,7 @@ pub struct CompletionRequest {
/// Specific token IDs to use as stop conditions
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
stop_token_ids
:
Option
<
Vec
<
i
32
>>
,
pub
stop_token_ids
:
Option
<
Vec
<
u
32
>>
,
/// Skip trimming stop tokens from output
/// Skip trimming stop tokens from output
#[serde(default)]
#[serde(default)]
...
@@ -1864,7 +1864,7 @@ pub struct SamplingParams {
...
@@ -1864,7 +1864,7 @@ pub struct SamplingParams {
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
min_tokens
:
Option
<
u32
>
,
pub
min_tokens
:
Option
<
u32
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
stop_token_ids
:
Option
<
Vec
<
i
32
>>
,
pub
stop_token_ids
:
Option
<
Vec
<
u
32
>>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
no_stop_trim
:
Option
<
bool
>
,
pub
no_stop_trim
:
Option
<
bool
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
...
...
sgl-router/src/routers/grpc/router.rs
View file @
0c3db889
...
@@ -17,19 +17,20 @@ use crate::grpc_client::{proto, SglangSchedulerClient};
...
@@ -17,19 +17,20 @@ use crate::grpc_client::{proto, SglangSchedulerClient};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
PolicyRegistry
;
use
crate
::
policies
::
PolicyRegistry
;
use
crate
::
protocols
::
spec
::
ChatMessage
;
use
crate
::
protocols
::
spec
::
ChatMessage
;
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
StringOrArray
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesGetParams
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
Responses
Request
,
Tool
,
ToolChoice
,
Responses
GetParams
,
ResponsesRequest
,
StringOrArray
,
Tool
,
ToolChoice
,
};
};
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
server
::
AppContext
;
use
crate
::
server
::
AppContext
;
use
crate
::
tokenizer
::
chat_template
::{
ChatTemplateContentFormat
,
ChatTemplateParams
};
use
crate
::
tokenizer
::
chat_template
::{
ChatTemplateContentFormat
,
ChatTemplateParams
};
use
crate
::
tokenizer
::
stop
::{
SequenceDecoderOutput
,
StopSequenceDecoderBuilder
};
use
crate
::
tokenizer
::
traits
::
Tokenizer
;
use
crate
::
tokenizer
::
traits
::
Tokenizer
;
use
crate
::
tokenizer
::
HuggingFaceTokenizer
;
use
crate
::
tokenizer
::
HuggingFaceTokenizer
;
use
crate
::
tool_parser
::
ParserRegistry
;
use
crate
::
tool_parser
::
ParserRegistry
;
use
serde_json
::
Value
;
use
serde_json
::
Value
;
use
tokio_stream
::
StreamExt
;
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
// Data structures for processing
// Data structures for processing
...
@@ -182,7 +183,7 @@ impl GrpcRouter {
...
@@ -182,7 +183,7 @@ impl GrpcRouter {
request_id
,
request_id
,
body
,
body
,
processed_messages
.text
.clone
(),
processed_messages
.text
.clone
(),
token_ids
.into_iter
()
.map
(|
id
|
id
as
i32
)
.collect
()
,
token_ids
,
processed_messages
.multimodal_inputs
,
processed_messages
.multimodal_inputs
,
tool_call_constraint
,
// Pass the full tuple (type, value)
tool_call_constraint
,
// Pass the full tuple (type, value)
)
{
)
{
...
@@ -479,28 +480,225 @@ impl GrpcRouter {
...
@@ -479,28 +480,225 @@ impl GrpcRouter {
None
None
}
}
/// Placeholder for streaming handler (to be implemented in Phase 2)
/// Create a StopSequenceDecoder from the chat completion request
fn
create_stop_decoder
(
&
self
,
original_request
:
&
ChatCompletionRequest
,
)
->
crate
::
tokenizer
::
stop
::
StopSequenceDecoder
{
// Extract stop sequences from request
let
stop_sequences
:
Vec
<
String
>
=
match
&
original_request
.stop
{
Some
(
StringOrArray
::
String
(
s
))
=>
vec!
[
s
.clone
()],
Some
(
StringOrArray
::
Array
(
arr
))
=>
arr
.clone
(),
None
=>
vec!
[],
};
// Build stop sequence decoder
let
mut
builder
=
StopSequenceDecoderBuilder
::
new
(
self
.tokenizer
.clone
())
.skip_special_tokens
(
original_request
.skip_special_tokens
);
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
for
seq
in
stop_sequences
{
builder
=
if
original_request
.no_stop_trim
{
builder
.visible_stop_sequence
(
seq
)
}
else
{
builder
.stop_sequence
(
seq
)
};
}
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
if
let
Some
(
stop_token_ids
)
=
&
original_request
.stop_token_ids
{
for
&
token_id
in
stop_token_ids
{
builder
=
if
original_request
.no_stop_trim
{
builder
.visible_stop_token
(
token_id
)
}
else
{
builder
.stop_token
(
token_id
)
};
}
}
builder
.build
()
}
/// Process a chunk of tokens through the stop decoder
fn
process_chunk_tokens
(
stop_decoder
:
&
mut
crate
::
tokenizer
::
stop
::
StopSequenceDecoder
,
token_ids
:
&
[
u32
],
)
->
(
String
,
bool
)
{
let
mut
chunk_text
=
String
::
new
();
for
&
token_id
in
token_ids
{
match
stop_decoder
.process_token
(
token_id
)
.unwrap_or_else
(|
e
|
{
debug!
(
"Error processing token {}: {}. Treating as Held."
,
token_id
,
e
);
SequenceDecoderOutput
::
Held
})
{
SequenceDecoderOutput
::
Text
(
text
)
=>
{
chunk_text
.push_str
(
&
text
);
}
SequenceDecoderOutput
::
StoppedWithText
(
text
)
=>
{
chunk_text
.push_str
(
&
text
);
return
(
chunk_text
,
true
);
// Return text and signal to stop
}
SequenceDecoderOutput
::
Stopped
=>
{
return
(
chunk_text
,
true
);
// Return text and signal to stop
}
SequenceDecoderOutput
::
Held
=>
{
// Text held for potential stop sequence match
}
}
}
(
chunk_text
,
false
)
// Return text and continue processing
}
/// Submit request and handle streaming response for chat completions route
async
fn
handle_streaming_chat
(
async
fn
handle_streaming_chat
(
&
self
,
&
self
,
_
client
:
SglangSchedulerClient
,
mut
client
:
SglangSchedulerClient
,
_
request
:
proto
::
GenerateRequest
,
request
:
proto
::
GenerateRequest
,
_
original_request
:
&
ChatCompletionRequest
,
original_request
:
&
ChatCompletionRequest
,
)
->
Response
{
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"Streaming not yet implemented"
)
.into_response
()
let
mut
stop_decoder
=
self
.create_stop_decoder
(
original_request
);
// Process streaming tokens
let
mut
grpc_stream
=
match
client
.generate
(
request
)
.await
{
Ok
(
stream
)
=>
stream
,
Err
(
e
)
=>
{
error!
(
"Failed to start generation: {}"
,
e
);
return
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Generation failed: {}"
,
e
),
)
.into_response
();
}
};
let
mut
decoded_text
=
String
::
new
();
while
let
Some
(
response
)
=
grpc_stream
.next
()
.await
{
let
gen_response
=
match
response
{
Ok
(
resp
)
=>
resp
,
Err
(
e
)
=>
{
error!
(
"Stream error: {}"
,
e
);
break
;
}
};
match
gen_response
.response
{
Some
(
proto
::
generate_response
::
Response
::
Chunk
(
chunk
))
=>
{
// Process tokens and check if we should stop
let
(
chunk_text
,
should_stop
)
=
Self
::
process_chunk_tokens
(
&
mut
stop_decoder
,
&
chunk
.token_ids
);
decoded_text
.push_str
(
&
chunk_text
);
if
should_stop
{
break
;
}
continue
;
}
Some
(
proto
::
generate_response
::
Response
::
Complete
(
_
complete
))
=>
{
// Flush any remaining text
if
let
SequenceDecoderOutput
::
Text
(
text
)
=
stop_decoder
.flush
()
{
if
!
text
.is_empty
()
{
decoded_text
.push_str
(
&
text
);
debug!
(
"Flushed text: {}"
,
text
);
}
}
break
;
}
Some
(
proto
::
generate_response
::
Response
::
Error
(
error
))
=>
{
error!
(
"Generation error: {}"
,
error
.message
);
break
;
}
None
=>
continue
,
}
}
// TODO: Replace with proper SSE streaming response
// For now, return the complete decoded text
(
StatusCode
::
OK
,
format!
(
"Decoded text: {}"
,
decoded_text
))
.into_response
()
}
}
///
Placeholder for non-streaming handler (to be implemented in Phase 3)
///
Submit request and handle non-streaming response for chat completions route
async
fn
handle_non_streaming_chat
(
async
fn
handle_non_streaming_chat
(
&
self
,
&
self
,
_
client
:
SglangSchedulerClient
,
mut
client
:
SglangSchedulerClient
,
_
request
:
proto
::
GenerateRequest
,
request
:
proto
::
GenerateRequest
,
_
original_request
:
&
ChatCompletionRequest
,
original_request
:
&
ChatCompletionRequest
,
)
->
Response
{
)
->
Response
{
(
let
mut
stop_decoder
=
self
.create_stop_decoder
(
original_request
);
StatusCode
::
NOT_IMPLEMENTED
,
"Non-streaming not yet implemented"
,
// Small helpers to log + return a uniform 500
)
let
fail_str
=
|
msg
:
&
'static
str
|
->
Response
{
.into_response
()
error!
(
"{}"
,
msg
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
msg
)
.into_response
()
};
let
fail_fmt
=
|
prefix
:
&
str
,
e
:
&
dyn
std
::
fmt
::
Display
|
->
Response
{
error!
(
"{}{}"
,
prefix
,
e
);
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"{}{}"
,
prefix
,
e
),
)
.into_response
()
};
// Start generation
let
mut
stream
=
match
client
.generate
(
request
)
.await
{
Ok
(
s
)
=>
s
,
Err
(
e
)
=>
return
fail_fmt
(
"Failed to start generation: "
,
&
e
),
};
// Get the single Complete response
let
gen_response
=
match
stream
.next
()
.await
{
Some
(
Ok
(
r
))
=>
r
,
Some
(
Err
(
e
))
=>
return
fail_fmt
(
"Failed to get GenerateResponse: "
,
&
e
),
None
=>
return
fail_str
(
"No response from server"
),
};
// Extract the expected variant early
let
complete
=
match
gen_response
.response
{
Some
(
proto
::
generate_response
::
Response
::
Complete
(
c
))
=>
c
,
Some
(
proto
::
generate_response
::
Response
::
Error
(
err
))
=>
{
error!
(
"Generation failed: {}"
,
err
.message
);
return
(
StatusCode
::
INTERNAL_SERVER_ERROR
,
format!
(
"Generation failed: {}"
,
err
.message
),
)
.into_response
();
}
Some
(
proto
::
generate_response
::
Response
::
Chunk
(
_
))
=>
{
return
fail_str
(
"Unexpected chunk response for non-streaming request"
)
}
None
=>
return
fail_str
(
"Empty response from server"
),
};
// Decode tokens
let
outputs
=
match
stop_decoder
.process_tokens
(
&
complete
.output_ids
)
{
Ok
(
o
)
=>
o
,
Err
(
e
)
=>
return
fail_fmt
(
"Failed to process tokens: "
,
&
e
),
};
// Accumulate text with early breaks
let
mut
final_text
=
String
::
new
();
for
output
in
outputs
{
match
output
{
SequenceDecoderOutput
::
Text
(
t
)
=>
final_text
.push_str
(
&
t
),
SequenceDecoderOutput
::
StoppedWithText
(
t
)
=>
{
final_text
.push_str
(
&
t
);
break
;
}
SequenceDecoderOutput
::
Stopped
=>
break
,
SequenceDecoderOutput
::
Held
=>
{}
}
}
// Flush remaining text
if
let
SequenceDecoderOutput
::
Text
(
t
)
=
stop_decoder
.flush
()
{
final_text
.push_str
(
&
t
);
}
// TODO: Create proper OpenAI-compatible response
(
StatusCode
::
OK
,
format!
(
"Final text: {}"
,
final_text
))
.into_response
()
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment