Unverified Commit 9d9b482a authored by Yineng Zhang, committed by GitHub

feat: integrate activation kernels into sgl-kernel (#3053)

parent 7353fb9b
@@ -2,6 +2,8 @@ from sgl_kernel.ops import (
custom_dispose,
custom_reduce,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
@@ -12,12 +14,15 @@ from sgl_kernel.ops import (
rmsnorm,
rotary_embedding,
sampling_scaling_penalties,
silu_and_mul,
)
__all__ = [
"custom_dispose",
"custom_reduce",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
@@ -28,4 +33,5 @@ __all__ = [
"rmsnorm",
"rotary_embedding",
"sampling_scaling_penalties",
"silu_and_mul",
]
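For reference, a minimal usage sketch of the newly exported activation ops (sizes are illustrative; assumes the package is built and a CUDA device is available). Each op expects the gate and value packed along the last dimension, so an input with last dimension 2 * d yields an output with last dimension d:

import torch
import sgl_kernel

x = torch.randn(4, 2 * 4096, dtype=torch.float16, device="cuda")
y_silu = sgl_kernel.silu_and_mul(x)            # silu(x[..., :d]) * x[..., d:]
y_gelu = sgl_kernel.gelu_and_mul(x)            # exact GELU gate
y_gelu_tanh = sgl_kernel.gelu_tanh_and_mul(x)  # tanh-approximated GELU gate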
@@ -43,6 +43,15 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
int64_t cuda_stream);
// silu and mul
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
// gelu tanh and mul
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
// gelu and mul
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -66,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
// fused gemma rms norm
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
// silu and mul
m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
// gelu tanh and mul
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
// gelu and mul
m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
}
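For orientation, a sketch of the calling convention these bindings expose, mirroring the Python wrappers further below: each op takes a caller-allocated output tensor, the packed input, and the raw CUDA stream handle as an integer. Obtaining the handle via torch.cuda.current_stream is an assumption about what get_cuda_stream does; in normal use the public wrappers should be preferred over the raw _kernels module:

import torch
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul

x = torch.randn(8, 2 * 1024, dtype=torch.float16, device="cuda")
out = torch.empty(8, 1024, dtype=x.dtype, device=x.device)
stream = torch.cuda.current_stream(x.device).cuda_stream  # raw stream handle as int
_silu_and_mul(out, x, stream)  # result is written into `out` in place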
@@ -4,6 +4,8 @@ import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
from sgl_kernel.ops._kernels import (
@@ -18,6 +20,7 @@ from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
def get_cuda_stream(device: torch.device) -> int:
@@ -127,3 +130,61 @@ def gemma_fused_add_rmsnorm(
) -> None:
with input.device as device:
_gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The size of the last dimension, in bytes, must be a multiple of 16.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
with input.device as device:
_silu_and_mul(out, input, get_cuda_stream(device))
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The size of the last dimension, in bytes, must be a multiple of 16.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
with input.device as device:
_gelu_tanh_and_mul(out, input, get_cuda_stream(device))
return out
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The size of the last dimension, in bytes, must be a multiple of 16.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
with input.device as device:
_gelu_and_mul(out, input, get_cuda_stream(device))
return out
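For reference, pure-PyTorch equivalents of the three wrappers (the same formulas the tests below check against): the activation is applied to the first half of the last dimension and multiplied elementwise by the second half.

import torch
import torch.nn.functional as F

def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def ref_gelu_tanh_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

def ref_gelu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="none") * x[..., d:]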
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
import pytest
import sgl_kernel
import torch
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
y = sgl_kernel.silu_and_mul(x)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
y = sgl_kernel.gelu_tanh_and_mul(x)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
y = sgl_kernel.gelu_and_mul(x)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
    test_fused_silu_mul(128, 1, 1)