Unverified Commit 52694b60, authored by Xiaoyu Zhang and committed by GitHub

Triton fused_moe_kernel: support EP MoE tuning (#12343)

parent 400bddf2
@@ -40,10 +40,20 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--tp-size 16 \
--dtype int8_w8a8 \
--tune
+# Tune with Expert Parallelism (EP) mode
+python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
+    --model Qwen/Qwen3-30B-A3B-FP8 \
+    --tp-size 1 \
+    --ep-size 2 \
+    --dtype fp8_w8a8 \
+    --tune
```
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to the `sglang/srt/layers/fused_moe_triton/configs/triton_version` directory so that `sglang` can use it.
+**Note for EP mode**: When using Expert Parallelism (`--ep-size > 1`), `--tp-size` must be set to 1. The configuration file is named with the local expert count rather than the total expert count. For example, with 64 total experts and EP=2, the config file will be named `E=32,N=640,device_name=...,dtype=...json`.
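The renaming is just integer division of the expert count. A minimal sketch of the naming rule (the helper below is illustrative only, not sglang's actual `get_config_file_name`):

```python
# Hypothetical helper, for illustration only -- not sglang's get_config_file_name.
def ep_config_file_name(num_experts: int, ep_size: int, n: int,
                        device_name: str, dtype: str) -> str:
    local_experts = num_experts // ep_size  # each EP rank holds E / ep_size experts
    return f"E={local_experts},N={n},device_name={device_name},dtype={dtype}.json"

# 64 total experts with --ep-size 2 -> E=32 in the filename
print(ep_config_file_name(64, 2, 640, "NVIDIA_GeForce_RTX_4090", "fp8_w8a8"))
```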
### Performance Comparison Tool
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
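The script changes below thread an `ep_size` parameter through `benchmark_config` and the tuning workers. The core transformation is sizing every expert-indexed tensor by the local expert count; a self-contained sketch with hypothetical sizes:

```python
import torch

# Minimal sketch of the allocation pattern introduced below, under assumed
# (hypothetical) sizes: every expert-indexed tensor shrinks from num_experts
# to local_experts = num_experts // ep_size.
num_experts, ep_size = 64, 2
hidden_size, shard_intermediate_size = 1024, 1536  # hypothetical model sizes
local_experts = num_experts // ep_size             # experts held by one EP rank

w1 = torch.randn(local_experts, shard_intermediate_size, hidden_size)
w2 = torch.randn(local_experts, hidden_size, shard_intermediate_size // 2)
gating_output = torch.randn(8, local_experts)      # logits for 8 tokens
assert w1.shape[0] == w2.shape[0] == 32
```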
@@ -50,7 +50,10 @@ def benchmark_config(
per_channel_quant: bool,
block_shape: List[int] = None,
num_iters: int = 100,
+    ep_size: int = 1,
) -> float:
+    # In EP mode, each rank only handles a subset of experts
+    local_experts = num_experts // ep_size
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
@@ -58,7 +61,7 @@
-127,
127,
(
-                num_experts,
+                local_experts,
shard_intermediate_size,
hidden_size,
),
@@ -68,7 +71,7 @@
-127,
127,
(
-                num_experts,
+                local_experts,
hidden_size,
shard_intermediate_size // 2,
),
@@ -76,12 +79,14 @@
)
else:
w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    gating_output = torch.randn(
+        num_iters, num_tokens, local_experts, dtype=torch.float32
+    )
w1_scale = None
w2_scale = None
@@ -89,18 +94,18 @@
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
-                num_experts, shard_intermediate_size, dtype=torch.float32
+                local_experts, shard_intermediate_size, dtype=torch.float32
)
-            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
+            w2_scale = torch.randn(local_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
@@ -110,17 +115,17 @@
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
-                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+                (local_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
-                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+                (local_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
-    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.randn(num_tokens, local_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
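In EP mode the benchmark draws gating logits over the local experts only and runs top-k directly on them, which appears to be a benchmarking simplification (a real EP deployment routes over the global expert set and then dispatches). A pure-torch sketch of this routing step, with hypothetical shapes:

```python
import torch

# Sketch of the EP-mode routing step: top-k with renormalization over
# *local* expert logits only (a simplification for benchmarking).
num_tokens, local_experts, top_k = 8, 32, 2
logits = torch.randn(num_tokens, local_experts)
probs = torch.softmax(logits, dim=-1)
topk_weights, topk_ids = torch.topk(probs, k=top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True
```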
@@ -265,6 +270,7 @@ class BenchmarkWorker:
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
+        ep_size: int = 1,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
@@ -272,10 +278,12 @@
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
+        # For EP mode, use local expert count for config lookup
+        local_experts = num_experts // ep_size
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
op_config = get_moe_configs(
-            num_experts,
+            local_experts,
shard_intermediate_size // 2,
dtype_str,
block_n,
@@ -285,7 +293,7 @@
if op_config is None:
config = get_default_config(
num_tokens,
-                num_experts,
+                local_experts,
shard_intermediate_size,
hidden_size,
topk,
@@ -309,6 +317,7 @@
use_int8_w8a16,
per_channel_quant,
block_shape,
+            ep_size,
)
return config, kernel_time
@@ -326,6 +335,7 @@ class BenchmarkWorker:
per_channel_quant: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
+        ep_size: int = 1,
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
@@ -346,6 +356,7 @@
per_channel_quant,
block_shape,
num_iters=10,
+                    ep_size=ep_size,
)
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
# Some configurations may be invalid and fail to compile.
@@ -395,6 +406,7 @@ def get_filename(
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
+    ep_size: int = 1,
) -> None:
dtype_str = get_config_dtype_str(
dtype,
@@ -405,9 +417,11 @@
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
+    # For EP mode, we use local expert count instead of total expert count
+    local_experts = num_experts // ep_size
filename = get_config_file_name(
-        num_experts,
-        shard_intermediate_size // 2,
+        local_experts,
+        shard_intermediate_size if ep_size > 1 else shard_intermediate_size // 2,
dtype_str,
block_shape,
per_channel_quant,
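Note the asymmetry in the second argument above: under TP, `shard_intermediate_size` holds the fused gate/up width (2x) and the filename's `N` halves it back to the per-rank intermediate size; under EP, `shard_intermediate_size` is already the plain intermediate size and is used directly. A worked example with hypothetical sizes:

```python
# Worked example of the filename arithmetic (hypothetical sizes).
intermediate_size, tp_size = 1536, 2

# TP mode: shard_intermediate_size holds the fused gate/up width (2x),
# so the filename's N halves it back to the per-rank intermediate size.
shard_tp = 2 * intermediate_size // tp_size      # 1536
n_tp = shard_tp // 2                             # 768

# EP mode (tp_size forced to 1): shard_intermediate_size is already the
# full intermediate size, so it is used for N without the // 2.
shard_ep = intermediate_size                     # 1536
n_ep = shard_ep                                  # 1536
```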
@@ -419,17 +433,35 @@
def main(args: argparse.Namespace):
print(args)
+    # Check EP mode constraint: tp_size must be 1 when ep_size > 1
+    if args.ep_size > 1 and args.tp_size != 1:
+        raise ValueError(
+            f"When using Expert Parallelism (ep_size={args.ep_size}), "
+            f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
+            f"Please set --tp-size 1 when using --ep-size > 1."
+        )
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in [
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
@@ -438,7 +470,12 @@
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
@@ -447,14 +484,24 @@
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
@@ -463,7 +510,12 @@
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in [
"BailingMoEForCausalLM",
"BailingMoeForCausalLM",
@@ -472,18 +524,33 @@
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
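Every architecture branch above repeats the same EP-vs-TP conditional; a small helper could centralize the rule (hypothetical refactor, not part of the patch):

```python
# Hypothetical refactor, not part of the patch: every architecture branch
# repeats the same EP-vs-TP rule, which a helper could centralize.
def shard_size(intermediate_size: int, tp_size: int, ep_size: int) -> int:
    # EP keeps the full intermediate size; TP shards the fused 2x width.
    return intermediate_size if ep_size > 1 else 2 * intermediate_size // tp_size
```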
@@ -559,6 +626,7 @@
use_int8_w8a16,
per_channel_quant,
block_shape,
+        args.ep_size,
)
print(
f"Start tuning over {len(search_space)} configurations to create {filename}..."
@@ -581,6 +649,7 @@
per_channel_quant,
block_shape,
search_space,
+                    args.ep_size,
)
for batch_size in batch_sizes
],
@@ -610,6 +679,7 @@
use_int8_w8a16,
per_channel_quant,
block_shape,
+                    args.ep_size,
)
for batch_size in batch_sizes
],
@@ -625,7 +695,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=1)
parser.add_argument(
"--dtype",
type=str,
@@ -640,6 +710,9 @@
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
+    parser.add_argument(
+        "--ep-size", "--ep", type=int, default=1, help="Expert parallelism size"
+    )
args = parser.parse_args()
main(args)
@@ -54,10 +54,13 @@ def benchmark_config(
topk_ids_dir: str,
block_shape: List[int] = None,
num_iters: int = 100,
+    ep_size: int = 1,
) -> float:
ncu_enable = os.getenv("NCU_ENABLE", "0") == "1"
if ncu_enable:
num_iters = 1
+    # In EP mode, each rank only handles a subset of experts
+    local_experts = num_experts // ep_size
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
@@ -65,7 +68,7 @@
-127,
127,
(
-                num_experts,
+                local_experts,
shard_intermediate_size,
hidden_size,
),
@@ -75,7 +78,7 @@
-127,
127,
(
-                num_experts,
+                local_experts,
hidden_size,
shard_intermediate_size // 2,
),
@@ -83,12 +86,14 @@
)
else:
w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    gating_output = torch.randn(
+        num_iters, num_tokens, local_experts, dtype=torch.float32
+    )
w1_scale = None
w2_scale = None
@@ -96,18 +101,18 @@
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
-                num_experts, shard_intermediate_size, dtype=torch.float32
+                local_experts, shard_intermediate_size, dtype=torch.float32
)
-            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
+            w2_scale = torch.randn(local_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
@@ -117,17 +122,17 @@
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
-                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+                (local_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
-                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+                (local_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
-    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.randn(num_tokens, local_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
@@ -152,7 +157,7 @@
topk_weights, topk_ids, _ = topk_output
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], num_experts
+        topk_ids, config["BLOCK_SIZE_M"], local_experts
)
M = hidden_states.shape[0]
E, N, _ = w1.shape
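Passing the local expert count to `moe_align_block_size` matters because the routine buckets routed token ids per expert and pads each bucket to a multiple of `BLOCK_SIZE_M`, so its expert dimension must match the `local_experts`-sized `w1`/`w2` above. A rough pure-torch sketch of the padding arithmetic (not the actual kernel):

```python
import torch

# Rough sketch of why moe_align_block_size needs the *local* expert count:
# routed token ids are bucketed per expert and each bucket is padded up to a
# multiple of BLOCK_SIZE_M.
def padded_token_count(topk_ids: torch.Tensor, num_experts: int, block_m: int) -> int:
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    padded = ((counts + block_m - 1) // block_m) * block_m
    return int(padded.sum())

topk_ids = torch.randint(0, 32, (8, 2))  # 8 tokens, top-2 over 32 local experts
print(padded_token_count(topk_ids, num_experts=32, block_m=16))
```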
@@ -399,6 +404,7 @@ class BenchmarkWorker:
block_shape: List[int],
cfg: Dict[str, int],
topk_ids_dir: str,
+        ep_size: int = 1,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
@@ -422,6 +428,7 @@
use_int8_w8a16,
topk_ids_dir,
block_shape,
+            ep_size=ep_size,
)
return cfg, kernel_time
@@ -439,6 +446,7 @@
block_shape: List[int],
search_space: List[Dict[str, int]],
topk_ids_dir: str,
+        ep_size: int = 1,
) -> Dict[str, int]:
trace0 = BestConfigTrace("kernel0")
trace1 = BestConfigTrace("kernel1")
@@ -461,6 +469,7 @@
topk_ids_dir,
block_shape,
num_iters=10,
+                    ep_size=ep_size,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
@@ -536,6 +545,7 @@ def save_configs(
use_int8_w8a16: bool,
block_shape: List[int],
down_moe: bool = False,
+    ep_size: int = 1,
) -> None:
dtype_str = get_config_dtype_str(
dtype,
@@ -546,8 +556,10 @@
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
+    # For EP mode, use local expert count instead of total expert count
+    local_experts = num_experts // ep_size
filename = get_config_file_name(
-        num_experts,
+        local_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
@@ -563,22 +575,45 @@
def main(args: argparse.Namespace):
print(args)
+    # Check EP mode constraint: tp_size must be 1 when ep_size > 1
+    if args.ep_size > 1 and args.tp_size != 1:
+        raise ValueError(
+            f"When using Expert Parallelism (ep_size={args.ep_size}), "
+            f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
+            f"Please set --tp-size 1 when using --ep-size > 1."
+        )
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
@@ -592,14 +627,24 @@
else config.num_experts_per_tok
)
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
@@ -608,18 +653,33 @@
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        shard_intermediate_size = (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
@@ -677,6 +737,7 @@
block_shape,
search_space,
topk_ids_dir,
+            args.ep_size,
)
else:
cfg = {
@@ -701,6 +762,7 @@
block_shape,
cfg,
topk_ids_dir,
+            args.ep_size,
)
print(f"{t0=}, {t0_tma=}, {t1=}, {t1_tma=}")
return
@@ -750,6 +812,7 @@
block_shape,
search_space,
topk_ids_dir,
+                    args.ep_size,
)
for batch_size in batch_sizes
],
@@ -775,6 +838,7 @@
use_int8_w8a8,
use_int8_w8a16,
block_shape,
+        ep_size=args.ep_size,
)
best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
@@ -790,6 +854,7 @@
use_int8_w8a16,
block_shape,
down_moe=True,
+        ep_size=args.ep_size,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
@@ -800,7 +865,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=1)
parser.add_argument(
"--dtype",
type=str,
@@ -813,6 +878,9 @@
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
parser.add_argument("--configs", type=int, nargs="+", required=False)
parser.add_argument("--topk-ids-dir", type=str, required=True)
+    parser.add_argument(
+        "--ep-size", "--ep", type=int, default=1, help="Expert parallelism size"
+    )
args = parser.parse_args()
main(args)
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}
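This batch-size-keyed table is how tuned configs are stored; at runtime the entry whose key is nearest to the actual `M` is typically selected. A minimal sketch of that lookup (mirroring the vLLM-derived behavior, and assuming the JSON above is saved to disk):

```python
import json

# Hedged sketch of how a tuned table like the one above is consumed:
# pick the entry whose batch-size key is nearest to the runtime M.
def load_best_config(path: str, m: int) -> dict:
    with open(path) as f:
        table = json.load(f)  # {"1": {...}, "2": {...}, ...}
    nearest_key = min(table, key=lambda k: abs(int(k) - m))
    return table[nearest_key]

# e.g. m=200 selects the "256" entry from the table above
```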