Unverified Commit bf669606 authored by Yineng Zhang, committed by GitHub

feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)

parent b2bd8f44
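
The new op is exposed as `sgl_kernel.bmm_fp8(A, B, A_scale, B_scale, dtype, out=None)`, wrapping the FlashInfer bmm_fp8 kernel (the binding takes a cuBLAS handle and a cached workspace buffer). Below is a minimal usage sketch, assuming a CUDA build of sgl-kernel with the FP8 flags enabled (SM90 per the setup.py change below); the shapes and the per-tensor quantization helper are taken from the test added in this commit and are illustrative only:

```python
import torch

from sgl_kernel import bmm_fp8


def to_float8(x, dtype=torch.float8_e4m3fn):
    # Per-tensor quantization, copied from the new test: scale maps the max
    # magnitude of x onto the fp8 range; bmm_fp8 expects the inverse scale.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
# As in the test, the second operand is laid out column-major via transpose.
B = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)

A_fp8, A_inv_s = to_float8(A, dtype=torch.float8_e4m3fn)
B_fp8, B_inv_s = to_float8(B, dtype=torch.float8_e4m3fn)  # e5m2 x e5m2 is unsupported

# out has shape (batch, M, N) = (16, 48, 80) in the requested dtype.
out = bmm_fp8(A_fp8, B_fp8, A_inv_s, B_inv_s, dtype=torch.bfloat16)
```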
......@@ -62,12 +62,22 @@ nvcc_flags = [
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
]
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if sm_version >= 90:
nvcc_flags.extend(
[
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
)
if sm_version >= 80:
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
......
from sgl_kernel.ops import (
    bmm_fp8,
    custom_dispose,
    custom_reduce,
    fused_add_rmsnorm,
......@@ -18,6 +19,7 @@ from sgl_kernel.ops import (
)
__all__ = [
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fused_add_rmsnorm",
......
......@@ -52,6 +52,10 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
// gelu and mul
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
// bmm fp8
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
......@@ -81,4 +85,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
// gelu and mul
m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
// bmm fp8
m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
}
......@@ -2,6 +2,7 @@ from typing import Optional
import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
......@@ -21,10 +22,7 @@ from sgl_kernel.ops._kernels import (
    sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
def get_cuda_stream(device: torch.device) -> int:
    return torch.cuda.current_stream(device).cuda_stream
from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
def init_custom_reduce(
......@@ -101,7 +99,7 @@ def rmsnorm(
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
        _rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
    return out
......@@ -109,7 +107,7 @@ def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    with input.device as device:
        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
        _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
def gemma_rmsnorm(
......@@ -121,7 +119,7 @@ def gemma_rmsnorm(
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
        _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
    return out
......@@ -129,7 +127,7 @@ def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    with input.device as device:
        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
        _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
......@@ -154,7 +152,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
            dtype=input.dtype,
        )
    with input.device as device:
        _silu_and_mul(out, input, get_cuda_stream(device))
        _silu_and_mul(out, input, _get_cuda_stream(device))
    return out
......@@ -170,7 +168,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
            dtype=input.dtype,
        )
    with input.device as device:
        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
        _gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
    return out
......@@ -186,5 +184,46 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
            dtype=input.dtype,
        )
    with input.device as device:
        _gelu_and_mul(out, input, get_cuda_stream(device))
        _gelu_and_mul(out, input, _get_cuda_stream(device))
    return out


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    with A.device as device:
        cublas_handle = torch.cuda.current_blas_handle()
        _bmm_fp8(
            A,
            B,
            D,
            A_scale,
            B_scale,
            workspace_buffer,
            cublas_handle,
            _get_cuda_stream(device),
        )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out
from typing import Dict, Tuple

import torch


def _get_cuda_stream(device: torch.device) -> int:
    return torch.cuda.current_stream(device).cuda_stream


_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}


def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
    key = (name, device)
    buf = _cache_buf.get(key)
    if buf is None:
        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import bmm_fp8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
        pytest.skip("Invalid combination: both input and mat2 are e5m2")

    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)

    # mat2 row major -> column major
    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
        -2, -1
    )
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)

    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)

    reference = torch.bmm(input, mat2)
    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
    assert cos_sim > 0.99


if __name__ == "__main__":
    pytest.main([__file__])