Unverified commit fc84b073 authored by Xiaoyu Zhang, committed by GitHub

[Refactor] Refactor fused_moe_triton tuning tools: extract shared utils, add EP/MLLM support, reduce overhead (#12440)
Co-authored-by: xu-yfei <xu-yfei@users.noreply.github.com>
Co-authored-by: Yongfei Xu <xuyongfei.xyf@antgroup.com>
parent 8be0e1bc
@@ -2,13 +2,27 @@
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
### Overview

The tuning tools support both **Tensor Parallelism (TP)** and **Expert Parallelism (EP)** modes:

- **TP Mode**: traditional tensor parallelism; each expert's intermediate projections are sharded across GPUs.
- **EP Mode**: expert parallelism; whole experts are distributed across GPUs. It can be combined with TP mode (e.g., `--tp-size 8 --ep-size 2`).
- **MLLM Support**: multi-modal large language models whose MoE lives in the text backbone (e.g., Llama4, Qwen3VL).
### Tuning Tools

#### 1. `tuning_fused_moe_triton.py`

A unified tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with support for EP mode and various model architectures.

#### 2. `tuning_fused_moe_triton_sep.py`

A specialized tool for separate kernel tuning: it optimizes the first and second MoE kernels independently, with TMA (Tensor Memory Accelerator) support.

### Usage Examples

#### Basic TP Mode Tuning
```bash
# Tune Mixtral-8x7B with default TP settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tune
@@ -20,29 +34,149 @@
    --dtype fp8_w8a8 \
    --tune

# Tune DeepSeek-V3 with FP8 and TP=8
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model deepseek-ai/DeepSeek-V3-0324 \
    --tp-size 8 \
    --dtype fp8_w8a8 \
    --tune
```
#### EP Mode Tuning (Expert Parallelism)

**Note**: EP mode can be used alone or combined with TP mode. When combining the two, `tp_size` must be divisible by `ep_size` (see the sketch after the examples below for how the two sizes determine the per-GPU shapes).
```bash
# Tune Mixtral-8x7B with EP only (TP=2, EP=2)
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tp-size 2 \
    --ep-size 2 \
    --tune

# Tune Qwen2-57B with TP=8 and EP=4 (combined mode)
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model Qwen/Qwen2-57B-A14B-Instruct \
    --tp-size 8 \
    --ep-size 4 \
    --dtype fp8_w8a8 \
    --tune
```
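For intuition, here is a minimal sketch of how the tuners derive the per-GPU problem shape from `--tp-size` and `--ep-size`; it mirrors the `calculate_shard_intermediate_size` helper in the shared utilities further down, and the Mixtral-8x7B numbers (8 experts, intermediate size 14336) are only an illustration:

```python
# Rough sketch: how EP and TP split a MoE layer across GPUs.
def moe_shard_shape(num_experts: int, intermediate_size: int, tp_size: int, ep_size: int):
    assert tp_size % ep_size == 0, "tp_size must be divisible by ep_size"
    experts_per_gpu = num_experts // ep_size   # EP places whole experts on each GPU
    moe_tp_size = tp_size // ep_size           # remaining TP factor inside each EP group
    # Factor of 2: w1 holds the fused gate and up projections.
    shard_intermediate_size = 2 * intermediate_size // moe_tp_size
    return experts_per_gpu, shard_intermediate_size

print(moe_shard_shape(8, 14336, tp_size=2, ep_size=2))  # (4, 28672): EP only
print(moe_shard_shape(8, 14336, tp_size=8, ep_size=2))  # (4, 7168): EP combined with TP
```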
#### MLLM Model Tuning (Multi-modal)

```bash
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model Qwen/Qwen3-VL-30B-A3B-Instruct \
    --tp-size 2 \
    --tune
```
#### Separate Kernel Tuning with `tuning_fused_moe_triton_sep.py`

This tool requires pre-generated `topk_ids` files and supports both TP and EP modes.

First, edit the model file in the installed Python site packages (e.g., `srt/models/deepseek_v2.py`) and add the logic for saving `topk_ids`:
```python
# import get_tensor_model_parallel_rank at the top of the file
# Inside DeepseekV2MoE.forward_normal:
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
    topk_ids_dir = "/path/to/topk_ids"  # the directory you will later pass via --topk-ids-dir
    if not hasattr(self, "save_idx"):
        self.save_idx = 0
    if self.save_idx <= 1:
        torch.save(
            topk_output.topk_ids,
            f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt",
        )
        self.save_idx += 1
```
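If the hook is in place, serving a couple of long prompts should leave two files per MoE layer (`topk_ids_layer<layer_id>_idx0.pt` and `..._idx1.pt`) in that directory; this is the directory you later pass to the tuner via `--topk-ids-dir`.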
Then launch the sglang server and send a request with `benchmark/kernels/fused_moe_triton/tuning_client.py`:

```bash
python benchmark/kernels/fused_moe_triton/tuning_client.py --port 8000
```
```bash
# TP Mode: Tune separate kernels with TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model Qwen/Qwen2-57B-A14B-Instruct \
    --tp-size 4 \
    --topk-ids-dir /path/to/topk_ids \
    --tune

# EP Mode: Tune separate kernels with TP=4 and EP=2
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tp-size 4 \
    --ep-size 2 \
    --topk-ids-dir /path/to/topk_ids \
    --tune

# Tune DeepSeek-V3 with separate kernels, TP=8 and EP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model deepseek-ai/DeepSeek-V3-0324 \
    --tp-size 8 \
    --ep-size 4 \
    --dtype fp8_w8a8 \
    --topk-ids-dir /path/to/topk_ids \
    --tune

# Benchmark a specific config without tuning
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model deepseek-ai/DeepSeek-V3-0324 \
    --tp-size 4 \
    --batch-size 1024 \
    --dtype fp8_w8a8 \
    --configs 128 256 128 16 8 4 \
    --topk-ids-dir /path/to/topk_ids
```
#### Advanced Options

```bash
# Channel-wise quantization
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model meituan/DeepSeek-R1-Channel-INT8 \
    --tp-size 16 \
    --dtype int8_w8a8 \
    --per-channel-quant \
    --tune

# Tune a specific batch size only
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --batch-size 2048 \
    --tune
```
### Configuration Files

After tuning, configuration files are generated in the current directory:

- **Standard tuning**: a single file such as `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`
- **Separate kernel tuning**: two files, one for the first (up) kernel and one for the second (down) kernel, including TMA optimization flags

Move these files into the `sglang/srt/layers/moe/fused_moe_triton/configs/<triton_version>/` directory so that SGLang picks them up. A sketch of how to inspect such a file follows.
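As a rough illustration (the file name is the example above and the values are made up), each config file is a JSON object keyed by batch size `M`, with one Triton kernel config per key:

```python
import json

# Hypothetical file name; the real name encodes E, N, device name and dtype as shown above.
with open("E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json") as f:
    best_configs = json.load(f)

# Keys are batch sizes (serialized as strings), values are the winning kernel parameters, e.g.:
# best_configs["64"] -> {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#                        "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
for batch_size, cfg in sorted(best_configs.items(), key=lambda kv: int(kv[0])):
    print(batch_size, cfg)
```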
### Supported Models
- **Mixtral**: mistralai/Mixtral-8x7B-Instruct-v0.1, mixtral-8x22b
- **Qwen**: Qwen2-57B, Qwen3-235B, Qwen3VL (MLLM)
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3, DeepSeek-R1
- **Llama**: Llama4-Vision (MLLM)
- **DBRX**: databricks/dbrx-instruct
- **Jamba**: ai21labs/AI21-Jamba
- **Grok**: xai-org/grok-1
- **GLM**: GLM-4 MoE variants (`Glm4MoeForCausalLM`)
- **Bailing**: Custom MoE models
### Parameters Reference
- `--model`: HuggingFace model name or local path
- `--tp-size`: Tensor parallelism size (default: 2)
- `--ep-size`: Expert parallelism size (default: 1). Can be combined with TP mode; `tp_size` must be divisible by `ep_size`.
- `--dtype`: Data type (`auto`, `fp8_w8a8`, `int8_w8a16`, `int8_w8a8`)
- `--batch-size`: Specific batch size for tuning (optional)
- `--tune`: Enable tuning mode
- `--per-channel-quant`: Enable per-channel quantization
- `--disable-shared-experts-fusion`: Disable shared expert fusion for some models
- `--topk-ids-dir`: Directory containing pre-generated topk_ids (for sep tool only)
- `--configs`: Manually specify a single kernel config as `BLOCK_M BLOCK_N BLOCK_K GROUP_M num_warps num_stages` (see the sketch below)
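For example, the six values passed above as `--configs 128 256 128 16 8 4` map, in order, onto the fields of a single kernel config (a sketch matching the `BenchmarkConfig` fields used by these scripts):

```python
# Map the positional --configs values onto the named kernel-config fields.
values = [128, 256, 128, 16, 8, 4]
keys = ["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M", "num_warps", "num_stages"]
manual_config = dict(zip(keys, values))
print(manual_config)
# {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128,
#  'GROUP_SIZE_M': 16, 'num_warps': 8, 'num_stages': 4}
```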
### Performance Comparison Tool

@@ -73,4 +207,4 @@ The benchmark results will be saved as plots and data files in the specified output directory.

- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the fused MoE kernel under `torch.compile` against the original fused MoE kernel.

Usage is similar to `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused_moe_kernel. Both tools now support EP mode via the `--ep-size` parameter.
@@ -3,7 +3,7 @@ import argparse
import torch
import triton
from common_utils import get_model_config

from sglang.srt.distributed.parallel_state import (
    destroy_distributed_environment,
@@ -21,60 +21,6 @@ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
def fused_moe_triton_api(
    x,
    w1,
@@ -239,7 +185,8 @@ def main():
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", "--tp", type=int, default=2)
    parser.add_argument("--ep-size", "--ep", type=int, default=1)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
@@ -270,11 +217,11 @@ def main():
        )

        initialize_model_parallel(
            tensor_model_parallel_size=args.ep_size,
            pipeline_model_parallel_size=args.tp_size,
        )
        model_config = get_model_config(args.model, args.tp_size, args.ep_size)

        benchmark.run(
            show_plots=True,
            print_data=True,
...
@@ -3,8 +3,6 @@ import argparse
import torch
import triton
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.distributed.parallel_state import (
@@ -17,91 +15,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe as fused_moe_sglang,
)

from common_utils import get_model_config
def fused_moe_vllm_api(
@@ -301,7 +215,8 @@ def main():
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", "--tp", type=int, default=2)
    parser.add_argument("--ep-size", "--ep", type=int, default=1)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--save-path",
@@ -332,12 +247,12 @@ def main():
            pipeline_model_parallel_size=1,
        )

        shape_configs = get_model_config(args.model, args.tp_size, args.ep_size)
        benchmark.run(
            show_plots=True,
            print_data=True,
            save_path=args.save_path,
            model_config=shape_configs,
            use_fp8_w8a8=args.use_fp8_w8a8,
        )
    finally:
...
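# common_utils.py — shared helper module for the MoE tuning/benchmark scripts:
# model shape parsing, tuning search spaces, and config file I/O.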
import json
from typing import Dict, List, TypedDict
import torch
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_config_dtype_str
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
get_config_file_name,
)
from sglang.srt.utils import is_hip
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def calculate_shard_intermediate_size(
intermediate_size: int, tp_size: int, ep_size: int = 1
) -> int:
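    # EP removes a factor of ep_size from the TP degree; the remaining MoE-TP ranks
    # shard each expert's intermediate dimension. The factor of 2 reflects w1 holding
    # the fused gate and up projections.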
assert tp_size % ep_size == 0
moe_tp_size = tp_size // ep_size
assert intermediate_size % moe_tp_size == 0
return 2 * intermediate_size // moe_tp_size
def get_model_config(
model_name: str,
tp_size: int,
ep_size: int = 1,
disable_shared_experts_fusion: bool = False,
topk_ids_dir: str = None,
) -> Dict:
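    """Derive the MoE problem shape from a HuggingFace config.

    Returns a dict with num_experts (per EP rank), topk, hidden_size,
    shard_intermediate_size, dtype, block_shape and architecture.
    """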
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
architecture = config.architectures[0]
    # For multi-modal models (those exposing a text_config, e.g., Llama4, Qwen3VL),
    # switch to the text backbone's config after capturing block_shape and architecture
if hasattr(config, "text_config"):
config = config.get_text_config()
if architecture == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts // ep_size
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
elif architecture == "JambaForCausalLM":
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
elif architecture in [
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"Qwen3VLMoeForConditionalGeneration",
]:
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (config.n_routed_experts // ep_size) + (
0
if disable_shared_experts_fusion
or architecture not in ["DeepseekV3ForCausalLM"]
else 1
)
topk = config.num_experts_per_tok + (
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
)
intermediate_size = config.moe_intermediate_size
elif architecture == "Llama4ForConditionalGeneration":
E = config.num_local_experts // ep_size + (
0 if disable_shared_experts_fusion else 1
)
topk = config.num_experts_per_tok + (
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
)
intermediate_size = config.intermediate_size
elif architecture in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in [
"BailingMoEForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
elif architecture in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
else:
# Default: Mixtral
E = config.num_local_experts // ep_size
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = calculate_shard_intermediate_size(
intermediate_size, tp_size, ep_size
)
return {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
"architecture": architecture,
}
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
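    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.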
configs: List[BenchmarkConfig] = []
if is_hip():
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
filename: str,
) -> None:
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_config_filename(
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
per_channel_quant: bool,
block_shape: List[int],
) -> str:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
per_channel_quant,
)
return filename
def get_default_batch_sizes() -> List[int]:
return [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
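# tuning_client.py — minimal OpenAI-compatible streaming client used to send a long
# prompt to a running sglang server (e.g., while capturing topk_ids for the sep tuner).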
import argparse
import os
import time
import openai
"""
# Edit the code file srt/models/deepseek_v2.py in the Python site package and add the logic for saving topk_ids:
# import get_tensor_model_parallel_rank
# DeepseekV2MoE::forward_normal
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
topk_ids_dir = xxxx
if not hasattr(self, "save_idx"):
self.save_idx = 0
if self.save_idx <= 1:
torch.save(topk_output.topk_ids, f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt")
self.save_idx += 1
"""
def read_long_prompt():
import json
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/tuning_text.json", "r") as fp:
text = fp.read()
rst = json.loads(text)
return rst["prompt"]
def openai_stream_test(model, ip, port):
client = openai.Client(base_url=f"http://{ip}:{port}/v1", api_key="None")
qst = read_long_prompt()
messages = [
{"role": "user", "content": qst},
]
msg2 = dict(
model=model,
messages=messages,
temperature=0.6,
top_p=0.75,
max_tokens=100,
)
response = client.chat.completions.create(**msg2, stream=True)
time_start = time.time()
time_cost = []
for chunk in response:
time_end = time.time()
# if chunk.choices[0].delta.content:
# print(chunk.choices[0].delta.content, end="", flush=True)
time_cost.append(time_end - time_start)
time_start = time.time()
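    # Treat the first two inter-chunk gaps as TTFT and the average of the remaining
    # gaps as TPOT (a rough approximation for this tuning workload).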
ttft = time_cost[0] + time_cost[1]
tpot = sum(time_cost[2:]) / len(time_cost[2:])
print(f"\nTTFT {ttft}, TPOT {tpot}")
return ttft, tpot
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="auto")
parser.add_argument(
"--ip",
type=str,
default="127.0.0.1",
)
parser.add_argument("--port", type=int, default=8188)
args = parser.parse_args()
openai_stream_test(args.model, args.ip, args.port)
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple

import ray
import torch
import triton
from common_utils import (
    BenchmarkConfig,
    get_config_filename,
    get_configs_compute_bound,
    get_default_batch_sizes,
    get_model_config,
    save_configs,
    sort_config,
)
from ray.experimental.tqdm_ray import tqdm

from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
    get_config_dtype_str,
    get_default_config,
    get_moe_configs,
)
@@ -27,15 +33,6 @@ from sglang.srt.utils import is_hip

_is_hip = is_hip()
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
@@ -173,74 +170,28 @@ def benchmark_config(
        graph.replay()
        torch.cuda.synchronize()
    # Flush L2 cache with 256 MB data
    cache_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
    cache_flush.zero_()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]

    for i in range(num_iters):
        prepare(i)
        start_events[i].record()
        graph.replay()
        end_events[i].record()
    torch.cuda.synchronize()

    latencies: List[float] = []
    for i in range(num_iters):
        latencies.append(start_events[i].elapsed_time(end_events[i]))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
@ray.remote(num_gpus=1)
class BenchmarkWorker:
@@ -360,189 +311,27 @@ class BenchmarkWorker:
        return best_config
def main(args: argparse.Namespace):
    print(args)
    model_config = get_model_config(
        args.model, args.tp_size, args.ep_size, args.disable_shared_experts_fusion
    )
    E = model_config["num_experts"]
    topk = model_config["topk"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    dtype = model_config["dtype"]
    block_shape = model_config["block_shape"]

    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a8 = args.dtype == "int8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    per_channel_quant = args.per_channel_quant

    if args.batch_size is None:
        batch_sizes = get_default_batch_sizes()
    else:
        batch_sizes = [args.batch_size]
@@ -571,7 +360,7 @@ def main(args: argparse.Namespace):
        if block_k % config["BLOCK_SIZE_K"] == 0
    ]

    filename = get_config_filename(
        E,
        shard_intermediate_size,
        hidden_size,
...
@@ -5,15 +5,22 @@ import os
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple

import ray
import torch
import triton
import triton.language as tl
from common_utils import (
    BenchmarkConfig,
    get_config_filename,
    get_configs_compute_bound,
    get_default_batch_sizes,
    get_model_config,
    sort_config,
)
from ray.experimental.tqdm_ray import tqdm
from sgl_kernel import silu_and_mul

from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
@@ -31,15 +38,6 @@ from sglang.srt.utils import is_hip

_is_hip = is_hip()
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
@@ -294,56 +292,6 @@ def benchmark_config(
    return avg, avg_tma, avg1, avg1_tma
class BestConfigTrace:
    def __init__(self, name):
        self.name = name
@@ -509,22 +457,7 @@ class BenchmarkWorker:
        )
def save_configs_sep(
    configs: Dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
@@ -563,100 +496,29 @@ def save_configs_sep(
def main(args: argparse.Namespace):
    print(args)

    model_config = get_model_config(
        args.model,
        args.tp_size,
        args.ep_size,
        args.disable_shared_experts_fusion,
        args.topk_ids_dir,
    )

    E = model_config["num_experts"]
    topk = model_config["topk"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    dtype = model_config["dtype"]
    block_shape = model_config["block_shape"]

    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a8 = args.dtype == "int8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    per_channel_quant = args.per_channel_quant

    topk_ids_dir = args.topk_ids_dir
    if args.batch_size is None:
        batch_sizes = get_default_batch_sizes()
        batch_sizes.reverse()
    else:
        batch_sizes = [args.batch_size]
@@ -731,7 +593,21 @@ def main(args: argparse.Namespace):
    search_space = [
        config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
    ]

    filename = get_config_filename(
        E,
        shard_intermediate_size,
        hidden_size,
        topk,
        dtype,
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w8a16,
        per_channel_quant,
        block_shape,
    )
    print(
        f"Start tuning over {len(search_space)} configurations to create {filename}..."
    )

    start = time.perf_counter()
    configs = _distribute(
@@ -764,7 +640,7 @@ def main(args: argparse.Namespace):
    configs0.reverse()
    configs1.reverse()

    best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)}
    save_configs_sep(
        best_configs0,
        E,
        shard_intermediate_size,
@@ -778,7 +654,7 @@ def main(args: argparse.Namespace):
    )

    best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
    save_configs_sep(
        best_configs1,
        E,
        shard_intermediate_size,
@@ -801,6 +677,7 @@ if __name__ == "__main__":
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", "--tp", type=int, default=2)
    parser.add_argument("--ep-size", "--ep", type=int, default=1)
    parser.add_argument(
        "--dtype",
        type=str,
...