Commit bad7c26f (unverified)
[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)

Authored May 12, 2025 by Ying Sheng; committed by GitHub on May 12, 2025.
Parent commit: 12319a67

Showing 8 changed files with 179 additions and 47 deletions (+179 / -47).
.github/workflows/pr-test.yml                     +12   -0
python/sglang/srt/managers/schedule_policy.py      +0   -3
python/sglang/srt/managers/scheduler.py            +4   -4
python/sglang/srt/model_executor/model_runner.py   +9   -2
python/sglang/srt/models/mixtral.py               +98  -34
python/sglang/srt/server_args.py                   +6   -0
python/sglang/srt/utils.py                         +6   -4
test/srt/test_bench_serving.py                    +44   -0
.github/workflows/pr-test.yml  (view file @ bad7c26f)

@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
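Note: the two new CI steps only invoke the unittest CLI, so they can be reproduced locally. A minimal sketch in Python, assuming the repository is checked out, the script is run from test/srt, and at least two GPUs are available for --pp 2:

    # Hypothetical local reproduction of the two new CI steps above.
    import unittest

    from test_bench_serving import TestBenchServing

    suite = unittest.TestSuite()
    suite.addTest(TestBenchServing("test_pp_offline_throughput_default_decode"))
    suite.addTest(TestBenchServing("test_pp_long_context_prefill"))
    unittest.TextTestRunner(verbosity=2).run(suite)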
python/sglang/srt/managers/schedule_policy.py  (view file @ bad7c26f)

@@ -468,9 +468,6 @@ class PrefillAdder:
                 return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
python/sglang/srt/managers/scheduler.py  (view file @ bad7c26f)

@@ -719,7 +719,7 @@ class Scheduler(
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)

-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (

@@ -759,18 +759,18 @@ class Scheduler(
                             self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]

-                # carry the outputs to the next stage
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )

-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
                     dp_offset = self.dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
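Note: the net effect of the second hunk is that the non-last-rank branch now owns both forwarding the previous round's proxy tensors and sending out requests, instead of using a second `if not self.pp_group.is_last_rank:` block. A minimal sketch of the relay pattern, with helper names that are illustrative rather than the scheduler's real API:

    # Illustrative sketch of one pipeline-parallel scheduling step.
    # `run_batch`, `sample_tokens`, and `send_downstream` are hypothetical
    # stand-ins for the scheduler's real methods.
    def pp_relay_step(pp_group, cur_batch, prev_round_outputs,
                      run_batch, sample_tokens, send_downstream):
        result = run_batch(cur_batch) if cur_batch else None
        if pp_group.is_last_rank:
            # Last stage: turn hidden states into next-token ids right away.
            return sample_tokens(result) if result is not None else None
        # Intermediate stage: first relay last round's outputs so the next
        # stage can post-process, then hand over this round's result.
        if prev_round_outputs is not None:
            send_downstream(prev_round_outputs)
        return result  # hidden_states / residual travel on to the next stage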
python/sglang/srt/model_executor/model_runner.py  (view file @ bad7c26f)

@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,

@@ -404,7 +405,10 @@ class ModelRunner:
         )
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

@@ -716,7 +720,10 @@ class ModelRunner:
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
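Note: this is the init_memory_pool desync fix from the commit title. With pipeline parallelism, `tp_size > 1` no longer covers all ranks, so each PP stage could reduce free memory over its own TP group only, profile a different token budget, and size its KV cache pool differently; reducing over the full world group gives every rank the same minimum. A rough sketch of the idea, with hypothetical names rather than sglang's API:

    # Hypothetical sketch: every rank must derive the same KV-pool size, so the
    # budget has to come from the minimum free memory across *all* ranks
    # (TP x PP), not just the local TP group.
    def shared_max_num_tokens(free_bytes_all_ranks, mem_fraction, cell_bytes):
        shared_min = min(free_bytes_all_ranks)  # what the world-group MIN all_reduce yields
        return int(shared_min * mem_fraction) // cell_bytes

    # Example: two PP stages seeing 40 GiB and 38 GiB free both size their
    # pools from the shared 38 GiB minimum, so the pools stay in sync.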
python/sglang/srt/models/mixtral.py  (view file @ bad7c26f)

@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )

@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):

@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,

@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states

@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(

@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
-        )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)

@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")

 EntryClass = MixtralForCausalLM
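Note: the Mixtral changes follow the usual pipeline-parallel recipe: only the first rank keeps the embedding, only the last rank keeps the final norm and logits path, each rank builds just its slice of decoder layers, intermediate ranks exchange PPProxyTensors (hidden_states and residual), and load_weights skips weights whose layer id falls outside [start_layer, end_layer). A minimal sketch of the contiguous layer split that a make_layers-style helper performs (illustrative code, not sglang's exact implementation):

    # Illustrative contiguous partition of decoder layers across pipeline ranks.
    def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
        per_rank = num_layers // pp_size
        remainder = num_layers % pp_size
        start = pp_rank * per_rank + min(pp_rank, remainder)
        end = start + per_rank + (1 if pp_rank < remainder else 0)
        return start, end  # this rank owns layers [start, end)

    # Example: 32 decoder layers with pp_size=2 ->
    # rank 0 owns [0, 16), rank 1 owns [16, 32).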
python/sglang/srt/server_args.py  (view file @ bad7c26f)

@@ -347,6 +347,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
python/sglang/srt/utils.py  (view file @ bad7c26f)

@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.

@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
-        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
+        )
         free_gpu_memory = tensor.item()

     return free_gpu_memory / (1 << 30)
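Note: the helper now keeps the reduction tensor on the CPU and reduces it over an explicitly passed process group rather than moving it to the GPU and using the default group. A self-contained sketch of that pattern with plain torch.distributed, assuming a gloo (CPU) backend; this is illustrative code, not sglang's:

    # Minimal sketch: MIN-reduce a free-memory reading over a CPU (gloo) group.
    # Launch with e.g. `torchrun --nproc_per_node=2 min_free_mem.py`
    # (hypothetical file name).
    import torch
    import torch.distributed as dist

    def min_free_memory_gib(free_bytes: float, cpu_group=None) -> float:
        tensor = torch.tensor(free_bytes, dtype=torch.float32)
        # group=None falls back to the default group; pass a gloo group to
        # keep the collective on CPU.
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
        return tensor.item() / (1 << 30)

    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")  # CPU-only collective backend
        free_bytes = float(
            torch.cuda.mem_get_info()[0] if torch.cuda.is_available() else 8 * (1 << 30)
        )
        print(f"shared minimum free memory: {min_free_memory_gib(free_bytes):.2f} GiB")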
test/srt/test_bench_serving.py  (view file @ bad7c26f)

@@ -272,6 +272,50 @@ class TestBenchServing(CustomTestCase):
         else:
             self.assertGreater(res["output_throughput"], 2200)

+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                2,
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_latency_prefill\n"
+                f'input_throughput: {res["input_throughput"]:.2f} ms\n'
+            )
+            self.assertGreater(res["input_throughput"], 4000)
+

 if __name__ == "__main__":
     unittest.main()