Unverified Commit 8d705996 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Minor enhancement of benchmark_moe (#22068)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 38c8bce8
......@@ -22,6 +22,13 @@ from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, (
"intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
)
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
......@@ -603,7 +610,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
ensure_divisibility(intermediate_size, args.tp_size)
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment