Unverified commit fc84b073, authored by Xiaoyu Zhang and committed by GitHub

[Refactor] Refactor fused_moe_triton tuning tools: extract shared utils, add EP/MLLM support, reduce overhead (#12440)
Co-authored-by: xu-yfei <xu-yfei@users.noreply.github.com>
Co-authored-by: Yongfei Xu <xuyongfei.xyf@antgroup.com>
parent 8be0e1bc
......@@ -2,13 +2,27 @@
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
### Overview
The tuning tools support both **Tensor Parallelism (TP)** and **Expert Parallelism (EP)** modes:
- **TP Mode**: Traditional tensor parallelism where intermediate layers are sharded across GPUs
- **EP Mode**: Expert parallelism where experts are distributed across GPUs. Can be combined with TP mode (e.g., `--tp-size 8 --ep-size 2`)
- **MLLM Support**: Multi-modal LLMs whose configs wrap the language model in a `text_config` (e.g., Llama4, Qwen3-VL)
### Tuning Tools
#### 1. `tuning_fused_moe_triton.py`
A unified tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with support for EP mode and various model architectures.
#### 2. `tuning_fused_moe_triton_sep.py`
A specialized tool for separate kernel tuning, optimizing the first and second MoE kernels independently with TMA (Tensor Memory Accelerator) support.
### Usage Examples
#### Basic TP Mode Tuning
```bash
# Tune Mixtral-8x7B with default TP settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tune
......@@ -20,29 +34,149 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-V3 with FP8 and TP=8
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model deepseek-ai/DeepSeek-V3-0324 \
    --tp-size 8 \
--dtype fp8_w8a8 \
--tune
```
#### EP Mode Tuning (Expert Parallelism)
**Note**: EP mode can be used alone or combined with TP mode. When using both, ensure `tp_size` is divisible by `ep_size`.
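The sketch below shows how these flags translate into per-GPU MoE shapes. It mirrors the `calculate_shard_intermediate_size` helper added in `common_utils.py` later in this diff; the wrapper function and the Mixtral-like numbers are only illustrative.
```python
# Illustrative sketch of the TP/EP sharding arithmetic used by the tuning tools
# (mirrors calculate_shard_intermediate_size in common_utils.py).
def moe_shard_shapes(num_experts: int, intermediate_size: int, tp_size: int, ep_size: int = 1):
    assert tp_size % ep_size == 0, "tp_size must be divisible by ep_size"
    moe_tp_size = tp_size // ep_size           # TP degree left over inside each EP group
    experts_per_rank = num_experts // ep_size  # experts owned by each EP rank
    # Factor of 2 because the fused gate_up projection doubles the intermediate width.
    shard_intermediate_size = 2 * intermediate_size // moe_tp_size
    return experts_per_rank, shard_intermediate_size

# Mixtral-8x7B-like shapes with --tp-size 8 --ep-size 2
print(moe_shard_shapes(num_experts=8, intermediate_size=14336, tp_size=8, ep_size=2))
# -> (4, 7168)
```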
```bash
# Tune Mixtral-8x7B with EP=2 (tp_size equals ep_size, so experts are not TP-sharded)
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tp-size 2 \
--ep-size 2 \
--tune
# Tune Qwen2-57B with TP=8 and EP=4 (combined mode)
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--tp-size 8 \
--ep-size 4 \
--dtype fp8_w8a8 \
--tune
```
#### MLLM Model Tuning (Multi-modal)
```bash
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--tp-size 2 \
--tune
```
#### Separate Kernel Tuning with `tuning_fused_moe_triton_sep.py`
This tool requires pre-generated topk_ids files and supports both TP and EP modes:
First, edit the model file (e.g., `srt/models/deepseek_v2.py`) in the installed Python package and add logic to save `topk_ids`:
```python
# In DeepseekV2MoE.forward_normal (import get_tensor_model_parallel_rank at the top of the file)
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
    topk_ids_dir = xxxx  # set to a writable output directory
    if not hasattr(self, "save_idx"):
        self.save_idx = 0
    if self.save_idx <= 1:
        torch.save(
            topk_output.topk_ids,
            f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt",
        )
        self.save_idx += 1
```
Then launch the SGLang server and send a request using `benchmark/kernels/fused_moe_triton/tuning_client.py` to populate the topk_ids files:
```bash
python benchmark/kernels/fused_moe_triton/tuning_client.py --port 8000
```
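Before running the sep tuner, you can optionally sanity-check the captured tensors. The snippet below is only an inspection aid; the directory path is a placeholder, and the filename pattern follows the `torch.save` call shown above.
```python
# Optional: inspect the captured topk_ids files before tuning.
import glob
import torch

topk_ids_dir = "/path/to/topk_ids"  # placeholder: the directory used in the patch above
for path in sorted(glob.glob(f"{topk_ids_dir}/topk_ids_layer*_idx*.pt")):
    topk_ids = torch.load(path, map_location="cpu")
    print(path, tuple(topk_ids.shape), topk_ids.dtype)  # typically [num_tokens, topk]
```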
```bash
# TP Mode: Tune separate kernels with TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--tp-size 4 \
--topk-ids-dir /path/to/topk_ids \
--tune
# EP Mode: Tune separate kernels with TP=4 and EP=2
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tp-size 4 \
--ep-size 2 \
--topk-ids-dir /path/to/topk_ids \
--tune
# Tune DeepSeek-V3 with separate kernels, TP=8 and EP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8 \
--ep-size 4 \
--dtype fp8_w8a8 \
--topk-ids-dir /path/to/topk_ids \
--tune
# Benchmark a specific config without tuning
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 4 \
--batch-size 1024 \
--dtype fp8_w8a8 \
--configs 128 256 128 16 8 4 \
--topk-ids-dir /path/to/topk_ids
```
#### Advanced Options
```bash
# Channel-wise quantization
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model meituan/DeepSeek-R1-Channel-INT8 \
--tp-size 16 \
--dtype int8_w8a8 \
--per-channel-quant \
--tune
# Specific batch size tuning
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--batch-size 2048 \
--tune
```
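A note on quantization handling: block-wise FP8 is detected automatically from the checkpoint's `quantization_config`, while `--per-channel-quant` is an explicit flag. A minimal sketch of that detection, mirroring the logic in `common_utils.get_model_config` (the model name is used purely as an example):
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3-0324", trust_remote_code=True)
block_shape = None
if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
    # e.g. [128, 128] for block-wise FP8 checkpoints
    block_shape = config.quantization_config["weight_block_size"]
print(block_shape)  # None means no block-wise quantization was found in the checkpoint
```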
### Configuration Files
After tuning, configuration files will be generated in the current directory:
- **Standard tuning**: `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`
- **Separate kernel tuning**: Two files for up/down kernels with TMA optimization flags
Move these files to the `sglang/srt/layers/fused_moe_triton/configs/triton_version/` directory to use them in SGLang.
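For reference, a generated file is a JSON object mapping each tuned batch size `M` to a Triton launch config (the separate-kernel tuner additionally records a `USE_TMA` flag). The values below are illustrative only, not tuned results:
```python
# Structure of a generated config file (illustrative values).
example_config = {
    "1":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 1,   "num_warps": 4, "num_stages": 3},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 16,  "num_warps": 8, "num_stages": 4},
}
```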
### Supported Models
- **Mixtral**: mistralai/Mixtral-8x7B-Instruct-v0.1, mixtral-8x22b
- **Qwen**: Qwen2-57B, Qwen3-235B, Qwen3VL (MLLM)
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3, DeepSeek-R1
- **Llama**: Llama4-Vision (MLLM)
- **DBRX**: databricks/dbrx-instruct
- **Jamba**: ai21labs/AI21-Jamba
- **Grok**: xai-org/grok-1
- **GLM**: THUDM/glm-4-9b-chat
- **Bailing**: Custom MoE models
### Parameters Reference
- `--model`: HuggingFace model name or local path
- `--tp-size`: Tensor parallelism size (default: 2)
- `--ep-size`: Expert parallelism size (default: 1; can be combined with TP mode, in which case `tp_size` must be divisible by `ep_size`)
- `--dtype`: Data type (`auto`, `fp8_w8a8`, `int8_w8a16`, `int8_w8a8`)
- `--batch-size`: Specific batch size for tuning (optional)
- `--tune`: Enable tuning mode
- `--per-channel-quant`: Enable per-channel quantization
- `--disable-shared-experts-fusion`: Disable shared expert fusion for some models
- `--topk-ids-dir`: Directory containing pre-generated topk_ids (for sep tool only)
- `--configs`: Manual config specification `[BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_warps, num_stages]` (see the mapping sketch below)
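For clarity, a manual `--configs` spec maps positionally onto the kernel config fields used throughout these tools; the dict below only illustrates that ordering and is not code from the tuner:
```python
values = [128, 256, 128, 16, 8, 4]  # e.g. --configs 128 256 128 16 8 4
keys = ["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M", "num_warps", "num_stages"]
manual_config = dict(zip(keys, values))
# {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128,
#  'GROUP_SIZE_M': 16, 'num_warps': 8, 'num_stages': 4}
```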
### Performance Comparison Tool
......@@ -73,4 +207,4 @@ The benchmark results will be saved as plots and data files in the specified out
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the fused MoE kernel under `torch.compile` against the original fused MoE kernel.
Usage is similar to `benchmark_vllm_vs_sglang_fused_moe_triton.py`. Note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused MoE kernels. Both tools now support EP mode via the `--ep-size` parameter.
......@@ -3,7 +3,7 @@ import argparse
import torch
import triton
from transformers import AutoConfig
from common_utils import get_model_config
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
......@@ -21,60 +21,6 @@ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_triton_api(
x,
w1,
......@@ -239,7 +185,8 @@ def main():
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--ep-size", "--ep", type=int, default=1)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
......@@ -270,11 +217,11 @@ def main():
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
tensor_model_parallel_size=args.ep_size,
pipeline_model_parallel_size=args.tp_size,
)
model_config = get_model_config(args.model, args.tp_size)
model_config = get_model_config(args.model, args.tp_size, args.ep_size)
benchmark.run(
show_plots=True,
print_data=True,
......
......@@ -3,8 +3,6 @@ import argparse
import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.distributed.parallel_state import (
......@@ -17,91 +15,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
vllm_version_num = (
vllm.__version_tuple__[0] * 100
+ vllm.__version_tuple__[1] * 10
+ vllm.__version_tuple__[2]
)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
from .common_utils import get_model_config
def fused_moe_vllm_api(
......@@ -301,7 +215,8 @@ def main():
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--ep-size", "--ep", type=int, default=1)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
......@@ -332,12 +247,12 @@ def main():
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
shape_configs = get_model_config(args.model, args.tp_size, args.ep_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
model_config=shape_configs,
use_fp8_w8a8=args.use_fp8_w8a8,
)
finally:
......
import json
from typing import Dict, List, TypedDict
import torch
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_config_dtype_str
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
get_config_file_name,
)
from sglang.srt.utils import is_hip
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def calculate_shard_intermediate_size(
intermediate_size: int, tp_size: int, ep_size: int = 1
) -> int:
assert tp_size % ep_size == 0
moe_tp_size = tp_size // ep_size
assert intermediate_size % moe_tp_size == 0
return 2 * intermediate_size // moe_tp_size
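# get_model_config reads the HuggingFace config (switching to text_config for
# multi-modal models such as Llama4/Qwen3-VL), divides the expert count by ep_size,
# optionally folds the fused shared expert into E/topk for DeepSeek-V3 and Llama4,
# and returns the shape dict consumed by the tuning and benchmark scripts.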
def get_model_config(
model_name: str,
tp_size: int,
ep_size: int = 1,
disable_shared_experts_fusion: bool = False,
topk_ids_dir: str = None,
) -> Dict:
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
architecture = config.architectures[0]
# Replace config with text_config for multi-modal models after getting block_shape and architecture
if hasattr(config, "text_config"):
config = config.get_text_config()
if architecture == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts // ep_size
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
elif architecture == "JambaForCausalLM":
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
elif architecture in [
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"Qwen3VLMoeForConditionalGeneration",
]:
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (config.n_routed_experts // ep_size) + (
0
if disable_shared_experts_fusion
or architecture not in ["DeepseekV3ForCausalLM"]
else 1
)
topk = config.num_experts_per_tok + (
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
)
intermediate_size = config.moe_intermediate_size
elif architecture == "Llama4ForConditionalGeneration":
E = config.num_local_experts // ep_size + (
0 if disable_shared_experts_fusion else 1
)
topk = config.num_experts_per_tok + (
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
)
intermediate_size = config.intermediate_size
elif architecture in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in [
"BailingMoEForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
else:
# Default: Mixtral
E = config.num_local_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = calculate_shard_intermediate_size(
intermediate_size, tp_size, ep_size
)
return {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
"architecture": architecture,
}
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
if is_hip():
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
filename: str,
) -> None:
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_config_filename(
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
) -> str:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
per_channel_quant,
)
return filename
def get_default_batch_sizes() -> List[int]:
return [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
import argparse
import os
import time
import openai
"""
# Edit the code file srt/models/deepseek_v2.py in the Python site package and add the logic for saving topk_ids:
# import get_tensor_model_parallel_rank
# DeepseekV2MoE::forward_normal
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
topk_ids_dir = xxxx
if not hasattr(self, "save_idx"):
self.save_idx = 0
if self.save_idx <= 1:
torch.save(topk_output.topk_ids, f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt")
self.save_idx += 1
"""
def read_long_prompt():
import json
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/tuning_text.json", "r") as fp:
text = fp.read()
rst = json.loads(text)
return rst["prompt"]
def openai_stream_test(model, ip, port):
client = openai.Client(base_url=f"http://{ip}:{port}/v1", api_key="None")
qst = read_long_prompt()
messages = [
{"role": "user", "content": qst},
]
msg2 = dict(
model=model,
messages=messages,
temperature=0.6,
top_p=0.75,
max_tokens=100,
)
response = client.chat.completions.create(**msg2, stream=True)
time_start = time.time()
time_cost = []
for chunk in response:
time_end = time.time()
# if chunk.choices[0].delta.content:
# print(chunk.choices[0].delta.content, end="", flush=True)
time_cost.append(time_end - time_start)
time_start = time.time()
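    # TTFT: time until the first two streamed chunks arrive; TPOT: mean interval
    # between the remaining chunks.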
ttft = time_cost[0] + time_cost[1]
tpot = sum(time_cost[2:]) / len(time_cost[2:])
print(f"\nTTFT {ttft}, TPOT {tpot}")
return ttft, tpot
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="auto")
parser.add_argument(
"--ip",
type=str,
default="127.0.0.1",
)
parser.add_argument("--port", type=int, default=8188)
args = parser.parse_args()
openai_stream_test(args.model, args.ip, args.port)
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
from typing import Any, Dict, List, Tuple
import ray
import torch
import triton
from common_utils import (
BenchmarkConfig,
get_config_filename,
get_configs_compute_bound,
get_default_batch_sizes,
get_model_config,
save_configs,
sort_config,
)
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
get_config_dtype_str,
get_config_file_name,
get_default_config,
get_moe_configs,
)
......@@ -27,15 +33,6 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
......@@ -173,74 +170,28 @@ def benchmark_config(
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Flush L2 cache with 256 MB data
cache_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
cache_flush.zero_()
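    # One start/end event pair per iteration; elapsed times are read after a single
    # synchronize below, so replays are not serialized by per-iteration event syncs.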
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
start_events[i].record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
end_events[i].record()
torch.cuda.synchronize()
latencies: List[float] = []
for i in range(num_iters):
latencies.append(start_events[i].elapsed_time(end_events[i]))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
@ray.remote(num_gpus=1)
class BenchmarkWorker:
......@@ -360,189 +311,27 @@ class BenchmarkWorker:
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
filename: str,
) -> None:
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_filename(
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
per_channel_quant,
)
return filename
def main(args: argparse.Namespace):
print(args)
def _calculate_shard_intermediate_size(intermediate_size: int) -> int:
# In EP mode, use original intermediate_size; otherwise apply TP sharding
return (
intermediate_size
if args.ep_size > 1
else 2 * intermediate_size // args.tp_size
)
model_config = get_model_config(
args.model, args.tp_size, args.ep_size, args.disable_shared_experts_fusion
)
# Check EP mode constraint: tp_size must be 1 when ep_size > 1
if args.ep_size > 1 and args.tp_size != 1:
raise ValueError(
f"When using Expert Parallelism (ep_size={args.ep_size}), "
f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
f"Please set --tp-size 1 when using --ep-size > 1."
)
E = model_config["num_experts"]
topk = model_config["topk"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
# Determine block shape for quantization
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
architecture = config.architectures[0]
# replace config with text_config for encoder-decoder models after getting block_shape and architecture
if hasattr(config, "text_config"):
config = config.get_text_config()
if architecture == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture in [
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"Qwen3VLMoeForConditionalGeneration",
]:
E = config.num_experts // args.ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if architecture == "DeepseekV3ForCausalLM"
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture == "Llama4ForConditionalGeneration":
E = config.num_local_experts + (0 if args.disable_shared_experts_fusion else 1)
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts // args.ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture in [
"BailingMoEForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
E = config.num_experts // args.ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
elif architecture in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
else:
# Default: Mixtral
E = config.num_local_experts // args.ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
per_channel_quant = args.per_channel_quant
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
batch_sizes = get_default_batch_sizes()
else:
batch_sizes = [args.batch_size]
......@@ -571,7 +360,7 @@ def main(args: argparse.Namespace):
if block_k % config["BLOCK_SIZE_K"] == 0
]
filename = get_filename(
filename = get_config_filename(
E,
shard_intermediate_size,
hidden_size,
......
......@@ -5,15 +5,22 @@ import os
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
from typing import Any, Dict, List, Tuple
import ray
import torch
import triton
import triton.language as tl
from common_utils import (
BenchmarkConfig,
get_config_filename,
get_configs_compute_bound,
get_default_batch_sizes,
get_model_config,
sort_config,
)
from ray.experimental.tqdm_ray import tqdm
from sgl_kernel import silu_and_mul
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
......@@ -31,15 +38,6 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
......@@ -294,56 +292,6 @@ def benchmark_config(
return avg, avg_tma, avg1, avg1_tma
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_stages in [2]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_stages in [2, 3, 4, 5]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
class BestConfigTrace:
def __init__(self, name):
self.name = name
......@@ -509,22 +457,7 @@ class BenchmarkWorker:
)
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
}
def save_configs(
def save_configs_sep(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
......@@ -563,100 +496,29 @@ def save_configs(
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = (
config.num_experts_per_tok
+ (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.num_experts_per_tok
)
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
model_config = get_model_config(
args.model,
args.tp_size,
args.ep_size,
args.disable_shared_experts_fusion,
args.topk_ids_dir,
)
E = model_config["num_experts"]
topk = model_config["topk"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
per_channel_quant = args.per_channel_quant
topk_ids_dir = args.topk_ids_dir
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
8192,
]
batch_sizes = get_default_batch_sizes()
batch_sizes.reverse()
else:
batch_sizes = [args.batch_size]
......@@ -731,7 +593,21 @@ def main(args: argparse.Namespace):
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
filename = get_config_filename(
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
print(
f"Start tuning over {len(search_space)} configurations to create {filename}..."
)
start = time.perf_counter()
configs = _distribute(
......@@ -764,7 +640,7 @@ def main(args: argparse.Namespace):
configs0.reverse()
configs1.reverse()
best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)}
save_configs(
save_configs_sep(
best_configs0,
E,
shard_intermediate_size,
......@@ -778,7 +654,7 @@ def main(args: argparse.Namespace):
)
best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
save_configs(
save_configs_sep(
best_configs1,
E,
shard_intermediate_size,
......@@ -801,6 +677,7 @@ if __name__ == "__main__":
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument("--ep-size", "--ep", type=int, default=1)
parser.add_argument(
"--dtype",
type=str,
......