Unverified Commit 849f58d6 authored by GaoYuYang's avatar GaoYuYang Committed by GitHub
Browse files

Update fused_moe's benchmark (#3346)

parent 64480df4
......@@ -2,6 +2,7 @@ import argparse
import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
......@@ -29,11 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "DeepseekV2ForCausalLM":
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
......@@ -41,12 +42,27 @@ def get_model_config(model_name: str, tp_size: int):
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
vllm_version_num = (
vllm.__version_tuple__[0] * 100
+ vllm.__version_tuple__[1] * 10
+ vllm.__version_tuple__[2]
)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
......@@ -63,21 +79,39 @@ def fused_moe_vllm_api(
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
if block_shape is not None:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
else:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
......@@ -91,6 +125,7 @@ def fused_moe_sglang_api(
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
return fused_moe_sglang(
x,
......@@ -105,6 +140,7 @@ def fused_moe_sglang_api(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
......@@ -141,8 +177,10 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = getattr(model_config, "block_shape", None)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None
if use_fp8:
init_dtype = dtype
......@@ -154,16 +192,29 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
if block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_scale = w2_scale = a1_scale = a2_scale = None
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
......@@ -185,6 +236,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
torch.cuda.synchronize()
......@@ -201,6 +253,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)[0],
quantiles=quantiles,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment