Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0ce84c82
Unverified
Commit
0ce84c82
authored
Jul 29, 2025
by
fzyzcjy
Committed by
GitHub
Jul 28, 2025
Browse files
Support colocating requests (#7973)
parent
59d0bf01
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
179 additions
and
6 deletions
+179
-6
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+4
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+10
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+10
-0
python/sglang/srt/managers/scheduler_input_blocker.py
python/sglang/srt/managers/scheduler_input_blocker.py
+106
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+18
-6
python/sglang/srt/poll_based_barrier.py
python/sglang/srt/poll_based_barrier.py
+31
-0
No files found.
python/sglang/srt/managers/data_parallel_controller.py
View file @
0ce84c82
...
@@ -26,6 +26,7 @@ import zmq
...
@@ -26,6 +26,7 @@ import zmq
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
BlockReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
)
)
...
@@ -282,6 +283,9 @@ class DataParallelController:
...
@@ -282,6 +283,9 @@ class DataParallelController:
),
),
):
):
self
.
dispatching
(
recv_req
)
self
.
dispatching
(
recv_req
)
elif
isinstance
(
recv_req
,
BlockReqInput
):
for
worker
in
self
.
workers
:
worker
.
send_pyobj
(
recv_req
)
else
:
else
:
# Send other control messages to first worker of tp group
# Send other control messages to first worker of tp group
for
worker
in
self
.
workers
[::
self
.
control_message_step
]:
for
worker
in
self
.
workers
[::
self
.
control_message_step
]:
...
...
python/sglang/srt/managers/io_struct.py
View file @
0ce84c82
...
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
...
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
LoadLoRAAdapterReqOutput
=
UnloadLoRAAdapterReqOutput
=
LoRAUpdateResult
LoadLoRAAdapterReqOutput
=
UnloadLoRAAdapterReqOutput
=
LoRAUpdateResult
class BlockReqType(Enum):
    """Kinds of scheduler-input blocking control messages."""

    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    """Control message asking the scheduler to block or unblock its input.

    Broadcast to every data-parallel worker so all ranks change state together.
    """

    # NOTE: field name shadows the builtin `type`; kept for wire compatibility
    # with existing senders/receivers.
    type: BlockReqType
python/sglang/srt/managers/scheduler.py
View file @
0ce84c82
...
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
...
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder
,
PrefillAdder
,
SchedulePolicy
,
SchedulePolicy
,
)
)
from
sglang.srt.managers.scheduler_input_blocker
import
SchedulerInputBlocker
from
sglang.srt.managers.scheduler_output_processor_mixin
import
(
from
sglang.srt.managers.scheduler_output_processor_mixin
import
(
SchedulerOutputProcessorMixin
,
SchedulerOutputProcessorMixin
,
)
)
...
@@ -504,6 +505,12 @@ class Scheduler(
...
@@ -504,6 +505,12 @@ class Scheduler(
)
)
self
.
init_profier
()
self
.
init_profier
()
self
.
input_blocker
=
(
SchedulerInputBlocker
(
noop
=
self
.
attn_tp_rank
!=
0
)
if
get_bool_env_var
(
"SGLANG_ENABLE_COLOCATED_BATCH_GEN"
)
else
None
)
# Init metrics stats
# Init metrics stats
self
.
init_metrics
(
tp_rank
,
pp_rank
,
dp_rank
)
self
.
init_metrics
(
tp_rank
,
pp_rank
,
dp_rank
)
self
.
init_kv_events
(
server_args
.
kv_events_config
)
self
.
init_kv_events
(
server_args
.
kv_events_config
)
...
@@ -1035,6 +1042,9 @@ class Scheduler(
...
@@ -1035,6 +1042,9 @@ class Scheduler(
else
:
else
:
recv_reqs
=
None
recv_reqs
=
None
if
self
.
input_blocker
is
not
None
:
recv_reqs
=
self
.
input_blocker
.
handle
(
recv_reqs
)
if
self
.
server_args
.
enable_dp_attention
:
if
self
.
server_args
.
enable_dp_attention
:
if
self
.
attn_tp_rank
==
0
:
if
self
.
attn_tp_rank
==
0
:
work_reqs
=
[
work_reqs
=
[
...
...
python/sglang/srt/managers/scheduler_input_blocker.py
0 → 100644
View file @
0ce84c82
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
logging
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
typing
import
Any
,
List
,
Optional
from
sglang.srt.managers.io_struct
import
BlockReqInput
,
BlockReqType
from
sglang.srt.poll_based_barrier
import
PollBasedBarrier
logger
=
logging
.
getLogger
(
__name__
)
class SchedulerInputBlocker:
    """Gates the scheduler's incoming request stream.

    State machine: UNBLOCKED -> (BLOCK req) -> BLOCKED -> (UNBLOCK req) ->
    GLOBAL_UNBLOCK_BARRIER -> (all ranks arrived) -> UNBLOCKED.

    While not UNBLOCKED, ordinary requests are buffered in ``_pending_reqs``
    and released only after every rank has passed the poll-based barrier, so
    all ranks resume at the same step.

    When ``noop`` is True (this rank receives no requests itself), the
    instance does no local filtering but still participates in the collective
    barrier poll.
    """

    def __init__(self, noop: bool):
        # Current position in the state machine above.
        self._state = _State.UNBLOCKED
        # Requests received while blocked; flushed after the barrier fires.
        self._pending_reqs = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests.

        Args:
            recv_reqs: The newly received requests, or None in noop mode.

        Returns:
            The requests that may be processed now (possibly including
            previously buffered ones), or None in noop mode.
        """
        assert (recv_reqs is None) == self._noop

        # Bug fix: initialize unconditionally so the barrier branch below can
        # never hit an UnboundLocalError, even if the noop/state invariant is
        # ever violated. Noop callers still get None (no explicit return).
        output_reqs = []
        if not self._noop:
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        # The barrier poll is a collective and must run on every iteration by
        # every rank, not only while waiting to unblock.
        global_arrived_unblock_barrier = (
            self._global_unblock_barrier.poll_global_arrived()
        )
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            output_reqs += self._handle_arrive_unblock_barrier()

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        """Route one request: control messages mutate state; others are
        forwarded (UNBLOCKED) or buffered (otherwise)."""
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
                return []
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
                return []
            else:
                raise NotImplementedError(f"{recv_req=}")
        else:
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            else:
                self._pending_reqs.append(recv_req)
                return []

    def _execute_block_req(self):
        # Entering the blocked phase; only legal from UNBLOCKED.
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        # Leaving the blocked phase: announce local arrival and wait for the
        # other ranks via the barrier before actually releasing requests.
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """All ranks arrived: return to UNBLOCKED and flush buffered reqs."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        output_reqs = [*self._pending_reqs]
        self._pending_reqs.clear()
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        # Guard against out-of-order control messages (e.g. double BLOCK).
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
class
_State
(
Enum
):
UNBLOCKED
=
auto
()
BLOCKED
=
auto
()
GLOBAL_UNBLOCK_BARRIER
=
auto
()
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Block the scheduler's input for the duration of the ``with`` body.

    Sends a BLOCK request on entry and always sends the matching UNBLOCK on
    exit, even when the body raises.
    """
    _send = send_to_scheduler.send_pyobj
    _send(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        _send(BlockReqInput(BlockReqType.UNBLOCK))
python/sglang/srt/managers/tokenizer_manager.py
View file @
0ce84c82
...
@@ -27,6 +27,7 @@ import threading
...
@@ -27,6 +27,7 @@ import threading
import
time
import
time
import
uuid
import
uuid
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
datetime
import
datetime
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
(
from
typing
import
(
...
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut
,
BatchMultimodalOut
,
BatchStrOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
BlockReqType
,
CloseSessionReqInput
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
ConfigureLoggingReq
,
EmbeddingReqInput
,
EmbeddingReqInput
,
...
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
)
)
from
sglang.srt.managers.mm_utils
import
TensorTransportMode
from
sglang.srt.managers.mm_utils
import
TensorTransportMode
from
sglang.srt.managers.multimodal_processor
import
get_mm_processor
,
import_processors
from
sglang.srt.managers.multimodal_processor
import
get_mm_processor
,
import_processors
from
sglang.srt.managers.scheduler_input_blocker
import
input_blocker_guard_region
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
@@ -819,12 +822,21 @@ class TokenizerManager:
...
@@ -819,12 +822,21 @@ class TokenizerManager:
rids
.
append
(
tmp_obj
.
rid
)
rids
.
append
(
tmp_obj
.
rid
)
else
:
else
:
# Sequential tokenization and processing
# Sequential tokenization and processing
for
i
in
range
(
batch_size
):
with
(
tmp_obj
=
obj
[
i
]
input_blocker_guard_region
(
send_to_scheduler
=
self
.
send_to_scheduler
)
tokenized_obj
=
await
self
.
_tokenize_one_request
(
tmp_obj
)
if
get_bool_env_var
(
"SGLANG_ENABLE_COLOCATED_BATCH_GEN"
)
state
=
self
.
_send_one_request
(
tmp_obj
,
tokenized_obj
,
created_time
)
else
nullcontext
()
generators
.
append
(
self
.
_wait_one_response
(
tmp_obj
,
state
,
request
))
):
rids
.
append
(
tmp_obj
.
rid
)
for
i
in
range
(
batch_size
):
tmp_obj
=
obj
[
i
]
tokenized_obj
=
await
self
.
_tokenize_one_request
(
tmp_obj
)
state
=
self
.
_send_one_request
(
tmp_obj
,
tokenized_obj
,
created_time
)
generators
.
append
(
self
.
_wait_one_response
(
tmp_obj
,
state
,
request
)
)
rids
.
append
(
tmp_obj
.
rid
)
else
:
else
:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if
batch_size
>
128
:
if
batch_size
>
128
:
...
...
python/sglang/srt/poll_based_barrier.py
0 → 100644
View file @
0ce84c82
import
torch
from
sglang.srt.distributed
import
get_world_group
class PollBasedBarrier:
    """A reusable barrier that is checked by repeated polling.

    Each rank calls ``local_arrive()`` once per cycle and then keeps calling
    ``poll_global_arrived()``; the poll returns True exactly once, after every
    participating rank has arrived. A ``noop`` rank never arrives locally but
    still joins the collective poll.
    """

    def __init__(self, noop: bool = False):
        self._noop = noop
        # Whether this rank has arrived in the current cycle.
        self._local_arrived = False

    def local_arrive(self):
        """Mark this rank as arrived; must not be called twice per cycle."""
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        """Collectively check arrival; resets local state when it fires."""
        # The collective must run on every poll, on every rank.
        everyone_arrived = self._compute_global_arrived()
        fired = self._local_arrived and everyone_arrived
        if fired:
            # Re-arm the barrier for the next cycle.
            self._local_arrived = False
        return fired

    def _compute_global_arrived(self) -> bool:
        # noop ranks always report arrived so they never hold the MIN down.
        local_arrived = self._noop or self._local_arrived
        flag = torch.tensor(local_arrived)
        # Can optimize if bottleneck
        torch.distributed.all_reduce(
            flag,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return flag.item()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment