Commit e607850f (unverified)
Enable mixed type LayerNorm kernel for NSA indexer (#12044)

Authored Nov 03, 2025 by akhilg-nv; committed by GitHub on Nov 03, 2025.
Parent: 15efbcb4
Showing 3 changed files with 166 additions and 25 deletions (+166, -25):

python/sglang/srt/layers/attention/nsa/nsa_indexer.py  (+2, -21)
python/sglang/srt/layers/layernorm.py  (+91, -3)
python/sglang/test/test_layernorm.py  (+73, -1)
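In this change, the NSA indexer stops carrying its own fp32-only V32LayerNorm and instead uses the shared LayerNorm op with float32 parameters, which can dispatch to a fused kernel for bfloat16 inputs. As a rough, self-contained sketch in plain PyTorch (the names below are chosen for illustration only, not taken from this commit), the mixed-type semantics being preserved are: normalize in float32 with float32 weight and bias, then cast the result back to the activation dtype.

import torch
import torch.nn.functional as F

# Illustrative sketch of the mixed-dtype layer norm semantics (not sglang code):
# bf16 activations, fp32 weight/bias, computation in fp32, output cast back to bf16.
hidden_size = 128
weight = torch.ones(hidden_size, dtype=torch.float32)
bias = torch.zeros(hidden_size, dtype=torch.float32)

x = torch.randn(4, hidden_size, dtype=torch.bfloat16)

out = F.layer_norm(x.float(), (hidden_size,), weight, bias, eps=1e-6).to(x.dtype)
assert out.dtype == torch.bfloat16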
python/sglang/srt/layers/attention/nsa/nsa_indexer.py

@@ -4,11 +4,10 @@ from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
-import torch.nn.functional as F
 from einops import rearrange
-from torch import nn
 
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.layers.layernorm import LayerNorm
 from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
 
 if is_cuda():
 ...
@@ -83,24 +82,6 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     return hadamard_transform(x, scale=hidden_size**-0.5)
 
 
-class V32LayerNorm(nn.Module):
-    """
-    Layer Normalization.
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
-
-    def forward(self, x: torch.Tensor):
-        return F.layer_norm(
-            x.float(), (self.dim,), self.weight, self.bias, self.eps
-        ).type_as(x)
-
-
 class Indexer(CustomOp):
     def __init__(
         self,
 ...
@@ -164,7 +145,7 @@ class Indexer(CustomOp):
             bias=False,
             prefix=add_prefix("weights_proj", prefix),
         )
-        self.k_norm = V32LayerNorm(self.head_dim)
+        self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
 ...
python/sglang/srt/layers/layernorm.py

@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from packaging.version import Version
 
 from sglang.srt.batch_invariant_ops import (
 ...
@@ -46,11 +47,19 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
+_flashinfer_layernorm_available = False
 
 if _is_cuda or _is_xpu:
-    # if _is_flashinfer_available:
-    #     from flashinfer.norm import fused_add_rmsnorm
-    # else:
+    if _is_flashinfer_available:
+        try:
+            from flashinfer.norm import layernorm
+
+            _flashinfer_layernorm_available = True
+        except (ImportError, AttributeError):
+            _flashinfer_layernorm_available = False
+    else:
+        _flashinfer_layernorm_available = False
+
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
 ...
@@ -289,6 +298,85 @@ class RMSNorm(CustomOp):
         return self.forward(x, residual)
 
 
+class LayerNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        elementwise_affine: bool = True,
+        bias: bool = True,
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.variance_epsilon = eps
+        self.elementwise_affine = elementwise_affine
+        self.use_bias = bias
+        self.dtype = dtype
+        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=self.dtype))
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=self.dtype))
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        if (
+            _flashinfer_layernorm_available
+            and x.dtype == torch.bfloat16
+            and self.dtype == torch.float32
+        ):
+            return layernorm(x, self.weight, self.bias, self.variance_epsilon)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        weight = self.weight if self.elementwise_affine else None
+        bias = self.bias if self.use_bias else None
+        orig_dtype = x.dtype
+        x = x.to(self.dtype)
+        return F.layer_norm(
+            x,
+            (self.hidden_size,),
+            weight=self.weight,
+            bias=bias,
+            eps=self.variance_epsilon,
+        ).to(orig_dtype)
+
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.forward_native(x)
+
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        orig_dtype = x.dtype
+        x = x.to(self.dtype)
+        mean = x.mean(dim=-1, keepdim=True)
+        variance = (x - mean).pow(2).mean(dim=-1, keepdim=True)
+        x = (x - mean) * torch.rsqrt(variance + self.variance_epsilon)
+        if self.elementwise_affine:
+            x = x * self.weight.to(self.dtype)
+        if self.use_bias:
+            x = x + self.bias.to(self.dtype)
+        return x.to(orig_dtype)
+
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.forward_native(x)
+
+
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
 ...
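For orientation, here is a hedged usage sketch of the new LayerNorm op, assuming an sglang build that contains this commit and a CUDA device. With float32 parameters and bfloat16 activations, forward_cuda takes the flashinfer layernorm fast path when that import succeeded; otherwise it falls back to forward_native, which computes in float32 and casts back to the input dtype. The tolerances mirror the ones used in the test added by this commit.

import torch

# Sketch only: assumes sglang (with this commit) is installed and a CUDA device is available.
from sglang.srt.layers.layernorm import LayerNorm

hidden_size = 4096
layer = LayerNorm(hidden_size, dtype=torch.float32).cuda()

x = torch.randn(8, hidden_size, device="cuda", dtype=torch.bfloat16)

out = layer(x)                  # dispatched path (flashinfer kernel when available)
ref = layer.forward_native(x)   # pure-PyTorch reference, fp32 compute, cast back to bf16

print(out.dtype, torch.allclose(out.float(), ref.float(), atol=2e-2, rtol=1e-3))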
python/sglang/test/test_layernorm.py

@@ -3,7 +3,7 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.layernorm import GemmaRMSNorm, LayerNorm, RMSNorm
 from sglang.test.test_utils import CustomTestCase
 ...
@@ -109,5 +109,77 @@ class TestGemmaRMSNorm(CustomTestCase):
                 self._run_gemma_rms_norm_test(*params)
 
 
+class TestLayerNorm(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    PARAM_DTYPES = [torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 1024]
+    HIDDEN_SIZES = [128, 512, 1536, 5120, 5124, 5125, 5126, 7168]
+    USE_AFFINE = [False, True]
+    USE_BIAS = [False, True]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_layer_norm_test(
+        self, num_tokens, hidden_size, use_affine, use_bias, dtype, seed, param_dtype
+    ):
+        torch.manual_seed(seed)
+
+        layer = LayerNorm(
+            hidden_size,
+            elementwise_affine=use_affine,
+            bias=use_bias,
+            dtype=param_dtype,
+        )
+        if use_affine:
+            layer.weight.data.normal_(mean=1.0, std=0.1)
+        if use_bias:
+            layer.bias.data.normal_(mean=0.0, std=0.1)
+        scale = 1 / (2 * hidden_size)
+        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x)
+            out = layer(x)
+
+        self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-3))
+
+        if (
+            use_affine
+            and use_bias
+            and not (dtype == torch.bfloat16 and param_dtype == torch.float32)
+        ):
+            layer.dtype = torch.float32
+            layer.weight.data = layer.weight.data.to(torch.float32)
+            layer.bias.data = layer.bias.data.to(torch.float32)
+            with torch.inference_mode():
+                cuda_out = layer(x.to(torch.bfloat16)).to(x.dtype)
+            self.assertTrue(torch.allclose(cuda_out, ref_out, atol=2e-2, rtol=1e-3))
+
+    def test_layer_norm(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.HIDDEN_SIZES,
+            self.USE_AFFINE,
+            self.USE_BIAS,
+            self.DTYPES,
+            self.SEEDS,
+            self.PARAM_DTYPES,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                hidden_size=params[1],
+                use_affine=params[2],
+                use_bias=params[3],
+                dtype=params[4],
+                seed=params[5],
+                param_dtype=params[6],
+            ):
+                self._run_layer_norm_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
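The new TestLayerNorm case sweeps token counts, hidden sizes, affine/bias flags, activation dtypes, and parameter dtypes, comparing the dispatched forward against forward_native; for affine-with-bias configurations it additionally casts the parameters to float32 and feeds bfloat16 inputs to exercise the mixed-dtype kernel path against the same reference within looser tolerances. A minimal sketch for running just this test class, assuming sglang is installed and a CUDA device is present (the class skips itself otherwise):

import unittest

# Sketch: run only the LayerNorm tests added in this commit
# (module path follows python/sglang/test/test_layernorm.py).
suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.test.test_layernorm.TestLayerNorm"
)
unittest.TextTestRunner(verbosity=2).run(suite)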