Commit 8f4471f0 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev_lightop_moe_sum_mul_add' into 'v0.15.1-dev'

feat(deepseek-moe): 接入 VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD 融合链路

See merge request dcutoolkit/deeplearing/vllm!485
parents 7676d0c9 0639678c
......@@ -313,6 +313,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
......@@ -1957,6 +1958,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_USE_CUDA_GRAPH_SIZES", "False").lower() in
("true", "1")),
# If set to 1/True, vLLM will use the lightop fused moe_sum + mul + add kernel (bias + factor)
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD",
"False").lower() in ("true", "1")),
# If set to 1/True, enable the fused top-p/top-k kernel in lightop
"VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in
......
......@@ -397,7 +397,22 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
)
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
if envs.VLLM_USE_LIGHTOP_MOE_SUM:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and shared_output is not None:
from lightop import op as op
factor = (
float(routed_scaling_factor)
if routed_scaling_factor is not None
else 1.0
)
op.moe_sum(
input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx],
bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None,
num_local_tokens=None,
factor=factor,
)
elif envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
......@@ -406,4 +421,4 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
\ No newline at end of file
return out_hidden_states
......@@ -1404,6 +1404,8 @@ def inplace_fused_experts(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None:
fused_experts_impl(
hidden_states,
......@@ -1433,6 +1435,8 @@ def inplace_fused_experts(
w1_bias,
w2_bias,
use_nn_moe,
shared_output,
routed_scaling_factor,
)
......@@ -1463,6 +1467,8 @@ def inplace_fused_experts_fake(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None:
pass
......@@ -1508,7 +1514,9 @@ def outplace_fused_experts(
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -1540,6 +1548,8 @@ def outplace_fused_experts(
use_nn_moe,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -1550,6 +1560,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1569,6 +1580,8 @@ def outplace_fused_experts_fake(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1618,7 +1631,9 @@ def fused_experts(
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......@@ -1652,6 +1667,8 @@ def fused_experts(
use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -1712,7 +1729,9 @@ def fused_experts_impl(
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
# Check constraints.
num_tokens = hidden_states.size(0)
......@@ -1820,6 +1839,8 @@ def fused_experts_impl(
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
if use_nn_moe:
......@@ -2283,6 +2304,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
......@@ -2416,10 +2439,33 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
self.moe_sum(
intermediate_cache3,
output,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_sum(
    self,
    input: torch.Tensor,
    output: torch.Tensor,
    shared_output: torch.Tensor | None = None,
    routed_scaling_factor: float | None = 1.0,
) -> None:
    """Reduce the per-expert results in ``input`` into ``output``.

    Kept as a separate method because MoE + LoRA requires a separate
    function for this step.

    Args:
        input: Tensor handed to the reduction kernel as its input.
        output: Tensor the reduction kernel writes into.
        shared_output: Optional tensor forwarded to the lightop kernel as
            ``bias``. Only consumed on the fused lightop path.
        routed_scaling_factor: Value forwarded to the lightop kernel as
            ``factor``. ``None`` is treated as ``1.0``, matching the
            Marlin W16A16 path. Only consumed on the fused lightop path.
    """
    use_fused_lightop = (
        envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and shared_output is not None
    )
    if not use_fused_lightop:
        # NOTE(review): shared_output / routed_scaling_factor are ignored
        # here; presumably the caller applies them itself when the fused
        # path is disabled -- confirm against the model code.
        ops.moe_sum(input, output)
        return

    # Imported lazily so lightop is only required when the flag is enabled.
    from lightop import op

    # float(None) would raise TypeError; fall back to a neutral factor.
    factor = (
        1.0 if routed_scaling_factor is None else float(routed_scaling_factor)
    )
    op.moe_sum(
        input=input,
        output=output,
        bias=shared_output,
        expert_mask=None,
        num_local_tokens=None,
        factor=factor,
    )
class TritonWNA16Experts(TritonExperts):
......
......@@ -1670,7 +1670,9 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
......@@ -1709,7 +1711,13 @@ class FusedMoE(CustomOp):
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, encode_layer_name()
hidden_states,
router_logits,
encode_layer_name(),
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
return reduce_output(fused_output)[..., :og_hidden_states]
else:
......@@ -1723,13 +1731,21 @@ class FusedMoE(CustomOp):
else:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name(),
hidden_states,
router_logits,
encode_layer_name(),
i_q=i_q,
i_s=i_s
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name()
hidden_states,
router_logits,
encode_layer_name(),
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
return (
reduce_output(shared_output)[..., :og_hidden_states],
......@@ -1747,12 +1763,18 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_native(hidden_states, router_logits)
return self.forward_native(
hidden_states,
router_logits,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def forward_impl_chunked(
self,
......@@ -1895,10 +1917,11 @@ class FusedMoE(CustomOp):
router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_moe_quant_config_init()
self.ensure_dp_chunking_init()
......@@ -2026,21 +2049,25 @@ class FusedMoE(CustomOp):
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s
)
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe
use_nn_moe=self.use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
if has_separate_shared_experts:
......@@ -2164,11 +2191,20 @@ def moe_forward(
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
self = get_layer_from_name(layer_name)
assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
return self.forward_impl(
hidden_states,
router_logits,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_forward_fake(
......@@ -2176,7 +2212,9 @@ def moe_forward_fake(
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -2195,14 +2233,28 @@ def moe_forward_shared(
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
return self.forward_impl(
hidden_states,
router_logits,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(
hidden_states,
router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_forward_shared_fake(
......@@ -2210,7 +2262,9 @@ def moe_forward_shared_fake(
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
......
......@@ -61,7 +61,10 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
*,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
if self._shared_experts is not None:
......@@ -92,12 +95,16 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
i_s=iqis[1],
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
router_logits=router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
......@@ -107,12 +114,16 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
i_s=iqis[1],
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
# ensure early TP reduction of shared expert outputs when required
if (
......@@ -122,4 +133,4 @@ class SharedFusedMoE(FusedMoE):
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
\ No newline at end of file
return shared_out, fused_out
......@@ -370,7 +370,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, **_
use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
**_,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
......@@ -378,6 +381,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
......@@ -397,6 +402,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
......@@ -415,6 +422,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map,
quant_config=self.get_fused_moe_quant_config(layer),
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
assert self.kernel is not None
......@@ -430,6 +439,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def forward_monolithic_cpu(
......
......@@ -400,7 +400,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -424,7 +426,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl(
......@@ -461,4 +465,4 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return DeepGemmExperts(moe_config=self.moe,
quant_config=self.moe_quant_config,
N=self.N,
K=self.K)
\ No newline at end of file
K=self.K)
......@@ -225,9 +225,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
......@@ -246,5 +247,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale
)
\ No newline at end of file
a2_scale=layer.w2_input_scale,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -367,35 +367,64 @@ class DeepseekV2MoE(nn.Module):
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states,
iqis=iqis
)
else:
# router_logits: (num_tokens, n_experts)
needs_post_moe_combine = (
getattr(self.experts, "dp_size", 1) > 1
or getattr(self.experts, "pcp_size", 1) > 1
)
if (
envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and self.shared_experts is not None
and not needs_post_moe_combine
):
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
shared_output = self.shared_experts(hidden_states, iqis=iqis)
routed_scaling_factor = (
1.0 if self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor
)
self.experts.use_overlapped = False
self.experts._shared_experts = None
# Marlin W16A16 fused reduce consumes the precomputed
# shared_output and routed_scaling_factor directly.
_, final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
iqis=iqis,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states,
router_logits=hidden_states,
iqis=iqis,
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment