Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
602358f8
Unverified
Commit
602358f8
authored
Mar 12, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 12, 2024
Browse files
Add kernel for GeGLU with approximate GELU (#3337)
parent
49a3c866
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
7 deletions
+49
-7
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+21
-1
csrc/ops.h
csrc/ops.h
+4
-0
csrc/pybind.cpp
csrc/pybind.cpp
+5
-1
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+8
-3
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+11
-2
No files found.
csrc/activation_kernels.cu
View file @
602358f8
...
@@ -33,12 +33,25 @@ template<typename T>
...
@@ -33,12 +33,25 @@ template<typename T>
__device__
__forceinline__
T
gelu_kernel
(
const
T
&
x
)
{
__device__
__forceinline__
T
gelu_kernel
(
const
T
&
x
)
{
// Equivalent to PyTorch GELU with 'none' approximation.
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const
float
f
=
(
float
)
x
;
const
float
f
=
(
float
)
x
;
constexpr
float
ALPHA
=
M_SQRT1_2
;
constexpr
float
ALPHA
=
M_SQRT1_2
;
return
(
T
)
(
f
*
0.5
f
*
(
1.0
f
+
::
erf
(
f
*
ALPHA
)));
return
(
T
)
(
f
*
0.5
f
*
(
1.0
f
+
::
erf
(
f
*
ALPHA
)));
}
}
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // GELU using the tanh approximation, equivalent to PyTorch GELU with the
  // 'tanh' approximation:
  //   0.5 * v * (1 + tanh(sqrt(2 / pi) * (v + 0.044715 * v^3)))
  // Reference:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30

  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi).
  constexpr float kSqrt2OverPi = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float kCubicCoeff = 0.044715;

  // The math is carried out in float whatever T is; only the final result is
  // converted back to T.
  const float v = static_cast<float>(x);
  const float v_cubed = v * v * v;
  const float tanh_arg = kSqrt2OverPi * (v + kCubicCoeff * v_cubed);
  return static_cast<T>(0.5f * v * (1.0f + ::tanhf(tanh_arg)));
}
}
// namespace vllm
}
// namespace vllm
// Launch activation and gating kernel.
// Launch activation and gating kernel.
...
@@ -73,6 +86,13 @@ void gelu_and_mul(
...
@@ -73,6 +86,13 @@ void gelu_and_mul(
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
);
}
}
// GeGLU entry point for the tanh-approximated GELU: expands
// LAUNCH_ACTIVATION_GATE_KERNEL with vllm::gelu_tanh_kernel as the activation.
// NOTE(review): the macro is defined outside this view and presumably refers
// to the parameters `out` and `input` by name -- do not rename them.
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
namespace
vllm
{
namespace
vllm
{
// Element-wise activation kernel template.
// Element-wise activation kernel template.
...
...
csrc/ops.h
View file @
602358f8
...
@@ -61,6 +61,10 @@ void gelu_and_mul(
...
@@ -61,6 +61,10 @@ void gelu_and_mul(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
input
);
// Activation function used in GeGLU with the `tanh` approximation.
// Writes its result into `out`; `input` is the tensor the activation reads.
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
input
);
...
...
csrc/pybind.cpp
View file @
602358f8
...
@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
ops
.
def
(
"gelu_and_mul"
,
"gelu_and_mul"
,
&
gelu_and_mul
,
&
gelu_and_mul
,
"Activation function used in GeGLU."
);
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
ops
.
def
(
"gelu_new"
,
"gelu_new"
,
&
gelu_new
,
&
gelu_new
,
...
...
tests/kernels/test_activation.py
View file @
602358f8
...
@@ -16,7 +16,7 @@ CUDA_DEVICES = [
...
@@ -16,7 +16,7 @@ CUDA_DEVICES = [
]
]
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
S
ilu
AndMul
,
GeluAndMul
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"s
ilu
"
,
"gelu"
,
"gelu_tanh"
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
@@ -24,7 +24,7 @@ CUDA_DEVICES = [
...
@@ -24,7 +24,7 @@ CUDA_DEVICES = [
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_act_and_mul
(
def
test_act_and_mul
(
activation
:
Type
[
torch
.
nn
.
Module
]
,
activation
:
str
,
num_tokens
:
int
,
num_tokens
:
int
,
d
:
int
,
d
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
...
@@ -36,7 +36,12 @@ def test_act_and_mul(
...
@@ -36,7 +36,12 @@ def test_act_and_mul(
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
layer
=
activation
()
if
activation
==
"silu"
:
layer
=
SiluAndMul
()
elif
activation
==
"gelu"
:
layer
=
GeluAndMul
(
approximate
=
"none"
)
elif
activation
==
"gelu_tanh"
:
layer
=
GeluAndMul
(
approximate
=
"tanh"
)
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
_forward
(
x
)
ref_out
=
layer
.
_forward
(
x
)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# The SiLU and GELU implementations are equivalent to the native PyTorch
...
...
vllm/model_executor/layers/activation.py
View file @
602358f8
...
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
...
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
return: (batch_size, seq_len, d) or (num_tokens, d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
"""
def __init__(self, approximate: str = "none"):
    """Record which GELU approximation mode this layer uses.

    Args:
        approximate: Approximation mode; must be "none" or "tanh".

    Raises:
        ValueError: If ``approximate`` is neither "none" nor "tanh".
    """
    super().__init__()
    # The attribute is assigned before the validity check below.
    self.approximate = approximate
    supported_modes = ("none", "tanh")
    if approximate not in supported_modes:
        raise ValueError(f"Unknown approximate mode: {approximate}")
def
_forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
gelu
(
x
[...,
:
d
])
*
x
[...,
d
:]
return
F
.
gelu
(
x
[...,
:
d
]
,
approximate
=
self
.
approximate
)
*
x
[...,
d
:]
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ops
.
gelu_and_mul
(
out
,
x
)
if
self
.
approximate
==
"none"
:
ops
.
gelu_and_mul
(
out
,
x
)
elif
self
.
approximate
==
"tanh"
:
ops
.
gelu_tanh_and_mul
(
out
,
x
)
return
out
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment