change / sglang / Commits / 41efcaeb

Commit 41efcaeb (unverified), authored Nov 01, 2025 by ykcombat, committed by GitHub Nov 01, 2025
Parent: 70562969

[Feature] PD-Multiplexing Context and Scheduler, lazy import spatial. (#12275)

Showing 9 changed files with 458 additions and 24 deletions (+458, -24)
python/sglang/srt/layers/logits_processor.py              +3    -6
python/sglang/srt/managers/schedule_batch.py              +5    -1
python/sglang/srt/managers/scheduler.py                   +15   -2
python/sglang/srt/managers/tp_worker.py                   +24   -1
python/sglang/srt/mem_cache/memory_pool.py                +4    -1
python/sglang/srt/model_executor/forward_batch_info.py    +1    -0
python/sglang/srt/model_executor/model_runner.py          +33   -13
python/sglang/srt/multiplex/multiplexing_mixin.py         +209  -0
python/sglang/srt/multiplex/pdmux_context.py              +164  -0
python/sglang/srt/layers/logits_processor.py

@@ -134,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):

@@ -384,8 +381,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            or logits_metadata.forward_mode.is_split_prefill()
-        ) and not logits_metadata.extend_return_logprob:
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
python/sglang/srt/managers/schedule_batch.py

@@ -72,7 +72,11 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs, get_global_server_args
python/sglang/srt/managers/scheduler.py

@@ -152,6 +152,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -213,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):

@@ -252,6 +254,7 @@ class Scheduler(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (

@@ -285,6 +288,10 @@ class Scheduler(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)

+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()

@@ -424,6 +431,8 @@ class Scheduler(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0

@@ -1952,7 +1961,6 @@ class Scheduler(
         # Run forward
         if self.is_generation:
             batch_or_worker_batch = batch
             if self.enable_overlap or self.spec_algorithm.is_none():

@@ -2009,6 +2017,9 @@ class Scheduler(
                 # The future value, usually for next batch preparation
                 # Current implementation strictly synchronizes the seq_lens
                 batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                future_indices_or_next_token_ids = batch_result.next_token_ids
             else:
                 batch_result = self.model_worker.forward_batch_generation(
                     batch_or_worker_batch

@@ -2791,7 +2802,9 @@ def run_scheduler_process(
     disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
     if disaggregation_mode == DisaggregationMode.NULL:
-        if server_args.pp_size > 1:
+        if scheduler.enable_pdmux:
+            scheduler.event_loop_pdmux()
+        elif server_args.pp_size > 1:
             scheduler.event_loop_pp()
         elif scheduler.enable_overlap:
             scheduler.event_loop_overlap()
python/sglang/srt/managers/tp_worker.py

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

@@ -425,3 +425,26 @@ class TpModelWorker(BaseTpWorker):
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )

+    def forward_batch_split_prefill(self, batch: ScheduleBatch):
+        if batch.split_index == 0:
+            model_worker_batch = batch.get_model_worker_batch()
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+            batch.split_forward_batch = forward_batch
+            batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+        else:
+            model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+        logits_output, can_run_cuda_graph = self.model_runner.forward(
+            batch.split_forward_batch, split_forward_count=batch.split_forward_count
+        )
+        if logits_output:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        else:
+            next_token_ids = None
+        batch_result = GenerationBatchResult(
+            logits_output=logits_output,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+        batch_result.next_token_ids = next_token_ids
+        return batch_result
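
To see how the split_index and split_forward_count fields used above evolve, here is a small self-contained sketch (not part of the commit) of the chunk schedule that event_loop_pdmux in multiplexing_mixin.py (further below) computes before each call to forward_batch_split_prefill. The model depth, token count, and budget are illustrative values only.

# Illustrative only: mirrors the split-prefill chunk arithmetic from
# SchedulerMultiplexMixin.event_loop_pdmux, consumed by forward_batch_split_prefill.
num_hidden_layers = 32               # hypothetical model depth
extend_num_tokens = 8192             # hypothetical prefill tokens in the batch
split_forward_token_budget = 65536   # default from PDMuxConfig

split_index = 0
iteration = 0
while split_index < num_hidden_layers:
    forward_count = max(1, split_forward_token_budget // extend_num_tokens)
    next_split_index = min(split_index + forward_count, num_hidden_layers)
    forward_count = next_split_index - split_index  # layers actually run this step
    iteration += 1
    print(f"iter {iteration}: layers [{split_index}, {next_split_index})")
    split_index = next_split_index
# With these numbers the prefill finishes in 4 scheduler iterations of 8 layers each.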
python/sglang/srt/mem_cache/memory_pool.py

@@ -509,6 +509,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_alt_stream: bool = True,
         enable_kv_cache_copy: bool = False,
     ):
         super().__init__(

@@ -527,7 +528,9 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()

         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if _is_cuda else None
+        self.alt_stream = (
+            self.device_module.Stream() if _is_cuda and enable_alt_stream else None
+        )

         if enable_kv_cache_copy:
             self._init_kv_copy_and_warmup()
python/sglang/srt/model_executor/forward_batch_info.py

@@ -96,6 +96,7 @@ class ForwardMode(IntEnum):
                 else False
             )
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )

     def is_decode(self):
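
The added line makes SPLIT_PREFILL satisfy is_extend(). That is why the explicit is_split_prefill() checks removed from logits_processor.py above become redundant, and why ModelRunner.forward below must test is_split_prefill() before is_extend(). A stand-in sketch (not the real ForwardMode, which has many more members) of the resulting behavior:

from enum import IntEnum

class _ForwardModeSketch(IntEnum):
    EXTEND = 1
    TARGET_VERIFY = 2
    SPLIT_PREFILL = 3

    def is_extend(self):
        # After this commit, SPLIT_PREFILL is also treated as an extend mode.
        return (
            self == _ForwardModeSketch.EXTEND
            or self == _ForwardModeSketch.TARGET_VERIFY
            or self == _ForwardModeSketch.SPLIT_PREFILL
        )

    def is_split_prefill(self):
        return self == _ForwardModeSketch.SPLIT_PREFILL

mode = _ForwardModeSketch.SPLIT_PREFILL
# Both predicates are now true, so dispatch must check is_split_prefill() first.
assert mode.is_extend() and mode.is_split_prefill()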
python/sglang/srt/model_executor/model_runner.py

@@ -1765,6 +1765,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(self.server_args.speculative_algorithm is not None),

@@ -1833,12 +1834,18 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()

-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()

@@ -1852,10 +1859,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(

@@ -1869,7 +1878,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
         (

@@ -1878,9 +1888,12 @@ class ModelRunner:
         ) = (
             self.prefill_attention_backend_str,
             self.decode_attention_backend_str,
         )
         return attn_backend

-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)

@@ -1978,6 +1991,9 @@ class ModelRunner:
             device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
             tensor_parallel(self.model, device_mesh)

+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,

@@ -1985,7 +2001,11 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:

@@ -2123,18 +2143,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
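
A minimal sketch (not the actual ModelRunner) of the decode-backend bookkeeping introduced here, assuming only that some backend factory callable is available: init_attention_backend builds one decode backend per SM group, and update_decode_attn_backend switches among them when the scheduler changes stream groups.

class _DecodeBackendGroupSketch:
    def __init__(self, make_backend, sm_group_num: int):
        # One prefill backend with a fresh workspace, plus one decode backend per
        # SM partition (stream group), mirroring init_attention_backend above.
        self.attn_backend = make_backend(init_new_workspace=True)
        self.decode_attn_backend_group = [make_backend() for _ in range(sm_group_num)]
        self.decode_attn_backend = self.decode_attn_backend_group[0]

    def update_decode_attn_backend(self, stream_idx: int):
        # Called by the scheduler after adjust_stream_groups() picks a new SM split.
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]


# Usage with a stand-in backend factory:
group = _DecodeBackendGroupSketch(
    make_backend=lambda init_new_workspace=False: object(), sm_group_num=8
)
group.update_decode_attn_backend(3)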
python/sglang/srt/multiplex/multiplexing_mixin.py (new file, 0 → 100644)

"""
Mixin class providing multiplexing scheduling logic
"""

import logging

import torch
import torch.distributed as dist
from torch.cuda.streams import ExternalStream

from sglang.srt.distributed.parallel_state import set_pdmux_status
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.multiplex.pdmux_context import (
    get_current_stream_idx,
    get_sm_counts,
    get_stream_groups,
    initialize_stream_groups,
    load_pdmux_config,
    set_current_stream_idx,
)

logger = logging.getLogger(__name__)


class SchedulerMultiplexMixin:

    def init_pdmux(self):
        # for pd_multiplexing, Init stream_groups, exclude normal stream for prefill only and decode only
        self.pdmux_config = load_pdmux_config(self.server_args.pdmux_config_path)
        initialize_stream_groups(self.gpu_id, self.pdmux_config)
        self.stream_groups = get_stream_groups()
        self.sm_counts = get_sm_counts()
        self.real_sm_group_num = len(self.stream_groups)
        logger.info(
            f"PD-Multiplexing enabled with {self.real_sm_group_num} stream groups, sm_counts (prefill_sm, decode_sm): {self.sm_counts}"
        )

    # TODO(jason-fxz): This is a temporary demo
    def adjust_stream_groups(self) -> tuple[int, tuple[ExternalStream, ExternalStream]]:
        if not self.running_batch.is_empty() and self.split_prefill_batch:
            decode_bs = self.running_batch.batch_size()
            manual_divisions = self.pdmux_config.manual_divisions
            if manual_divisions:
                for i in range(len(manual_divisions)):
                    _, _, threshold = manual_divisions[i]
                    if decode_bs >= threshold:
                        stream_idx = i + 1
            else:
                stream_idx = max(
                    1,
                    min(
                        self.real_sm_group_num - 2,
                        decode_bs
                        * (self.real_sm_group_num - 2)
                        // self.pdmux_config.decode_bs_divisor,
                    ),
                )
            set_current_stream_idx(stream_idx)
        elif not self.running_batch.is_empty():
            set_current_stream_idx(self.real_sm_group_num - 1)
        else:
            set_current_stream_idx(0)

        stream_idx = get_current_stream_idx()
        self.tp_worker.model_runner.update_decode_attn_backend(stream_idx)
        return stream_idx, self.stream_groups[stream_idx]

    def update_split_prefill_batch(self, sm_count: int) -> bool:
        if self.split_prefill_batch:
            return False

        # add new request
        batch = self.get_new_batch_prefill()

        if batch and not batch.is_empty():
            batch.forward_mode = (
                ForwardMode.SPLIT_PREFILL
            )  # Set forward mode for split prefill
            self.split_prefill_batch = batch
            return True
        return False

    @torch.inference_mode()
    def event_loop_pdmux(self):
        """A scheduler loop for pd multiplexing."""
        decode_done = False
        prefill_done = False
        wait_prefill_kernel_done = False
        adjust_stream_group = False
        stream_idx = get_current_stream_idx()
        stream_group = self.stream_groups[stream_idx]
        prefill_stream = stream_group[0]
        decode_stream = stream_group[1]
        torch.cuda.empty_cache()
        logger.debug("Starting event loop for pd multiplexing...")
        while True:
            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                sm_count = self.sm_counts[stream_idx][0]
                if not wait_prefill_kernel_done:
                    adjust_stream_group = (
                        self.update_split_prefill_batch(sm_count) or adjust_stream_group
                    )

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                self.running_batch = self.update_running_batch(self.running_batch)
                adjust_stream_group = adjust_stream_group or (
                    stream_idx > 0 and self.running_batch.is_empty()
                )

                if self.running_batch.is_empty() and self.split_prefill_batch is None:
                    self.check_memory()
                    self.check_tree_cache()
                    self.new_token_ratio = self.init_new_token_ratio
                    self.maybe_sleep_on_idle()

            if adjust_stream_group:
                prefill_stream.synchronize()
                decode_stream.synchronize()
                stream_idx, stream_group = self.adjust_stream_groups()
                prefill_stream = stream_group[0]
                decode_stream = stream_group[1]
                adjust_stream_group = False
                logger.debug(
                    f"Adjusting stream groups: {stream_idx}, prefill sm: {self.sm_counts[stream_idx][0]}, decode sm: {self.sm_counts[stream_idx][1]}"
                )

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                # process decode batch
                if self.running_batch and not self.running_batch.is_empty():
                    decode_result = self.run_batch(self.running_batch)
                    decode_done = True
                else:
                    decode_done = False

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                if (
                    self.split_prefill_batch
                    and not self.split_prefill_batch.is_empty()
                    and not wait_prefill_kernel_done
                ):
                    prefill_done = True
                    forward_count = (
                        max(
                            1,
                            self.pdmux_config.split_forward_token_budget
                            // self.split_prefill_batch.extend_num_tokens,
                        )
                        if self.split_prefill_batch.extend_num_tokens > 0
                        else self.model_config.num_hidden_layers
                    )
                    next_split_index = min(
                        self.split_prefill_batch.split_index + forward_count,
                        self.model_config.num_hidden_layers,
                    )
                    forward_count = (
                        next_split_index - self.split_prefill_batch.split_index
                    )
                    self.split_prefill_batch.split_forward_count = forward_count
                    prefill_result = self.run_batch(self.split_prefill_batch)
                    if next_split_index == self.model_config.num_hidden_layers:
                        self.split_prefill_batch.split_prefill_finished = True
                        prefill_exe_done = prefill_stream.record_event()
                    self.split_prefill_batch.split_index = next_split_index
                elif wait_prefill_kernel_done:
                    prefill_done = True
                else:
                    prefill_done = False

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                decode_stream.synchronize()
                if decode_done:
                    self.process_batch_result(self.running_batch, decode_result)

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                if prefill_done and self.split_prefill_batch.split_prefill_finished:
                    wait_prefill_kernel_done = True
                    prefill_exe_done_flag = prefill_exe_done.query()
                    flags = (
                        torch.ones(1, device="cpu", dtype=torch.int32)
                        if prefill_exe_done_flag
                        else torch.zeros(1, device="cpu", dtype=torch.int32)
                    )
                    self.tp_cpu_group.allreduce(flags, dist.ReduceOp.SUM).wait()
                    if flags.item() == self.tp_size:
                        self.process_batch_result(
                            self.split_prefill_batch, prefill_result
                        )
                        if self.running_batch and not self.running_batch.is_empty():
                            self.running_batch.merge_batch(self.split_prefill_batch)
                        else:
                            self.running_batch = self.split_prefill_batch
                        self.split_prefill_batch = None
                        wait_prefill_kernel_done = False
                        adjust_stream_group = True
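
For intuition, a worked example (not part of the commit) of the stream-group index chosen by adjust_stream_groups when no manual_divisions are configured, using the PDMuxConfig defaults. Index 0 is the plain prefill-only stream pair, the last index is the decode-only pair, and higher indices in between correspond to partitions that give decode a larger SM share.

# Illustrative only: the index formula from adjust_stream_groups with default config.
real_sm_group_num = 8     # 6 green-context partitions + 2 plain stream pairs
decode_bs_divisor = 36    # default from PDMuxConfig

def pick_stream_idx(decode_bs: int) -> int:
    return max(
        1,
        min(
            real_sm_group_num - 2,
            decode_bs * (real_sm_group_num - 2) // decode_bs_divisor,
        ),
    )

for decode_bs in (1, 6, 18, 36, 60):
    # A bigger decode batch pushes the index up, i.e. toward partitions that
    # reserve more SMs for decode.
    print(decode_bs, pick_stream_idx(decode_bs))
# -> 1:1, 6:1, 18:3, 36:6, 60:6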
python/sglang/srt/multiplex/pdmux_context.py (new file, 0 → 100644)

from dataclasses import dataclass, field
from typing import List

import torch
import yaml

STREAM_GROUPS = []
SM_COUNTS = []
SM_GROUP_NUM = 8  # Default number of SM groups
CURRENT_STREAM_IDX = 0
CURRENT_STREAM_GROUP = None


@dataclass
class PDMuxConfig:
    sm_group_num: int = 8
    manual_divisions: List[List[int]] = field(
        default_factory=list
    )  # [prefill_sm, decode_sm, decode_bs_threshold]
    split_forward_token_budget: int = 65536
    decode_bs_divisor: int = 36


def load_pdmux_config(config_path: str) -> PDMuxConfig:
    """Load pdmux configuration from YAML file into a dataclass."""
    if not config_path:
        return PDMuxConfig()

    with open(config_path, "r") as f:
        raw = yaml.safe_load(f)

    if "sm_group_num" not in raw:
        raise ValueError("Missing required field: sm_group_num")
    if raw["sm_group_num"] < 3:
        raise ValueError("sm_group_num must greater than 3")

    manual_divisions = raw.get("manual_divisions", [])
    expected = raw["sm_group_num"] - 2
    if manual_divisions and len(manual_divisions) != expected:
        raise ValueError(
            f"manual_divisions must have {expected} entries, "
            f"but got {len(manual_divisions)}"
        )

    return PDMuxConfig(
        sm_group_num=raw["sm_group_num"],
        manual_divisions=manual_divisions,
        split_forward_token_budget=raw.get("split_forward_token_budget", 65536),
        decode_bs_divisor=raw.get("decode_bs_divisor", 36),
    )


def get_arch_constraints(compute_capability):
    major, minor = compute_capability
    # green context constraints for different architectures
    if major == 6:
        return 1, 1  # min_per_part, multiple
    elif major == 7:
        return 2, 2
    elif major == 8:
        return 4, 2
    elif major == 9 and minor >= 0:
        return 8, 8
    else:
        raise ValueError(f"Unsupported compute capability: {major}.{minor}")


def divide_sm(total_sms, compute_capability, groups):
    """
    :param total_sms: total sm count on a single GPU
    :param compute_capability: (major, minor)
    :return: SM partition group(prefill sm, decode sm)
    """
    min_per_part, multiple = get_arch_constraints(compute_capability)

    possible_values = [
        x
        for x in range(min_per_part, total_sms - min_per_part + 1, multiple)
        if x >= total_sms - x and total_sms - x >= 16
    ]
    if not possible_values:
        raise ValueError(
            f"No valid partitions found for total SMs {total_sms} "
            f"with constraints (min per part: {min_per_part}, multiple: {multiple})"
        )

    if len(possible_values) >= groups:
        step = max(1, len(possible_values) // groups)
        selected_values = possible_values[::step][:groups]
    else:
        selected_values = possible_values

    divisions = []
    for part1 in selected_values:
        part2 = total_sms - part1
        divisions.append((part1, part2))
    divisions.reverse()  # Reverse to have larger prefill SM first

    return divisions


def initialize_stream_groups(gpu_id: int, config: PDMuxConfig):
    from sgl_kernel import spatial

    global STREAM_GROUPS, SM_COUNTS, SM_GROUP_NUM, CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP

    # for pd_multiplexing, Init stream_groups
    device = torch.cuda.current_device()
    total_sm_count = spatial.get_sm_available(gpu_id)

    # (prefill_sm_count, decode_sm_count)
    if config.manual_divisions:
        divisions = [
            (prefill_sm, decode_sm)
            for prefill_sm, decode_sm, _ in config.manual_divisions
        ]
    else:
        divisions = divide_sm(
            total_sm_count,
            torch.cuda.get_device_capability(device),
            config.sm_group_num - 2,
        )

    SM_COUNTS = []
    SM_COUNTS.append((total_sm_count, 0))  # Normal stream for prefill
    SM_COUNTS.extend(divisions)  # Add the divided SM counts
    SM_COUNTS.append((0, total_sm_count))  # Normal stream for decode

    STREAM_GROUPS = []
    STREAM_GROUPS.append(
        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
    )  # Normal stream for prefill
    for prefill_sm, decode_sm in divisions:
        STREAM_GROUPS.append(
            (spatial.create_greenctx_stream_by_value(prefill_sm, decode_sm, gpu_id))
        )
    STREAM_GROUPS.append(
        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
    )  # Normal stream for decode

    CURRENT_STREAM_IDX = 0
    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]


def set_current_stream_idx(idx: int):
    global CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
    if idx < 0 or idx >= len(STREAM_GROUPS):
        raise ValueError(f"Invalid stream index: {idx}")
    CURRENT_STREAM_IDX = idx
    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]


def get_stream_groups() -> list[tuple[torch.cuda.Stream, torch.cuda.Stream]]:
    """Get the stream groups."""
    return STREAM_GROUPS


def get_sm_counts() -> list[tuple[int, int]]:
    """Get the SM counts."""
    return SM_COUNTS


def get_current_stream_idx() -> int:
    """Get the current stream index."""
    return CURRENT_STREAM_IDX
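
An illustrative usage sketch (not part of the commit) that loads a pdmux YAML config and previews the SM partitions divide_sm would produce. The 132-SM count and compute capability (9, 0) are assumptions for an H100-class GPU, not values queried from hardware; this avoids initialize_stream_groups, so the lazy sgl_kernel spatial import is never triggered.

import tempfile

from sglang.srt.multiplex.pdmux_context import PDMuxConfig, divide_sm, load_pdmux_config

yaml_text = """
sm_group_num: 8
split_forward_token_budget: 65536
decode_bs_divisor: 36
# manual_divisions: optional list of [prefill_sm, decode_sm, decode_bs_threshold]
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
    path = f.name

config = load_pdmux_config(path)
# With only defaults in the file, the loaded config equals the dataclass defaults.
assert config == PDMuxConfig(sm_group_num=8)

# 8 stream groups = 2 plain stream pairs + 6 green-context partitions:
print(divide_sm(total_sms=132, compute_capability=(9, 0), groups=config.sm_group_num - 2))
# -> [(112, 20), (104, 28), (96, 36), (88, 44), (80, 52), (72, 60)]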