Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dcc5cb8
Commit
5dcc5cb8
authored
Oct 09, 2025
by
王敏
Browse files
优化mori ep
parent
e0ba23b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
179 additions
and
186 deletions
+179
-186
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-1
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+178
-185
No files found.
vllm/distributed/parallel_state.py
View file @
5dcc5cb8
...
...
@@ -948,7 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available."
)
backend
=
"gloo"
# this backend is used for WORLD
parallel_config
=
config
.
parallel_config
data_parallel_size
=
parallel_config
.
data_parallel_size
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
if
use_mori_ep
:
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
5dcc5cb8
...
...
@@ -10,30 +10,31 @@ import torch.nn.functional as F
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_ep_group
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.distributed
import
expert_parallel_all_gather
,
expert_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
import
mori
import
torch.distributed
as
dist
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
try
:
import
mori
from
lmslim.layers.gemm.int8_utils
import
(
per_token_quant_int8
)
except
ImportError
:
is_mori_available
=
False
logger
=
init_logger
(
__name__
)
_MORI_OP
=
None
@
CustomOp
.
register
(
"unquantized_ep_moe"
)
class
UnquantizedEPGroupedGemmMethod
(
UnquantizedFusedMoEMethod
):
"""MoE method without quantization."""
...
...
@@ -43,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
.
topk_indices_dtype
=
None
self
.
moe
=
moe
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
def
apply_ep
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
...
...
@@ -72,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
use_nn_moe
=
use_nn_moe
)
def
forward_cuda
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
# process MoE
...
...
@@ -108,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return
output
def
forward_cpu
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
**
kwargs
,
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
**
kwargs
,
):
raise
NotImplementedError
def
forward_hpu
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -166,49 +167,50 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl
"""
def
__init__
(
self
,
num_experts
:
int
,
# Global number of experts
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
False
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
ep_size
:
Optional
[
int
]
=
None
,
dp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
False
,
moe_shared_expert_overlap
:
bool
=
False
self
,
num_experts
:
int
,
# Global number of experts
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
False
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
ep_size
:
Optional
[
int
]
=
None
,
dp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
False
,
moe_shared_expert_overlap
:
bool
=
False
):
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
intermediate_size
,
params_dtype
,
reduce_results
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
,
quant_config
,
tp_size
,
ep_size
,
dp_size
,
prefix
,
ep_size
,
dp_size
,
prefix
,
custom_routing_function
,
scoring_func
,
e_score_correction_bias
,
apply_router_weight_on_input
,
activation
,
activation
,
routed_scaling_factor
=
routed_scaling_factor
,
enable_eplb
=
enable_eplb
,
num_redundant_experts
=
num_redundant_experts
,
)
self
.
ep_moe_config
:
EpMoeConfig
=
EpMoeConfig
.
make
(
moe_router_topk
=
self
.
top_k
,
# TODO: support fusion permute
...
...
@@ -221,7 +223,7 @@ class EPMoE(FusedMoE):
)
local_expert_indices_offset
=
(
self
.
ep_rank
*
self
.
local_num_experts
self
.
ep_rank
*
self
.
local_num_experts
)
self
.
local_expert_indices
=
[
local_expert_indices_offset
+
i
for
i
in
range
(
self
.
local_num_experts
)
...
...
@@ -229,10 +231,10 @@ class EPMoE(FusedMoE):
self
.
use_shared_expert
=
False
self
.
token_dispatcher
=
MoEAlltoAllTokenDispatcher
(
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
,
layer_name
=
f
"
{
self
.
layer_name
}
.token_dispatcher"
,
)
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
,
layer_name
=
f
"
{
self
.
layer_name
}
.token_dispatcher"
,
)
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
self
.
shared_experts
=
None
...
...
@@ -241,29 +243,30 @@ class EPMoE(FusedMoE):
self
.
scales
=
None
self
.
use_int8_dispatch
=
True
vllm_config
=
get_current_vllm_config
()
self
.
max_num_inp_token_per_rank
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
mori_op
=
self
.
get_mori_op
()
self
.
first
=
True
def
get_mori_op
(
self
):
global
_MORI_OP
if
_MORI_OP
is
None
:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group
=
torch
.
distributed
.
group
.
WORLD
assert
world_group
is
not
None
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"default"
,
world_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"default"
)
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori_ep"
,
get_ep_group
().
device_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori_ep"
)
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config
=
get_current_vllm_config
()
multi_node
=
self
.
ep_size
/
8
>
1
mori_data_type
=
vllm_config
.
model_config
.
dtype
mori_data_type
=
vllm_config
.
model_config
.
dtype
mori_scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
if
self
.
use_int8_dispatch
:
mori_scale_type_size
=
4
mori_scale_type_size
=
4
config
=
mori
.
ops
.
EpDispatchCombineConfig
(
data_type
=
mori_data_type
,
...
...
@@ -272,17 +275,18 @@ class EPMoE(FusedMoE):
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
1
if
self
.
use_int8_dispatch
else
0
,
scale_type_size
=
mori_scale_type_size
,
max_num_inp_token_per_rank
=
2048
,
max_num_inp_token_per_rank
=
self
.
max_num_inp_token_per_rank
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
block_num
=
80
,
warp_num_per_block
=
16
,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
)
_MORI_OP
=
mori
.
ops
.
EpDispatchCombineOp
(
config
)
return
_MORI_OP
def
set_shared_experts
(
self
,
shared_experts
:
torch
.
nn
.
Module
):
...
...
@@ -302,15 +306,15 @@ class EPMoE(FusedMoE):
assert
quant_method
is
not
None
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
return
quant_method
def
sync
(
self
):
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
dist
.
barrier
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
)
def
get_expert_weights
(
self
)
->
Iterable
[
torch
.
Tensor
]:
weights
=
list
(
self
.
named_parameters
())
...
...
@@ -329,30 +333,29 @@ class EPMoE(FusedMoE):
return
[
weight
.
view
(
self
.
local_num_experts
,
-
1
)
for
name
,
weight
in
weights
if
name
not
in
NON_EXPERT_WEIGHTS
]
]
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
topk_weights
,
topk_ids
=
self
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
use_grouped_topk
=
self
.
use_grouped_topk
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
topk_group
=
self
.
topk_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
indices_type
=
torch
.
int32
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
use_grouped_topk
=
self
.
use_grouped_topk
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
topk_group
=
self
.
topk_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
indices_type
=
torch
.
int32
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
self
.
use_int8_dispatch
:
hidden_states
,
scales
=
per_token_quant_int8
(
hidden_states
)
else
:
...
...
@@ -365,75 +368,64 @@ class EPMoE(FusedMoE):
)
scales
=
self
.
scales
#self.sync()
# self.sync()
(
dispatch_output
,
dispatch_weights
,
dispatch_scales
,
dispatch_indices
,
dispatch_recv_num_token
,
dispatch_output
,
dispatch_weights
,
dispatch_scales
,
dispatch_indices
,
dispatch_recv_num_token
,
)
=
self
.
mori_op
.
dispatch
(
hidden_states
,
topk_weights
,
scales
,
topk_ids
,
)
#self.sync()
expect_m
=
hidden_states
.
shape
[
0
]
*
self
.
ep_size
dispatch_output_clip
=
dispatch_output
[:
expect_m
]
dispatch_weights_clip
=
dispatch_weights
[:
expect_m
]
dispatch_indices_clip
=
dispatch_indices
[:
expect_m
]
dispatch_scales_clip
=
dispatch_scales
[:
expect_m
]
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
x
=
dispatch_output_clip
,
topk_weights
=
dispatch_weights_clip
,
topk_ids
=
dispatch_indices_clip
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
],
scales
=
dispatch_scales_clip
if
self
.
use_int8_dispatch
else
None
#routed_scaling_factor=self.routed_scaling_factor,
)
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=dispatch_weights,
# topk_ids=dispatch_indices,
# x=dispatch_output
_clip
,
# topk_weights=dispatch_weights
_clip
,
# topk_ids=dispatch_indices
_clip
,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]
*2
,
# scales=dispatch_scales if self.use_int8_dispatch else None
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales
_clip
if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
#self.sync()
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
x
=
dispatch_output
,
topk_weights
=
dispatch_weights
,
topk_ids
=
dispatch_indices
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
],
scales
=
dispatch_scales
if
self
.
use_int8_dispatch
else
None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
#self.sync()
#
self.sync()
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
...
...
@@ -448,12 +440,13 @@ class EPMoE(FusedMoE):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
return
final_hidden_states
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
...
...
@@ -462,7 +455,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def
ep_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -472,5 +465,5 @@ direct_register_custom_op(
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
fake_impl
=
ep_moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment