"docs/vscode:/vscode.git/clone" did not exist on "534ea8c97037b74fafadebc53693848073919132"
Unverified Commit 81964328 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts...

Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (#6736)
parent f0f84975
......@@ -400,7 +400,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + 1
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
......@@ -408,7 +408,9 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + 1
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
......@@ -558,7 +560,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
......@@ -568,6 +570,7 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
args = parser.parse_args()
main(args)
......@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
......@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
......
......@@ -21,6 +21,7 @@ def fused_moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -41,6 +42,7 @@ def fused_moe_forward_native(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......@@ -71,6 +73,7 @@ def moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -84,6 +87,7 @@ def moe_forward_native(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
......@@ -1540,6 +1540,7 @@ def fused_moe(
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
......@@ -1609,6 +1610,7 @@ def fused_moe(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)
......
......@@ -127,6 +127,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
......@@ -163,6 +165,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -179,6 +182,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......@@ -232,6 +236,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
......@@ -245,6 +250,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
)
......@@ -289,6 +295,7 @@ class FusedMoE(torch.nn.Module):
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -321,6 +328,7 @@ class FusedMoE(torch.nn.Module):
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
......@@ -651,6 +659,7 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
......
......@@ -303,6 +303,7 @@ def select_experts(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
......@@ -310,7 +311,6 @@ def select_experts(
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
......
......@@ -289,6 +289,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......
......@@ -367,6 +367,7 @@ class BlockInt8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -387,6 +388,7 @@ class BlockInt8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -272,6 +272,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
......@@ -294,6 +295,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......@@ -627,6 +629,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
......@@ -651,6 +654,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
correction_bias=correction_bias,
......
......@@ -937,6 +937,7 @@ class Fp8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -957,6 +958,7 @@ class Fp8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -341,6 +341,7 @@ class MoeWNA16Method:
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -362,6 +363,7 @@ class MoeWNA16Method:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -287,6 +287,7 @@ class W8A8FP8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -306,6 +307,7 @@ class W8A8FP8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -225,6 +225,7 @@ class W8A8Int8MoEMethod:
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
......@@ -245,6 +246,7 @@ class W8A8Int8MoEMethod:
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -89,7 +89,7 @@ global_server_args_dict = {
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
"num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
......
......@@ -204,7 +204,7 @@ class ModelRunner:
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
"num_fused_shared_experts": server_args.num_fused_shared_experts,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
"sampling_backend": server_args.sampling_backend,
......
......@@ -224,9 +224,11 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
self.layer_id = layer_id
......@@ -248,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
......@@ -256,6 +258,7 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
......@@ -363,6 +366,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=forward_batch.num_token_non_padded,
......@@ -456,6 +460,7 @@ class DeepseekV2MoE(nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=state.forward_batch.num_token_non_padded,
......@@ -1679,9 +1684,11 @@ class DeepseekV2ForCausalLM(nn.Module):
def determine_num_fused_shared_experts(
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = global_server_args_dict[
"num_fused_shared_experts"
]
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else self.config.n_shared_experts
)
if self.num_fused_shared_experts > 0:
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
......@@ -1690,15 +1697,11 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_routed_experts != 256
):
self.num_fused_shared_experts = 0
global_server_args_dict["num_fused_shared_experts"] = 0
global_server_args_dict["disable_shared_experts_fusion"] = 1
log_info_on_rank0(
logger,
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
)
else:
assert (
self.num_fused_shared_experts == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
elif self.num_fused_shared_experts == 0:
if (
_is_cuda
......@@ -1707,8 +1710,8 @@ class DeepseekV2ForCausalLM(nn.Module):
and self.config.n_routed_experts == 256
and (not global_server_args_dict["enable_deepep_moe"])
):
self.num_fused_shared_experts = self.tp_size
global_server_args_dict["num_fused_shared_experts"] = self.tp_size
self.num_fused_shared_experts = self.config.n_shared_experts
global_server_args_dict["disable_shared_experts_fusion"] = 0
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
......@@ -1910,6 +1913,7 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
......@@ -1971,22 +1975,21 @@ class DeepseekV2ForCausalLM(nn.Module):
for moe_layer in tqdm(
moe_layers,
desc=f"Cloning {self.num_fused_shared_experts} "
"replicas of the shared expert into MoE",
"shared expert into MoE",
):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
for num_repeat in range(self.num_fused_shared_experts):
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
......
......@@ -207,7 +207,7 @@ class ServerArgs:
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
num_fused_shared_experts: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
mm_attention_backend: Optional[str] = None
......@@ -1384,13 +1384,10 @@ class ServerArgs:
default=ServerArgs.deepep_config,
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
)
parser.add_argument(
"--num-fused-shared-experts",
type=int,
default=0,
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
"set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
"--disable-shared-experts-fusion",
action="store_true",
help="Disable shared experts fusion optimization for deepseek v3/r1.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",
......
......@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
......@@ -224,13 +224,21 @@ __device__ void moe_fused_gate_impl(
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
// Use round-robin to select expert
int64_t expert_offset = thread_row % num_fused_shared_experts;
int64_t expert_offset = 0;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
if (num_fused_shared_experts > 1) {
for (int i = 1; i < num_fused_shared_experts; ++i) {
++last_idx;
++expert_offset;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
}
}
}
__syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment