Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d997afc4
Commit
d997afc4
authored
Sep 04, 2025
by
王敏
Browse files
1.优化大EP,合入grouped gemm
2.解决mtp >1 大EP推理all gather卡住问题
parent
6f5d76dc
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
147 additions
and
69 deletions
+147
-69
vllm/envs.py
vllm/envs.py
+5
-1
vllm/forward_context.py
vllm/forward_context.py
+4
-2
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
+3
-1
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+59
-34
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+22
-13
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+6
-2
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+18
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+15
-15
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+11
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-1
No files found.
vllm/envs.py
View file @
d997afc4
...
@@ -164,6 +164,7 @@ if TYPE_CHECKING:
...
@@ -164,6 +164,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_FLASH_ATTN_PA
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_APEX_RN
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_ALLTOALL_EP
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1089,7 +1090,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1089,7 +1090,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13"
:
"VLLM_USE_GLOBAL_CACHE13"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_ALLTOALL_EP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/forward_context.py
View file @
d997afc4
...
@@ -135,8 +135,10 @@ def set_forward_context(
...
@@ -135,8 +135,10 @@ def set_forward_context(
if
need_to_track_batchsize
:
if
need_to_track_batchsize
:
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_metadata
:
Optional
[
DPMetadata
]
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
(
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
attn_metadata
is
not
None
or
num_tokens
is
not
None
):
use_all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
if
not
use_all2all_ep
and
dp_size
>
1
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
)
:
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
attn_metadata
,
num_tokens
or
0
,
attn_metadata
,
num_tokens
or
0
,
num_tokens_across_dp
)
num_tokens_across_dp
)
...
...
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
View file @
d997afc4
...
@@ -327,7 +327,8 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
...
@@ -327,7 +327,8 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output
=
input
.
new_empty
(
output
=
input
.
new_empty
(
size
=
[
sum
(
output_split_sizes
)]
+
list
(
input
.
size
()[
1
:]),
size
=
[
sum
(
output_split_sizes
)]
+
list
(
input
.
size
()[
1
:]),
dtype
=
input
.
dtype
,
dtype
=
input
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
#device=torch.cuda.current_device(),
device
=
input
.
device
,
)
)
torch
.
distributed
.
all_to_all_single
(
torch
.
distributed
.
all_to_all_single
(
...
@@ -336,6 +337,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
...
@@ -336,6 +337,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output_split_sizes
=
output_split_sizes
,
output_split_sizes
=
output_split_sizes
,
input_split_sizes
=
input_split_sizes
,
input_split_sizes
=
input_split_sizes
,
group
=
group
,
group
=
group
,
async_op
=
True
)
)
return
output
return
output
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
d997afc4
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
lightop
import
groupgemm
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -33,6 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -33,6 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
.
moe
=
moe
self
.
moe
=
moe
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
zero_token_count
=
None
def
apply
(
def
apply
(
self
,
self
,
...
@@ -55,40 +57,59 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -55,40 +57,59 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
# process MoE
# process MoE
def
custom_forward
(
layer
,
hidden_states
,
tokens_per_expert
):
def
custom_forward
(
layer
,
hidden_states
,
tokens_per_expert
):
tokens_per_expert
=
tokens_per_expert
.
cpu
().
numpy
()
if
False
:
tokens_per_expert
=
tokens_per_expert
.
cpu
().
numpy
()
outputs
=
[]
outputs
=
[]
start_idx
=
0
start_idx
=
0
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
end_idx
=
start_idx
+
num_tokens
end_idx
=
start_idx
+
num_tokens
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
w1
=
layer
.
w13_weight
[
i
]
w1
=
layer
.
w13_weight
[
i
]
w2
=
layer
.
w2_weight
[
i
]
w2
=
layer
.
w2_weight
[
i
]
tokens_for_this_expert
=
hidden_states
[
start_idx
:
end_idx
]
tokens_for_this_expert
=
hidden_states
[
start_idx
:
end_idx
]
gateup_output
=
torch
.
matmul
(
tokens_for_this_expert
,
w1
.
T
)
gateup_output
=
torch
.
matmul
(
tokens_for_this_expert
,
w1
)
# Act
# Act
down_input
=
torch
.
zeros
(
down_input
=
torch
.
zeros
(
gateup_output
.
shape
[
0
],
gateup_output
.
shape
[
0
],
gateup_output
.
shape
[
1
]
//
2
,
gateup_output
.
shape
[
1
]
//
2
,
device
=
gateup_output
.
device
,
device
=
gateup_output
.
device
,
dtype
=
hidden_states
.
dtype
dtype
=
hidden_states
.
dtype
)
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
w1
.
shape
[
0
]))
gateup_output
.
view
(
-
1
,
w1
.
shape
[
1
]))
expert_out
=
torch
.
matmul
(
down_input
,
w2
.
T
)
expert_out
=
torch
.
matmul
(
down_input
,
w2
)
outputs
.
append
(
expert_out
)
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
start_idx
=
end_idx
if
len
(
outputs
)
>
0
:
if
len
(
outputs
)
>
0
:
expert_output
=
torch
.
cat
(
outputs
,
dim
=
0
)
expert_output
=
torch
.
cat
(
outputs
,
dim
=
0
)
else
:
assert
hidden_states
.
numel
()
==
0
,
f
"sorted_tokens: should be empty, but got
{
hidden_states
.
shape
}
"
expert_output
=
hidden_states
else
:
else
:
assert
hidden_states
.
numel
()
==
0
,
f
"sorted_tokens: should be empty, but got
{
hidden_states
.
shape
}
"
if
self
.
zero_token_count
is
None
:
expert_output
=
hidden_states
self
.
zero_token_count
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
total_tokens
=
tokens_per_expert
.
sum
()
if
total_tokens
>
self
.
zero_token_count
:
gateup_output
=
groupgemm
(
hidden_states
,
layer
.
w13_weight
,
tokens_per_expert
,
False
)
# Act
down_input
=
torch
.
zeros
(
gateup_output
.
shape
[
0
],
gateup_output
.
shape
[
1
]
//
2
,
device
=
gateup_output
.
device
,
dtype
=
hidden_states
.
dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
layer
.
w13_weight
.
shape
[
2
]))
expert_output
=
groupgemm
(
down_input
,
layer
.
w2_weight
,
tokens_per_expert
,
False
)
else
:
expert_output
=
hidden_states
return
expert_output
return
expert_output
...
@@ -157,6 +178,8 @@ class EPMoE(FusedMoE):
...
@@ -157,6 +178,8 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
True
,
moe_permute_fusion
:
bool
=
True
,
moe_shared_expert_overlap
:
bool
=
False
moe_shared_expert_overlap
:
bool
=
False
):
):
...
@@ -170,7 +193,9 @@ class EPMoE(FusedMoE):
...
@@ -170,7 +193,9 @@ class EPMoE(FusedMoE):
e_score_correction_bias
,
e_score_correction_bias
,
apply_router_weight_on_input
,
apply_router_weight_on_input
,
activation
,
activation
,
routed_scaling_factor
=
routed_scaling_factor
routed_scaling_factor
=
routed_scaling_factor
,
enable_eplb
=
enable_eplb
,
num_redundant_experts
=
num_redundant_experts
,
)
)
self
.
ep_moe_config
:
EpMoeConfig
=
EpMoeConfig
.
make
(
self
.
ep_moe_config
:
EpMoeConfig
=
EpMoeConfig
.
make
(
...
...
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
View file @
d997afc4
...
@@ -24,6 +24,7 @@ from vllm.platforms import current_platform
...
@@ -24,6 +24,7 @@ from vllm.platforms import current_platform
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
cuda_dtoh_sync_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
class
MoETokenDispatcher
:
class
MoETokenDispatcher
:
"""
"""
...
@@ -137,7 +138,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -137,7 +138,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
output_splits
=
None
self
.
output_splits
=
None
# [tp_size]. Represents the number of tokens received by the current rank from
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
# other TP ranks.
self
.
output_splits_tp
=
None
#
self.output_splits_tp = None
self
.
permute_idx_device
=
torch
.
device
(
"cuda"
)
if
self
.
config
.
moe_permute_fusion
else
None
self
.
permute_idx_device
=
torch
.
device
(
"cuda"
)
if
self
.
config
.
moe_permute_fusion
else
None
input_chunk_idxs
=
torch
.
arange
(
input_chunk_idxs
=
torch
.
arange
(
self
.
num_experts
*
self
.
tp_size
,
device
=
self
.
permute_idx_device
self
.
num_experts
*
self
.
tp_size
,
device
=
self
.
permute_idx_device
...
@@ -211,6 +212,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -211,6 +212,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if
self
.
use_all_gather
:
if
self
.
use_all_gather
:
# Gather is not supported for some devices such as TPUs.
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# Use all-gather instead.
num_global_tokens_per_expert
=
expert_parallel_all_gather
(
num_local_tokens_per_expert
)
\
num_global_tokens_per_expert
=
expert_parallel_all_gather
(
num_local_tokens_per_expert
)
\
.
reshape
(
self
.
ep_size
,
self
.
tp_size
,
self
.
num_experts
)
\
.
reshape
(
self
.
ep_size
,
self
.
tp_size
,
self
.
num_experts
)
\
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
...
@@ -233,7 +235,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -233,7 +235,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# [tp_size, ep_size] -> [tp_size]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
# rank from other TP rank.
self
.
output_splits_tp
=
num_global_tokens_per_rank
.
sum
(
axis
=
1
)
#
self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert
=
num_global_tokens_per_local_expert
.
sum
(
dim
=
(
0
,
1
))
num_tokens_per_local_expert
=
num_global_tokens_per_local_expert
.
sum
(
dim
=
(
0
,
1
))
...
@@ -317,11 +319,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -317,11 +319,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
num_out_tokens
=
self
.
num_out_tokens
,
num_out_tokens
=
self
.
num_out_tokens
,
fused
=
self
.
config
.
moe_permute_fusion
fused
=
self
.
config
.
moe_permute_fusion
)
)
# Perform expert parallel AlltoAll communication
# Perform expert parallel AlltoAll communication
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall"
,
tokens_per_expert
# "before_ep_alltoall", tokens_per_expert
)
# )
###test##############
#cuda_dtoh_stream.synchronize()
#cuda_dtoh_sync_event.synchronize()
###test##############
global_input_tokens
=
all_to_all
(
global_input_tokens
=
all_to_all
(
self
.
ep_group
.
device_group
,
permutated_local_input_tokens
,
self
.
output_splits
,
self
.
input_splits
self
.
ep_group
.
device_group
,
permutated_local_input_tokens
,
self
.
output_splits
,
self
.
input_splits
...
@@ -331,9 +338,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -331,9 +338,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
shared_experts
.
linear_fc1_forward_and_act
(
global_input_tokens
)
self
.
shared_experts
.
linear_fc1_forward_and_act
(
global_input_tokens
)
# Permutation 2: Sort tokens by local expert.
# Permutation 2: Sort tokens by local expert.
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
#
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_2"
,
tokens_per_expert
#
"before_permutation_2", tokens_per_expert
)
#
)
if
self
.
num_local_experts
>
1
:
if
self
.
num_local_experts
>
1
:
global_input_tokens
=
sort_chunks_by_idxs
(
global_input_tokens
=
sort_chunks_by_idxs
(
global_input_tokens
,
global_input_tokens
,
...
@@ -342,7 +349,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -342,7 +349,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused
=
self
.
config
.
moe_permute_fusion
,
fused
=
self
.
config
.
moe_permute_fusion
,
)
)
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
"before_finish"
,
tokens_per_expert
)
#
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return
global_input_tokens
,
tokens_per_expert
return
global_input_tokens
,
tokens_per_expert
...
@@ -444,9 +451,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -444,9 +451,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
output_splits
=
maybe_move_tensor_to_cpu
(
self
.
output_splits
=
maybe_move_tensor_to_cpu
(
self
.
output_splits
,
as_numpy
=
True
,
record_stream
=
on_side_stream
self
.
output_splits
,
as_numpy
=
True
,
record_stream
=
on_side_stream
)
)
self
.
output_splits_tp
=
maybe_move_tensor_to_cpu
(
#
self.output_splits_tp = maybe_move_tensor_to_cpu(
self
.
output_splits_tp
,
as_numpy
=
True
,
record_stream
=
on_side_stream
#
self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
)
#
)
self
.
num_out_tokens
=
maybe_move_tensor_to_cpu
(
self
.
num_out_tokens
=
maybe_move_tensor_to_cpu
(
self
.
num_out_tokens
,
record_stream
=
on_side_stream
self
.
num_out_tokens
,
record_stream
=
on_side_stream
)
)
...
@@ -455,6 +462,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -455,6 +462,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
num_global_tokens_per_local_expert
,
record_stream
=
on_side_stream
self
.
num_global_tokens_per_local_expert
,
record_stream
=
on_side_stream
)
)
#cuda_dtoh_sync_event.record()
if
point
==
self
.
cuda_sync_point
:
if
point
==
self
.
cuda_sync_point
:
# Synchronize with the dtoh stream at self.cuda_sync_point.
# Synchronize with the dtoh stream at self.cuda_sync_point.
cuda_dtoh_stream
.
synchronize
()
cuda_dtoh_stream
.
synchronize
()
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
d997afc4
...
@@ -772,12 +772,16 @@ class FusedMoE(torch.nn.Module):
...
@@ -772,12 +772,16 @@ class FusedMoE(torch.nn.Module):
self
.
moe_config
=
moe
self
.
moe_config
=
moe
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
quant_method
=
self
.
create_quant_method
(
moe
,
quant_config
,
prefix
)
quant_method
=
self
.
create_quant_method
(
moe
,
quant_config
,
prefix
)
assert
quant_method
is
not
None
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
self
.
quant_method
=
quant_method
if
self
.
enable_eplb
:
if
self
.
enable_eplb
:
from
vllm.model_executor.layers.quantization.fp8
import
(
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8MoEMethod
)
Fp8MoEMethod
)
if
not
isinstance
(
self
.
quant_method
,
Fp8MoEMethod
):
if
not
isinstance
(
quant_method
,
Fp8MoEMethod
):
# TODO: Add support for additional quantization methods.
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# contain essential differences, but the current quant API
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
d997afc4
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -24,6 +25,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -24,6 +25,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
.deepseek_v2
import
(
DeepseekV2DecoderLayer
,
from
.deepseek_v2
import
(
DeepseekV2DecoderLayer
,
get_spec_layer_idx_from_weight_name
)
get_spec_layer_idx_from_weight_name
)
from
vllm.distributed
import
get_dp_group
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -174,6 +176,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -174,6 +176,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix
,
"model"
))
prefix
,
"model"
))
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
self
.
use_all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
forward
(
def
forward
(
self
,
self
,
...
@@ -205,6 +211,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -205,6 +211,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
if
self
.
use_all2all_ep
:
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
...
@@ -233,6 +243,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -233,6 +243,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
...
@@ -248,6 +261,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -248,6 +261,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -257,6 +273,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -257,6 +273,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id
=
expert_id
)
expert_id
=
expert_id
)
break
break
else
:
else
:
if
self
.
use_all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d997afc4
...
@@ -155,9 +155,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -155,9 +155,9 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_
ep_opt
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
moe_cls
=
FusedMoE
if
not
self
.
use_
ep_opt
else
EPMoE
moe_cls
=
FusedMoE
if
not
self
.
use_
all2all_ep
else
EPMoE
self
.
experts
=
moe_cls
(
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
...
@@ -172,12 +172,14 @@ class DeepseekV2MoE(nn.Module):
...
@@ -172,12 +172,14 @@ class DeepseekV2MoE(nn.Module):
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_
ep_opt
else
EPSharedExperts
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_
all2all_ep
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
...
@@ -187,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -187,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
),
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
if
self
.
use_all2all_ep
:
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
...
@@ -196,13 +198,13 @@ class DeepseekV2MoE(nn.Module):
...
@@ -196,13 +198,13 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_
ep_opt
:
if
not
self
.
use_
all2all_ep
:
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
self
.
use_
ep_opt
:
if
not
self
.
use_
all2all_ep
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -216,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -216,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
not
self
.
use_
ep_opt
:
if
not
self
.
use_
all2all_ep
:
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
...
@@ -637,8 +639,6 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -637,8 +639,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
)
)
#ops.print_tensor(hidden_states)
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Fix FP16 overflow
# We scale both hidden_states and residual before
# We scale both hidden_states and residual before
...
@@ -808,7 +808,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -808,7 +808,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_
ep_opt
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
set_eplb_state
(
def
set_eplb_state
(
self
,
self
,
...
@@ -891,7 +891,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -891,7 +891,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
if
self
.
use_
ep_opt
:
if
self
.
use_
all2all_ep
:
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
...
@@ -928,7 +928,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -928,7 +928,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
ep_opt
:
if
self
.
use_
all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
...
@@ -957,7 +957,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -957,7 +957,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
# Instead, create a new variable
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
ep_opt
:
if
self
.
use_
all2all_ep
:
name_mapped
=
name_mapped
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name_mapped
=
name_mapped
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
if
is_pp_missing_parameter
(
name_mapped
,
self
):
if
is_pp_missing_parameter
(
name_mapped
,
self
):
...
@@ -985,7 +985,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -985,7 +985,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it
# So we simply skip it
continue
continue
if
self
.
use_
ep_opt
:
if
self
.
use_
all2all_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
...
vllm/v1/spec_decode/eagle.py
View file @
d997afc4
...
@@ -86,6 +86,9 @@ class EagleProposer:
...
@@ -86,6 +86,9 @@ class EagleProposer:
1
,
1
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
self
.
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
self
.
enable_expert_parallel
=
vllm_config
.
parallel_config
.
enable_expert_parallel
def
propose
(
def
propose
(
self
,
self
,
...
@@ -510,6 +513,14 @@ class EagleProposer:
...
@@ -510,6 +513,14 @@ class EagleProposer:
self
.
hidden_states
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
)
)
if
self
.
dp_size
>
1
and
self
.
enable_expert_parallel
and
self
.
num_speculative_tokens
>
1
:
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
)
def
validate_same_kv_cache_group
(
self
,
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d997afc4
...
@@ -319,6 +319,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -319,6 +319,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
self
.
use_all2all_ep
=
envs
.
VLLM_USE_ALLTOALL_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -1229,7 +1232,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1229,7 +1232,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
:
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
self
.
use_all2all_ep
:
# Early exit.
# Early exit.
return
0
,
None
return
0
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment