Unverified commit 849f58d6, authored by GaoYuYang, committed by GitHub

Update fused_moe's benchmark (#3346)

parent 64480df4
@@ -2,6 +2,7 @@ import argparse
 import torch
 import triton
+import vllm
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
@@ -29,11 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
@@ -41,12 +42,27 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    vllm_version_num = (
+        vllm.__version_tuple__[0] * 100
+        + vllm.__version_tuple__[1] * 10
+        + vllm.__version_tuple__[2]
+    )
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+        assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
     shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
+        "block_shape": block_shape,
     }
     print(f"{shape_configs=}")
     return shape_configs
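For context on the new gate: `vllm.__version_tuple__` is flattened into a small integer, so 0.6.6 maps to 66 and 0.7.2 maps to 72 (the encoding assumes single-digit minor and patch parts; 0.6.10 would collide with 0.7.0). The sketch below walks through that arithmetic and the `weight_block_size` lookup; the `[128, 128]` value is an assumption about typical DeepSeek-V3 FP8 checkpoints, not something read from a real config here.

```python
# Standalone sketch of the version gate and block-shape detection above.
# The config values are illustrative; nothing here is read from disk.
version_tuple = (0, 6, 6)  # stand-in for vllm.__version_tuple__
version_num = version_tuple[0] * 100 + version_tuple[1] * 10 + version_tuple[2]
print(version_num)  # 66, the minimum accepted by the assert in get_model_config

quantization_config = {"weight_block_size": [128, 128]}  # assumed checkpoint metadata
block_shape = quantization_config.get("weight_block_size")
assert block_shape is not None and len(block_shape) == 2
print(block_shape)  # [128, 128]
```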
@@ -63,7 +79,25 @@ def fused_moe_vllm_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
-    return fused_moe_vllm(
-        x,
-        w1,
+    if block_shape is not None:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            block_shape=block_shape,
+        )
+    else:
+        return fused_moe_vllm(
+            x,
+            w1,
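A note on the branch above: it exists because only vLLM builds that ship block-wise FP8 support accept the `block_shape` keyword, so forwarding it unconditionally would fail on older installs, while the non-block path keeps working everywhere. A more generic guard (hypothetical, not what this PR does) would be to introspect the installed kernel's signature:

```python
import inspect

def accepts_kwarg(fn, name: str) -> bool:
    """Return True if `fn` declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters

# Hypothetical usage with the imported kernel:
# extra = {"block_shape": block_shape} if accepts_kwarg(fused_moe_vllm, "block_shape") else {}
# return fused_moe_vllm(x, w1, w2, input_gating, topk, renormalize=True, inplace=True, **extra)
```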
@@ -91,6 +125,7 @@ def fused_moe_sglang_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
     return fused_moe_sglang(
         x,
@@ -105,6 +140,7 @@ def fused_moe_sglang_api(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
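With both wrappers now sharing the same signature, they can be swapped behind a single `api_func`. The smoke test below is a sketch only: it assumes a CUDA device, both kernels installed, and that `fused_moe_vllm_api` / `fused_moe_sglang_api` from this script are in scope; the shapes are made up and follow the layout the benchmark uses (w1: experts x shard_intermediate x hidden, w2: experts x hidden x shard_intermediate/2).

```python
import torch

# Illustrative small shapes; block_shape stays None on this unquantized bf16 path.
num_tokens, hidden_size = 16, 4096
num_experts, shard_intermediate_size, topk = 8, 1024, 2

x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size,
                 dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2,
                 dtype=torch.bfloat16, device="cuda")
gating = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# Clone x for each call because the wrappers run the kernels with inplace=True.
out_vllm = fused_moe_vllm_api(x.clone(), w1, w2, gating, topk)
out_sglang = fused_moe_sglang_api(x.clone(), w1, w2, gating, topk)
print((out_vllm - out_sglang).abs().max())
```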
@@ -141,8 +177,10 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     shard_intermediate_size = model_config["shard_intermediate_size"]
     topk = model_config["topk"]
     dtype = model_config["dtype"]
+    block_shape = model_config.get("block_shape", None)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    w1_scale = w2_scale = a1_scale = a2_scale = None
     if use_fp8:
         init_dtype = dtype
@@ -154,16 +192,29 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
         )
         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand(
+                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+            )
+            w2_scale = torch.rand(
+                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+            )
     else:
         w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
         w2 = torch.randn(
             num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
         )
-        w1_scale = w2_scale = a1_scale = a2_scale = None
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
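As a sanity check on the ceil-division above: with DeepSeek-V3-like numbers (hidden_size 7168, moe_intermediate_size 2048, tp_size 8, weight_block_size [128, 128], all assumed for illustration), the per-expert scale grids come out as (4, 56) for w1 and (56, 2) for w2. This also shows why the scale defaults were hoisted to `w1_scale = w2_scale = a1_scale = a2_scale = None` earlier in the function: the block-wise branch never assigns a1_scale/a2_scale, so they must fall back to None.

```python
# Ceil-division helper matching the (a + b - 1) // b pattern used above.
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

# Illustrative DeepSeek-V3-style values (assumed, not read from any config here).
hidden_size = 7168
shard_intermediate_size = 2 * 2048 // 8  # 2 * moe_intermediate_size // tp_size = 512
block_n, block_k = 128, 128              # weight_block_size

# w1 is (E, shard_intermediate_size, hidden_size): one scale per 128x128 tile.
n_tiles_w1 = cdiv(shard_intermediate_size, block_n)      # 4
k_tiles_w1 = cdiv(hidden_size, block_k)                  # 56
# w2 is (E, hidden_size, shard_intermediate_size // 2).
n_tiles_w2 = cdiv(hidden_size, block_n)                  # 56
k_tiles_w2 = cdiv(shard_intermediate_size // 2, block_k) # 2

print((n_tiles_w1, k_tiles_w1), (n_tiles_w2, k_tiles_w2))  # (4, 56) (56, 2)
```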
@@ -185,6 +236,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )
     torch.cuda.synchronize()
@@ -201,6 +253,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )[0],
         quantiles=quantiles,
     )
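For context on the `quantiles=quantiles` argument above: `triton.testing.do_bench` repeatedly times a closure and, when given quantiles, returns the corresponding latency percentiles in milliseconds. A self-contained sketch of that pattern follows; the matmul is just a stand-in for the fused_moe closure being benchmarked, and it assumes a CUDA device.

```python
import torch
import triton

# Median plus 20th/80th percentile latencies, matching the benchmark's convention.
quantiles = [0.5, 0.2, 0.8]

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

ms, min_ms, max_ms = triton.testing.do_bench(lambda: a @ b, quantiles=quantiles)
print(f"{ms:.3f} ms (p20={min_ms:.3f}, p80={max_ms:.3f})")
```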