Unverified commit 7353fb9b authored by Yineng Zhang, committed by GitHub

feat: integrate norm kernels into sgl-kernel (#3052)

parent bcda0c9e
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
     get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
@@ -12,14 +15,17 @@ from sgl_kernel.ops import (
 )

 __all__ = [
-    "moe_align_block_size",
-    "init_custom_reduce",
     "custom_dispose",
     "custom_reduce",
-    "int8_scaled_mm",
-    "sampling_scaling_penalties",
+    "fused_add_rmsnorm",
+    "gemma_fused_add_rmsnorm",
+    "gemma_rmsnorm",
     "get_graph_buffer_ipc_meta",
+    "init_custom_reduce",
+    "int8_scaled_mm",
+    "moe_align_block_size",
     "register_graph_buffers",
-    "rotary_embedding",
     "rmsnorm",
+    "rotary_embedding",
+    "sampling_scaling_penalties",
 ]
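Note: with the re-exports above, the new norm ops sit alongside the existing kernels at the top level of the sgl_kernel package. A minimal usage sketch (tensor shapes and dtype are illustrative, not taken from this commit):

import torch
from sgl_kernel import gemma_rmsnorm, rmsnorm

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")
y = rmsnorm(x, w)              # standard RMSNorm, scaled by w
y_gemma = gemma_rmsnorm(x, w)  # Gemma variant, scaled by (1 + w)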
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);

+// fused rms norm
+void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// gemma rms norm
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// fused gemma rms norm
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
+                             int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
   // rms norm
   m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
+  // fused rms norm
+  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
+  // gemma rms norm
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
+  // fused gemma rms norm
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
 }
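Note: every binding registered above takes the CUDA stream explicitly as an int64, matching the declarations. A minimal sketch of calling a binding directly (shapes are illustrative; in practice the Python wrappers in the next file are the intended entry point):

import torch
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
w = torch.randn(1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
# Arguments: output, input, weight, eps, and the raw stream handle as an integer.
_rmsnorm(out, x, w, 1e-6, torch.cuda.current_stream(x.device).cuda_stream)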
@@ -3,6 +3,9 @@ from typing import Optional
 import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
 from sgl_kernel.ops._kernels import (
     get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
 )
@@ -17,6 +20,10 @@ from sgl_kernel.ops._kernels import (
 )


+def get_cuda_stream(device: torch.device) -> int:
+    return torch.cuda.current_stream(device).cuda_stream
+
+
 def init_custom_reduce(
     rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
 ):
@@ -88,9 +95,35 @@ def rmsnorm(
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if out is None:
-        out = torch.empty_like(input)
-    stream = torch.cuda.current_stream().cuda_stream
-    stream_int = int(stream)
-    _rmsnorm(out, input, weight, eps, stream_int)
-    return out
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def gemma_rmsnorm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def gemma_fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
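Note: rmsnorm and gemma_rmsnorm return a new (or caller-provided) output tensor, while the fused wrappers return None and update their arguments in place, as the tests below verify. A minimal sketch (shapes illustrative):

import torch
import sgl_kernel

hidden = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")

# In place: residual becomes (hidden + residual); hidden becomes the
# RMS-normalized, weight-scaled version of that sum.
sgl_kernel.fused_add_rmsnorm(hidden, residual, weight, 1e-6)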
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py

import pytest
import sgl_kernel
import torch


def llama_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x


def gemma_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x


def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x + residual
    residual = x
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x, residual


def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = (x * weight.float()).to(orig_dtype)
    return x, residual
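# Restating the reference implementations above: for a row x of size d,
#   llama_rms_norm(x, w)[i] = x[i] / sqrt(mean(x**2) + eps) * w[i]
#   gemma_rms_norm(x, w)[i] = x[i] / sqrt(mean(x**2) + eps) * (1 + w[i])
# The fused_add variants first form s = x + residual, then normalize s and
# also return s as the updated residual.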
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)
    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.rmsnorm(x, w)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6
    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    x_native, residual_native = fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )
    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)
    y_ref = gemma_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.gemma_rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.gemma_rmsnorm(x, w)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6
    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    x_native, residual_native = gemma_fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )
    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
import pytest
import torch
from sgl_kernel import rmsnorm


def llama_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)
    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        rmsnorm(x, w, out=y)
    else:
        y = rmsnorm(x, w)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)