Commit 7353fb9b (unverified) · sglang

feat: integrate norm kernels into sgl-kernel (#3052)

Authored Jan 22, 2025 by Yineng Zhang; committed by GitHub on Jan 22, 2025.
Parent: bcda0c9e
Showing 5 changed files, with 195 additions and 42 deletions (+195 / -42):
- sgl-kernel/src/sgl-kernel/__init__.py (+11 / -5)
- sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+16 / -0)
- sgl-kernel/src/sgl-kernel/ops/__init__.py (+39 / -6)
- sgl-kernel/tests/test_norm.py (+129 / -0)
- sgl-kernel/tests/test_rmsnorm.py (+0 / -31)
sgl-kernel/src/sgl-kernel/__init__.py (+11 / -5)

```diff
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
     get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
 ...
@@ -12,14 +15,17 @@ from sgl_kernel.ops import (
 )
 
 __all__ = [
-    "moe_align_block_size",
-    "init_custom_reduce",
     "custom_dispose",
     "custom_reduce",
-    "int8_scaled_mm",
-    "sampling_scaling_penalties",
+    "fused_add_rmsnorm",
+    "gemma_fused_add_rmsnorm",
+    "gemma_rmsnorm",
     "get_graph_buffer_ipc_meta",
+    "init_custom_reduce",
+    "int8_scaled_mm",
+    "moe_align_block_size",
     "register_graph_buffers",
-    "rotary_embedding",
     "rmsnorm",
+    "rotary_embedding",
+    "sampling_scaling_penalties",
 ]
```
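With this change the new norm entry points are re-exported at the package top level. As a minimal sketch (assuming an sgl-kernel build that includes this commit), they can be imported directly:

```python
# The four norm entry points added to __all__ above, imported from the
# top-level package rather than from sgl_kernel.ops.
from sgl_kernel import (
    fused_add_rmsnorm,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
)
```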
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+16 / -0)

```diff
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
 
+// fused rms norm
+void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// gemma rms norm
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// fused gemma rms norm
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
   // rms norm
   m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
+  // fused rms norm
+  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
+  // gemma rms norm
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
+  // fused gemma rms norm
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
 }
```
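For reference, the plain and Gemma norm variants declared and bound above differ only in how the learned weight is applied; restated from the pure-PyTorch reference implementations in the new test file below (x is one hidden-size row of length d, w the weight vector, ε the eps argument):

```latex
\operatorname{rmsnorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^{2} + \varepsilon}}\, w_i
\qquad
\operatorname{gemma\_rmsnorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^{2} + \varepsilon}}\, (1 + w_i)
```

The fused_add_* variants first form x + residual, write that sum back into residual, and then apply the corresponding normalization to it in place of x.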
sgl-kernel/src/sgl-kernel/ops/__init__.py (+39 / -6)

```diff
@@ -3,6 +3,9 @@ from typing import Optional
 import torch
 
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
 from sgl_kernel.ops._kernels import (
     get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
 )
@@ -17,6 +20,10 @@ from sgl_kernel.ops._kernels import (
 )
 
 
+def get_cuda_stream(device: torch.device) -> int:
+    return torch.cuda.current_stream(device).cuda_stream
+
+
 def init_custom_reduce(
     rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
 ):
@@ -88,9 +95,35 @@ def rmsnorm(
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if out is None:
-        out = torch.empty_like(input)
-    stream = torch.cuda.current_stream().cuda_stream
-    stream_int = int(stream)
-    _rmsnorm(out, input, weight, eps, stream_int)
-    return out
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def gemma_rmsnorm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def gemma_fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
```
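A hypothetical usage sketch of the new Python wrappers (the batch size, hidden size, and dtype are illustrative choices, not taken from the commit; a CUDA device and an installed sgl-kernel build are assumed):

```python
import torch

import sgl_kernel

hidden_size = 4096  # illustrative
x = torch.randn(8, hidden_size, dtype=torch.float16, device="cuda")
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

# Out-of-place variants: allocate and return the output when `out` is not given.
y = sgl_kernel.rmsnorm(x, w)
y_gemma = sgl_kernel.gemma_rmsnorm(x, w, eps=1e-6)

# In-place fused variants: update both `x` and `residual`, return nothing.
sgl_kernel.fused_add_rmsnorm(x, residual, w, eps=1e-6)
sgl_kernel.gemma_fused_add_rmsnorm(x, residual, w, eps=1e-6)
```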
sgl-kernel/tests/test_norm.py (new file, 0 → 100644, +129)

```python
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
import pytest
import sgl_kernel
import torch


def llama_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x


def gemma_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x


def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x + residual
    residual = x
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x, residual


def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = (x * weight.float()).to(orig_dtype)
    return x, residual


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = gemma_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.gemma_rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.gemma_rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = gemma_fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
```
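The new suite can be exercised with pytest, e.g. `pytest sgl-kernel/tests/test_norm.py`; a CUDA-capable GPU and an installed sgl-kernel build are assumed, since the tests place their tensors on device 0 / "cuda".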
sgl-kernel/tests/test_rmsnorm.py (deleted, 100644 → 0, -31)

```python
import pytest
import torch
from sgl_kernel import rmsnorm


def llama_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        rmsnorm(x, w, out=y)
    else:
        y = rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
```