sglang · Commit 37f3325b

[router][grpc] Support E2E non-stream chat completions (#10980)

Unverified commit 37f3325b, authored Sep 26, 2025 by Chang Su, committed by GitHub on Sep 26, 2025. Parent: bd95944c
Showing 8 changed files with 325 additions and 136 deletions.
Files changed:
- python/sglang/srt/entrypoints/grpc_request_manager.py (+2, -2)
- python/sglang/srt/entrypoints/grpc_server.py (+29, -9)
- python/sglang/srt/grpc/sglang_scheduler.proto (+8, -14)
- python/sglang/srt/grpc/sglang_scheduler_pb2.py (+50, -52)
- python/sglang/srt/grpc/sglang_scheduler_pb2.pyi (+7, -16)
- sgl-router/src/proto/sglang_scheduler.proto (+8, -14)
- sgl-router/src/protocols/spec.rs (+16, -1)
- sgl-router/src/routers/grpc/router.rs (+205, -28)
python/sglang/srt/entrypoints/grpc_request_manager.py

@@ -13,7 +13,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import grpc
 import zmq
...
@@ -156,7 +156,7 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ):
+    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
         """
         Submit a generation request to the scheduler with n>1 parallel sampling support.
...
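The annotation makes the contract explicit: generate_request is an async generator that yields a Dict per output, or a List[Dict] when n>1 parallel samples are produced. A minimal consumer sketch (the collect_outputs helper is illustrative, not part of this commit):

import asyncio
from typing import Any, AsyncGenerator, Dict, List, Union

# Drain the async generator returned by GrpcRequestManager.generate_request.
async def collect_outputs(
    gen: AsyncGenerator[Union[Dict, List[Dict]], None],
) -> List[Any]:
    outputs = []
    async for item in gen:
        # Each item is a Dict (single sample) or a List[Dict] (n>1 sampling).
        outputs.append(item)
    return outputs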
python/sglang/srt/entrypoints/grpc_server.py

@@ -321,14 +321,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         logger.info(f"Sending health check request to request manager...")

         # Submit and wait for response
-        output_queue = await self.request_manager.generate_request(
+        output_generator = self.request_manager.generate_request(
             health_request, request_id=rid
         )

         try:
-            # Wait for response with configurable timeout
+            # Get first response with timeout
             response = await asyncio.wait_for(
-                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
+                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
             )

             # Clean up
...
@@ -492,13 +492,32 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""

-        # Determine finish reason
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
-        meta_info = output.get("meta_info", {})
-        if meta_info.get("finish_reason") == "length":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
-        elif meta_info.get("finish_reason") == "eos_token":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+        # Extract meta info and finish reason details
+        meta_info = output.get("meta_info", {})
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched

         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
...
@@ -510,6 +529,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                 "completion_tokens", len(output.get("token_ids", []))
             ),
             cached_tokens=meta_info.get("cached_tokens", 0),
+            **matched_stop_kwargs,
         ),
     )
...
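The server now reports finish reasons as OpenAI-compatible strings and forwards stop-match details. The mapping above can be summarized as a standalone sketch (map_finish_reason is a hypothetical helper mirroring the diff, not code from the commit):

def map_finish_reason(meta_info: dict) -> tuple:
    """Map scheduler meta_info to (finish_reason string, matched_stop kwargs)."""
    data = meta_info.get("finish_reason")
    reason = "stop"  # default
    if data:
        # New format is a dict with a "type" key; legacy format is a bare string.
        kind = data.get("type") if isinstance(data, dict) else data
        if kind in ("length", "abort"):
            reason = kind
    kwargs = {}
    if isinstance(data, dict) and "matched" in data:
        matched = data["matched"]
        if isinstance(matched, int):
            kwargs["matched_token_id"] = matched
        elif isinstance(matched, str):
            kwargs["matched_stop_str"] = matched
    return reason, kwargs

# map_finish_reason({"finish_reason": {"type": "stop", "matched": 128009}})
# -> ("stop", {"matched_token_id": 128009})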
python/sglang/srt/grpc/sglang_scheduler.proto

@@ -185,20 +185,8 @@ message GenerateComplete {
   // Final output
   repeated uint32 output_ids = 1;

-  // Finish reason
-  enum FinishReason {
-    // The model generated a stop sequence.
-    STOP = 0;
-    // The model reached the maximum generation length.
-    LENGTH = 1;
-    // The model generated an end-of-sequence (EOS) token.
-    EOS_TOKEN = 2;
-    // The model generated a user-provided stop string.
-    STOP_STR = 3;
-    // The request was aborted by the user or system.
-    ABORT = 4;
-  }
-  FinishReason finish_reason = 2;
+  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
+  string finish_reason = 2;

   // Token usage counts
   int32 prompt_tokens = 3;
...
@@ -210,6 +198,12 @@ message GenerateComplete {
   // All hidden states if requested
   repeated HiddenStates all_hidden_states = 7;
+
+  // Matched stop information (for stop sequences)
+  oneof matched_stop {
+    uint32 matched_token_id = 8;
+    string matched_stop_str = 9;
+  }
 }

 message GenerateError {
...
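Two protocol details worth noting: finish_reason moves from an enum to a free-form string (so the router can pass it straight through to the OpenAI response), and matched_stop is a oneof, so a completion carries at most one of matched_token_id or matched_stop_str. A small sketch with the regenerated Python bindings (assuming they are importable from the path above):

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

complete = pb2.GenerateComplete(
    output_ids=[1, 2, 3],
    finish_reason="stop",
    matched_stop_str="\n\n",  # assigning matched_token_id instead would clear this
)
# WhichOneof reports which branch of the oneof is set, if any.
assert complete.WhichOneof("matched_stop") == "matched_stop_str"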
python/sglang/srt/grpc/sglang_scheduler_pb2.py

(Diff collapsed: regenerated protobuf bindings.)
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi

@@ -3,7 +3,6 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
-from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
...
@@ -179,19 +178,7 @@ class GenerateStreamChunk(_message.Message):
     def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
-    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
-        __slots__ = ()
-        STOP: _ClassVar[GenerateComplete.FinishReason]
-        LENGTH: _ClassVar[GenerateComplete.FinishReason]
-        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
-        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
-        ABORT: _ClassVar[GenerateComplete.FinishReason]
-    STOP: GenerateComplete.FinishReason
-    LENGTH: GenerateComplete.FinishReason
-    EOS_TOKEN: GenerateComplete.FinishReason
-    STOP_STR: GenerateComplete.FinishReason
-    ABORT: GenerateComplete.FinishReason
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
...
@@ -199,14 +186,18 @@ class GenerateComplete(_message.Message):
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-    finish_reason: GenerateComplete.FinishReason
+    finish_reason: str
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
+    matched_token_id: int
+    matched_stop_str: str
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
...
sgl-router/src/proto/sglang_scheduler.proto

The same two hunks as in python/sglang/srt/grpc/sglang_scheduler.proto above: the FinishReason enum is replaced by the string finish_reason field (OpenAI-compatible "stop", "length", "abort"), and the matched_stop oneof (matched_token_id = 8, matched_stop_str = 9) is added to GenerateComplete.
...
sgl-router/src/protocols/spec.rs
View file @
37f3325b
...
@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
...
@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
pub
system_fingerprint
:
Option
<
String
>
,
pub
system_fingerprint
:
Option
<
String
>
,
}
}
/// Response message structure for ChatCompletionResponse (different from request ChatMessage)
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ChatCompletionMessage
{
pub
role
:
String
,
// Always "assistant" for responses
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
content
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
tool_calls
:
Option
<
Vec
<
ToolCall
>>
,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
reasoning_content
:
Option
<
String
>
,
// Note: function_call is deprecated and not included
// Note: refusal, annotations, audio are not added yet
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ChatChoice
{
pub
struct
ChatChoice
{
pub
index
:
u32
,
pub
index
:
u32
,
pub
message
:
ChatMessage
,
pub
message
:
Chat
Completion
Message
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
logprobs
:
Option
<
ChatLogProbs
>
,
pub
logprobs
:
Option
<
ChatLogProbs
>
,
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "tool_calls", "content_filter", "function_call"
pub
finish_reason
:
Option
<
String
>
,
// "stop", "length", "tool_calls", "content_filter", "function_call"
...
...
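ChatCompletionMessage exists because response messages have a different shape from the request-side ChatMessage: content is optional (it is dropped when the model only emits tool calls), and tool_calls / reasoning_content are skipped entirely when None. A serialized choice therefore looks roughly like this (illustrative JSON shape written as a Python dict, values invented):

choice = {
    "index": 0,
    "message": {
        "role": "assistant",   # always "assistant" for responses
        "content": "Hello!",   # omitted when tool calls replace the content
        # "tool_calls" and "reasoning_content" appear only when set
    },
    "finish_reason": "stop",
}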
sgl-router/src/routers/grpc/router.rs

@@ -8,6 +8,7 @@ use axum::{
     extract::Request,
     http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
+    Json,
 };
 use tracing::{debug, error, info, warn};
...
@@ -18,8 +19,9 @@ use crate::metrics::RouterMetrics;
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::ChatMessage;
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
-    ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
+    ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
+    CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
+    ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
 };
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
...
@@ -30,6 +32,7 @@ use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
 use crate::tool_parser::ParserRegistry;
 use serde_json::Value;
+use std::time::{SystemTime, UNIX_EPOCH};
 use tokio_stream::StreamExt;
 use uuid::Uuid;
...
@@ -648,36 +651,99 @@ impl GrpcRouter {
             Err(e) => return fail_fmt("Failed to start generation: ", &e),
         };

-        // Get the single Complete response
-        let gen_response = match stream.next().await {
-            Some(Ok(r)) => r,
-            Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
-            None => return fail_str("No response from server"),
-        };
-
-        // Extract the expected variant early
-        let complete = match gen_response.response {
-            Some(proto::generate_response::Response::Complete(c)) => c,
-            Some(proto::generate_response::Response::Error(err)) => {
-                error!("Generation failed: {}", err.message);
-                return (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    format!("Generation failed: {}", err.message),
-                )
-                    .into_response();
-            }
-            Some(proto::generate_response::Response::Chunk(_)) => {
-                return fail_str("Unexpected chunk response for non-streaming request")
-            }
-            None => return fail_str("Empty response from server"),
-        };
-
-        // Decode tokens
-        let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
-            Ok(o) => o,
-            Err(e) => return fail_fmt("Failed to process tokens: ", &e),
-        };
+        // Collect all responses (for n>1 support)
+        let mut all_responses = Vec::new();
+        while let Some(response) = stream.next().await {
+            match response {
+                Ok(gen_response) => match gen_response.response {
+                    Some(proto::generate_response::Response::Complete(complete)) => {
+                        all_responses.push(complete);
+                    }
+                    Some(proto::generate_response::Response::Error(err)) => {
+                        error!("Generation failed for one choice: {}", err.message);
+                        return (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            format!("Generation failed: {}", err.message),
+                        )
+                            .into_response();
+                    }
+                    Some(proto::generate_response::Response::Chunk(_)) => {
+                        return fail_str("Unexpected chunk response for non-streaming request")
+                    }
+                    None => return fail_str("Empty response from server"),
+                },
+                Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
+            }
+        }
+
+        if all_responses.is_empty() {
+            return fail_str("No responses from server");
+        }
+
+        // Process each response into a ChatChoice
+        let mut choices = Vec::new();
+        for (index, complete) in all_responses.iter().enumerate() {
+            match self
+                .process_single_choice(complete, index, original_request, &mut stop_decoder)
+                .await
+            {
+                Ok(choice) => choices.push(choice),
+                Err(e) => {
+                    error!("Failed to process choice {}: {}", index, e);
+                    return (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        format!("Failed to process choice {}: {}", index, e),
+                    )
+                        .into_response();
+                }
+            }
+        }
+
+        // Aggregate usage information from all responses
+        let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
+        let total_completion_tokens: u32 =
+            all_responses.iter().map(|r| r.completion_tokens as u32).sum();
+        let usage = Usage {
+            prompt_tokens: total_prompt_tokens,
+            completion_tokens: total_completion_tokens,
+            total_tokens: total_prompt_tokens + total_completion_tokens,
+            completion_tokens_details: None,
+        };
+
+        // Build final ChatCompletionResponse
+        let response = ChatCompletionResponse {
+            id: format!("chatcmpl-{}", Uuid::new_v4()),
+            object: "chat.completion".to_string(),
+            created: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs(),
+            model: original_request.model.clone(),
+            choices,
+            usage: Some(usage),
+            system_fingerprint: None,
+        };
+
+        // Serialize and return JSON response
+        Json(response).into_response()
+    }
+
+    /// Process a single GenerateComplete response into a ChatChoice
+    async fn process_single_choice(
+        &self,
+        complete: &proto::GenerateComplete,
+        index: usize,
+        original_request: &ChatCompletionRequest,
+        stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
+    ) -> Result<ChatChoice, String> {
+        stop_decoder.reset();
+
+        // Decode tokens
+        let outputs = stop_decoder
+            .process_tokens(&complete.output_ids)
+            .map_err(|e| format!("Failed to process tokens: {}", e))?;

         // Accumulate text with early breaks
         let mut final_text = String::new();
         for output in outputs {
...
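Note how usage is aggregated when n>1: the per-choice prompt_tokens and completion_tokens are summed across all returned completions. A worked example with invented numbers:

# Two choices (n=2) reported by the scheduler for one request:
responses = [
    {"prompt_tokens": 12, "completion_tokens": 30},
    {"prompt_tokens": 12, "completion_tokens": 45},
]
prompt = sum(r["prompt_tokens"] for r in responses)          # 24
completion = sum(r["completion_tokens"] for r in responses)  # 75
usage = {
    "prompt_tokens": prompt,
    "completion_tokens": completion,
    "total_tokens": prompt + completion,                     # 99
}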
@@ -697,8 +763,119 @@ impl GrpcRouter {
             final_text.push_str(&t);
         }

-        // TODO: Create proper OpenAI-compatible response
-        (StatusCode::OK, format!("Final text: {}", final_text)).into_response()
+        // Step 1: Handle reasoning content parsing
+        let mut reasoning_text: Option<String> = None;
+        let mut processed_text = final_text;
+
+        // Check if reasoning parsing is enabled and separate_reasoning is requested
+        if original_request.separate_reasoning {
+            if let Ok(mut parser) = self.reasoning_parser_factory.create(&original_request.model) {
+                match parser.detect_and_parse_reasoning(&processed_text) {
+                    Ok(result) => {
+                        if !result.reasoning_text.is_empty() {
+                            reasoning_text = Some(result.reasoning_text);
+                        }
+                        processed_text = result.normal_text;
+                    }
+                    Err(e) => {
+                        return Err(format!("Reasoning parsing error: {}", e));
+                    }
+                }
+            }
+        }
+
+        // Step 2: Handle tool call parsing
+        let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None;
+
+        // Check if tool calls should be processed
+        let tool_choice_enabled = !matches!(
+            &original_request.tool_choice,
+            Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::None))
+        );
+
+        if tool_choice_enabled && original_request.tools.is_some() {
+            if let Some(parser) = self.tool_parser_registry.get_parser(&original_request.model) {
+                match parser.parse_complete(&processed_text).await {
+                    Ok(parsed_tool_calls) => {
+                        if !parsed_tool_calls.is_empty() {
+                            let spec_tool_calls = parsed_tool_calls
+                                .into_iter()
+                                .map(|tc| crate::protocols::spec::ToolCall {
+                                    id: tc.id,
+                                    tool_type: "function".to_string(),
+                                    function: crate::protocols::spec::FunctionCallResponse {
+                                        name: tc.function.name,
+                                        arguments: Some(
+                                            serde_json::to_string(&tc.function.arguments)
+                                                .unwrap_or_else(|_| "{}".to_string()),
+                                        ),
+                                    },
+                                })
+                                .collect();
+                            tool_calls = Some(spec_tool_calls);
+                            processed_text = String::new();
+                        }
+                    }
+                    Err(e) => {
+                        error!("Tool call parsing error: {}", e);
+                        // Continue without tool calls rather than failing
+                    }
+                }
+            }
+        }
+
+        // Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
+        let finish_reason_str = &complete.finish_reason;
+
+        // Override finish reason if we have tool calls
+        let final_finish_reason_str = if tool_calls.is_some() {
+            "tool_calls"
+        } else {
+            finish_reason_str
+        };
+
+        // Extract matched_stop information from proto
+        let matched_stop = match &complete.matched_stop {
+            Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some(
+                serde_json::Value::Number(serde_json::Number::from(*token_id)),
+            ),
+            Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
+                Some(serde_json::Value::String(stop_str.clone()))
+            }
+            None => None,
+        };
+
+        // Step 4: Build ChatCompletionMessage (proper response message type)
+        let chat_message = ChatCompletionMessage {
+            role: "assistant".to_string(),
+            content: if processed_text.is_empty() {
+                None
+            } else {
+                Some(processed_text)
+            },
+            tool_calls,
+            reasoning_content: reasoning_text,
+        };
+
+        // Step 5: Build ChatChoice
+        let choice = ChatChoice {
+            index: index as u32,
+            message: chat_message,
+            logprobs: None,
+            finish_reason: Some(final_finish_reason_str.to_string()),
+            matched_stop,
+            hidden_states: None,
+        };
+
+        Ok(choice)
     }
 }
...
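Taken together, these changes complete the non-streaming chat path through the gRPC router: collect GenerateComplete responses, post-process each into a ChatChoice (reasoning split, tool calls, finish_reason, matched_stop), and wrap everything in an OpenAI-compatible ChatCompletionResponse. A client-side sketch (assuming the router serves the OpenAI-compatible endpoint on localhost:30000; URL and model id are illustrative):

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,  # this PR covers the non-stream path
        "n": 2,           # parallel samples come back as two choices
    },
    timeout=60,
)
body = resp.json()
print(body["choices"][0]["message"]["content"])
print(body["usage"])  # token counts summed across choices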