Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ea0696b9
Unverified
Commit
ea0696b9
authored
Aug 25, 2025
by
Sundara Raman Ramachandran
Committed by
GitHub
Aug 26, 2025
Browse files
[Performance] Batch Send from Tokenizer Manager. (#9436)
parent
3aec3d4f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
117 additions
and
6 deletions
+117
-6
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+30
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+51
-3
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+36
-3
No files found.
python/sglang/srt/managers/io_struct.py
View file @
ea0696b9
...
@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
...
@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
dp_balance_id
:
int
=
-
1
dp_balance_id
:
int
=
-
1
@
dataclass
class
BatchTokenizedGenerateReqInput
:
# The batch of tokenized requests
batch
:
List
[
TokenizedGenerateReqInput
]
def
__len__
(
self
):
return
len
(
self
.
batch
)
def
__getitem__
(
self
,
i
):
return
self
.
batch
[
i
]
def
__iter__
(
self
):
return
iter
(
self
.
batch
)
@
dataclass
@
dataclass
class
EmbeddingReqInput
:
class
EmbeddingReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
# The input prompt. It can be a single prompt or a batch of prompts.
...
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
...
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
dp_balance_id
:
int
=
-
1
dp_balance_id
:
int
=
-
1
@
dataclass
class
BatchTokenizedEmbeddingReqInput
:
# The batch of tokenized embedding requests
batch
:
List
[
TokenizedEmbeddingReqInput
]
def
__len__
(
self
):
return
len
(
self
.
batch
)
def
__getitem__
(
self
,
i
):
return
self
.
batch
[
i
]
def
__iter__
(
self
):
return
iter
(
self
.
batch
)
@
dataclass
@
dataclass
class
BatchTokenIDOut
:
class
BatchTokenIDOut
:
# The request id
# The request id
...
...
python/sglang/srt/managers/scheduler.py
View file @
ea0696b9
...
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from
sglang.srt.layers.moe
import
initialize_moe_config
from
sglang.srt.layers.moe
import
initialize_moe_config
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
AbortReq
,
BatchTokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
CloseSessionReqInput
,
CloseSessionReqInput
,
ExpertDistributionReq
,
ExpertDistributionReq
,
ExpertDistributionReqOutput
,
ExpertDistributionReqOutput
,
...
@@ -510,6 +512,8 @@ class Scheduler(
...
@@ -510,6 +512,8 @@ class Scheduler(
[
[
(
TokenizedGenerateReqInput
,
self
.
handle_generate_request
),
(
TokenizedGenerateReqInput
,
self
.
handle_generate_request
),
(
TokenizedEmbeddingReqInput
,
self
.
handle_embedding_request
),
(
TokenizedEmbeddingReqInput
,
self
.
handle_embedding_request
),
(
BatchTokenizedGenerateReqInput
,
self
.
handle_batch_generate_request
),
(
BatchTokenizedEmbeddingReqInput
,
self
.
handle_batch_embedding_request
),
(
FlushCacheReqInput
,
self
.
flush_cache_wrapped
),
(
FlushCacheReqInput
,
self
.
flush_cache_wrapped
),
(
AbortReq
,
self
.
abort_request
),
(
AbortReq
,
self
.
abort_request
),
(
OpenSessionReqInput
,
self
.
open_session
),
(
OpenSessionReqInput
,
self
.
open_session
),
...
@@ -1018,14 +1022,26 @@ class Scheduler(
...
@@ -1018,14 +1022,26 @@ class Scheduler(
req
req
for
req
in
recv_reqs
for
req
in
recv_reqs
if
isinstance
(
if
isinstance
(
req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
)
req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
BatchTokenizedEmbeddingReqInput
,
),
)
)
]
]
control_reqs
=
[
control_reqs
=
[
req
req
for
req
in
recv_reqs
for
req
in
recv_reqs
if
not
isinstance
(
if
not
isinstance
(
req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
)
req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
BatchTokenizedEmbeddingReqInput
,
),
)
)
]
]
else
:
else
:
...
@@ -1253,6 +1269,17 @@ class Scheduler(
...
@@ -1253,6 +1269,17 @@ class Scheduler(
else
:
else
:
self
.
_add_request_to_queue
(
req
)
self
.
_add_request_to_queue
(
req
)
def
handle_batch_generate_request
(
self
,
recv_req
:
BatchTokenizedGenerateReqInput
,
):
"""Handle optimized batch generate request."""
logger
.
debug
(
f
"Processing batch generate request with
{
len
(
recv_req
)
}
requests"
)
# Process each request in the batch
for
tokenized_req
in
recv_req
:
self
.
handle_generate_request
(
tokenized_req
)
def
_add_request_to_queue
(
self
,
req
:
Req
):
def
_add_request_to_queue
(
self
,
req
:
Req
):
req
.
queue_time_start
=
time
.
perf_counter
()
req
.
queue_time_start
=
time
.
perf_counter
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
...
@@ -1335,6 +1362,19 @@ class Scheduler(
...
@@ -1335,6 +1362,19 @@ class Scheduler(
req
.
logprob_start_len
=
len
(
req
.
origin_input_ids
)
-
1
req
.
logprob_start_len
=
len
(
req
.
origin_input_ids
)
-
1
self
.
_add_request_to_queue
(
req
)
self
.
_add_request_to_queue
(
req
)
def
handle_batch_embedding_request
(
self
,
recv_req
:
BatchTokenizedEmbeddingReqInput
,
):
"""Handle optimized batch embedding request."""
logger
.
debug
(
f
"Processing batch embedding request with
{
len
(
recv_req
)
}
requests"
)
# Process each request in the batch
for
tokenized_req
in
recv_req
:
self
.
handle_embedding_request
(
tokenized_req
)
def
self_check_during_idle
(
self
):
def
self_check_during_idle
(
self
):
self
.
check_memory
()
self
.
check_memory
()
self
.
check_tree_cache
()
self
.
check_tree_cache
()
...
@@ -2513,7 +2553,15 @@ def is_health_check_generate_req(recv_req):
...
@@ -2513,7 +2553,15 @@ def is_health_check_generate_req(recv_req):
def
is_work_request
(
recv_req
):
def
is_work_request
(
recv_req
):
return
isinstance
(
recv_req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
))
return
isinstance
(
recv_req
,
(
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
BatchTokenizedEmbeddingReqInput
,
),
)
def
run_scheduler_process
(
def
run_scheduler_process
(
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
ea0696b9
...
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
...
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut
,
BatchMultimodalOut
,
BatchStrOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
BatchTokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
CloseSessionReqInput
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
ConfigureLoggingReq
,
EmbeddingReqInput
,
EmbeddingReqInput
,
...
@@ -768,6 +770,30 @@ class TokenizerManager:
...
@@ -768,6 +770,30 @@ class TokenizerManager:
self
.
rid_to_state
[
obj
.
rid
]
=
state
self
.
rid_to_state
[
obj
.
rid
]
=
state
return
state
return
state
def
_send_batch_request
(
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
tokenized_objs
:
List
[
Union
[
TokenizedGenerateReqInput
,
TokenizedEmbeddingReqInput
]
],
created_time
:
Optional
[
float
]
=
None
,
):
"""Send a batch of tokenized requests as a single batched request to the scheduler."""
if
isinstance
(
tokenized_objs
[
0
],
TokenizedGenerateReqInput
):
batch_req
=
BatchTokenizedGenerateReqInput
(
batch
=
tokenized_objs
)
else
:
batch_req
=
BatchTokenizedEmbeddingReqInput
(
batch
=
tokenized_objs
)
self
.
send_to_scheduler
.
send_pyobj
(
batch_req
)
# Create states for each individual request in the batch
for
i
,
tokenized_obj
in
enumerate
(
tokenized_objs
):
tmp_obj
=
obj
[
i
]
state
=
ReqState
(
[],
False
,
asyncio
.
Event
(),
tmp_obj
,
created_time
=
created_time
)
self
.
rid_to_state
[
tmp_obj
.
rid
]
=
state
async
def
_wait_one_response
(
async
def
_wait_one_response
(
self
,
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
...
@@ -870,10 +896,17 @@ class TokenizerManager:
...
@@ -870,10 +896,17 @@ class TokenizerManager:
tokenized_objs
=
await
self
.
_batch_tokenize_and_process
(
batch_size
,
obj
)
tokenized_objs
=
await
self
.
_batch_tokenize_and_process
(
batch_size
,
obj
)
for
i
,
tokenized_obj
in
enumerate
(
tokenized_objs
):
# Send as a single batched request
self
.
_send_batch_request
(
obj
,
tokenized_objs
,
created_time
)
# Set up generators for each request in the batch
for
i
in
range
(
batch_size
):
tmp_obj
=
obj
[
i
]
tmp_obj
=
obj
[
i
]
state
=
self
.
_send_one_request
(
tmp_obj
,
tokenized_obj
,
created_time
)
generators
.
append
(
generators
.
append
(
self
.
_wait_one_response
(
tmp_obj
,
state
,
request
))
self
.
_wait_one_response
(
tmp_obj
,
self
.
rid_to_state
[
tmp_obj
.
rid
],
request
)
)
rids
.
append
(
tmp_obj
.
rid
)
rids
.
append
(
tmp_obj
.
rid
)
else
:
else
:
# Sequential tokenization and processing
# Sequential tokenization and processing
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment