Unverified Commit 82392da8 authored by HandH1998's avatar HandH1998 Committed by GitHub
Browse files

support w8a8 fp8 kernel with CUTLASS (#3047)


Co-authored-by: default avataryych0745 <1398089567@qq.com>
parent 95f789ad
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="fp8 scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
# M, N, K = batch_size, 4096, 8192
M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
quantiles = [0.5, 0.2, 0.8]
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
elif "sglang-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
quantiles=quantiles,
)
gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)
print("Benchmark finished!")
...@@ -56,6 +56,7 @@ include_dirs = [ ...@@ -56,6 +56,7 @@ include_dirs = [
turbomind.resolve(), turbomind.resolve(),
turbomind.resolve() / "src", turbomind.resolve() / "src",
] ]
nvcc_flags = [ nvcc_flags = [
"-DNDEBUG", "-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}", f"-DOPERATOR_NAMESPACE={operator_namespace}",
...@@ -82,6 +83,7 @@ sources = [ ...@@ -82,6 +83,7 @@ sources = [
"src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu", "src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/activation.cu",
......
...@@ -2,6 +2,7 @@ from sgl_kernel.ops import ( ...@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
bmm_fp8, bmm_fp8,
custom_dispose, custom_dispose,
custom_reduce, custom_reduce,
fp8_scaled_mm,
fused_add_rmsnorm, fused_add_rmsnorm,
gelu_and_mul, gelu_and_mul,
gelu_tanh_and_mul, gelu_tanh_and_mul,
...@@ -27,6 +28,7 @@ __all__ = [ ...@@ -27,6 +28,7 @@ __all__ = [
"bmm_fp8", "bmm_fp8",
"custom_dispose", "custom_dispose",
"custom_reduce", "custom_reduce",
"fp8_scaled_mm",
"fused_add_rmsnorm", "fused_add_rmsnorm",
"gelu_and_mul", "gelu_and_mul",
"gelu_tanh_and_mul", "gelu_tanh_and_mul",
......
This diff is collapsed.
...@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma ...@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias); const c10::optional<torch::Tensor>& bias);
// fp8_scaled_mm
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
// lightning_attention_decode // lightning_attention_decode
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
......
...@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ...@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
) )
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernels.fp8_scaled_mm(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
torch.ops.sgl_kernels.lightning_attention_decode( torch.ops.sgl_kernels.lightning_attention_decode(
q, k, v, past_kv, slope, output, new_kv q, k, v, past_kv, slope, output, new_kv
......
...@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { ...@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"bias) -> Tensor"); "bias) -> Tensor");
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
// fp8_scaled_mm
m.def(
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
// lightning_attention_decode // lightning_attention_decode
m.def( m.def(
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
......
import unittest
import torch
from sgl_kernel import fp8_scaled_mm
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
o = o.to(torch.float32)
temp1 = o * scale_a.view(-1, 1)
temp2 = temp1 * scale_b.view(1, -1)
final = temp2.to(out_dtype)
if bias is not None:
final = final + bias.view(1, -1)
return final
class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [16, 128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment