Unverified commit 33deca81 authored by Lianmin Zheng, committed by GitHub

Add more fused moe benchmark utilities (#2314)

parent 18108abe

import argparse

import torch
import triton
from torch.nn import functional as F
from transformers import AutoConfig

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters."""
    config = AutoConfig.from_pretrained(model_name)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    # The gate (w1) and up (w3) projections are fused into a single weight,
    # hence the factor of 2 before sharding across tensor-parallel ranks.
    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs
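
# As a sanity check (values from the Hugging Face config of
# mistralai/Mixtral-8x7B-Instruct-v0.1, assuming tp_size=2), the print above
# should show num_experts=8, topk=2, hidden_size=4096,
# shard_intermediate_size=2 * 14336 // 2 = 14336, and dtype=torch.bfloat16.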


def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Native (unfused) top-k gating, used as the reference implementation."""
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
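
# Worked example of the renormalization step: with topk=2 and softmax weights
# [0.5, 0.3, 0.2] for a token, the selected weights [0.5, 0.3] are divided by
# their sum (0.8) to give [0.625, 0.375], so the expert mix sums to 1 again.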


@torch.compile
def fused_moe_torch(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
) -> torch.Tensor:
    assert not use_fp8_w8a8, "Not supported"
    topk_weights, topk_ids = fused_topk_native(
        hidden_states=x,
        gating_output=input_gating,
        topk=topk,
        renormalize=True,
    )
    # Gather the fused gate/up weights of the selected experts, then split them.
    w13_weights = w1[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = w2[topk_ids]
    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
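
# Index legend for the einsums above: t = token, a = slot among the top-k
# selected experts, i = hidden_size, o = per-expert intermediate size; the
# final contraction over "a" applies the (renormalized) routing weights.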


def fused_moe_torch_compile(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_torch(
        x,
        w1,
        w2,
        input_gating,
        topk,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_triton(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
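
# Note: renormalize=True matches fused_topk_native above; inplace=True asks the
# Triton path to write its result back into the activation buffer (an
# assumption worth checking against the fused_moe signature in your version).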


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 5)),
        line_arg="provider",
        line_vals=[
            "fused_moe_triton",
            "fused_moe_torch_compile",
        ],
        line_names=[
            "fused_moe_triton",
            "fused_moe_torch_compile",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    if use_fp8:
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )
        w1_scale = w2_scale = a1_scale = a2_scale = None

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    api_func = (
        fused_moe_torch_compile
        if provider == "fused_moe_torch_compile"
        else fused_moe_sglang_api
    )

    # Warmup
    for _ in range(10):
        api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms
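
# triton.testing.do_bench returns the requested quantiles in order, so ms is
# the median latency and min_ms / max_ms are the 20th / 80th percentiles.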


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/fused_moe_torch_compile/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )


if __name__ == "__main__":
    main()
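
# Example invocation (script file name is hypothetical; assumes a CUDA GPU and
# network access to fetch the model config from Hugging Face):
#   python benchmark_torch_compile_fused_moe.py \
#       --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2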
 import argparse
-import numbers
-from typing import Optional
 import torch
 import triton
-from torch.nn import init
-from torch.nn.parameter import Parameter
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    get_moe_configs as get_moe_configs_vllm,
-)
-from vllm.utils import FlexibleArgumentParser
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
-    get_moe_configs as get_moe_configs_sglang,
-)

 def get_model_config(model_name: str, tp_size: int):
@@ -39,19 +28,21 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
-        # Default: Mixtral, Grok1, etc.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    return {
+    shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
     }
+    print(f"{shape_configs=}")
+    return shape_configs

 def fused_moe_vllm_api(
@@ -133,7 +124,7 @@ def fused_moe_sglang_api(
     )
 )
 def benchmark(batch_size, provider, model_config, use_fp8=False):
-    print(f"benchmark for batch_size={batch_size}")
+    print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -210,7 +201,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
 def main():
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
...
 # Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
 import argparse
+import json
 import time
 from datetime import datetime
 from typing import Any, Dict, List, Tuple, TypedDict
@@ -9,10 +10,14 @@ import torch
 import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
-from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser
-from sglang.srt.layers.fused_moe_triton.fused_moe import *
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
+    fused_moe,
+    get_config_dtype_str,
+    get_config_file_name,
+    get_default_config,
+    get_moe_configs,
+)

 class BenchmarkConfig(TypedDict):
@@ -92,7 +97,7 @@ def benchmark_config(
         input_gating.copy_(gating_output[i])

     def run():
-        from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
+        from sglang.srt.layers.fused_moe_triton import override_config

         with override_config(config):
             fused_moe(
@@ -174,7 +179,7 @@ class BenchmarkWorker:
     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
-        current_platform.seed_everything(seed)
+        torch.cuda.manual_seed_all(0)
         self.seed = seed

     def benchmark(
@@ -188,7 +193,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
-        current_platform.seed_everything(self.seed)
+        torch.cuda.manual_seed_all(0)
         dtype_str = get_config_dtype_str(
             dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        # Default: Mixtral.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
...