change / sglang / Commits / 27e8ffed

Commit 27e8ffed (Unverified)
Authored Sep 04, 2025 by Liangsheng Yin; committed by GitHub on Sep 04, 2025

[1/N] DP-refactor: move dp balance code into scheduler's mixin class (#10004)

Parent: 4dbb34fe

Showing 2 changed files with 116 additions and 106 deletions (+116 / -106):

  python/sglang/srt/managers/scheduler.py                 +3    -99
  python/sglang/srt/managers/scheduler_metrics_mixin.py   +113  -7
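Note: most of the churn in scheduler_metrics_mixin.py below is one change repeated: every mixin method gains a `self: Scheduler` annotation so that type checkers can resolve attributes such as server_args, tp_rank, and balance_meta, which only the host Scheduler class defines. A minimal sketch of that pattern, using made-up Host/LoggingMixin names as stand-ins for Scheduler and SchedulerMetricsMixin:

class LoggingMixin:
    # Annotating `self` with the host class lets type checkers resolve
    # attributes (here `name`) that the mixin itself never defines.
    def log_name(self: "Host") -> None:
        print(f"host name: {self.name}")


class Host(LoggingMixin):
    def __init__(self, name: str) -> None:
        self.name = name


if __name__ == "__main__":
    Host("scheduler-0").log_name()  # prints: host name: scheduler-0

The annotation is purely for static analysis; at runtime Python ignores it, which is why the real file can point `self` at a class it only imports under TYPE_CHECKING.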
python/sglang/srt/managers/scheduler.py  (view file @ 27e8ffed)
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
-                # 1. Check if the rid received by each worker this round is present in onfly.
-                #    If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
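Note: the removed handle_dp_balance_data (it reappears, split into methods, on the mixin below) serializes each worker's state into a fixed 512-slot int32 tensor so a single torch.distributed.gather can collect every worker's data on TP rank 0. Slot 0 carries the worker's current token load, slot 1 the number of dp_balance_ids received this round, and slots 2..N+1 the ids themselves. Below is a minimal single-process sketch of that packing layout; the pack/unpack names are illustrative and the collective call is omitted, since it needs an initialized process group and a destination list allocated only on the gathering rank.

from typing import List, Tuple

import torch

GATHER_TENSOR_SIZE = 512  # fixed upper bound, matching the diff (at most 511 ids per round)


def pack(holding_tokens: int, recv_ids: List[int]) -> torch.Tensor:
    """Layout: | holding_tokens | len(recv_ids) | recv_ids ... | zero padding |"""
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack(t: torch.Tensor) -> Tuple[int, List[int]]:
    num_ids = int(t[1].item())
    return int(t[0].item()), t[2 : num_ids + 2].tolist()


if __name__ == "__main__":
    packed = pack(holding_tokens=1234, recv_ids=[7, 8, 9])
    print(unpack(packed))  # -> (1234, [7, 8, 9])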
python/sglang/srt/managers/scheduler_metrics_mixin.py  (view file @ 27e8ffed)
 from __future__ import annotations

 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
 from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Scheduler

 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
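Note: the TYPE_CHECKING block is what allows the mixin to annotate self as Scheduler without importing scheduler.py at runtime; scheduler.py already imports this module, so a real import back would risk a cycle. A two-module sketch of the pattern, with hypothetical file names mixin.py and scheduler.py rather than the real module paths:

# mixin.py (hypothetical name): annotate `self` without importing the host at runtime
from __future__ import annotations  # keeps annotations as strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no circular import occurs.
    from scheduler import Scheduler


class MetricsMixin:
    def init_metrics(self: Scheduler) -> None:
        self.last_gen_throughput = 0.0


# scheduler.py (hypothetical name) imports the mixin normally:
#     from mixin import MetricsMixin
#     class Scheduler(MetricsMixin): ...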
@@ -28,7 +37,9 @@ class KvMetrics:


 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+            self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
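Note: init_dp_balance stores the DPBalanceMeta handle that the balancing methods further down read and write. The real class, imported above from sglang.srt.managers.utils, keeps the per-worker onfly map and token counts in shared memory (the "shm" in the comment below) so that every DP worker sees the same view. The following is only an in-process stand-in that mirrors the surface this diff actually calls (num_workers, mutex, get_shared_onfly, set_shared_onfly_info, set_shared_local_tokens); it is not the real implementation.

import threading
from typing import Dict, List


class InProcessDPBalanceMeta:
    """Illustrative stand-in; the real DPBalanceMeta backs this state with shared memory."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self.mutex = threading.Lock()
        # One dict per worker: dp_balance_id -> bookkeeping value for requests
        # that were dispatched to the worker but not yet confirmed as received.
        self._onfly: List[Dict[int, int]] = [{} for _ in range(num_workers)]
        self._local_tokens: List[int] = [0] * num_workers

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        return [dict(d) for d in self._onfly]

    def set_shared_onfly_info(self, onfly: List[Dict[int, int]]) -> None:
        self._onfly = [dict(d) for d in onfly]

    def set_shared_local_tokens(self, tokens: List[int]) -> None:
        self._local_tokens = list(tokens)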
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
             self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1. Check if the rid received by each worker this round is present in onfly.
+            #    If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
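Note: write_shared_dp_balance_info reconciles the shared onfly bookkeeping under the mutex: every dp_balance_id a worker reported as received this round must already appear in that worker's onfly dict (otherwise dispatcher and scheduler have diverged), and it is deleted before the updated onfly map and local token counts are written back. A small self-contained sketch of just that reconciliation step, using plain dicts and a threading.Lock in place of the shared-memory meta object; the sample ids and values are made up.

import threading
from typing import Dict, List

lock = threading.Lock()

# onfly[worker_id]: dp_balance_id -> bookkeeping value, written when requests were dispatched
onfly: List[Dict[int, int]] = [{101: 30, 102: 12}, {201: 7}]

# ids each worker reports as actually received this round
received: List[List[int]] = [[101], [201]]

with lock:
    assert len(received) == len(onfly), "num_worker not equal"
    for worker_id, (rids, on_fly_reqs) in enumerate(zip(received, onfly)):
        for rid in rids:
            # A missing id would mean the dispatcher and the worker disagree.
            assert rid in on_fly_reqs, f"{rid=} not tracked for {worker_id=}"
            del on_fly_reqs[rid]

print(onfly)  # -> [{102: 12}, {}]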