Unverified commit 04e5b6fa authored by Xiaoyu Zhang, committed by GitHub

Revert "Triton fused_moe_kernel support ep moe tuning" (#12377)

parent ce6b17c0
@@ -40,20 +40,10 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
     --tp-size 16 \
     --dtype int8_w8a8 \
     --tune
-# Tune with Expert Parallelism (EP) mode
-python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
-    --model Qwen/Qwen3-30B-A3B-FP8 \
-    --tp-size 1 \
-    --ep-size 2 \
-    --dtype fp8_w8a8 \
-    --tune
 ```
 After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/triton_version` dir to use it in `sglang`.
-**Note for EP mode**: When using Expert Parallelism (`--ep-size > 1`), `--tp-size` must be set to 1. The configuration file uses local expert count instead of total expert count. For example, with 64 total experts and EP=2, the config file will be named `E=32,N=640,device_name=...,dtype=...json`.
 ### Performance Comparison Tool
 - `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
...
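For reference, the configuration file name mentioned above encodes the expert count `E`, the per-shard intermediate size `N` (w2.shape[2] after silu_and_mul), the GPU device name, and the quantization dtype. Below is a minimal sketch of that naming scheme, assuming a visible CUDA device; `config_file_name` is a hypothetical stand-in for the `get_config_file_name` helper used by the tuning scripts, not its exact signature.

```python
import torch


def config_file_name(E: int, N: int, dtype_str: str) -> str:
    """Hypothetical, simplified illustration of the config naming convention.

    E is the expert count and N is w2.shape[2], the per-shard intermediate
    size after silu_and_mul, as noted in the tuning scripts below.
    """
    # Device name with spaces replaced, e.g. "NVIDIA_GeForce_RTX_4090".
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    return f"E={E},N={N},device_name={device_name},dtype={dtype_str}.json"


# Example: config_file_name(64, 640, "fp8_w8a8") on an RTX 4090 yields
# "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json",
# matching the file name cited in the README line above.
```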
@@ -50,10 +50,7 @@ def benchmark_config(
 per_channel_quant: bool,
 block_shape: List[int] = None,
 num_iters: int = 100,
-ep_size: int = 1,
 ) -> float:
-# In EP mode, each rank only handles a subset of experts
-local_experts = num_experts // ep_size
 init_dtype = torch.float16 if use_fp8_w8a8 else dtype
 x = torch.randn(num_tokens, hidden_size, dtype=dtype)
 if use_int8_w8a16 or use_int8_w8a8:
@@ -61,7 +58,7 @@ def benchmark_config(
 -127,
 127,
 (
-local_experts,
+num_experts,
 shard_intermediate_size,
 hidden_size,
 ),
@@ -71,7 +68,7 @@ def benchmark_config(
 -127,
 127,
 (
-local_experts,
+num_experts,
 hidden_size,
 shard_intermediate_size // 2,
 ),
@@ -79,14 +76,12 @@ def benchmark_config(
 )
 else:
 w1 = torch.randn(
-local_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
 )
 w2 = torch.randn(
-local_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
 )
-gating_output = torch.randn(
-num_iters, num_tokens, local_experts, dtype=torch.float32
-)
+gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
 w1_scale = None
 w2_scale = None
@@ -94,18 +89,18 @@ def benchmark_config(
 a2_scale = None
 if use_int8_w8a16:
 w1_scale = torch.randn(
-(local_experts, 2 * shard_intermediate_size), dtype=torch.float32
+(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
 )
-w2_scale = torch.randn((hidden_size, local_experts), dtype=torch.float32)
+w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
 if use_fp8_w8a8 or use_int8_w8a8:
 if use_int8_w8a8 and block_shape is None:
 w1_scale = torch.randn(
-local_experts, shard_intermediate_size, dtype=torch.float32
+num_experts, shard_intermediate_size, dtype=torch.float32
 )
-w2_scale = torch.randn(local_experts, hidden_size, dtype=torch.float32)
+w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
 elif block_shape is None:
-w1_scale = torch.randn(local_experts, dtype=torch.float32)
-w2_scale = torch.randn(local_experts, dtype=torch.float32)
+w1_scale = torch.randn(num_experts, dtype=torch.float32)
+w2_scale = torch.randn(num_experts, dtype=torch.float32)
 a1_scale = torch.randn(1, dtype=torch.float32)
 a2_scale = torch.randn(1, dtype=torch.float32)
 else:
@@ -115,17 +110,17 @@ def benchmark_config(
 k_tiles_w1 = (hidden_size + block_k - 1) // block_k
 k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
 w1_scale = torch.rand(
-(local_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
 )
 w2_scale = torch.rand(
-(local_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
 )
 if use_fp8_w8a8:
 w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
 w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
-input_gating = torch.randn(num_tokens, local_experts, dtype=torch.float32)
+input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
 topk_config = TopKConfig(
 top_k=topk,
 renormalize=True,
@@ -270,7 +265,6 @@ class BenchmarkWorker:
 use_int8_w8a16: bool,
 per_channel_quant: bool,
 block_shape: List[int],
-ep_size: int = 1,
 ) -> Tuple[Dict[str, int], float]:
 torch.cuda.manual_seed_all(0)
 dtype_str = get_config_dtype_str(
@@ -278,12 +272,10 @@ class BenchmarkWorker:
 )
 # NOTE(woosuk): The current naming convention uses w2.shape[2], which
 # is the intermediate size after silu_and_mul.
-# For EP mode, use local expert count for config lookup
-local_experts = num_experts // ep_size
 block_n = block_shape[0] if block_shape else 0
 block_k = block_shape[1] if block_shape else 0
 op_config = get_moe_configs(
-local_experts,
+num_experts,
 shard_intermediate_size // 2,
 dtype_str,
 block_n,
@@ -293,7 +285,7 @@ class BenchmarkWorker:
 if op_config is None:
 config = get_default_config(
 num_tokens,
-local_experts,
+num_experts,
 shard_intermediate_size,
 hidden_size,
 topk,
@@ -317,7 +309,6 @@ class BenchmarkWorker:
 use_int8_w8a16,
 per_channel_quant,
 block_shape,
-ep_size,
 )
 return config, kernel_time
@@ -335,7 +326,6 @@ class BenchmarkWorker:
 per_channel_quant: bool,
 block_shape: List[int],
 search_space: List[Dict[str, int]],
-ep_size: int = 1,
 ) -> Dict[str, int]:
 best_config = None
 best_time = float("inf")
@@ -356,7 +346,6 @@ class BenchmarkWorker:
 per_channel_quant,
 block_shape,
 num_iters=10,
-ep_size=ep_size,
 )
 except (triton.runtime.autotuner.OutOfResources, RuntimeError):
 # Some configurations may be invalid and fail to compile.
@@ -406,7 +395,6 @@ def get_filename(
 use_int8_w8a16: bool,
 per_channel_quant: bool,
 block_shape: List[int],
-ep_size: int = 1,
 ) -> None:
 dtype_str = get_config_dtype_str(
 dtype,
@@ -417,11 +405,9 @@ def get_filename(
 # NOTE(woosuk): The current naming convention uses w2.shape[2], which
 # is the intermediate size after silu_and_mul.
-# For EP mode, we use local expert count instead of total expert count
-local_experts = num_experts // ep_size
 filename = get_config_file_name(
-local_experts,
-shard_intermediate_size if ep_size > 1 else shard_intermediate_size // 2,
+num_experts,
+shard_intermediate_size // 2,
 dtype_str,
 block_shape,
 per_channel_quant,
@@ -433,35 +419,17 @@
 def main(args: argparse.Namespace):
 print(args)
-# Check EP mode constraint: tp_size must be 1 when ep_size > 1
-if args.ep_size > 1 and args.tp_size != 1:
-raise ValueError(
-f"When using Expert Parallelism (ep_size={args.ep_size}), "
-f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
-f"Please set --tp-size 1 when using --ep-size > 1."
-)
 config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
 if config.architectures[0] == "DbrxForCausalLM":
 E = config.ffn_config.moe_num_experts
 topk = config.ffn_config.moe_top_k
 intermediate_size = config.ffn_config.ffn_hidden_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] == "JambaForCausalLM":
 E = config.num_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in [
 "Qwen2MoeForCausalLM",
 "Qwen3MoeForCausalLM",
@@ -470,12 +438,7 @@ def main(args: argparse.Namespace):
 E = config.num_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
 E = (
 config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
@@ -484,24 +447,14 @@ def main(args: argparse.Namespace):
 )
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] == "Llama4ForConditionalGeneration":
 E = config.text_config.num_local_experts + (
 0 if args.disable_shared_experts_fusion else 1
 )
 topk = config.text_config.num_experts_per_tok
 intermediate_size = config.text_config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in [
 "Grok1ForCausalLM",
 "Grok1ImgGen",
@@ -510,12 +463,7 @@ def main(args: argparse.Namespace):
 E = config.num_local_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in [
 "BailingMoEForCausalLM",
 "BailingMoeForCausalLM",
@@ -524,33 +472,18 @@ def main(args: argparse.Namespace):
 E = config.num_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
 E = config.n_routed_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 else:
 # Default: Mixtral
 E = config.num_local_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
 dtype = config.torch_dtype
@@ -626,7 +559,6 @@ def main(args: argparse.Namespace):
 use_int8_w8a16,
 per_channel_quant,
 block_shape,
-args.ep_size,
 )
 print(
 f"Start tuning over {len(search_space)} configurations to create {filename}..."
@@ -649,7 +581,6 @@ def main(args: argparse.Namespace):
 per_channel_quant,
 block_shape,
 search_space,
-args.ep_size,
 )
 for batch_size in batch_sizes
 ],
@@ -679,7 +610,6 @@ def main(args: argparse.Namespace):
 use_int8_w8a16,
 per_channel_quant,
 block_shape,
-args.ep_size,
 )
 for batch_size in batch_sizes
 ],
@@ -695,7 +625,7 @@ if __name__ == "__main__":
 parser.add_argument(
 "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
 )
-parser.add_argument("--tp-size", "--tp", type=int, default=1)
+parser.add_argument("--tp-size", "--tp", type=int, default=2)
 parser.add_argument(
 "--dtype",
 type=str,
@@ -710,9 +640,6 @@ if __name__ == "__main__":
 parser.add_argument("--batch-size", type=int, required=False)
 parser.add_argument("--tune", action="store_true")
 parser.add_argument("--disable-shared-experts-fusion", action="store_true")
-parser.add_argument(
-"--ep-size", "--ep", type=int, default=1, help="Expert parallelism size"
-)
 args = parser.parse_args()
 main(args)
@@ -54,13 +54,10 @@ def benchmark_config(
 topk_ids_dir: str,
 block_shape: List[int] = None,
 num_iters: int = 100,
-ep_size: int = 1,
 ) -> float:
 ncu_enable = os.getenv("NCU_ENABLE", "0") == "1"
 if ncu_enable:
 num_iters = 1
-# In EP mode, each rank only handles a subset of experts
-local_experts = num_experts // ep_size
 init_dtype = torch.float16 if use_fp8_w8a8 else dtype
 hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
 if use_int8_w8a16 or use_int8_w8a8:
@@ -68,7 +65,7 @@ def benchmark_config(
 -127,
 127,
 (
-local_experts,
+num_experts,
 shard_intermediate_size,
 hidden_size,
 ),
@@ -78,7 +75,7 @@ def benchmark_config(
 -127,
 127,
 (
-local_experts,
+num_experts,
 hidden_size,
 shard_intermediate_size // 2,
 ),
@@ -86,14 +83,12 @@ def benchmark_config(
 )
 else:
 w1 = torch.randn(
-local_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
 )
 w2 = torch.randn(
-local_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
 )
-gating_output = torch.randn(
-num_iters, num_tokens, local_experts, dtype=torch.float32
-)
+gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
 w1_scale = None
 w2_scale = None
@@ -101,18 +96,18 @@ def benchmark_config(
 a2_scale = None
 if use_int8_w8a16:
 w1_scale = torch.randn(
-(local_experts, 2 * shard_intermediate_size), dtype=torch.float32
+(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
 )
-w2_scale = torch.randn((hidden_size, local_experts), dtype=torch.float32)
+w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
 if use_fp8_w8a8 or use_int8_w8a8:
 if use_int8_w8a8 and block_shape is None:
 w1_scale = torch.randn(
-local_experts, shard_intermediate_size, dtype=torch.float32
+num_experts, shard_intermediate_size, dtype=torch.float32
 )
-w2_scale = torch.randn(local_experts, hidden_size, dtype=torch.float32)
+w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
 elif block_shape is None:
-w1_scale = torch.randn(local_experts, dtype=torch.float32)
-w2_scale = torch.randn(local_experts, dtype=torch.float32)
+w1_scale = torch.randn(num_experts, dtype=torch.float32)
+w2_scale = torch.randn(num_experts, dtype=torch.float32)
 a1_scale = torch.randn(1, dtype=torch.float32)
 a2_scale = torch.randn(1, dtype=torch.float32)
 else:
@@ -122,17 +117,17 @@ def benchmark_config(
 k_tiles_w1 = (hidden_size + block_k - 1) // block_k
 k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
 w1_scale = torch.rand(
-(local_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
 )
 w2_scale = torch.rand(
-(local_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
 )
 if use_fp8_w8a8:
 w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
 w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
-input_gating = torch.randn(num_tokens, local_experts, dtype=torch.float32)
+input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
 topk_config = TopKConfig(
 top_k=topk,
 renormalize=True,
@@ -157,7 +152,7 @@ def benchmark_config(
 topk_weights, topk_ids, _ = topk_output
 sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-topk_ids, config["BLOCK_SIZE_M"], local_experts
+topk_ids, config["BLOCK_SIZE_M"], num_experts
 )
 M = hidden_states.shape[0]
 E, N, _ = w1.shape
@@ -404,7 +399,6 @@ class BenchmarkWorker:
 block_shape: List[int],
 cfg: Dict[str, int],
 topk_ids_dir: str,
-ep_size: int = 1,
 ) -> Tuple[Dict[str, int], float]:
 torch.cuda.manual_seed_all(0)
 dtype_str = get_config_dtype_str(
@@ -428,7 +422,6 @@ class BenchmarkWorker:
 use_int8_w8a16,
 topk_ids_dir,
 block_shape,
-ep_size=ep_size,
 )
 return cfg, kernel_time
@@ -446,7 +439,6 @@ class BenchmarkWorker:
 block_shape: List[int],
 search_space: List[Dict[str, int]],
 topk_ids_dir: str,
-ep_size: int = 1,
 ) -> Dict[str, int]:
 trace0 = BestConfigTrace("kernel0")
 trace1 = BestConfigTrace("kernel1")
@@ -469,7 +461,6 @@ class BenchmarkWorker:
 topk_ids_dir,
 block_shape,
 num_iters=10,
-ep_size=ep_size,
 )
 except triton.runtime.autotuner.OutOfResources:
 # Some configurations may be invalid and fail to compile.
@@ -545,7 +536,6 @@ def save_configs(
 use_int8_w8a16: bool,
 block_shape: List[int],
 down_moe: bool = False,
-ep_size: int = 1,
 ) -> None:
 dtype_str = get_config_dtype_str(
 dtype,
@@ -556,10 +546,8 @@ def save_configs(
 # NOTE(woosuk): The current naming convention uses w2.shape[2], which
 # is the intermediate size after silu_and_mul.
-# For EP mode, use local expert count instead of total expert count
-local_experts = num_experts // ep_size
 filename = get_config_file_name(
-local_experts,
+num_experts,
 shard_intermediate_size // 2,
 dtype_str,
 block_shape,
@@ -575,45 +563,22 @@
 def main(args: argparse.Namespace):
 print(args)
-# Check EP mode constraint: tp_size must be 1 when ep_size > 1
-if args.ep_size > 1 and args.tp_size != 1:
-raise ValueError(
-f"When using Expert Parallelism (ep_size={args.ep_size}), "
-f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
-f"Please set --tp-size 1 when using --ep-size > 1."
-)
 config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
 if config.architectures[0] == "DbrxForCausalLM":
 E = config.ffn_config.moe_num_experts
 topk = config.ffn_config.moe_top_k
 intermediate_size = config.ffn_config.ffn_hidden_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] == "JambaForCausalLM":
 E = config.num_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
 E = config.num_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
 E = (
 config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
@@ -627,24 +592,14 @@ def main(args: argparse.Namespace):
 else config.num_experts_per_tok
 )
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] == "Llama4ForConditionalGeneration":
 E = config.text_config.num_local_experts + (
 0 if args.disable_shared_experts_fusion else 1
 )
 topk = config.text_config.num_experts_per_tok
 intermediate_size = config.text_config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in [
 "Grok1ForCausalLM",
 "Grok1ImgGen",
@@ -653,33 +608,18 @@ def main(args: argparse.Namespace):
 E = config.num_local_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
 E = config.n_routed_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.moe_intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 else:
 # Default: Mixtral
 E = config.num_local_experts
 topk = config.num_experts_per_tok
 intermediate_size = config.intermediate_size
-# In EP mode, use original intermediate_size; otherwise apply TP sharding
-shard_intermediate_size = (
-intermediate_size
-if args.ep_size > 1
-else 2 * intermediate_size // args.tp_size
-)
+shard_intermediate_size = 2 * intermediate_size // args.tp_size
 hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
 dtype = config.torch_dtype
@@ -737,7 +677,6 @@ def main(args: argparse.Namespace):
 block_shape,
 search_space,
 topk_ids_dir,
-args.ep_size,
 )
 else:
 cfg = {
@@ -762,7 +701,6 @@ def main(args: argparse.Namespace):
 block_shape,
 cfg,
 topk_ids_dir,
-args.ep_size,
 )
 print(f"{t0=}, {t0_tma=}, {t1=}, {t1_tma=}")
 return
@@ -812,7 +750,6 @@ def main(args: argparse.Namespace):
 block_shape,
 search_space,
 topk_ids_dir,
-args.ep_size,
 )
 for batch_size in batch_sizes
 ],
@@ -838,7 +775,6 @@ def main(args: argparse.Namespace):
 use_int8_w8a8,
 use_int8_w8a16,
 block_shape,
-ep_size=args.ep_size,
 )
 best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
@@ -854,7 +790,6 @@ def main(args: argparse.Namespace):
 use_int8_w8a16,
 block_shape,
 down_moe=True,
-ep_size=args.ep_size,
 )
 end = time.perf_counter()
 print(f"Tuning took {end - start:.2f} seconds")
@@ -865,7 +800,7 @@ if __name__ == "__main__":
 parser.add_argument(
 "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
 )
-parser.add_argument("--tp-size", "--tp", type=int, default=1)
+parser.add_argument("--tp-size", "--tp", type=int, default=2)
 parser.add_argument(
 "--dtype",
 type=str,
@@ -878,9 +813,6 @@ if __name__ == "__main__":
 parser.add_argument("--disable-shared-experts-fusion", action="store_true")
 parser.add_argument("--configs", type=int, nargs="+", required=False)
 parser.add_argument("--topk-ids-dir", type=str, required=True)
-parser.add_argument(
-"--ep-size", "--ep", type=int, default=1, help="Expert parallelism size"
-)
 args = parser.parse_args()
 main(args)
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}
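Each top-level key in the tuned configuration above is a batch size `M`, mapping to Triton launch parameters (`BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, `GROUP_SIZE_M`, `num_warps`, `num_stages`). Below is a minimal sketch of reading such a file and choosing an entry for a given `M`, assuming nearest-key selection similar to the fused MoE Triton config lookup; `load_tuned_config` is a hypothetical helper, not part of this commit.

```python
import json


def load_tuned_config(path: str, M: int) -> dict:
    """Hypothetical helper: pick kernel parameters for batch size M.

    Assumes nearest-key selection over the tuned batch sizes; the real
    lookup lives in sglang's fused_moe_triton module.
    """
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Fall back to the closest tuned batch size if M was not tuned directly.
    best = min(configs, key=lambda k: abs(k - M))
    return configs[best]


# Example: for M=48 the file above yields
# {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#  "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}.
```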