Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0ce84c82
Unverified
Commit
0ce84c82
authored
Jul 29, 2025
by
fzyzcjy
Committed by
GitHub
Jul 28, 2025
Browse files
Support colocating requests (#7973)
parent
59d0bf01
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
179 additions
and
6 deletions
+179
-6
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+4
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+10
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+10
-0
python/sglang/srt/managers/scheduler_input_blocker.py
python/sglang/srt/managers/scheduler_input_blocker.py
+106
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+18
-6
python/sglang/srt/poll_based_barrier.py
python/sglang/srt/poll_based_barrier.py
+31
-0
No files found.
python/sglang/srt/managers/data_parallel_controller.py
View file @
0ce84c82
...
...
@@ -26,6 +26,7 @@ import zmq
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
from
sglang.srt.managers.io_struct
import
(
BlockReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
)
...
...
@@ -282,6 +283,9 @@ class DataParallelController:
),
):
self
.
dispatching
(
recv_req
)
elif
isinstance
(
recv_req
,
BlockReqInput
):
for
worker
in
self
.
workers
:
worker
.
send_pyobj
(
recv_req
)
else
:
# Send other control messages to first worker of tp group
for
worker
in
self
.
workers
[::
self
.
control_message_step
]:
...
...
python/sglang/srt/managers/io_struct.py
View file @
0ce84c82
...
...
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
LoadLoRAAdapterReqOutput
=
UnloadLoRAAdapterReqOutput
=
LoRAUpdateResult
class BlockReqType(Enum):
    """Kind of scheduler-input control request carried by BlockReqInput."""

    BLOCK = 1  # start buffering incoming requests at the scheduler
    UNBLOCK = 2  # release buffered requests (after a global barrier)
@dataclass
class BlockReqInput:
    """Control message asking the scheduler to block or unblock its input."""

    # BlockReqType.BLOCK or BlockReqType.UNBLOCK.
    type: BlockReqType
python/sglang/srt/managers/scheduler.py
View file @
0ce84c82
...
...
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder
,
SchedulePolicy
,
)
from
sglang.srt.managers.scheduler_input_blocker
import
SchedulerInputBlocker
from
sglang.srt.managers.scheduler_output_processor_mixin
import
(
SchedulerOutputProcessorMixin
,
)
...
...
@@ -504,6 +505,12 @@ class Scheduler(
)
self
.
init_profier
()
self
.
input_blocker
=
(
SchedulerInputBlocker
(
noop
=
self
.
attn_tp_rank
!=
0
)
if
get_bool_env_var
(
"SGLANG_ENABLE_COLOCATED_BATCH_GEN"
)
else
None
)
# Init metrics stats
self
.
init_metrics
(
tp_rank
,
pp_rank
,
dp_rank
)
self
.
init_kv_events
(
server_args
.
kv_events_config
)
...
...
@@ -1035,6 +1042,9 @@ class Scheduler(
else
:
recv_reqs
=
None
if
self
.
input_blocker
is
not
None
:
recv_reqs
=
self
.
input_blocker
.
handle
(
recv_reqs
)
if
self
.
server_args
.
enable_dp_attention
:
if
self
.
attn_tp_rank
==
0
:
work_reqs
=
[
...
...
python/sglang/srt/managers/scheduler_input_blocker.py
0 → 100644
View file @
0ce84c82
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
logging
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
typing
import
Any
,
List
,
Optional
from
sglang.srt.managers.io_struct
import
BlockReqInput
,
BlockReqType
from
sglang.srt.poll_based_barrier
import
PollBasedBarrier
logger
=
logging
.
getLogger
(
__name__
)
class SchedulerInputBlocker:
    """Buffers scheduler input while blocked; releases it once all ranks unblock.

    State machine: UNBLOCKED -> BLOCKED (on a BLOCK request) ->
    GLOBAL_UNBLOCK_BARRIER (on an UNBLOCK request) -> UNBLOCKED (once the
    poll-based barrier reports every rank has arrived).

    A ``noop`` instance (used on ranks that receive no input, e.g.
    attn_tp_rank != 0) never sees requests but still participates in the
    global barrier poll.
    """

    def __init__(self, noop: bool):
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter ``recv_reqs``, buffering them while blocked.

        Returns the requests that may be processed now; a noop instance is
        always called with ``recv_reqs=None`` and returns ``None``.
        """
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            output_reqs = []
            for req in recv_reqs:
                output_reqs.extend(self._handle_recv_req(req))

        # Poll the barrier on every call: it performs a collective, so all
        # ranks (including noop ones) must take part each iteration.
        barrier_arrived = self._global_unblock_barrier.poll_global_arrived()
        if self._state == _State.GLOBAL_UNBLOCK_BARRIER and barrier_arrived:
            output_reqs.extend(self._handle_arrive_unblock_barrier())

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        # Control messages are consumed here; anything else is passed through
        # (when unblocked) or buffered (when blocked).
        if not isinstance(recv_req, BlockReqInput):
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []
        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
            return []
        if recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
            return []
        raise NotImplementedError(f"{recv_req=}")

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        # Announce local readiness; pending requests are released only after
        # poll_global_arrived() confirms every rank has done the same.
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        released = list(self._pending_reqs)
        self._pending_reqs.clear()
        return released

    def _change_state(self, original: "_State", target: "_State"):
        # Guard against out-of-order control messages (e.g. double BLOCK).
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
class _State(Enum):
    # Requests flow straight through to the scheduler.
    UNBLOCKED = auto()
    # A BLOCK request was received; incoming requests are buffered.
    BLOCKED = auto()
    # An UNBLOCK request was received locally; waiting for all ranks.
    GLOBAL_UNBLOCK_BARRIER = auto()
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Block scheduler input for the duration of the ``with`` body.

    Sends a BLOCK request on entry and an UNBLOCK request on exit (the
    ``finally`` guarantees unblocking even when the body raises), so requests
    arriving in between are buffered by the scheduler's SchedulerInputBlocker.
    """
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
python/sglang/srt/managers/tokenizer_manager.py
View file @
0ce84c82
...
...
@@ -27,6 +27,7 @@ import threading
import
time
import
uuid
from
collections
import
deque
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
http
import
HTTPStatus
from
typing
import
(
...
...
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut
,
BatchStrOut
,
BatchTokenIDOut
,
BlockReqType
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
EmbeddingReqInput
,
...
...
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
)
from
sglang.srt.managers.mm_utils
import
TensorTransportMode
from
sglang.srt.managers.multimodal_processor
import
get_mm_processor
,
import_processors
from
sglang.srt.managers.scheduler_input_blocker
import
input_blocker_guard_region
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
@@ -819,12 +822,21 @@ class TokenizerManager:
rids
.
append
(
tmp_obj
.
rid
)
else
:
# Sequential tokenization and processing
for
i
in
range
(
batch_size
):
tmp_obj
=
obj
[
i
]
tokenized_obj
=
await
self
.
_tokenize_one_request
(
tmp_obj
)
state
=
self
.
_send_one_request
(
tmp_obj
,
tokenized_obj
,
created_time
)
generators
.
append
(
self
.
_wait_one_response
(
tmp_obj
,
state
,
request
))
rids
.
append
(
tmp_obj
.
rid
)
with
(
input_blocker_guard_region
(
send_to_scheduler
=
self
.
send_to_scheduler
)
if
get_bool_env_var
(
"SGLANG_ENABLE_COLOCATED_BATCH_GEN"
)
else
nullcontext
()
):
for
i
in
range
(
batch_size
):
tmp_obj
=
obj
[
i
]
tokenized_obj
=
await
self
.
_tokenize_one_request
(
tmp_obj
)
state
=
self
.
_send_one_request
(
tmp_obj
,
tokenized_obj
,
created_time
)
generators
.
append
(
self
.
_wait_one_response
(
tmp_obj
,
state
,
request
)
)
rids
.
append
(
tmp_obj
.
rid
)
else
:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if
batch_size
>
128
:
...
...
python/sglang/srt/poll_based_barrier.py
0 → 100644
View file @
0ce84c82
import
torch
from
sglang.srt.distributed
import
get_world_group
class PollBasedBarrier:
    """Reusable barrier whose completion is observed by polling, not blocking.

    Each rank calls ``local_arrive()`` once, then repeatedly calls
    ``poll_global_arrived()``; the latter returns True exactly once, after
    every rank has arrived, and resets the barrier for reuse. A ``noop``
    participant always reports itself as arrived to the collective.
    """

    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        """Mark this rank as arrived. Must not be called twice per round."""
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        """Return True once all ranks have arrived, then reset locally."""
        arrived_everywhere = self._compute_global_arrived()
        completed = self._local_arrived and arrived_everywhere
        if completed:
            # Reset so the barrier can be reused for the next round.
            self._local_arrived = False
        return completed

    def _compute_global_arrived(self) -> bool:
        # A noop participant always votes "arrived" so it never stalls others.
        flag = self._noop or self._local_arrived
        arrived = torch.tensor(flag)
        # Can optimize if bottleneck
        torch.distributed.all_reduce(
            arrived,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return arrived.item()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment