Unverified commit e607850f authored by akhilg-nv, committed by GitHub

Enable mixed type LayerNorm kernel for NSA indexer (#12044)

parent 15efbcb4
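
This change replaces the NSA indexer's private V32LayerNorm with a shared LayerNorm CustomOp in sglang.srt.layers.layernorm. On CUDA, when the activation is bfloat16 and the parameters are float32 (the mixed-type case the indexer's k_norm now uses), the op dispatches to flashinfer.norm.layernorm; otherwise it falls back to computing torch.nn.functional.layer_norm in float32 and casting back to the activation dtype. A minimal usage sketch of the new op follows; the head_dim value, tensor shape, and availability of a CUDA device are illustrative assumptions, not taken from this diff:

import torch

from sglang.srt.layers.layernorm import LayerNorm

head_dim = 128  # illustrative only; the Indexer passes its own self.head_dim
k_norm = LayerNorm(head_dim, dtype=torch.float32).cuda()  # float32 weight and bias

x = torch.randn(8, head_dim, dtype=torch.bfloat16, device="cuda")
# bf16 input with fp32 parameters on CUDA takes the flashinfer layernorm kernel
# when it is importable; any other combination falls back to forward_native,
# which normalizes in float32 and casts the result back to x.dtype.
y = k_norm(x)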
@@ -4,11 +4,10 @@ from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
-import torch.nn.functional as F
 from einops import rearrange
-from torch import nn
 
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.layers.layernorm import LayerNorm
 from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
 
 if is_cuda():
@@ -83,24 +82,6 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     return hadamard_transform(x, scale=hidden_size**-0.5)
 
 
-class V32LayerNorm(nn.Module):
-    """
-    Layer Normalization.
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
-
-    def forward(self, x: torch.Tensor):
-        return F.layer_norm(
-            x.float(), (self.dim,), self.weight, self.bias, self.eps
-        ).type_as(x)
-
-
 class Indexer(CustomOp):
     def __init__(
         self,
@@ -164,7 +145,7 @@ class Indexer(CustomOp):
             bias=False,
             prefix=add_prefix("weights_proj", prefix),
         )
-        self.k_norm = V32LayerNorm(self.head_dim)
+        self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
...
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from packaging.version import Version
 
 from sglang.srt.batch_invariant_ops import (
@@ -46,11 +47,19 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
 
+_flashinfer_layernorm_available = False
+
 if _is_cuda or _is_xpu:
-    # if _is_flashinfer_available:
-    #     from flashinfer.norm import fused_add_rmsnorm
-    # else:
+    if _is_flashinfer_available:
+        try:
+            from flashinfer.norm import layernorm
+
+            _flashinfer_layernorm_available = True
+        except (ImportError, AttributeError):
+            _flashinfer_layernorm_available = False
+    else:
+        _flashinfer_layernorm_available = False
+
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -289,6 +298,85 @@ class RMSNorm(CustomOp):
         return self.forward(x, residual)
+class LayerNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        elementwise_affine: bool = True,
+        bias: bool = True,
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.variance_epsilon = eps
+        self.elementwise_affine = elementwise_affine
+        self.use_bias = bias
+        self.dtype = dtype
+        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=self.dtype))
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=self.dtype))
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        if (
+            _flashinfer_layernorm_available
+            and x.dtype == torch.bfloat16
+            and self.dtype == torch.float32
+        ):
+            return layernorm(x, self.weight, self.bias, self.variance_epsilon)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        weight = self.weight if self.elementwise_affine else None
+        bias = self.bias if self.use_bias else None
+        orig_dtype = x.dtype
+        x = x.to(self.dtype)
+        return F.layer_norm(
+            x,
+            (self.hidden_size,),
+            weight=weight,
+            bias=bias,
+            eps=self.variance_epsilon,
+        ).to(orig_dtype)
+
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.forward_native(x)
+
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        orig_dtype = x.dtype
+        x = x.to(self.dtype)
+        mean = x.mean(dim=-1, keepdim=True)
+        variance = (x - mean).pow(2).mean(dim=-1, keepdim=True)
+        x = (x - mean) * torch.rsqrt(variance + self.variance_epsilon)
+        if self.elementwise_affine:
+            x = x * self.weight.to(self.dtype)
+        if self.use_bias:
+            x = x + self.bias.to(self.dtype)
+        return x.to(orig_dtype)
+
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.forward_native(x)
+
+
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
...
@@ -3,7 +3,7 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.layernorm import GemmaRMSNorm, LayerNorm, RMSNorm
 from sglang.test.test_utils import CustomTestCase
 
@@ -109,5 +109,77 @@ class TestGemmaRMSNorm(CustomTestCase):
                 self._run_gemma_rms_norm_test(*params)
+
+class TestLayerNorm(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    PARAM_DTYPES = [torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 1024]
+    HIDDEN_SIZES = [128, 512, 1536, 5120, 5124, 5125, 5126, 7168]
+    USE_AFFINE = [False, True]
+    USE_BIAS = [False, True]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_layer_norm_test(
+        self, num_tokens, hidden_size, use_affine, use_bias, dtype, seed, param_dtype
+    ):
+        torch.manual_seed(seed)
+
+        layer = LayerNorm(
+            hidden_size, elementwise_affine=use_affine, bias=use_bias, dtype=param_dtype
+        )
+        if use_affine:
+            layer.weight.data.normal_(mean=1.0, std=0.1)
+        if use_bias:
+            layer.bias.data.normal_(mean=0.0, std=0.1)
+
+        scale = 1 / (2 * hidden_size)
+        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x)
+            out = layer(x)
+        self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-3))
+
+        if (
+            use_affine
+            and use_bias
+            and not (dtype == torch.bfloat16 and param_dtype == torch.float32)
+        ):
+            layer.dtype = torch.float32
+            layer.weight.data = layer.weight.data.to(torch.float32)
+            layer.bias.data = layer.bias.data.to(torch.float32)
+            with torch.inference_mode():
+                cuda_out = layer(x.to(torch.bfloat16)).to(x.dtype)
+            self.assertTrue(torch.allclose(cuda_out, ref_out, atol=2e-2, rtol=1e-3))
+
+    def test_layer_norm(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.HIDDEN_SIZES,
+            self.USE_AFFINE,
+            self.USE_BIAS,
+            self.DTYPES,
+            self.SEEDS,
+            self.PARAM_DTYPES,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                hidden_size=params[1],
+                use_affine=params[2],
+                use_bias=params[3],
+                dtype=params[4],
+                seed=params[5],
+                param_dtype=params[6],
+            ):
+                self._run_layer_norm_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
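
For reference, the numerics that the new test's forward_native baseline checks (and that the mixed-type kernel path is expected to match within the test tolerances) can be written as a short standalone sketch; the function name below is illustrative and not part of the diff:

import torch
import torch.nn.functional as F

def layer_norm_reference(x, weight, bias, eps=1e-6):
    # Mirror of the forward_native path: compute in the parameter dtype
    # (float32 in the mixed-type case), then return in the activation dtype.
    out = F.layer_norm(x.to(weight.dtype), (x.shape[-1],), weight, bias, eps)
    return out.to(x.dtype)

w = torch.ones(128, dtype=torch.float32)
b = torch.zeros(128, dtype=torch.float32)
x = torch.randn(4, 128, dtype=torch.bfloat16)
print(layer_norm_reference(x, w, b).dtype)  # torch.bfloat16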