sglang · Commit 37158f20 (unverified)

Authored Sep 25, 2025 by Chang Su; committed by GitHub on Sep 25, 2025.

router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)

Parent: 3e95aa1a
Showing 8 changed files with 281 additions and 135 deletions:

- python/sglang/srt/entrypoints/grpc_request_manager.py  (+166, -30)
- python/sglang/srt/entrypoints/grpc_server.py  (+27, -27)
- python/sglang/srt/grpc/sglang_scheduler.proto  (+5, -2)
- python/sglang/srt/grpc/sglang_scheduler_pb2.py  (+62, -62)
- python/sglang/srt/grpc/sglang_scheduler_pb2.pyi  (+8, -6)
- sgl-router/src/grpc_client/sglang_scheduler.rs  (+3, -2)
- sgl-router/src/proto/sglang_scheduler.proto  (+5, -2)
- sgl-router/src/routers/grpc/router.rs  (+5, -4)
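The heart of the change is in grpc_request_manager.py below: when `sampling_params.n > 1`, the request manager first sends a prefill-only request (`max_new_tokens=0`) to cache the shared prefix, then issues n independent generation requests that reuse it. The following is a minimal, self-contained sketch of that two-phase pattern in the non-streaming case; `Request`, `SamplingParams`, and `handle_single` are simplified stand-ins for the real `TokenizedGenerateReqInput` and `_handle_single_request`, not the actual sglang types.

```python
# A minimal, runnable sketch of the two-phase n>1 pattern (not the real sglang
# types): `Request`, `SamplingParams`, and `handle_single` are simplified
# stand-ins for TokenizedGenerateReqInput and _handle_single_request.
import asyncio
import copy
import dataclasses
import uuid


@dataclasses.dataclass
class SamplingParams:
    n: int = 1
    max_new_tokens: int = 16


@dataclasses.dataclass
class Request:
    prompt: str
    sampling_params: SamplingParams


async def handle_single(req: Request, request_id: str):
    # Placeholder for the real scheduler round-trip.
    yield {"request_id": request_id, "finished": True}


async def generate(req: Request):
    n = req.sampling_params.n
    base_id = f"grpc-{uuid.uuid4().hex}"

    # Phase 1: a prefill-only request caches the shared prefix (max_new_tokens=0, n=1).
    prefix = copy.copy(req)
    prefix.sampling_params = copy.copy(req.sampling_params)
    prefix.sampling_params.max_new_tokens = 0
    prefix.sampling_params.n = 1
    async for _ in handle_single(prefix, f"{base_id}-prefix"):
        pass  # consume the prefill-only response

    # Phase 2: n independent generation requests reuse the cached prefix.
    for i in range(n):
        gen = copy.copy(req)
        gen.sampling_params = copy.copy(req.sampling_params)
        gen.sampling_params.n = 1
        async for out in handle_single(gen, f"{base_id}-{i}"):
            yield out


async def main():
    req = Request("hello", SamplingParams(n=3))
    print([out async for out in generate(req)])


asyncio.run(main())
```

The streaming path additionally multiplexes the n generators concurrently; see the sketch after the grpc_request_manager.py diff.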
python/sglang/srt/entrypoints/grpc_request_manager.py  (view file @ 37158f20)

@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
 """

 import asyncio
+import copy
 import dataclasses
 import logging
 import os

@@ -11,6 +12,7 @@ import signal
 import sys
 import threading
 import time
+import uuid
 from typing import Any, Dict, List, Optional, Union

 import grpc

@@ -79,11 +81,9 @@ class GrpcReqState:
     last_completion_tokens: int = 1

     # Streaming state
-    last_output_offset: int = 0
     stream_finished: bool = False

     # Output accumulation
     text: str = ""
+    # Token accumulation (for non-streaming)
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)

@@ -139,8 +139,6 @@ class GrpcRequestManager:
         self.is_pause_cond = asyncio.Condition()

         # Metrics
-        self.request_counter = 0
-        self.request_counter_lock = asyncio.Lock()
         self.last_receive_tstamp = time.time()

         # Crash dump for debugging

@@ -158,22 +156,133 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ) -> asyncio.Queue:
+    ):
         """
-        Submit a generation request to the scheduler.
-        Returns a queue for streaming outputs.
+        Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+        This method implements the same two-phase approach as tokenizer_manager.py:
+        1. Phase 1: Send prefix caching request (max_new_tokens=0)
+        2. Phase 2: Send n generation requests that reuse the cached prefix
+
+        Yields individual responses for streaming, or aggregated responses for non-streaming.
         """
+        n = getattr(obj.sampling_params, "n", 1)
+
+        if n <= 1:
+            async for response in self._handle_single_request(
+                obj, request_id, grpc_context
+            ):
+                yield response
+            return
+
+        # N>1 handling - two-phase approach
+        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+        # Generate base request ID if not provided
+        if request_id is None:
+            base_request_id = f"grpc-{uuid.uuid4().hex}"
+        else:
+            base_request_id = request_id
+
+        # Phase 1: Cache the common prefix
+        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+        prefix_obj = copy.copy(obj)
+        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+        # Send prefix caching request and consume response
+        async for _ in self._handle_single_request(
+            prefix_obj, f"{base_request_id}-prefix", grpc_context
+        ):
+            # Consume prefix response (usually just one chunk with finish_reason)
+            pass
+
+        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+        # Phase 2: Generate n parallel requests
+        logger.debug(f"Phase 2: Generating {n} parallel requests")
+        generators = []
+        request_ids = []
+
+        for i in range(n):
+            # Create individual generation request
+            gen_obj = copy.copy(obj)
+            gen_obj.sampling_params = copy.copy(obj.sampling_params)
+            gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+            gen_request_id = f"{base_request_id}-{i}"
+            request_ids.append(gen_request_id)
+
+            # Start generation request
+            generators.append(
+                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+            )
+
+        # Handle response aggregation
+        is_stream = getattr(obj, "stream", False)
+
+        if not is_stream:
+            # Non-streaming: collect all responses and return as batch
+            logger.debug(f"Non-streaming mode: collecting {n} responses")
+            responses = []
+            for generator in generators:
+                async for response in generator:
+                    responses.append(response)
+            yield responses  # Return all responses as a batch
+        else:
+            # Streaming mode: multiplex responses with index for ordering
+            logger.debug(f"Streaming mode: multiplexing {n} streams")
+            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+            # Create async tasks for all generators
+            task_map = {}
+            for generator in generators:
+                task = asyncio.create_task(generator.__anext__())
+                task_map[task] = generator
+
+            # Process responses as they arrive
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    generator = task_map.pop(task)
+                    try:
+                        response = await task
+
+                        # Add index for client-side ordering
+                        if isinstance(response, dict) and "meta_info" in response:
+                            response_rid = response["meta_info"].get("id", "")
+                            if response_rid in rid_to_index:
+                                response["index"] = rid_to_index[response_rid]
+
+                        yield response
+
+                        # Create next task for this generator
+                        next_task = asyncio.create_task(generator.__anext__())
+                        task_map[next_task] = generator
+                    except StopAsyncIteration:
+                        # This generator is finished
+                        pass
+
+    async def _handle_single_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ):
+        """Handle a single request - core implementation without n>1 logic."""
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-{uuid.uuid4().hex}"
         obj.rid = request_id

-        # Create and register request state
+        # TODO: support log_request
+        # Create request state
         state = GrpcReqState(
             request_id=request_id,
             grpc_context=grpc_context,

@@ -189,19 +298,51 @@ class GrpcRequestManager:
             state.session_id = obj.session_params.session_id
             state.is_session_request = True

         # Register state
         self.rid_to_state[request_id] = state
         self.record_request_for_crash_dump(obj)

-        # Send to scheduler via ZMQ
-        try:
-            await self._send_to_scheduler(obj)
-        except Exception as e:
-            # Clean up on failure
-            del self.rid_to_state[request_id]
-            raise RuntimeError(f"Failed to send request to scheduler: {e}")
-
-        return state.out_queue
+        # Send to scheduler - let exceptions bubble up to grpc_server.py
+        await self._send_to_scheduler(obj)
+
+        is_stream = getattr(obj, "stream", False)
+        try:
+            while True:
+                # Client cancelled - notify scheduler and exit
+                if grpc_context and grpc_context.cancelled():
+                    await self.abort_request(request_id)
+                    return
+                try:
+                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                    if is_stream:
+                        yield response
+
+                    # Non-streaming: yield final response with accumulated tokens from state
+                    if isinstance(response, dict) and response.get("finished", False):
+                        if not is_stream:
+                            final_response = response.copy()
+                            final_response["token_ids"] = state.output_ids
+                            yield final_response
+                        break
+                except asyncio.TimeoutError:
+                    # Timeout waiting for response - abort and cleanup
+                    logger.warning(f"Timeout waiting for response for request {request_id}")
+                    await self.abort_request(request_id)
+                    return
+        finally:
+            # Always clean up request state when exiting
+            self._cleanup_request_state(request_id)
+
+    def _cleanup_request_state(self, request_id: str):
+        """Clean up local request state (does not notify scheduler)."""
+        if request_id in self.rid_to_state:
+            del self.rid_to_state[request_id]

     async def embedding_request(
         self,

@@ -214,9 +355,7 @@ class GrpcRequestManager:
         """
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-embed-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-embed-{uuid.uuid4().hex}"
         obj.rid = request_id

@@ -355,7 +494,6 @@ class GrpcRequestManager:
                 # Extract output for this request
                 output_data = {
                     "request_id": rid,
                     "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                     "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                     "finished": batch_out.finished_reasons[i] is not None,
                     "meta_info": {

@@ -367,6 +505,9 @@ class GrpcRequestManager:
                             if batch_out.completion_tokens
                             else 0
                         ),
+                        "cached_tokens": (
+                            batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                        ),
                         "finish_reason": (
                             str(batch_out.finished_reasons[i])
                             if batch_out.finished_reasons[i]

@@ -389,15 +530,10 @@ class GrpcRequestManager:
                     ),
                 }

-                # Update state
-                if output_data["text"]:
-                    state.text += output_data["text"][state.last_output_offset:]
-                    state.last_output_offset = len(output_data["text"])
-
+                # Update state for accumulation
                 if output_data["token_ids"]:
                     state.output_ids.extend(output_data["token_ids"])

                 # Send to output queue
                 await state.out_queue.put(output_data)

                 # Handle completion
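The streaming branch of generate_request fans n per-request async generators into one stream by racing their `__anext__()` calls with `asyncio.wait` and re-arming each generator after it produces an item. The sketch below isolates that fan-in pattern; `fake_stream` and `multiplex` are illustrative names, not part of sglang.

```python
# Standalone sketch of the fan-in used in the streaming n>1 branch:
# drive several async generators concurrently and yield items as they arrive.
import asyncio


async def fake_stream(rid: str, delay: float):
    # Stand-in for _handle_single_request: a few chunks, then a final one.
    for i in range(3):
        await asyncio.sleep(delay)
        yield {"meta_info": {"id": rid}, "chunk": i, "finished": i == 2}


async def multiplex(generators):
    # Map each pending __anext__() task back to its generator.
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                item = await task
            except StopAsyncIteration:
                continue  # this generator is exhausted
            yield item
            # Re-arm the generator so its next item can race with the others.
            task_map[asyncio.create_task(gen.__anext__())] = gen


async def main():
    streams = [fake_stream("req-0", 0.01), fake_stream("req-1", 0.003)]
    async for item in multiplex(streams):
        print(item["meta_info"]["id"], item["chunk"])


asyncio.run(main())
```

Chunks from different sub-requests interleave in arrival order, which is why the real code stamps each response with an `index` derived from its request id so the client can reorder them.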
python/sglang/srt/entrypoints/grpc_server.py  (view file @ 37158f20)

@@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)

-            # Submit to request manager
-            output_queue = await self.request_manager.generate_request(
+            # Submit to request manager (automatically handles n>1)
+            response_generator = self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )

-            # Stream outputs
-            while True:
-                try:
-                    # Get output with timeout
-                    output = await asyncio.wait_for(output_queue.get(), timeout=4)
-
-                    # Check for errors
+            async for output in response_generator:
+                # Handle batch responses (for n>1 non-streaming)
+                if isinstance(output, list):
+                    for batch_output in output:
+                        if "error" in batch_output:
+                            yield sglang_scheduler_pb2.GenerateResponse(
+                                request_id=request.request_id,
+                                error=sglang_scheduler_pb2.GenerateError(
+                                    message=batch_output["error"],
+                                    http_status_code=(
+                                        "500" if "abort" not in batch_output else "499"
+                                    ),
+                                ),
+                            )
+                        else:
+                            # All non-error batch outputs are final responses
+                            yield self._create_completion_response(
+                                request.request_id, batch_output
+                            )
+                else:
+                    # Handle single response (for streaming or n=1 non-streaming)
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,

@@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                                 ),
                             ),
                         )
-                        break
-
-                    # Check if finished
-                    if output.get("finished", False):
-                        # Send completion
+                    elif output.get("finished", False):
                         yield self._create_completion_response(request.request_id, output)
-                        break
                     else:
-                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)

-                except asyncio.TimeoutError:
-                    # Check if context is still active
-                    if context.cancelled():
-                        # Abort the request
-                        await self.request_manager.abort_request(request.request_id)
-                        break
-                    continue
-
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(

@@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             return_logprob=grpc_req.return_logprob,
             logprob_start_len=grpc_req.logprob_start_len or -1,
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=True,  # Always stream for gRPC
+            stream=grpc_req.stream or False,
             lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None

@@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+                token_ids=output.get("token_ids", []),
                 prompt_tokens=meta_info.get("prompt_tokens", 0),
                 completion_tokens=meta_info.get("completion_tokens", 0),
-                cached_tokens=0,
+                cached_tokens=meta_info.get("cached_tokens", 0),
             ),
         )
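On the servicer side, `Generate()` now iterates the request manager's async generator instead of polling an output queue, and it branches on the payload shape: a list is an aggregated n>1 non-streaming batch, a dict is a single chunk or final response. A trimmed sketch of that dispatch logic, with the protobuf response construction replaced by plain dicts and `fake_manager` standing in for `generate_request`:

```python
# Sketch of the Generate() dispatch: batch (list) vs single (dict) outputs.
import asyncio


async def stream_responses(response_generator, request_id: str):
    async for output in response_generator:
        if isinstance(output, list):
            # n>1 non-streaming: every element is either an error or a final response.
            for batch_output in output:
                if "error" in batch_output:
                    yield {"request_id": request_id, "error": batch_output["error"]}
                else:
                    yield {"request_id": request_id, "complete": batch_output}
        elif "error" in output:
            yield {"request_id": request_id, "error": output["error"]}
        elif output.get("finished", False):
            yield {"request_id": request_id, "complete": output}
        else:
            yield {"request_id": request_id, "chunk": output}


async def demo():
    async def fake_manager():
        # One aggregated batch, as the n>1 non-streaming path would yield.
        yield [{"text": "a", "finished": True}, {"error": "oops"}]

    async for resp in stream_responses(fake_manager(), "req-123"):
        print(resp)


asyncio.run(demo())
```

Because the generator owns the timeout and cancellation handling (see `_handle_single_request` above), the servicer no longer needs its own `asyncio.wait_for` polling loop.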
python/sglang/srt/grpc/sglang_scheduler.proto  (view file @ 37158f20)

@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {

@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
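On the wire, GenerateStreamChunk now carries `repeated int32 token_ids` instead of a single `token_id`, so one chunk can deliver several new tokens. A quick round-trip check with the regenerated Python bindings (assuming the `sglang.srt.grpc.sglang_scheduler_pb2` module from this commit is importable):

```python
# Assumes the regenerated bindings from this commit are on the import path.
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

chunk = pb2.GenerateStreamChunk(
    token_ids=[1234, 5678],   # repeated int32: an incremental batch of tokens
    prompt_tokens=5,
    completion_tokens=2,
    cached_tokens=3,
)
assert list(chunk.token_ids) == [1234, 5678]

# Round-trip through the serialized wire format.
raw = chunk.SerializeToString()
decoded = pb2.GenerateStreamChunk.FromString(raw)
assert list(decoded.token_ids) == [1234, 5678]
```

The field number stays 1, but existing consumers of the scalar `token_id` accessor must switch to iterating `token_ids`, as the .pyi and Rust diffs below show.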
python/sglang/srt/grpc/sglang_scheduler_pb2.py  (view file @ 37158f20)

This diff is collapsed (generated protobuf code; +62, -62).
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi  (view file @ 37158f20)

@@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message):
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]

@@ -101,6 +101,7 @@ class GenerateRequest(_message.Message):
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
     DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
+    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs

@@ -118,7 +119,8 @@ class GenerateRequest(_message.Message):
     lora_id: str
     data_parallel_rank: int
     dp_balance_id: int
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
+    stream: bool
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...

 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")

@@ -161,20 +163,20 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
-    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    token_id: int
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     logprobs: LogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-    def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
     __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
sgl-router/src/grpc_client/sglang_scheduler.rs  (view file @ 37158f20)

@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
             logprob_start_len: -1,
             top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
             return_hidden_states: body.return_hidden_states,
+            stream: body.stream,
             ..Default::default()
         };

@@ -367,14 +368,14 @@ mod tests {
     #[test]
     fn test_generate_stream_chunk() {
         let chunk = proto::GenerateStreamChunk {
-            token_id: 1234,
+            token_ids: vec![1234, 5678],
             prompt_tokens: 5,
             completion_tokens: 2,
             cached_tokens: 3,
             ..Default::default()
         };

-        assert_eq!(chunk.token_id, 1234);
+        assert_eq!(chunk.token_ids, vec![1234, 5678]);
         assert_eq!(chunk.prompt_tokens, 5);
         assert_eq!(chunk.completion_tokens, 2);
         assert_eq!(chunk.cached_tokens, 3);
sgl-router/src/proto/sglang_scheduler.proto  (view file @ 37158f20)

@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {

@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
sgl-router/src/routers/grpc/router.rs  (view file @ 37158f20)

@@ -203,6 +203,7 @@ impl GrpcRouter {
         debug!("Selected worker: {}", worker.url());

         // Step 2: Get gRPC client for worker (fail fast if can't connect)
+        // TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
         let client = match self.get_or_create_grpc_client(worker.url()).await {
             Ok(c) => c,
             Err(e) => {

@@ -249,7 +250,7 @@ impl GrpcRouter {
         // Step 6: Build the base gRPC request
         let request_id = format!("chatcmpl-{}", Uuid::new_v4());
-        let base_request = match client.build_generate_request(
+        let request = match client.build_generate_request(
             request_id,
             body,
             processed_messages.text.clone(),

@@ -268,11 +269,11 @@ impl GrpcRouter {
             }
         };

         // Step 7: Handle streaming vs non-streaming
         if body.stream {
-            self.handle_streaming_chat(client, base_request, body).await
+            self.handle_streaming_chat(client, request, body).await
         } else {
-            self.handle_non_streaming_chat(client, base_request, body).await
+            self.handle_non_streaming_chat(client, request, body).await
         }
     }