Commit d698d6f2 authored by 王敏's avatar 王敏
Browse files

[feat]整合mori和deepep相关代码

parent 7293a072
File added
...@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank = None num_qps_per_rank = None
if self.internode: if self.internode:
num_rdma_bytes = int(1e9/2)#1024 * 1024 * 1024 num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30#self.num_sms // 2 num_qps_per_rank = 30 #self.num_sms // 2
import deep_ep # import deep_ep
num_nvl_bytes, num_rdma_bytes = 0, 0 # num_nvl_bytes, num_rdma_bytes = 0, 0
hidden_size = 7168 # hidden_size = 7168
hidden_bytes = hidden_size * 2 # hidden_bytes = hidden_size * 2
for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())): # for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes) # num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes) # num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
num_qps_per_rank = 1 num_qps_per_rank = 1
......
...@@ -175,6 +175,7 @@ if TYPE_CHECKING: ...@@ -175,6 +175,7 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# pd separation p2p async buf tokens
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
import os from typing import Callable, Optional
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
from collections.abc import Iterable from collections.abc import Iterable
import torch import torch
import torch.nn.functional as F import torch.distributed as dist
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
import torch.distributed as dist
try: try:
import mori import mori
...@@ -35,8 +30,8 @@ logger = init_logger(__name__) ...@@ -35,8 +30,8 @@ logger = init_logger(__name__)
_MORI_OP = None _MORI_OP = None
@CustomOp.register("unquantized_ep_moe") @CustomOp.register("unquantized_mori_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): class UnquantizedMoriMoeMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization.""" """MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
...@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False
def apply_ep( def apply_mori_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
forward_native = forward_cuda forward_native = forward_cuda
class EPMoE(FusedMoE): class MoriMoE(FusedMoE):
""" """
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
...@@ -194,7 +189,6 @@ class EPMoE(FusedMoE): ...@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype, intermediate_size, params_dtype,
...@@ -215,7 +209,6 @@ class EPMoE(FusedMoE): ...@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
moe_router_topk=self.top_k, moe_router_topk=self.top_k,
# TODO: support fusion permute # TODO: support fusion permute
moe_permute_fusion=moe_permute_fusion, moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size, ep_size=self.ep_size,
num_moe_experts=self.global_num_experts, num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
...@@ -228,21 +221,14 @@ class EPMoE(FusedMoE): ...@@ -228,21 +221,14 @@ class EPMoE(FusedMoE):
self.local_expert_indices = [ self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
] ]
self.use_shared_expert = False
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None self.shared_experts = None
self.scales = None self.scales = None
self.use_int8_dispatch = True self.use_int8_dispatch = True
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = 1024#vllm_config.scheduler_config.max_num_seqs self.max_num_inp_token_per_rank = 1024 #vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
def get_mori_op(self): def get_mori_op(self):
...@@ -252,10 +238,6 @@ class EPMoE(FusedMoE): ...@@ -252,10 +238,6 @@ class EPMoE(FusedMoE):
assert world_group is not None assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group) torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep") mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1 multi_node = self.ep_size / 8 > 1
...@@ -278,7 +260,6 @@ class EPMoE(FusedMoE): ...@@ -278,7 +260,6 @@ class EPMoE(FusedMoE):
max_token_type_size=2, max_token_type_size=2,
block_num=80, block_num=80,
warp_num_per_block=4, warp_num_per_block=4,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \ kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode mori.ops.EpDispatchCombineKernelType.IntraNode
) )
...@@ -290,14 +271,11 @@ class EPMoE(FusedMoE): ...@@ -290,14 +271,11 @@ class EPMoE(FusedMoE):
if self.shared_experts is None: if self.shared_experts is None:
self.shared_experts = shared_experts self.shared_experts = shared_experts
# if self.shared_expert_overlap:
# self.token_dispatcher.set_shared_experts(self.shared_experts)
def create_quant_method(self, moe, quant_config, prefix): def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None quant_method = (UnquantizedMoriMoeMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix)) else quant_config.get_quant_method(self, prefix))
assert quant_method is not None assert quant_method is not None
...@@ -310,7 +288,7 @@ class EPMoE(FusedMoE): ...@@ -310,7 +288,7 @@ class EPMoE(FusedMoE):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.mori_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
...@@ -350,7 +328,7 @@ class EPMoE(FusedMoE): ...@@ -350,7 +328,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch: if self.use_int8_dispatch:
...@@ -377,11 +355,10 @@ class EPMoE(FusedMoE): ...@@ -377,11 +355,10 @@ class EPMoE(FusedMoE):
hidden_states, hidden_states,
topk_weights, topk_weights,
scales, scales,
topk_ids, topk_ids
#layer_idx=int(self.layer_name.split('.')[2])
) )
expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_mori_ep(
layer=self, layer=self,
x=dispatch_output, x=dispatch_output,
topk_weights=dispatch_weights, topk_weights=dispatch_weights,
...@@ -394,7 +371,6 @@ class EPMoE(FusedMoE): ...@@ -394,7 +371,6 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size, config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
) )
# self.sync() # self.sync()
...@@ -404,11 +380,7 @@ class EPMoE(FusedMoE): ...@@ -404,11 +380,7 @@ class EPMoE(FusedMoE):
# self.sync() # self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if self.shared_experts is not None:
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
else: else:
...@@ -420,7 +392,7 @@ class EPMoE(FusedMoE): ...@@ -420,7 +392,7 @@ class EPMoE(FusedMoE):
return final_hidden_states return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def mori_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
...@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def mori_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
direct_register_custom_op( direct_register_custom_op(
op_name="ep_moe_forward", op_name="mori_moe_forward",
op_func=ep_moe_forward, op_func=mori_moe_forward,
mutates_args=["hidden_states", "router_logits"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=mori_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -166,6 +166,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -166,6 +166,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.use_deepep = parallel_config.enable_expert_parallel and \ self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_moe_group_gemm = parallel_config.enable_expert_parallel and envs.VLLM_ENABLE_MOE_GROUP_GEMM
def create_weights( def create_weights(
...@@ -250,32 +252,36 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -250,32 +252,36 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_ ): **_ ):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() if not self.enable_moe_group_gemm:
return fused_experts_impl_w4a8_marlin( workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
x, return fused_experts_impl_w4a8_marlin(
w1, x,
w2, w1,
topk_ids=topk_ids, w2,
topk_weights=topk_weights, topk_ids=topk_ids,
workspace=workspace, topk_weights=topk_weights,
global_reduce_buffer=global_reduce_buffer, workspace=workspace,
inplace=True, global_reduce_buffer=global_reduce_buffer,
use_int4_w4a8=True, inplace=True,
per_channel_quant=True, use_int4_w4a8=True,
activation=activation, per_channel_quant=True,
expert_map=expert_map, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map,
global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=w1_scale, global_num_experts=global_num_experts,
w2_scale=w2_scale, w1_scale=w1_scale,
a1_scale=a1_scale, w2_scale=w2_scale,
a2_scale=a2_scale, a1_scale=a1_scale,
use_nn_moe=use_nn_moe, a2_scale=a2_scale,
shared_output=shared_output, use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor, shared_output=shared_output,
) routed_scaling_factor=routed_scaling_factor,
)
def apply_ep( #dp+ep else:
# TODO:
return None
def apply_mori_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
...@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=scales,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs, config_select_bs=config_select_bs,
q_scales=scales
) )
def apply( def apply(
......
...@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, ...@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.ep_moe.layer import EPMoE from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_mori_ep else EPMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -225,12 +225,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -225,12 +225,12 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not self.use_mori_ep: if not self.enable_expert_parallel:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output) shared_output=shared_output)
else: else:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
...@@ -249,8 +249,22 @@ class DeepseekV2MoE(nn.Module): ...@@ -249,8 +249,22 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else:
final_hidden_states = self.experts(hidden_states=hidden_states, if not self.use_mori_ep:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if not self.use_mori_ep: if not self.use_mori_ep:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment