sglang · Commits · 9d9b482a

Unverified commit 9d9b482a, authored Jan 22, 2025 by Yineng Zhang, committed via GitHub on Jan 22, 2025.
feat: integrate activation kernels into sgl-kernel (#3053)
Parent: 7353fb9b
Showing 4 changed files with 120 additions and 0 deletions (+120, -0).
sgl-kernel/src/sgl-kernel/__init__.py (+6, -0)
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+15, -0)
sgl-kernel/src/sgl-kernel/ops/__init__.py (+61, -0)
sgl-kernel/tests/test_activation.py (+38, -0)
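For context, the commit exposes three fused activation kernels (silu_and_mul, gelu_tanh_and_mul, gelu_and_mul) through the sgl_kernel package. A minimal usage sketch, assuming a CUDA device and a built sgl-kernel extension (not part of the diff itself):

import torch
import sgl_kernel

# The fused ops take a tensor whose last dimension is 2*d and return the
# activated first half multiplied elementwise by the second half, shape (..., d).
x = torch.randn(4, 16, 2 * 4096, dtype=torch.float16, device="cuda")
y = sgl_kernel.silu_and_mul(x)        # SiLU gate, output shape (4, 16, 4096)
z = sgl_kernel.gelu_tanh_and_mul(x)   # tanh-approximated GELU gate
w = sgl_kernel.gelu_and_mul(x)        # exact ("none") GELU gate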
sgl-kernel/src/sgl-kernel/__init__.py (view file @ 9d9b482a)

@@ -2,6 +2,8 @@ from sgl_kernel.ops import (
    custom_dispose,
    custom_reduce,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    get_graph_buffer_ipc_meta,
@@ -12,12 +14,15 @@ from sgl_kernel.ops import (
    rmsnorm,
    rotary_embedding,
    sampling_scaling_penalties,
    silu_and_mul,
)

__all__ = [
    "custom_dispose",
    "custom_reduce",
    "fused_add_rmsnorm",
    "gelu_and_mul",
    "gelu_tanh_and_mul",
    "gemma_fused_add_rmsnorm",
    "gemma_rmsnorm",
    "get_graph_buffer_ipc_meta",
@@ -28,4 +33,5 @@ __all__ = [
    "rmsnorm",
    "rotary_embedding",
    "sampling_scaling_penalties",
    "silu_and_mul",
]
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (view file @ 9d9b482a)

@@ -43,6 +43,15 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);

// silu and mul
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

// gelu tanh and mul
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

// gelu and mul
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -66,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
  // fused gemma rms norm
  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
  // silu and mul
  m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
  // gelu tanh and mul
  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
  // gelu and mul
  m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
}
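The declarations and pybind11 registrations above only expose the kernel entry points; the numerical contract they implement (as exercised by the new test file) is the split-activate-multiply pattern used in gated MLP layers. A pure-PyTorch reference sketch of that contract, for comparison only:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply SiLU to the first half, and
    # multiply elementwise by the second half (matches the reference in the tests).
    d = x.shape[-1] // 2
    return x[..., d:] * F.silu(x[..., :d])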
sgl-kernel/src/sgl-kernel/ops/__init__.py (view file @ 9d9b482a)

@@ -4,6 +4,8 @@ import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
from sgl_kernel.ops._kernels import (
@@ -18,6 +20,7 @@ from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
    sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul


def get_cuda_stream(device: torch.device) -> int:
@@ -127,3 +130,61 @@ def gemma_fused_add_rmsnorm(
) -> None:
    with input.device as device:
        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        _silu_and_mul(out, input, get_cuda_stream(device))
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        _gelu_and_mul(out, input, get_cuda_stream(device))
    return out
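The Python wrappers above allocate the output tensor when out is omitted, or validate a caller-provided buffer via _check_shape. A small usage sketch (assumed, not part of the diff) showing both paths:

import torch
import sgl_kernel

x = torch.randn(2, 8, 2 * 1024, dtype=torch.float16, device="cuda")

# Path 1: let the wrapper allocate an output of shape (..., 1024).
y = sgl_kernel.silu_and_mul(x)

# Path 2: reuse a caller-provided buffer; _check_shape requires (..., last_dim // 2).
out = torch.empty(2, 8, 1024, dtype=torch.float16, device="cuda")
y = sgl_kernel.silu_and_mul(x, out)

# Note the alignment guard in the wrapper: the last dimension times the element
# size must be a multiple of 16 bytes, otherwise a ValueError is raised.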
sgl-kernel/tests/test_activation.py (new file, mode 100644, view file @ 9d9b482a)

# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
import pytest
import sgl_kernel
import torch


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
    y = sgl_kernel.silu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
    y = sgl_kernel.gelu_tanh_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
    y = sgl_kernel.gelu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


test_fused_silu_mul(128, 1, 1)
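As a usage note: the suite can presumably be run with pytest against sgl-kernel/tests/test_activation.py on a machine with a CUDA device and a built sgl-kernel extension. The unguarded test_fused_silu_mul(128, 1, 1) call at the bottom of the committed file also executes whenever the module is imported, effectively acting as a quick smoke check.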