Unverified commit 8a548052, authored by Cheng Wan, committed by GitHub

[Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (#6735)

parent b6d0ce9f
@@ -27,19 +27,17 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --dtype fp8_w8a8 \
     --tune

-# Tune DeepSeek-V3 with FP8, TP=8 and n_share_experts_fusion=8
+# Tune DeepSeek-V3 with FP8 and TP=8
 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --model deepseek-ai/DeepSeek-V3-0324 \
     --tp-size 8 \
-    --n-share-experts-fusion 8 \
     --dtype fp8_w8a8 \
     --tune

-# Tune DeepSeek-R1 with channel-wise INT8, TP=16 and n_share_experts_fusion=16
+# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --model meituan/DeepSeek-R1-Channel-INT8 \
     --tp-size 16 \
-    --n-share-experts-fusion 16 \
     --dtype int8_w8a8 \
     --tune
 ```
@@ -65,11 +63,10 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri
     --model deepseek-ai/DeepSeek-V3-0324 \
     --tp-size 8

-# Compare with custom TP size and n_share_experts_fusion
+# Compare with custom TP size
 python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
     --model deepseek-ai/DeepSeek-V3-0324 \
-    --tp-size 8 \
-    --n-share-experts-fusion 8
+    --tp-size 8
 ```

 The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
......
@@ -18,7 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )


-def get_model_config(model_name: str, tp_size: int, n_share_experts_fusion: int = 0):
+def get_model_config(model_name: str, tp_size: int):
     """Get model configuration parameters"""
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -43,9 +43,8 @@ def get_model_config(model_name: str, tp_size: int, n_share_experts_fusion: int
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        n_share_fusion_experts = n_share_experts_fusion
         E = (
-            config.n_routed_experts + n_share_fusion_experts
+            config.n_routed_experts + 1
             if config.architectures[0] in ["DeepseekV3ForCausalLM"]
             else config.n_routed_experts
         )
@@ -294,7 +293,6 @@ def main():
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument("--tp-size", type=int, default=2)
-    parser.add_argument("--n-share-experts-fusion", type=int, default=0)
     parser.add_argument("--use-fp8-w8a8", action="store_true")
     parser.add_argument(
         "--save-path",
@@ -325,9 +323,7 @@ def main():
         pipeline_model_parallel_size=1,
     )
-    model_config = get_model_config(
-        args.model, args.tp_size, args.n_share_experts_fusion
-    )
+    model_config = get_model_config(args.model, args.tp_size)
     benchmark.run(
         show_plots=True,
         print_data=True,
......
@@ -399,9 +399,8 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        n_share_fusion_experts = args.n_share_experts_fusion
         E = (
-            config.n_routed_experts + n_share_fusion_experts
+            config.n_routed_experts + 1
             if config.architectures[0] in ["DeepseekV3ForCausalLM"]
             else config.n_routed_experts
         )
@@ -409,8 +408,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        n_share_fusion_experts = args.n_share_experts_fusion
-        E = config.text_config.num_local_experts + n_share_fusion_experts
+        E = config.text_config.num_local_experts + 1
         topk = config.text_config.num_experts_per_tok
         intermediate_size = config.text_config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -570,12 +568,6 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
-    parser.add_argument(
-        "--n-share-experts-fusion",
-        type=int,
-        default=0,
-        help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
-    )
     args = parser.parse_args()
     main(args)
@@ -103,7 +103,7 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    n_share_experts_fusion: int = 0,
+    num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -128,10 +128,10 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-    if n_share_experts_fusion:
+    if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
             low=num_experts,
-            high=num_experts + n_share_experts_fusion,
+            high=num_experts + num_fused_shared_experts,
             size=(topk_ids.size(0),),
             dtype=topk_ids.dtype,
             device=topk_ids.device,
@@ -141,7 +141,7 @@ def grouped_topk(
     if renormalize:
         topk_weights_sum = (
             topk_weights.sum(dim=-1, keepdim=True)
-            if n_share_experts_fusion == 0
+            if num_fused_shared_experts == 0
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
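
For intuition, the renamed branch above can be read in isolation: when `num_fused_shared_experts > 0`, the last top-k slot is redirected to a random index in `[num_experts, num_experts + num_fused_shared_experts)` (the fused shared-expert replicas sit right after the routed experts), and renormalization sums only the routed slots. A minimal standalone sketch with made-up sizes, not the sglang code path itself:

```python
import torch

# Toy shapes: 4 tokens, 8 routed experts, top-k of 3
# (2 routed slots + 1 slot reserved for a fused shared-expert replica).
num_experts, topk, num_fused_shared_experts = 8, 3, 2
scores = torch.rand(4, num_experts)

topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
if num_fused_shared_experts:
    # Route the reserved slot to one of the shared-expert replicas, which
    # occupy the indices right after the routed experts.
    topk_ids[:, -1] = torch.randint(
        low=num_experts,
        high=num_experts + num_fused_shared_experts,
        size=(topk_ids.size(0),),
        dtype=topk_ids.dtype,
        device=topk_ids.device,
    )

# Renormalize over the routed slots only; the shared-expert slot keeps its weight.
topk_weights_sum = (
    topk_weights.sum(dim=-1, keepdim=True)
    if num_fused_shared_experts == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
print(topk_ids, topk_weights)
```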
@@ -160,7 +160,7 @@ def biased_grouped_topk_impl(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    n_share_experts_fusion: int = 0,
+    num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -192,10 +192,10 @@ def biased_grouped_topk_impl(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)
-    if n_share_experts_fusion:
+    if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
             low=num_experts,
-            high=num_experts + n_share_experts_fusion,
+            high=num_experts + num_fused_shared_experts,
             size=(topk_ids.size(0),),
             dtype=topk_ids.dtype,
             device=topk_ids.device,
@@ -205,7 +205,7 @@ def biased_grouped_topk_impl(
     if renormalize:
         topk_weights_sum = (
             topk_weights.sum(dim=-1, keepdim=True)
-            if n_share_experts_fusion == 0
+            if num_fused_shared_experts == 0
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
@@ -239,7 +239,7 @@ def biased_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     compiled: bool = True,
-    n_share_experts_fusion: int = 0,
+    num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -247,7 +247,7 @@ def biased_grouped_topk(
     assert (
         routed_scaling_factor is not None
     ), "routed_scaling_factor is required for biased_grouped_topk"
-    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    # TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.
     if (
         _is_cuda
         and gating_output.shape[1] // num_expert_group
@@ -260,7 +260,7 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             topk,
-            n_share_experts_fusion,
+            num_fused_shared_experts,
             routed_scaling_factor,
         )
         # TODO merge into kernel for this branch
@@ -288,7 +288,7 @@ def biased_grouped_topk(
             renormalize,
             num_expert_group,
             topk_group,
-            n_share_experts_fusion=n_share_experts_fusion,
+            num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
@@ -310,7 +310,7 @@ def select_experts(
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -332,7 +332,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
-            n_share_experts_fusion=n_share_experts_fusion,
+            num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
@@ -346,7 +346,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
-            n_share_experts_fusion=n_share_experts_fusion,
+            num_fused_shared_experts=num_fused_shared_experts,
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
......
@@ -89,7 +89,7 @@ global_server_args_dict = {
     "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
-    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
     "sampling_backend": ServerArgs.sampling_backend,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
......
@@ -204,7 +204,7 @@ class ModelRunner:
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
-                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "num_fused_shared_experts": server_args.num_fused_shared_experts,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "torchao_config": server_args.torchao_config,
                 "sampling_backend": server_args.sampling_backend,
......
@@ -122,7 +122,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
+        self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
......
@@ -224,7 +224,9 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.num_fused_shared_experts = global_server_args_dict[
+            "num_fused_shared_experts"
+        ]
         self.config = config
         self.layer_id = layer_id
@@ -244,9 +246,9 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
-            + self.n_share_experts_fusion
+            + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
-            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+            top_k=config.num_experts_per_tok + min(self.num_fused_shared_experts, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
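
To make the sizing concrete, here is a back-of-the-envelope calculation under assumed values (a DeepSeek-V3-style config with 256 routed experts and 8 experts per token, `tp_size=8`, and no redundant experts; the numbers are illustrative, not read from a real config object). The `min(..., 1)` term reserves exactly one extra routing slot for the fused shared expert, regardless of how many replicas exist.

```python
# Assumed, illustrative values; not taken from an actual config object.
n_routed_experts = 256        # config.n_routed_experts
num_experts_per_tok = 8       # config.num_experts_per_tok
num_fused_shared_experts = 8  # defaults to tp_size when fusion is enabled
ep_num_redundant_experts = 0

num_experts = n_routed_experts + num_fused_shared_experts + ep_num_redundant_experts
top_k = num_experts_per_tok + min(num_fused_shared_experts, 1)
print(num_experts, top_k)  # 264 9
```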
@@ -265,7 +267,7 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
-        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             self.shared_experts = DeepseekV2MLP(
@@ -418,7 +420,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
+        if self.num_fused_shared_experts == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
@@ -434,7 +436,7 @@ class DeepseekV2MoE(nn.Module):
     def op_shared_experts(self, state):
         hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
-        if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
+        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, hidden_states_mlp_input
         ):
             state.shared_output = self.shared_experts(hidden_states_mlp_input)
@@ -1648,7 +1650,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_n_share_experts_fusion()
+        self.determine_num_fused_shared_experts()
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -1674,28 +1676,30 @@ class DeepseekV2ForCausalLM(nn.Module):
     def routed_experts_weights_of_layer(self):
         return self._routed_experts_weights_of_layer.value

-    def determine_n_share_experts_fusion(
+    def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-        if self.n_share_experts_fusion > 0:
+        self.num_fused_shared_experts = global_server_args_dict[
+            "num_fused_shared_experts"
+        ]
+        if self.num_fused_shared_experts > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
                 not _is_cuda
                 or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
-                self.n_share_experts_fusion = 0
-                global_server_args_dict["n_share_experts_fusion"] = 0
+                self.num_fused_shared_experts = 0
+                global_server_args_dict["num_fused_shared_experts"] = 0
                 log_info_on_rank0(
                     logger,
                     "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
-                    self.n_share_experts_fusion == self.tp_size
+                    self.num_fused_shared_experts == self.tp_size
                 ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
-        elif self.n_share_experts_fusion == 0:
+        elif self.num_fused_shared_experts == 0:
             if (
                 _is_cuda
                 and torch.cuda.get_device_capability("cuda") >= (9, 0)
@@ -1703,8 +1707,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
-                self.n_share_experts_fusion = self.tp_size
-                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                self.num_fused_shared_experts = self.tp_size
+                global_server_args_dict["num_fused_shared_experts"] = self.tp_size
                 log_info_on_rank0(
                     logger,
                     "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
@@ -1905,7 +1909,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion > 0:
+        if self.num_fused_shared_experts > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config is not None:
@@ -1966,14 +1970,14 @@ class DeepseekV2ForCausalLM(nn.Module):
             for moe_layer in tqdm(
                 moe_layers,
-                desc=f"Cloning {self.n_share_experts_fusion} "
+                desc=f"Cloning {self.num_fused_shared_experts} "
                 "replicas of the shared expert into MoE",
             ):
                 for suffix in suffix_list:
                     shared_expert_weight_name = (
                         f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                     )
-                    for num_repeat in range(self.n_share_experts_fusion):
+                    for num_repeat in range(self.num_fused_shared_experts):
                         weights_list.append(
                             (
                                 f"model.layers.{moe_layer}."
@@ -1992,7 +1996,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
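
The cloning loop in the hunk above (its target f-string is truncated in this view) appends each shared-expert tensor once per replica so that the expert slots added after the routed experts receive real weights. A rough, hypothetical sketch of that idea follows; the target naming scheme (`mlp.experts.<n_routed_experts + num_repeat>.<suffix>`) is an assumption for illustration, not the verbatim string from the file:

```python
# Hypothetical illustration of the shared-expert cloning step; the names below
# are assumptions for clarity, not quoted from deepseek_v2.py.
n_routed_experts = 256
num_fused_shared_experts = 8
moe_layers = [3]                    # pretend a single MoE layer
suffix_list = ["gate_proj.weight"]  # abbreviated suffix list

# Checkpoint tensors keyed by name (stand-in value instead of a real tensor).
weights_dict = {"model.layers.3.mlp.shared_experts.gate_proj.weight": "tensor"}

weights_list = []
for moe_layer in moe_layers:
    for suffix in suffix_list:
        shared_name = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
        for num_repeat in range(num_fused_shared_experts):
            # Each replica fills one of the expert slots appended after the
            # routed experts (assumed naming scheme).
            fused_name = (
                f"model.layers.{moe_layer}.mlp.experts."
                f"{n_routed_experts + num_repeat}.{suffix}"
            )
            weights_list.append((fused_name, weights_dict[shared_name]))

print(len(weights_list))  # 8 entries: one clone per fused shared-expert slot
```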
......
@@ -206,7 +206,7 @@ class ServerArgs:
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
-    n_share_experts_fusion: int = 0
+    num_fused_shared_experts: int = 0
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     mm_attention_backend: Optional[str] = None
@@ -1373,11 +1373,11 @@ class ServerArgs:
         )
         parser.add_argument(
-            "--n-share-experts-fusion",
+            "--num-fused-shared-experts",
             type=int,
             default=0,
             help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-            "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with n_share_experts_fusion automatically set to the TP size.",
+            "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with num_fused_shared_experts automatically set to the TP size.",
         )
         parser.add_argument(
             "--disable-chunked-prefix-cache",
......
@@ -161,7 +161,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "n_share_experts_fusion, float routed_scaling_factor) -> "
+      "num_fused_shared_experts, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(
......
@@ -57,7 +57,7 @@ __device__ void moe_fused_gate_impl(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor,
     Params params) {
   int tidx = threadIdx.x;
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
   }

   // Calculate topk_excluding_share_expert_fusion from topk
-  int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
+  int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);

   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
@@ -222,11 +222,11 @@ __device__ void moe_fused_gate_impl(
     __syncthreads();
   }

-  if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
+  if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
     int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;

     // Use round-robin to select expert
-    int64_t expert_offset = thread_row % n_share_experts_fusion;
+    int64_t expert_offset = thread_row % num_fused_shared_experts;
     indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);

     // Set the weight to the sum of all weights divided by routed_scaling_factor
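
Read together, the comments above say that one reserved slot per row is filled round-robin across the fused shared-expert replicas and given a weight equal to the sum of the routed top-k weights divided by `routed_scaling_factor`. A tensor-level sketch of that policy in Python, with invented shapes (the real logic runs per thread row inside the CUDA kernel):

```python
import torch

# Invented shapes for illustration only.
NUM_EXPERTS, topk, num_fused_shared_experts = 8, 3, 2
routed_scaling_factor = 2.5
num_rows = 4

topk_weights = torch.rand(num_rows, topk)
topk_ids = torch.randint(0, NUM_EXPERTS, (num_rows, topk))

rows = torch.arange(num_rows)
# Round-robin over the shared-expert replicas for the reserved last slot.
topk_ids[:, -1] = NUM_EXPERTS + rows % num_fused_shared_experts
# Weight of the reserved slot: sum of the routed weights / routed_scaling_factor.
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
print(topk_ids)
print(topk_weights)
```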
@@ -273,7 +273,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
@@ -284,7 +284,7 @@ __global__ void moe_fused_gate_kernel(
       num_rows,
       topk_group,
       topk,
-      n_share_experts_fusion,
+      num_fused_shared_experts,
       routed_scaling_factor,
       params);
 }
@@ -305,7 +305,7 @@ __global__ void moe_fused_gate_kernel(
         num_rows,                \
         topk_group,              \
         topk,                    \
-        n_share_experts_fusion,  \
+        num_fused_shared_experts, \
        routed_scaling_factor);  \
     dispatched = true;           \
   } while (0)
@@ -333,7 +333,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t num_expert_group,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;  // e.g, for deepseek v3, this is 256
@@ -351,7 +351,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
       num_rows,
       topk_group,
       topk,
-      n_share_experts_fusion,
+      num_fused_shared_experts,
       routed_scaling_factor,
       params);
 }
@@ -365,7 +365,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t num_expert_group,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
@@ -464,7 +464,7 @@ std::vector<at::Tensor> moe_fused_gate(
         num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
@@ -477,7 +477,7 @@ std::vector<at::Tensor> moe_fused_gate(
         num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
@@ -490,7 +490,7 @@ std::vector<at::Tensor> moe_fused_gate(
         num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
......
@@ -206,7 +206,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t num_expert_group,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor);

 void fp8_blockwise_scaled_grouped_mm(
......
@@ -42,7 +42,7 @@ def moe_fused_gate(
     num_expert_group,
     topk_group,
     topk,
-    n_share_experts_fusion=0,
+    num_fused_shared_experts=0,
     routed_scaling_factor=0,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
@@ -51,7 +51,7 @@ def moe_fused_gate(
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
+    # num_fused_shared_experts: if > 0, the last expert will be replaced with a round-robin shared expert
     # routed_scaling_factor: if > 0, the last expert will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
@@ -59,7 +59,7 @@ def moe_fused_gate(
         num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor,
     )
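
A usage sketch for the renamed wrapper, with DeepSeek-V3-like numbers assumed for illustration (256 experts in 8 groups, top-4 groups, 8 routed slots plus 1 fused shared-expert slot). It needs a CUDA device and the `sgl_kernel` package; the import path is an assumption rather than a quote from the repo:

```python
import torch
from sgl_kernel import moe_fused_gate  # assumed import path for the wrapper above

num_tokens, num_experts = 16, 256
logits = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(num_experts, dtype=torch.bfloat16, device="cuda")

# topk = 8 routed slots + 1 slot reserved for the fused shared expert.
weights, indices = moe_fused_gate(
    logits,
    bias,
    num_expert_group=8,
    topk_group=4,
    topk=9,
    num_fused_shared_experts=1,
    routed_scaling_factor=2.5,
)
print(weights.shape, indices.shape)  # expected: (16, 9) for both
```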
......
@@ -19,15 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
         (512, 16, 8, 16),
     ],
 )
-@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
-def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
+@pytest.mark.parametrize("num_fused_shared_experts", [0, 1])
+def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts):
     num_experts, num_expert_group, topk_group, topk = params
     torch.manual_seed(seq_length)

     tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
     scores = tensor.clone()
     bias = torch.rand(num_experts).to(dtype).cuda()

-    topk = topk + min(1, n_share_experts_fusion)
+    topk = topk + min(1, num_fused_shared_experts)

     output, indices = moe_fused_gate(
         tensor,
@@ -35,7 +35,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         topk=topk,
-        n_share_experts_fusion=n_share_experts_fusion,
+        num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
     )
     ref_output, ref_indices = biased_grouped_topk(
@@ -47,12 +47,12 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         compiled=False,
-        n_share_experts_fusion=n_share_experts_fusion,
+        num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
     )

-    # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
-    if n_share_experts_fusion > 0:
+    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
+    if num_fused_shared_experts > 0:
         original_indices = indices.clone()
         original_ref_indices = ref_indices.clone()
@@ -60,7 +60,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         ref_indices = ref_indices[:, :-1]

         valid_min = num_experts
-        valid_max = num_experts + n_share_experts_fusion
+        valid_max = num_experts + num_fused_shared_experts

         shared_indices = original_indices[:, -1]
         shared_ref_indices = original_ref_indices[:, -1]
         if shared_indices is not None:
@@ -87,11 +87,11 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
     assert idx_check, (
         f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
+        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
     )

     assert output_check, (
         f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
+        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
     )
......