"vscode:/vscode.git/clone" did not exist on "98ac391ed9081ba28b2e8085d1dc05ba2d75fd67"
Unverified commit e132cba2 authored by Xiaoyu Zhang, committed by GitHub

fused moe triton tuning script support qwen3 (#5842)

parent 0045f4b2
@@ -20,6 +20,13 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --dtype fp8_w8a8 \
     --tune
 
+# Tune Qwen3-235B-A22B-FP8 with FP8 and TP=4
+python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
+    --model Qwen/Qwen3-235B-A22B-FP8 \
+    --tp-size 4 \
+    --dtype fp8_w8a8 \
+    --tune
+
 # Tune DeepSeek-V3 with FP8, TP=8 and n_share_experts_fusion=8
 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --model deepseek-ai/DeepSeek-V3-0324 \
......
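To sanity-check what the Qwen3 tuning command above will sweep, here is a minimal standalone sketch (not part of this commit) that reads the same Hugging Face config fields the script's new Qwen3 branch uses; the model name and TP size are the example values from the README snippet, and the field names come from the diff below.

# Standalone sketch, not part of this commit: derive the MoE shapes the tuner
# sweeps for a Qwen3-MoE checkpoint. Requires `transformers` and hub access.
from transformers import AutoConfig

model_name = "Qwen/Qwen3-235B-A22B-FP8"  # example model from the README snippet
tp_size = 4                              # example TP size from the README snippet

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
assert config.architectures[0] == "Qwen3MoeForCausalLM"

E = config.num_experts                   # number of routed experts
topk = config.num_experts_per_tok        # experts activated per token
intermediate_size = config.moe_intermediate_size
# gate and up projections are fused (factor 2) and sharded across TP ranks
shard_intermediate_size = 2 * intermediate_size // tp_size

print(f"E={E}, topk={topk}, shard_intermediate_size={shard_intermediate_size}")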
@@ -30,10 +30,15 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] == "Qwen3MoeForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
+        intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
......
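Besides adding the Qwen3 branch, the hunk above also fixes the DeepSeek-V2/V3 case: those configs expose both intermediate_size (the dense-MLP width used in the first, non-MoE layers) and moe_intermediate_size (the per-expert MLP width), and the fused MoE kernel shapes must come from the latter. A small illustration of the distinction, assuming the public DeepSeek-V3-0324 config:

# Illustration only, not part of this commit: why the DeepSeek branch must read
# moe_intermediate_size rather than intermediate_size.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3-0324", trust_remote_code=True)
print(cfg.intermediate_size)      # dense-MLP width, not what the MoE experts use
print(cfg.moe_intermediate_size)  # per-expert MLP width the fused kernel needs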
@@ -30,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] == "Qwen3MoeForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
......