Unverified commit 81964328 authored by Cheng Wan, committed by GitHub

Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (#6736)
parent f0f84975
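In short: the server previously took an explicit --num-fused-shared-experts count (recommended to equal the TP size); after this change a single --disable-shared-experts-fusion flag controls the behavior and the fused count is derived from the model config. A minimal sketch of the new rule, assuming a config object exposing n_shared_experts (1 for DeepSeek V3/R1); this is an illustration, not the exact SGLang code:

# Sketch of the new behavior: when fusion is enabled, the number of fused
# shared experts comes from the model config instead of a user-supplied count.
def resolve_num_fused_shared_experts(config, disable_shared_experts_fusion: bool) -> int:
    if disable_shared_experts_fusion:
        return 0
    # DeepSeek V3/R1 configs expose n_shared_experts (1 for V3/R1).
    return config.n_shared_experts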
@@ -400,7 +400,7 @@ def main(args: argparse.Namespace):
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = (
-           config.n_routed_experts + 1
+           config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
@@ -408,7 +408,9 @@ def main(args: argparse.Namespace):
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "Llama4ForConditionalGeneration":
-       E = config.text_config.num_local_experts + 1
+       E = config.text_config.num_local_experts + (
+           0 if args.disable_shared_experts_fusion else 1
+       )
        topk = config.text_config.num_experts_per_tok
        intermediate_size = config.text_config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -558,7 +560,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
-   parser.add_argument("--tp-size", "-tp", type=int, default=2)
+   parser.add_argument("--tp-size", "--tp", type=int, default=2)
    parser.add_argument(
        "--dtype",
        type=str,
@@ -568,6 +570,7 @@ if __name__ == "__main__":
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
+   parser.add_argument("--disable-shared-experts-fusion", action="store_true")
    args = parser.parse_args()
    main(args)
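In the tuning script above, the total expert count E now depends on the fusion flag. A small worked example, assuming DeepSeek-V3-style values (256 routed experts); the numbers are illustrative:

# Illustration only: how E is derived in the tuning script.
n_routed_experts = 256            # DeepSeek V3/R1
disable_shared_experts_fusion = False

# One extra slot is reserved for the fused shared expert unless fusion is disabled.
E = n_routed_experts + (0 if disable_shared_experts_fusion else 1)
assert E == 257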
@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
+       self.num_fused_shared_experts = num_fused_shared_experts
        self.topk_group = topk_group
        self.correction_bias = correction_bias
        self.custom_routing_function = custom_routing_function
@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
+           num_fused_shared_experts=self.num_fused_shared_experts,
            correction_bias=self.correction_bias,
            custom_routing_function=self.custom_routing_function,
            routed_scaling_factor=self.routed_scaling_factor,
...
@@ -21,6 +21,7 @@ def fused_moe_forward_native(
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
+   num_fused_shared_experts: int = 0,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
+       num_fused_shared_experts=num_fused_shared_experts,
        custom_routing_function=custom_routing_function,
        correction_bias=correction_bias,
        routed_scaling_factor=routed_scaling_factor,
@@ -71,6 +73,7 @@ def moe_forward_native(
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
+   num_fused_shared_experts: int = 0,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
@@ -84,6 +87,7 @@ def moe_forward_native(
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
+       num_fused_shared_experts=num_fused_shared_experts,
        custom_routing_function=custom_routing_function,
        correction_bias=correction_bias,
        torch_native=True,
...
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
@@ -1540,6 +1540,7 @@ def fused_moe(
    activation: str = "silu",
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
+   num_fused_shared_experts: int = 0,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
@@ -1609,6 +1610,7 @@ def fused_moe(
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
+       num_fused_shared_experts=num_fused_shared_experts,
        custom_routing_function=custom_routing_function,
        routed_scaling_factor=routed_scaling_factor,
    )
...
@@ -127,6 +127,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            activation=activation,
@@ -163,6 +165,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -179,6 +182,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
@@ -232,6 +236,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        inplace: bool = True,
@@ -245,6 +250,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            renormalize,
            topk_group,
            num_expert_group,
+           num_fused_shared_experts,
            custom_routing_function,
            correction_bias,
        )
@@ -289,6 +295,7 @@ class FusedMoE(torch.nn.Module):
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
@@ -321,6 +328,7 @@ class FusedMoE(torch.nn.Module):
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
+       self.num_fused_shared_experts = num_fused_shared_experts
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.correction_bias = correction_bias
@@ -651,6 +659,7 @@ class FusedMoE(torch.nn.Module):
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
+           num_fused_shared_experts=self.num_fused_shared_experts,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
            activation=self.activation,
...
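The MoE layers and every quantization backend touched below follow the same plumbing pattern: the layer stores the count at construction time and threads it through apply() down to expert selection. A minimal pattern sketch, with invented class and method names (placeholders, not the actual SGLang classes):

# Pattern sketch only; names here are placeholders.
class MoELayerSketch:
    def __init__(self, num_expert_group, topk_group, num_fused_shared_experts=0):
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        # Stored once so every forward call can pass it on to expert selection.
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward(self, hidden_states, router_logits, select_experts_fn):
        # select_experts_fn stands in for the backend's apply()/select_experts chain.
        return select_experts_fn(
            hidden_states,
            router_logits,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
        )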
@@ -303,6 +303,7 @@ def select_experts(
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
+   num_fused_shared_experts: int = 0,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    torch_native: bool = False,
@@ -310,7 +311,6 @@ def select_experts(
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
-   num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
...
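select_experts no longer reads the fused-expert count from global_server_args_dict; callers now pass it explicitly. A hedged sketch of the calling pattern (only the keywords shown in the diff are taken from the source; the other arguments and their values are assumptions):

# Sketch only: the fused shared-expert count travels as an explicit keyword
# argument instead of being read from global server args inside select_experts.
topk_weights, topk_ids = select_experts(
    hidden_states,                      # assumed positional inputs
    router_logits,
    renormalize=True,
    topk_group=topk_group,
    num_expert_group=num_expert_group,
    num_fused_shared_experts=num_fused_shared_experts,  # 0 when fusion is disabled
    correction_bias=correction_bias,
)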
@@ -289,6 +289,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
...
@@ -367,6 +367,7 @@ class BlockInt8MoEMethod:
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -387,6 +388,7 @@ class BlockInt8MoEMethod:
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
...
@@ -272,6 +272,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
@@ -294,6 +295,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
@@ -627,6 +629,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
@@ -651,6 +654,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            correction_bias=correction_bias,
...
@@ -937,6 +937,7 @@ class Fp8MoEMethod:
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -957,6 +958,7 @@ class Fp8MoEMethod:
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
...
@@ -341,6 +341,7 @@ class MoeWNA16Method:
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -362,6 +363,7 @@ class MoeWNA16Method:
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
...
@@ -287,6 +287,7 @@ class W8A8FP8MoEMethod:
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -306,6 +307,7 @@ class W8A8FP8MoEMethod:
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
...
@@ -225,6 +225,7 @@ class W8A8Int8MoEMethod:
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
@@ -245,6 +246,7 @@ class W8A8Int8MoEMethod:
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
+           num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
...
@@ -89,7 +89,7 @@ global_server_args_dict = {
    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
-   "num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
+   "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
    "sampling_backend": ServerArgs.sampling_backend,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
...
@@ -204,7 +204,7 @@ class ModelRunner:
    "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
    "moe_dense_tp_size": server_args.moe_dense_tp_size,
    "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
-   "num_fused_shared_experts": server_args.num_fused_shared_experts,
+   "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
    "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
    "torchao_config": server_args.torchao_config,
    "sampling_backend": server_args.sampling_backend,
...
@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
-       self.num_fused_shared_experts = global_server_args_dict[
-           "num_fused_shared_experts"
-       ]
+       self.num_fused_shared_experts = (
+           0
+           if global_server_args_dict["disable_shared_experts_fusion"]
+           else config.n_shared_experts
+       )
        self.config = config
        self.layer_id = layer_id
@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
-           top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
+           top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
+           num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
+           num_fused_shared_experts=self.num_fused_shared_experts,
            correction_bias=self.correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            num_token_non_padded=forward_batch.num_token_non_padded,
@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
+           num_fused_shared_experts=self.num_fused_shared_experts,
            correction_bias=self.correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
-       self.num_fused_shared_experts = global_server_args_dict[
-           "num_fused_shared_experts"
-       ]
+       self.num_fused_shared_experts = (
+           0
+           if global_server_args_dict["disable_shared_experts_fusion"]
+           else self.config.n_shared_experts
+       )
        if self.num_fused_shared_experts > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                or self.config.n_routed_experts != 256
            ):
                self.num_fused_shared_experts = 0
-               global_server_args_dict["num_fused_shared_experts"] = 0
+               global_server_args_dict["disable_shared_experts_fusion"] = 1
                log_info_on_rank0(
                    logger,
                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
-           else:
-               assert (
-                   self.num_fused_shared_experts == self.tp_size
-               ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
        elif self.num_fused_shared_experts == 0:
            if (
                _is_cuda
@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
-               self.num_fused_shared_experts = self.tp_size
-               global_server_args_dict["num_fused_shared_experts"] = self.tp_size
+               self.num_fused_shared_experts = self.config.n_shared_experts
+               global_server_args_dict["disable_shared_experts_fusion"] = 0
                log_info_on_rank0(
                    logger,
                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.num_fused_shared_experts > 0:
+           assert self.num_fused_shared_experts == 1
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is not None:
@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
            for moe_layer in tqdm(
                moe_layers,
                desc=f"Cloning {self.num_fused_shared_experts} "
-               "replicas of the shared expert into MoE",
+               "shared expert into MoE",
            ):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                    )
-                   for num_repeat in range(self.num_fused_shared_experts):
-                       weights_list.append(
-                           (
-                               f"model.layers.{moe_layer}."
-                               f"mlp.experts."
-                               f"{self.config.n_routed_experts + num_repeat}"
-                               f".{suffix}",
-                               weights_dict[shared_expert_weight_name],
-                           )
-                       )
+                   weights_list.append(
+                       (
+                           f"model.layers.{moe_layer}."
+                           f"mlp.experts."
+                           f"{self.config.n_routed_experts + 0}"
+                           f".{suffix}",
+                           weights_dict[shared_expert_weight_name],
+                       )
+                   )
                    names_to_remove += [shared_expert_weight_name]
        weights = [w for w in weights_list if w[0] not in names_to_remove]
...
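The cloning loop above now amounts to a simple rename: each layer's shared-expert tensors are re-registered once as routed expert index n_routed_experts, rather than being replicated tp_size times. A toy illustration with invented layer and suffix values:

# Toy illustration of the renaming performed above (values are invented).
n_routed_experts = 256
moe_layer, suffix = 5, "gate_proj.weight"

shared_name = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
fused_name = f"model.layers.{moe_layer}.mlp.experts.{n_routed_experts}.{suffix}"
# The shared-expert weight is appended under fused_name and the original
# shared_name entry is dropped, so the fused expert occupies slot 256.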
@@ -207,7 +207,7 @@ class ServerArgs:
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None
    moe_dense_tp_size: Optional[int] = None
-   num_fused_shared_experts: int = 0
+   disable_shared_experts_fusion: bool = False
    disable_chunked_prefix_cache: bool = False
    disable_fast_image_processor: bool = False
    mm_attention_backend: Optional[str] = None
@@ -1384,13 +1384,10 @@ class ServerArgs:
            default=ServerArgs.deepep_config,
            help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
        )
        parser.add_argument(
-           "--num-fused-shared-experts",
-           type=int,
-           default=0,
-           help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-           "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
+           "--disable-shared-experts-fusion",
+           action="store_true",
+           help="Disable shared experts fusion optimization for deepseek v3/r1.",
        )
        parser.add_argument(
            "--disable-chunked-prefix-cache",
...
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
  }
  // Calculate topk_excluding_share_expert_fusion from topk
- int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
+ int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
  // Cast pointers to type T:
  auto* input_ptr = reinterpret_cast<T*>(input);
@@ -224,13 +224,21 @@ __device__ void moe_fused_gate_impl(
  if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
    int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
-   // Use round-robin to select expert
-   int64_t expert_offset = thread_row % num_fused_shared_experts;
+   int64_t expert_offset = 0;
    indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
    // Set the weight to the sum of all weights divided by routed_scaling_factor
    output_ptr[last_idx] = output_sum / routed_scaling_factor;
+   if (num_fused_shared_experts > 1) {
+     for (int i = 1; i < num_fused_shared_experts; ++i) {
+       ++last_idx;
+       ++expert_offset;
+       indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
+       // Set the weight to the sum of all weights divided by routed_scaling_factor
+       output_ptr[last_idx] = output_sum / routed_scaling_factor;
+     }
+   }
  }
  __syncthreads();
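The kernel change above can be mirrored in plain Python: after the routed top-k, one slot per fused shared expert is appended with index NUM_EXPERTS + offset and, following the kernel comment, a weight of output_sum / routed_scaling_factor. A sketch only; the list-based row representation is an assumption for illustration:

# Python mirror of the gating epilogue above (illustrative only).
def append_fused_shared_experts(indices_row, weights_row, num_experts,
                                num_fused_shared_experts, routed_scaling_factor):
    # indices_row / weights_row already hold the routed top-k entries.
    shared_weight = sum(weights_row) / routed_scaling_factor
    for offset in range(num_fused_shared_experts):
        indices_row.append(num_experts + offset)  # experts NUM_EXPERTS, NUM_EXPERTS+1, ...
        weights_row.append(shared_weight)
    return indices_row, weights_row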