Commit 5dcc5cb8 authored by 王敏
Browse files

优化mori ep

parent e0ba23b5
...@@ -948,7 +948,7 @@ def init_distributed_environment( ...@@ -948,7 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.") "Fallback Gloo backend is not available.")
backend = "gloo" backend = "gloo"
# this backend is used for WORLD # this backend is used for WORLD
parallel_config = config.parallel_config
data_parallel_size = parallel_config.data_parallel_size data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep: if use_mori_ep:
......
...@@ -10,30 +10,31 @@ import torch.nn.functional as F ...@@ -10,30 +10,31 @@ import torch.nn.functional as F
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.distributed import expert_parallel_all_gather, expert_parallel_all_reduce
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
import mori
import torch.distributed as dist import torch.distributed as dist
from lmslim.layers.gemm.int8_utils import ( try:
per_token_group_quant_int8, import mori
per_token_quant_int8) from lmslim.layers.gemm.int8_utils import (
per_token_quant_int8)
except ImportError:
is_mori_available = False
logger = init_logger(__name__) logger = init_logger(__name__)
_MORI_OP = None _MORI_OP = None
@CustomOp.register("unquantized_ep_moe") @CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization.""" """MoE method without quantization."""
...@@ -43,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -43,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply_ep( def apply_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
...@@ -72,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -72,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# process MoE # process MoE
...@@ -108,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -108,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return output return output
def forward_cpu( def forward_cpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
**kwargs, **kwargs,
): ):
raise NotImplementedError raise NotImplementedError
def forward_hpu( def forward_hpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -166,49 +167,50 @@ class EPMoE(FusedMoE): ...@@ -166,49 +167,50 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
""" """
def __init__( def __init__(
self, self,
num_experts: int, # Global number of experts num_experts: int, # Global number of experts
top_k: int, top_k: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False, reduce_results: bool = False,
renormalize: bool = True, renormalize: bool = True,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
ep_size: Optional[int] = None, ep_size: Optional[int] = None,
dp_size: Optional[int] = None, dp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype, intermediate_size, params_dtype,
reduce_results, renormalize, reduce_results, renormalize,
use_grouped_topk, num_expert_group, use_grouped_topk, num_expert_group,
topk_group, quant_config, tp_size, topk_group, quant_config, tp_size,
ep_size, dp_size, prefix, ep_size, dp_size, prefix,
custom_routing_function, scoring_func, custom_routing_function, scoring_func,
e_score_correction_bias, e_score_correction_bias,
apply_router_weight_on_input, apply_router_weight_on_input,
activation, activation,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts, num_redundant_experts=num_redundant_experts,
) )
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make( self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
moe_router_topk=self.top_k, moe_router_topk=self.top_k,
# TODO: support fusion permute # TODO: support fusion permute
...@@ -221,7 +223,7 @@ class EPMoE(FusedMoE): ...@@ -221,7 +223,7 @@ class EPMoE(FusedMoE):
) )
local_expert_indices_offset = ( local_expert_indices_offset = (
self.ep_rank * self.local_num_experts self.ep_rank * self.local_num_experts
) )
self.local_expert_indices = [ self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
...@@ -229,10 +231,10 @@ class EPMoE(FusedMoE): ...@@ -229,10 +231,10 @@ class EPMoE(FusedMoE):
self.use_shared_expert = False self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher( self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher", config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
) )
self.shared_expert_overlap = moe_shared_expert_overlap self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None self.shared_experts = None
...@@ -241,29 +243,30 @@ class EPMoE(FusedMoE): ...@@ -241,29 +243,30 @@ class EPMoE(FusedMoE):
self.scales = None self.scales = None
self.use_int8_dispatch = True self.use_int8_dispatch = True
vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
self.first = True self.first = True
def get_mori_op(self): def get_mori_op(self):
global _MORI_OP global _MORI_OP
if _MORI_OP is None: if _MORI_OP is None:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group = torch.distributed.group.WORLD world_group = torch.distributed.group.WORLD
assert world_group is not None assert world_group is not None
torch._C._distributed_c10d._register_process_group("default", world_group) torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("default") mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1 multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch: if self.use_int8_dispatch:
mori_scale_type_size = 4 mori_scale_type_size = 4
config = mori.ops.EpDispatchCombineConfig( config = mori.ops.EpDispatchCombineConfig(
data_type=mori_data_type, data_type=mori_data_type,
...@@ -272,17 +275,18 @@ class EPMoE(FusedMoE): ...@@ -272,17 +275,18 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size, hidden_dim=self.hidden_size,
scale_dim=1 if self.use_int8_dispatch else 0, scale_dim=1 if self.use_int8_dispatch else 0,
scale_type_size=mori_scale_type_size, scale_type_size=mori_scale_type_size,
max_num_inp_token_per_rank=2048, max_num_inp_token_per_rank=self.max_num_inp_token_per_rank,
num_experts_per_rank=self.local_num_experts, num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k, num_experts_per_token=self.top_k,
max_token_type_size=2, max_token_type_size=2,
block_num=80, block_num=80,
warp_num_per_block=16, warp_num_per_block=16,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \ kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode mori.ops.EpDispatchCombineKernelType.IntraNode
) )
_MORI_OP = mori.ops.EpDispatchCombineOp(config) _MORI_OP = mori.ops.EpDispatchCombineOp(config)
return _MORI_OP return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module): def set_shared_experts(self, shared_experts: torch.nn.Module):
...@@ -302,15 +306,15 @@ class EPMoE(FusedMoE): ...@@ -302,15 +306,15 @@ class EPMoE(FusedMoE):
assert quant_method is not None assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method return quant_method
def sync(self): def sync(self):
#torch.cuda.synchronize() # torch.cuda.synchronize()
dist.barrier() dist.barrier()
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters()) weights = list(self.named_parameters())
...@@ -329,30 +333,29 @@ class EPMoE(FusedMoE): ...@@ -329,30 +333,29 @@ class EPMoE(FusedMoE):
return [ return [
weight.view(self.local_num_experts, -1) for name, weight in weights weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS if name not in NON_EXPERT_WEIGHTS
] ]
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts( topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int32, indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch: if self.use_int8_dispatch:
hidden_states, scales = per_token_quant_int8(hidden_states) hidden_states, scales = per_token_quant_int8(hidden_states)
else: else:
...@@ -365,75 +368,64 @@ class EPMoE(FusedMoE): ...@@ -365,75 +368,64 @@ class EPMoE(FusedMoE):
) )
scales = self.scales scales = self.scales
# self.sync()
#self.sync()
( (
dispatch_output, dispatch_output,
dispatch_weights, dispatch_weights,
dispatch_scales, dispatch_scales,
dispatch_indices, dispatch_indices,
dispatch_recv_num_token, dispatch_recv_num_token,
) = self.mori_op.dispatch( ) = self.mori_op.dispatch(
hidden_states, hidden_states,
topk_weights, topk_weights,
scales, scales,
topk_ids, topk_ids,
) )
#self.sync() # self.sync()
expect_m = hidden_states.shape[0] * self.ep_size
dispatch_output_clip = dispatch_output[:expect_m]
dispatch_weights_clip = dispatch_weights[:expect_m]
dispatch_indices_clip = dispatch_indices[:expect_m]
dispatch_scales_clip = dispatch_scales[:expect_m]
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output_clip,
topk_weights=dispatch_weights_clip,
topk_ids=dispatch_indices_clip,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales_clip if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
)
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep( # expert_output = self.quant_method.apply_ep(
# layer=self, # layer=self,
# x=dispatch_output, # x=dispatch_output_clip,
# topk_weights=dispatch_weights, # topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices, # topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts, # global_num_experts=self.global_num_experts,
# expert_map=self.expert_map, # expert_map=self.expert_map,
# activation=self.activation, # activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input, # apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe, # use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token, # num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]*2, # config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None # scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor, # #routed_scaling_factor=self.routed_scaling_factor,
# ) # )
#self.sync()
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
topk_ids=dispatch_indices,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids) combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :] final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync() # self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
...@@ -448,12 +440,13 @@ class EPMoE(FusedMoE): ...@@ -448,12 +440,13 @@ class EPMoE(FusedMoE):
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
return final_hidden_states return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
...@@ -462,7 +455,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -462,7 +455,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -472,5 +465,5 @@ direct_register_custom_op( ...@@ -472,5 +465,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment