Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
...@@ -5,12 +5,14 @@ import time ...@@ -5,12 +5,14 @@ import time
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
@torch.inference_mode() @torch.inference_mode()
@default_vllm_config()
def main( def main(
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
...@@ -32,14 +34,14 @@ def main( ...@@ -32,14 +34,14 @@ def main(
residual = torch.randn_like(x) * scale if add_residual else None residual = torch.randn_like(x) * scale if add_residual else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
for _ in range(num_iters): for _ in range(num_iters):
layer(x, residual) layer(x, residual)
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -1035,7 +1035,7 @@ def bench_optype( ...@@ -1035,7 +1035,7 @@ def bench_optype(
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for kwargs in kwargs_list: for kwargs in kwargs_list:
op_type.bench_fn()(**kwargs) op_type.bench_fn()(**kwargs)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Merge into a single kwargs and qualify arguments as ArgPool # Merge into a single kwargs and qualify arguments as ArgPool
kwargs = {k: ArgPool([]) for k in kwargs_list[0]} kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
......
...@@ -47,13 +47,13 @@ def benchmark_method( ...@@ -47,13 +47,13 @@ def benchmark_method(
# Warmup # Warmup
for _ in range(num_warmup): for _ in range(num_warmup):
_ = method(k_nope, k_pe) _ = method(k_nope, k_pe)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Benchmark # Benchmark
start = time.perf_counter() start = time.perf_counter()
for _ in range(num_iters): for _ in range(num_iters):
_ = method(k_nope, k_pe) _ = method(k_nope, k_pe)
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / num_iters * 1000 # Convert to ms return (end - start) / num_iters * 1000 # Convert to ms
......
...@@ -16,6 +16,10 @@ import torch ...@@ -16,6 +16,10 @@ import torch
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
...@@ -50,7 +54,7 @@ def clear_triton_cache(): ...@@ -50,7 +54,7 @@ def clear_triton_cache():
# Clear CUDA memory cache # Clear CUDA memory cache
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# Try to clear Triton's runtime cache # Try to clear Triton's runtime cache
try: try:
...@@ -99,13 +103,38 @@ def benchmark_config( ...@@ -99,13 +103,38 @@ def benchmark_config(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
num_iters: int = 100, num_iters: int = 100,
block_quant_shape: list[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16: if use_int4_w4a16:
# Int4 packed weights: 2 int4 values per uint8 byte
# K dimension is packed (halved)
intermediate_size = shard_intermediate_size // 2 # after silu_and_mul
w1 = torch.randint(
0,
255,
(
num_experts,
shard_intermediate_size,
hidden_size // 2, # int4 packing
),
dtype=torch.uint8,
)
w2 = torch.randint(
0,
255,
(
num_experts,
hidden_size,
intermediate_size // 2, # int4 packing
),
dtype=torch.uint8,
)
elif use_int8_w8a16:
w1 = torch.randint( w1 = torch.randint(
-127, -127,
127, 127,
...@@ -139,7 +168,20 @@ def benchmark_config( ...@@ -139,7 +168,20 @@ def benchmark_config(
w2_scale = None w2_scale = None
a1_scale = None a1_scale = None
a2_scale = None a2_scale = None
if use_int8_w8a16: if use_int4_w4a16:
if block_quant_shape is None:
raise ValueError("block_quant_shape is required for int4_w4a16")
group_size = block_quant_shape[1]
# Scales shape: (E, N, K // group_size) in fp16
w1_scale = torch.rand(
(num_experts, shard_intermediate_size, hidden_size // group_size),
dtype=dtype,
)
w2_scale = torch.rand(
(num_experts, hidden_size, intermediate_size // group_size),
dtype=dtype,
)
elif use_int8_w8a16:
w1_scale = torch.randn( w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32 (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
) )
...@@ -198,27 +240,38 @@ def benchmark_config( ...@@ -198,27 +240,38 @@ def benchmark_config(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_quant_shape, block_shape=block_quant_shape,
weight_dtype="int4" if use_int4_w4a16 else None,
) )
deep_gemm_experts = None deep_gemm_experts = None
if use_deep_gemm: if use_deep_gemm:
deep_gemm_experts = mk.FusedMoEModularKernel( moe_config = (
prepare_finalize=MoEPrepareAndFinalizeNoEP(), FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
num_logical_experts=num_experts,
activation=MoEActivation.SILU,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
)
deep_gemm_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=TritonOrDeepGemmExperts( fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig( moe_config=moe_config,
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
activation="silu",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=not disable_inplace(),
) )
with override_config(config): with override_config(config):
...@@ -226,9 +279,18 @@ def benchmark_config( ...@@ -226,9 +279,18 @@ def benchmark_config(
x, input_gating, topk, renormalize=not use_deep_gemm x, input_gating, topk, renormalize=not use_deep_gemm
) )
inplace = not disable_inplace()
if use_deep_gemm: if use_deep_gemm:
return deep_gemm_experts( return deep_gemm_experts.apply(
x, w1, w2, topk_weights, topk_ids, inplace=True x,
w1,
w2,
topk_weights,
topk_ids,
activation=MoEActivation.SILU,
global_num_experts=num_experts,
apply_router_weight_on_input=False,
expert_map=False,
) )
return fused_experts( return fused_experts(
x, x,
...@@ -236,25 +298,25 @@ def benchmark_config( ...@@ -236,25 +298,25 @@ def benchmark_config(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=True, inplace=inplace,
quant_config=quant_config, quant_config=quant_config,
) )
# JIT compilation & warmup # JIT compilation & warmup
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -262,7 +324,7 @@ def benchmark_config( ...@@ -262,7 +324,7 @@ def benchmark_config(
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
...@@ -478,6 +540,7 @@ class BenchmarkWorker: ...@@ -478,6 +540,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
block_quant_shape: list[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
...@@ -485,7 +548,10 @@ class BenchmarkWorker: ...@@ -485,7 +548,10 @@ class BenchmarkWorker:
set_random_seed(self.seed) set_random_seed(self.seed)
dtype_str = _get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
...@@ -516,6 +582,7 @@ class BenchmarkWorker: ...@@ -516,6 +582,7 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
num_iters=100, num_iters=100,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm, use_deep_gemm=use_deep_gemm,
...@@ -532,6 +599,7 @@ class BenchmarkWorker: ...@@ -532,6 +599,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool,
search_space: list[dict[str, int]], search_space: list[dict[str, int]],
block_quant_shape: list[int], block_quant_shape: list[int],
use_deep_gemm: bool, use_deep_gemm: bool,
...@@ -542,7 +610,7 @@ class BenchmarkWorker: ...@@ -542,7 +610,7 @@ class BenchmarkWorker:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
search_space = prune_rocm_search_space( search_space = prune_rocm_search_space(
num_tokens, num_tokens,
shard_intermediate_size, shard_intermediate_size,
...@@ -558,7 +626,11 @@ class BenchmarkWorker: ...@@ -558,7 +626,11 @@ class BenchmarkWorker:
if visible_device != f"{self.device_id}": if visible_device != f"{self.device_id}":
need_device_guard = True need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): with (
torch.accelerator.device_index(self.device_id)
if need_device_guard
else nullcontext()
):
for idx, config in enumerate(tqdm(search_space)): for idx, config in enumerate(tqdm(search_space)):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(
...@@ -571,6 +643,7 @@ class BenchmarkWorker: ...@@ -571,6 +643,7 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
num_iters=20, num_iters=20,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm, use_deep_gemm=use_deep_gemm,
...@@ -618,6 +691,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ...@@ -618,6 +691,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
else {} else {}
), ),
**({"kpack": config["kpack"]} if "kpack" in config else {}), **({"kpack": config["kpack"]} if "kpack" in config else {}),
**({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
} }
...@@ -630,11 +704,15 @@ def save_configs( ...@@ -630,11 +704,15 @@ def save_configs(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_quant_shape: list[int], block_quant_shape: list[int],
save_dir: str, save_dir: str,
) -> None: ) -> None:
dtype_str = _get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
...@@ -736,6 +814,38 @@ def get_model_params(config): ...@@ -736,6 +814,38 @@ def get_model_params(config):
return E, topk, intermediate_size, hidden_size return E, topk, intermediate_size, hidden_size
def get_quantization_group_size(config) -> int | None:
"""Extract the quantization group size from the HF model config.
This reads directly from the HuggingFace config object (as returned by
``get_config()``), not from vLLM's quantization config classes.
Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
compressed-tensors configs (nested inside 'config_groups').
"""
quantization_config = getattr(config, "quantization_config", {})
if not isinstance(quantization_config, dict):
return None
# AWQ / GPTQ style: group_size is a top-level key
gs = quantization_config.get("group_size")
if gs is not None:
return gs
# compressed-tensors style: group_size is nested in config_groups
config_groups = quantization_config.get("config_groups", {})
if not isinstance(config_groups, dict):
return None
for group_cfg in config_groups.values():
if not isinstance(group_cfg, dict):
continue
weights = group_cfg.get("weights", {})
if not isinstance(weights, dict):
continue
gs = weights.get("group_size")
if gs is not None:
return gs
return None
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
...@@ -754,7 +864,20 @@ def main(args: argparse.Namespace): ...@@ -754,7 +864,20 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_int4_w4a16 = args.dtype == "int4_w4a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
if use_int4_w4a16:
group_size = get_quantization_group_size(config)
if group_size is None:
raise ValueError(
"Could not determine group_size from model config. "
"The model's quantization_config must contain a 'group_size' "
"field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
"(compressed-tensors)."
)
# For int4_w4a16, block_shape = [0, group_size]
# block_shape[0]=0 means no block quantization on N dimension
block_quant_shape = [0, group_size]
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
...@@ -808,8 +931,20 @@ def main(args: argparse.Namespace): ...@@ -808,8 +931,20 @@ def main(args: argparse.Namespace):
return ray.get(outputs) return ray.get(outputs)
if args.tune: if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
search_space = get_configs_compute_bound(is_fp16, block_quant_shape) # search space generation (no matrix_instr_nonkdim/kpack exploration).
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
# For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
# apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
# of group_size. Skip block_quant_shape filtering to keep the full
# search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
if use_int4_w4a16:
# SPLIT_K is a required kernel constexpr for gptq_awq kernel;
# only SPLIT_K=1 is used at runtime, so fix it during tuning.
for cfg in search_space:
cfg["SPLIT_K"] = 1
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
if use_deep_gemm: if use_deep_gemm:
raise ValueError( raise ValueError(
...@@ -829,6 +964,7 @@ def main(args: argparse.Namespace): ...@@ -829,6 +964,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
search_space, search_space,
block_quant_shape, block_quant_shape,
use_deep_gemm, use_deep_gemm,
...@@ -848,6 +984,7 @@ def main(args: argparse.Namespace): ...@@ -848,6 +984,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
block_quant_shape, block_quant_shape,
args.save_dir, args.save_dir,
) )
...@@ -866,6 +1003,7 @@ def main(args: argparse.Namespace): ...@@ -866,6 +1003,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
block_quant_shape, block_quant_shape,
use_deep_gemm, use_deep_gemm,
) )
...@@ -888,7 +1026,10 @@ if __name__ == "__main__": ...@@ -888,7 +1026,10 @@ if __name__ == "__main__":
) )
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true") parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
default="auto",
) )
parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument( parser.add_argument(
......
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark comparing old vs new default fused MoE configs.
Runs the triton fused_moe kernel with three configurations for each scenario:
1. Tuned config (from JSON file, if available) — the target to match
2. Old default (the hardcoded defaults before this change)
3. New default (the improved defaults)
Usage:
python benchmarks/kernels/benchmark_moe_defaults.py
Produces a table showing kernel time (us) and speedup of new vs old defaults.
"""
import torch
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
get_default_config,
get_moe_configs,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.torch_utils import set_random_seed
FP8_DTYPE = current_platform.fp8_dtype()
def old_default_config(M, E, N, K, topk, dtype=None, block_shape=None):
"""The original defaults before https://github.com/vllm-project/vllm/pull/34846,
for comparison."""
if dtype == "fp8_w8a8" and block_shape is not None:
return {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_shape[0],
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"SPLIT_K": 1,
"num_warps": 4,
"num_stages": 3 if not current_platform.is_rocm() else 2,
}
elif M <= E:
return {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
}
else:
return {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
def benchmark_config(
config,
M,
E,
N,
K,
topk,
dtype,
use_fp8=False,
block_shape=None,
num_iters=100,
):
"""Time a single kernel config. Returns kernel time in microseconds."""
init_dtype = torch.float16 if use_fp8 else dtype
a = torch.randn(M, K, device="cuda", dtype=init_dtype) / 10
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=init_dtype) / 10
w2 = torch.randn(E, K, N, device="cuda", dtype=init_dtype) / 10
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_fp8:
if block_shape is not None:
bsn, bsk = block_shape
n_tiles_w1 = triton.cdiv(2 * N, bsn)
k_tiles_w1 = triton.cdiv(K, bsk)
n_tiles_w2 = triton.cdiv(K, bsn)
k_tiles_w2 = triton.cdiv(N, bsk)
w1_scale = torch.rand(
E, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32
)
w2_scale = torch.rand(
E, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32
)
else:
w1_scale = torch.rand(E, device="cuda", dtype=torch.float32)
w2_scale = torch.rand(E, device="cuda", dtype=torch.float32)
a1_scale = torch.rand(1, device="cuda", dtype=torch.float32)
a2_scale = torch.rand(1, device="cuda", dtype=torch.float32)
# Only weights are stored in fp8; activations stay in bf16/fp16
# and get dynamically quantized inside the kernel.
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)
quant_config = FusedMoEQuantConfig.make(
quant_dtype=torch.float8_e4m3fn if use_fp8 else None,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
gating = torch.randn(M, E, device="cuda", dtype=torch.float32)
# Warmup
for _ in range(20):
with override_config(config):
topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True)
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config,
)
torch.accelerator.synchronize()
# Benchmark
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(num_iters):
with override_config(config):
topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True)
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config,
)
end.record()
torch.accelerator.synchronize()
return start.elapsed_time(end) / num_iters * 1000 # ms -> us
# Model configurations: (name, E, N, K, topk, dtype_str, use_fp8, block_shape)
# N = moe_intermediate_size // tp_size (the value used in config file lookup)
MODELS = [
# --- Few experts ---
("Mixtral bf16", 8, 7168, 4096, 2, None, False, None),
("Mixtral fp8", 8, 7168, 4096, 2, "fp8_w8a8", True, None),
# --- Many experts: real model shapes at tp=1 ---
# Qwen2-MoE-57B: E=60, topk=4, N=1408, K=2048
("Qwen2-MoE bf16", 60, 1408, 2048, 4, None, False, None),
# DeepSeek-V2: E=64, topk=6, N=1407, K=4096
# (use 1408 to avoid odd alignment; real model is 1407)
("DeepSeek-V2 bf16", 64, 1408, 4096, 6, None, False, None),
# OLMoE-7B: E=64, topk=8, N=2048, K=2048
("OLMoE bf16", 64, 2048, 2048, 8, None, False, None),
# GLM-4-100B-A10B: E=128, topk=8, N=1408, K=4096
("GLM-4-MoE bf16", 128, 1408, 4096, 8, None, False, None),
# Qwen3-30B-A3B: E=128, topk=8, N=768, K=2048
("Qwen3-MoE bf16", 128, 768, 2048, 8, None, False, None),
# DeepSeek-V3 / MiMo-V2-Flash: E=256, topk=8, N=2048, K=7168
("DeepSeek-V3 bf16", 256, 2048, 7168, 8, None, False, None),
# Qwen3.5-70B-A22B (Qwen3-Next): E=512, topk=10, N=512, K=2048
("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None),
# E=128 N=1856 bf16
("E128 N1856 bf16", 128, 1856, 4096, 8, None, False, None),
# E=256 N=512 bf16 (DS-V3 tp=4)
("DS-V3 tp4 bf16", 256, 512, 7168, 8, None, False, None),
# E=512 N=512 bf16 (Qwen3-Next tp=1)
("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None),
# E=512 N=256 bf16 (Qwen3-Next tp=2)
("Qwen3-Next tp2", 512, 256, 2048, 10, None, False, None),
# --- FP8 block quant (many experts) ---
# DS-V3 tp=4: E=256, N=512, fp8 block
("DS-V3 tp4 fp8blk", 256, 512, 7168, 8, "fp8_w8a8", True, [128, 128]),
# DS-V3 tp=8: E=256, N=256, fp8 block
("DS-V3 tp8 fp8blk", 256, 256, 7168, 8, "fp8_w8a8", True, [128, 128]),
# Qwen3-Next tp=2 fp8 block
("Qwen3-Next tp2 fp8blk", 512, 256, 2048, 10, "fp8_w8a8", True, [128, 128]),
]
BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
def main():
set_random_seed(0)
torch.set_default_device("cuda")
dtype = torch.bfloat16
for name, E, N, K, topk, dtype_str, use_fp8, block_shape in MODELS:
print(f"\n{'=' * 90}")
print(f" {name} (E={E}, N={N}, K={K}, topk={topk})")
print(f"{'=' * 90}")
# Try to load tuned config
block_n = block_shape[0] if block_shape else None
block_k = block_shape[1] if block_shape else None
tuned = get_moe_configs(E, N, dtype_str, block_n, block_k)
has_tuned = tuned is not None
print(f" Tuned config available: {has_tuned}")
hdr = (
f"{'Batch':>6} | {'Tuned (us)':>11} | {'Old (us)':>11} | "
f"{'New (us)':>11} | {'New/Old':>8} | {'New/Tuned':>10}"
)
print(f" {hdr}")
print(f" {'-' * len(hdr)}")
for M in BATCH_SIZES:
old_cfg = old_default_config(M, E, N, K, topk, dtype_str, block_shape)
new_cfg = get_default_config(M, E, N, K, topk, dtype_str, block_shape)
if has_tuned:
tuned_cfg = tuned[min(tuned.keys(), key=lambda x: abs(x - M))]
t_tuned = benchmark_config(
tuned_cfg,
M,
E,
N,
K,
topk,
dtype,
use_fp8=use_fp8,
block_shape=block_shape,
)
else:
t_tuned = None
t_old = benchmark_config(
old_cfg,
M,
E,
N,
K,
topk,
dtype,
use_fp8=use_fp8,
block_shape=block_shape,
)
t_new = benchmark_config(
new_cfg,
M,
E,
N,
K,
topk,
dtype,
use_fp8=use_fp8,
block_shape=block_shape,
)
ratio_new_old = t_new / t_old
tuned_str = f"{t_tuned:11.2f}" if t_tuned else f"{'N/A':>11}"
ratio_tuned = f"{t_new / t_tuned:10.2f}x" if t_tuned else f"{'N/A':>10}"
# flag regressions where new default is >5% slower than old
marker = " <--" if ratio_new_old > 1.05 else ""
print(
f" {M:>6} | {tuned_str} | {t_old:11.2f} | {t_new:11.2f} "
f"| {ratio_new_old:7.2f}x | {ratio_tuned}{marker}"
)
if __name__ == "__main__":
main()
...@@ -72,19 +72,19 @@ def benchmark_permute( ...@@ -72,19 +72,19 @@ def benchmark_permute(
# JIT compilation & warmup # JIT compilation & warmup
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -92,7 +92,7 @@ def benchmark_permute( ...@@ -92,7 +92,7 @@ def benchmark_permute(
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
...@@ -185,26 +185,26 @@ def benchmark_unpermute( ...@@ -185,26 +185,26 @@ def benchmark_unpermute(
# JIT compilation & warmup # JIT compilation & warmup
input = prepare() input = prepare()
run(input) run(input)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for _ in range(10): for _ in range(10):
run(input) run(input)
torch.cuda.synchronize() torch.accelerator.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
graph.replay() graph.replay()
end_event.record() end_event.record()
......
...@@ -36,6 +36,7 @@ from typing import Any ...@@ -36,6 +36,7 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -78,6 +79,7 @@ def calculate_stats(times: list[float]) -> dict[str, float]: ...@@ -78,6 +79,7 @@ def calculate_stats(times: list[float]) -> dict[str, float]:
} }
@default_vllm_config()
def benchmark_mrope( def benchmark_mrope(
model_name: str, model_name: str,
num_tokens: int, num_tokens: int,
...@@ -133,14 +135,14 @@ def benchmark_mrope( ...@@ -133,14 +135,14 @@ def benchmark_mrope(
key.clone(), key.clone(),
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Time reference implementation # Time reference implementation
torch_times = [] torch_times = []
for _ in range(benchmark_iter): for _ in range(benchmark_iter):
query_clone = query.clone() query_clone = query.clone()
key_clone = key.clone() key_clone = key.clone()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
mrope_helper_class.forward_native( mrope_helper_class.forward_native(
...@@ -149,7 +151,7 @@ def benchmark_mrope( ...@@ -149,7 +151,7 @@ def benchmark_mrope(
key_clone, key_clone,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch_times.append(time.time() - start_time) torch_times.append(time.time() - start_time)
# Time triton kernel implementation # Time triton kernel implementation
...@@ -157,14 +159,14 @@ def benchmark_mrope( ...@@ -157,14 +159,14 @@ def benchmark_mrope(
for _ in range(benchmark_iter): for _ in range(benchmark_iter):
query_clone = query.clone() query_clone = query.clone()
key_clone = key.clone() key_clone = key.clone()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.time() start_time = time.time()
mrope_helper_class.forward_cuda( mrope_helper_class.forward_cuda(
positions, positions,
query_clone, query_clone,
key_clone, key_clone,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
triton_times.append(time.time() - start_time) triton_times.append(time.time() - start_time)
# Calculate statistics # Calculate statistics
......
...@@ -103,7 +103,7 @@ def main( ...@@ -103,7 +103,7 @@ def main(
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -173,7 +173,7 @@ def main( ...@@ -173,7 +173,7 @@ def main(
) )
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -28,7 +28,7 @@ def _time_cuda( ...@@ -28,7 +28,7 @@ def _time_cuda(
# warmup # warmup
for _ in range(warmup_iters): for _ in range(warmup_iters):
fn() fn()
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
...@@ -37,7 +37,7 @@ def _time_cuda( ...@@ -37,7 +37,7 @@ def _time_cuda(
for _ in range(bench_iters): for _ in range(bench_iters):
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
return start.elapsed_time(end) / bench_iters # ms/iter return start.elapsed_time(end) / bench_iters # ms/iter
......
...@@ -7,6 +7,7 @@ from unittest.mock import patch ...@@ -7,6 +7,7 @@ from unittest.mock import patch
import pandas as pd import pandas as pd
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton from vllm.triton_utils import triton
...@@ -84,6 +85,7 @@ def calculate_diff( ...@@ -84,6 +85,7 @@ def calculate_diff(
configs = [] configs = []
@default_vllm_config()
def benchmark_quantization( def benchmark_quantization(
batch_size, batch_size,
hidden_size, hidden_size,
......
...@@ -29,7 +29,7 @@ def main( ...@@ -29,7 +29,7 @@ def main(
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.accelerator.synchronize()
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -39,7 +39,7 @@ def main( ...@@ -39,7 +39,7 @@ def main(
ops.scaled_int8_quant(x, scale) ops.scaled_int8_quant(x, scale)
else: else:
ops.scaled_fp8_quant(x, scale) ops.scaled_fp8_quant(x, scale)
torch.cuda.synchronize() torch.accelerator.synchronize()
end_time = time.perf_counter() end_time = time.perf_counter()
if profile: if profile:
......
...@@ -84,16 +84,16 @@ def run_benchmark( ...@@ -84,16 +84,16 @@ def run_benchmark(
g = torch.cuda.CUDAGraph() g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g): with torch.cuda.graph(g):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
function_under_test = lambda: g.replay() function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float: def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.perf_counter() start = time.perf_counter()
for _ in range(n_iters): for _ in range(n_iters):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / n_iters return (end - start) / n_iters
...@@ -104,7 +104,7 @@ def run_benchmark( ...@@ -104,7 +104,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping # free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return lat return lat
......
...@@ -109,16 +109,16 @@ def run_benchmark( ...@@ -109,16 +109,16 @@ def run_benchmark(
g = torch.cuda.CUDAGraph() g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g): with torch.cuda.graph(g):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
function_under_test = lambda: g.replay() function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float: def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.perf_counter() start = time.perf_counter()
for _ in range(n_iters): for _ in range(n_iters):
function_under_test() function_under_test()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / n_iters return (end - start) / n_iters
...@@ -129,7 +129,7 @@ def run_benchmark( ...@@ -129,7 +129,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping # free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return lat return lat
......
...@@ -5,6 +5,7 @@ import itertools ...@@ -5,6 +5,7 @@ import itertools
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -29,6 +30,7 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device): ...@@ -29,6 +30,7 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
args={}, args={},
) )
) )
@default_vllm_config()
def benchmark(batch_size, seq_len, num_heads, provider): def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
max_position = 8192 max_position = 8192
......
...@@ -251,7 +251,7 @@ def benchmark( ...@@ -251,7 +251,7 @@ def benchmark(
kernel( kernel(
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
...@@ -259,7 +259,7 @@ def benchmark( ...@@ -259,7 +259,7 @@ def benchmark(
# Benchmark # Benchmark
latencies: list[float] = [] latencies: list[float] = []
for _ in range(runs): for _ in range(runs):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
for i in range(iterations_per_run): for i in range(iterations_per_run):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment