Commit d1a06223 authored by liuxu3

added vllm092 auto test scripts

parent fba2e3b5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from sglang quantization/tuning_block_wise_kernel.py
import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any
import torch
from tqdm import tqdm
import triton
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True)
assert current_platform.is_cuda(), (
"Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP = {
"float32": torch.float32,
"float16": torch.float16,
"half": torch.half,
"bfloat16": torch.bfloat16,
}
def w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
config: dict[str, Any],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with
block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization.
It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul
else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
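# A minimal usage sketch (illustrative only; the shapes and config values below
# are assumptions, not tuned settings). `A` carries one scale per token per
# 128-wide K group; `B` carries one scale per 128x128 weight block:
#
#     A = torch.randn(4, 7168, device="cuda").to(torch.float8_e4m3fn)
#     B = torch.randn(4096, 7168, device="cuda").to(torch.float8_e4m3fn)
#     As = torch.rand(4, 7168 // 128, dtype=torch.float32, device="cuda")
#     Bs = torch.rand(4096 // 128, 7168 // 128, dtype=torch.float32, device="cuda")
#     config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#               "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}
#     C = w8a8_block_matmul(A, B, As, Bs, [128, 128], config)  # (4, 4096), fp16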
def get_configs_compute_bound():
configs = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
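# The nested loops above enumerate 4 * 5 * 2 * 4 * 2 * 4 = 1280 candidate
# configs; tune_on_gpu() later keeps only those whose BLOCK_SIZE_K evenly
# divides the quantization block_k.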
def get_weight_shapes(tp_size):
    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3.
    # Modify them if you tune for a different model.
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(12288, 7168),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
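# Example (illustrative): with tp_size=8, the N-shardable shape (24576, 1536)
# becomes (3072, 1536) and the K-shardable shape (7168, 18432) becomes
# (7168, 2304), while the "cannot TP" shapes are kept unchanged.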
def benchmark_config(
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
    # JIT compilation & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() reports milliseconds; convert the mean latency to microseconds.
    avg = sum(latencies) / num_iters * 1000  # us
return avg
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
factor_for_scale = 1e-2
if input_type == "fp8":
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
Bs = (
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
* factor_for_scale
)
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
A,
B,
As,
Bs,
block_size,
config,
out_dtype,
num_iters=10,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
assert best_config is not None
return best_config
def save_configs(
N,
K,
block_n,
block_k,
configs,
save_path,
input_type="fp8",
) -> None:
os.makedirs(save_path, exist_ok=True)
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = (
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
f"block_shape=[{block_n},{block_k}].json"
)
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
with open(config_file_path, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def tune_on_gpu(args_dict):
"""Run tuning on a specific GPU."""
gpu_id = args_dict["gpu_id"]
batch_sizes = args_dict["batch_sizes"]
weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"]
torch.cuda.set_device(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n
block_k = args.block_k
out_dtype = DTYPE_MAP[args.out_dtype]
save_path = args.save_path
input_type = args.input_type
search_space = get_configs_compute_bound()
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.time()
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
N, K = shape[0], shape[1]
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
benchmark_results = [
tune(
batch_size,
N,
K,
[block_n, block_k],
out_dtype,
search_space,
input_type,
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
def distribute_batch_sizes(batch_sizes, num_gpus):
"""Distribute batch sizes across available GPUs."""
batches_per_gpu = []
for i in range(num_gpus):
start_idx = i * len(batch_sizes) // num_gpus
end_idx = (i + 1) * len(batch_sizes) // num_gpus
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
return batches_per_gpu
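# Example: 18 batch sizes over 8 GPUs give slice sizes [2, 2, 2, 3, 2, 2, 2, 3];
# the integer arithmetic spreads the remainder across the later GPUs.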
def main(args):
print(args)
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning")
torch.cuda.init()
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
num_gpus = 1 # If only one batch size, use only one GPU
weight_shapes = get_weight_shapes(args.tp_size)
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
process_args = []
for gpu_id in range(num_gpus):
process_args.append(
{
"gpu_id": gpu_id,
"batch_sizes": batches_per_gpu[gpu_id],
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args,
}
)
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
pool.map(tune_on_gpu, process_args)
print("Multi-GPU tuning completed")
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="""
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
""",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
parser.add_argument(
"--out-dtype",
type=str,
choices=["float32", "float16", "bfloat16", "half"],
default="float16",
)
parser.add_argument("--block-n", type=int, default=128)
parser.add_argument("--block-k", type=int, default=128)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--save-path", type=str, default="./")
args = parser.parse_args()
main(args)
# DeepSeek DeepGEMM Kernels Benchmark
This directory contains benchmarks comparing DeepSeek's DeepGEMM block FP8 kernels with vLLM's existing Triton- and CUTLASS-based kernels.
Currently it covers only dense GEMMs and runs only on Hopper GPUs.
## Setup
Install vLLM as usual, then install DeepGEMM from source in its own directory:
```
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
cd DeepGEMM
python setup.py install
uv pip install -e .
```
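To confirm the install, a quick sanity check (assuming DeepGEMM's package imports as `deep_gemm`):

```
python -c "import deep_gemm; print(deep_gemm.__file__)"
```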
## Usage
```
python benchmark_fp8_block_dense_gemm.py
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
===== STARTING FP8 GEMM BENCHMARK =====
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
Triton version: 3.1.0
Using device: NVIDIA H100 80GB HBM3
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
===== PERFORMANCE COMPARISON =====
DeepGEMM Implementation:
+------+-------+-------+-----------+--------+--------+
| m | n | k | Time (μs) | TFLOPS | GB/s |
+------+-------+-------+-----------+--------+--------+
| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 |
| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 |
| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 |
| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 |
| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 |
| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 |
| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 |
| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 |
| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 |
| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 |
| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 |
| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 |
| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 |
| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 |
| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 |
| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 |
+------+-------+-------+-----------+--------+--------+
vLLM Triton Implementation:
+------+-------+-------+-----------+--------+--------+--------------+
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM |
+------+-------+-------+-----------+--------+--------+--------------+
| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster |
| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower |
| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower |
| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower |
| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower |
| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower |
| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster |
| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster |
| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower |
| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster |
| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower |
| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower |
| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower |
| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower |
| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower |
| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower |
+------+-------+-------+-----------+--------+--------+--------------+
vLLM CUTLASS Implementation:
+------+-------+-------+-----------+--------+--------+--------------+--------------+
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton |
+------+-------+-------+-----------+--------+--------+--------------+--------------+
| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster |
| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster |
| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster |
| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster |
| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster |
| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster |
| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster |
| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster |
| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster |
| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster |
| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster |
| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster |
| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster |
| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster |
| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster |
| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster |
+------+-------+-------+-----------+--------+--------+--------------+--------------+
===== AVERAGE PERFORMANCE =====
+----------------+------------+----------+---------------+
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
+----------------+------------+----------+---------------+
| DeepGEMM | 310.98 | 1052.10 | 0.11 |
| vLLM Triton | 144.30 | 715.60 | 0.23 |
| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 |
+----------------+------------+----------+---------------+
===== AVERAGE SPEEDUPS =====
+-----------------------------+--------------+
| Comparison | Speedup |
+-----------------------------+--------------+
| DeepGEMM vs vLLM Triton | 1.71x faster |
| DeepGEMM vs vLLM CUTLASS | 0.94x slower |
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
+-----------------------------+--------------+
===== ACCURACY COMPARISON =====
+----------------+-----------------------+
| Implementation | Avg Diff vs Reference |
+----------------+-----------------------+
| DeepGEMM | 0.000684 |
| vLLM Triton | 0.000684 |
| vLLM CUTLASS | 0.000684 |
+----------------+-----------------------+
```