sglang · commit bad7c26f

Unverified commit bad7c26f, authored May 12, 2025 by Ying Sheng, committed via GitHub on May 12, 2025.
Parent: 12319a67

[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)

Showing 8 changed files with 179 additions and 47 deletions (+179, -47).
.github/workflows/pr-test.yml                      +12   -0
python/sglang/srt/managers/schedule_policy.py       +0   -3
python/sglang/srt/managers/scheduler.py             +4   -4
python/sglang/srt/model_executor/model_runner.py    +9   -2
python/sglang/srt/models/mixtral.py                +98  -34
python/sglang/srt/server_args.py                    +6   -0
python/sglang/srt/utils.py                          +6   -4
test/srt/test_bench_serving.py                     +44   -0

.github/workflows/pr-test.yml
@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
 
+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false

python/sglang/srt/managers/schedule_policy.py
@@ -468,9 +468,6 @@ class PrefillAdder:
                 return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None

python/sglang/srt/managers/scheduler.py
@@ -719,7 +719,7 @@ class Scheduler(
                 server_is_idle = False
                 result = self.run_batch(self.cur_batch)
 
-            # send the outputs to the next step
+            # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
                     next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
                     self.process_batch_result(mbs[next_mb_id], output_result)
                     last_mbs[next_mb_id] = mbs[next_mb_id]
 
-            # carry the outputs to the next stage
+            # (not last rank)
             if not self.pp_group.is_last_rank:
                 if self.cur_batch:
                     bids[mb_id] = result.bid
+                # carry the outputs to the next stage
+                # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
-                    # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
                         pp_outputs.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
 
             if not self.pp_group.is_last_rank:
                 # send out reqs to the next stage
                 dp_offset = self.dp_rank * self.attn_tp_size
                 if self.attn_tp_rank == 0:
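
For orientation, the sketch below illustrates the hand-off these scheduler comments describe: every rank except the last forwards its micro-batch outputs to the next pipeline stage, and the next stage receives them before continuing. It is a minimal stand-in built on plain torch.distributed send/recv with fixed tensor shapes, not sglang's GroupCoordinator.send_tensor_dict; the launch layout (torchrun, a gloo backend, CPU tensors) is assumed for illustration.

# pp_handoff_sketch.py -- run with: torchrun --nproc_per_node=2 pp_handoff_sketch.py
import torch
import torch.distributed as dist


def send_tensor_dict(tensors: dict, dst: int) -> None:
    # Send tensors in a fixed key order so both sides agree on the protocol.
    for key in sorted(tensors):
        dist.send(tensors[key], dst=dst)


def recv_tensor_dict(keys, shape, src: int) -> dict:
    out = {}
    for key in sorted(keys):
        buf = torch.empty(shape)
        dist.recv(buf, src=src)
        out[key] = buf
    return out


def main():
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    shape = (4, 8)  # (num_tokens, hidden_size) in this toy example

    if rank == 0:
        # First stage: pretend we ran our slice of layers, then ship the proxy on.
        proxy = {"hidden_states": torch.randn(shape), "residual": torch.zeros(shape)}
    else:
        # Later stage: block until the previous stage hands us its outputs.
        proxy = recv_tensor_dict(["hidden_states", "residual"], shape, src=rank - 1)

    if rank < world - 1:
        send_tensor_dict(proxy, dst=rank + 1)
    else:
        print(f"last stage got hidden_states with mean {proxy['hidden_states'].mean():.3f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()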

python/sglang/srt/model_executor/model_runner.py
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -404,7 +405,10 @@ class ModelRunner:
         )
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ class ModelRunner:
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
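
These two call-site changes are the heart of the init_memory_pool desync fix: the free-memory measurement is now reduced across the whole world group (all TP and PP ranks) over a CPU group, so every rank derives the same memory-pool budget. Below is a minimal sketch of that idea, not sglang's actual get_available_gpu_memory; the torchrun launch, the gloo group, and the CPU-only fallback value are assumptions for illustration.

# min_mem_sketch.py -- run with: torchrun --nproc_per_node=2 min_mem_sketch.py
import torch
import torch.distributed as dist


def min_free_gpu_memory_gb(local_free_bytes: float, cpu_group) -> float:
    # All-reduce with MIN so every rank ends up with the same (smallest) value.
    tensor = torch.tensor(local_free_bytes, dtype=torch.float32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
    return tensor.item() / (1 << 30)


def main():
    dist.init_process_group(backend="gloo")
    cpu_group = dist.new_group(backend="gloo")  # stand-in for get_world_group().cpu_group

    if torch.cuda.is_available():
        free_bytes, _total = torch.cuda.mem_get_info()
    else:
        free_bytes = 32 * (1 << 30)  # pretend 32 GiB free on a CPU-only box

    budget_gb = min_free_gpu_memory_gb(float(free_bytes), cpu_group)
    # Every rank prints the same number, so every rank sizes its KV-cache pool
    # identically and the pipeline stages stay in sync.
    print(f"rank {dist.get_rank()}: shared memory budget = {budget_gb:.2f} GiB")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()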

python/sglang/srt/models/mixtral.py
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
 
 
 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = MixtralForCausalLM
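
As a rough mental model for the constructor and load_weights changes above: make_layers gives each pipeline rank a contiguous [start_layer, end_layer) slice of the decoder stack, placeholder modules stand in for layers owned by other ranks, and weight loading skips any layer index outside the local range. The stand-in below assumes an even split and uses invented names (partition_layers, PPMissingLayerSketch); sglang's real make_layers and PPMissingLayer may differ in details.

# layer_partition_sketch.py -- simplified illustration of per-rank layer ownership
import torch.nn as nn


class PPMissingLayerSketch(nn.Identity):
    """Placeholder for a layer that lives on another pipeline rank."""


def partition_layers(num_layers: int, pp_rank: int, pp_size: int, make_layer):
    # Even split with the remainder going to the earlier ranks; the real helper
    # may divide differently, so treat this as illustrative only.
    per_rank = num_layers // pp_size
    extra = num_layers % pp_size
    start = pp_rank * per_rank + min(pp_rank, extra)
    end = start + per_rank + (1 if pp_rank < extra else 0)
    layers = nn.ModuleList(
        make_layer(i) if start <= i < end else PPMissingLayerSketch()
        for i in range(num_layers)
    )
    return layers, start, end


if __name__ == "__main__":
    layers, start, end = partition_layers(
        num_layers=32, pp_rank=1, pp_size=2, make_layer=lambda i: nn.Linear(8, 8)
    )
    print(f"this stage owns layers [{start}, {end})")  # [16, 32) for rank 1 of 2
    assert isinstance(layers[0], PPMissingLayerSketch)
    assert isinstance(layers[16], nn.Linear)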

python/sglang/srt/server_args.py
@@ -347,6 +347,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE

python/sglang/srt/utils.py
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
-        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
+        )
         free_gpu_memory = tensor.item()
 
     return free_gpu_memory / (1 << 30)

test/srt/test_bench_serving.py
@@ -272,6 +272,50 @@ class TestBenchServing(CustomTestCase):
         else:
             self.assertGreater(res["output_throughput"], 2200)
 
+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                2,
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_latency_prefill\n"
+                f'input_throughput: {res["input_throughput"]:.2f} ms\n'
+            )
+            self.assertGreater(res["input_throughput"], 4000)
+
 
 if __name__ == "__main__":
     unittest.main()
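
To reproduce the new PP benchmarks outside the CI workflow, the commands added to pr-test.yml can also be driven from Python. The snippet below is a hedged sketch that assumes a machine with at least 2 GPUs, the sglang test dependencies installed, and that it is run from test/srt so the module name resolves.

# run_pp_benchmarks.py -- local equivalent of the two new CI steps (assumptions above)
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode",
        "test_bench_serving.TestBenchServing.test_pp_long_context_prefill",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)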