norm / vllm · Commits · d64bf164

Unverified commit d64bf164, authored Aug 23, 2023 by Woosuk Kwon; committed via GitHub on Aug 23, 2023.
Parent: a41c2043

Implement approximate GELU kernels (#828)
Showing 4 changed files with 164 additions and 18 deletions (+164, -18).
csrc/activation.cpp                        +16  -0
csrc/activation_kernels.cu                 +68  -0
tests/kernels/test_activation.py           +43  -1
vllm/model_executor/layers/activation.py   +37  -17
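For orientation before the per-file diffs: both new kernels evaluate the tanh-based GELU approximation element-wise. Below is a minimal PyTorch sketch of the math they are expected to reproduce; the helper names gelu_new_ref and gelu_fast_ref are illustrative and not part of this commit.

import math
import torch

# Reference versions of the two tanh-based GELU approximations computed by the
# new CUDA kernels: gelu_new follows GPT-2's formulation, gelu_fast is an
# algebraically equivalent rewrite. The constant 0.79788456 in the kernels is
# sqrt(2 / pi).

def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(
        0.7978845608 * x * (1.0 + 0.044715 * x * x)))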
csrc/activation.cpp

@@ -4,9 +4,25 @@ void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  m.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  m.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");
}
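Once the extension is compiled, these bindings are exposed as functions on the vllm.activation_ops module and follow the same out-parameter convention as the existing silu_and_mul op, as the test file below exercises. A minimal usage sketch, assuming a CUDA device and a built extension:

import torch
from vllm import activation_ops

x = torch.randn(16, 4096, dtype=torch.half, device='cuda')   # [num_tokens, d]
out = torch.empty_like(x)
activation_ops.gelu_fast(out, x)   # the kernel writes the activation into `out`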
csrc/activation_kernels.cu

@@ -46,3 +46,71 @@ void silu_and_mul(
        d);
    });
}

namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,               // [num_tokens, d]
  const scalar_t* __restrict__ input,       // [num_tokens, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace vllm

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
  int num_tokens = input.size(0);                                                         \
  int d = input.size(1);                                                                  \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  AT_DISPATCH_FLOATING_TYPES_AND2(                                                        \
    at::ScalarType::Half,                                                                 \
    at::ScalarType::BFloat16,                                                             \
    input.scalar_type(),                                                                  \
    "activation_kernel",                                                                  \
    [&] {                                                                                 \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });

namespace vllm {

template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace vllm

void gelu_new(
  torch::Tensor& out,     // [num_tokens, d]
  torch::Tensor& input)   // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
  torch::Tensor& out,     // [num_tokens, d]
  torch::Tensor& input)   // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
tests/kernels/test_activation.py

import torch
import torch.nn.functional as F
from transformers.activations import get_activation

from vllm import activation_ops

@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_silu_and_mul(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_new(out, x)
    ref_out = get_activation("gelu_new")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_new() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_new(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_fast() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_fast(num_tokens, d, dtype)
vllm/model_executor/layers/activation.py

@@ -4,23 +4,6 @@ import torch.nn as nn

from vllm import activation_ops

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    # NOTE: The following GELU functions may introduce small rounding errors.
    "gelu_new": nn.GELU(approximate="tanh"),
    "gelu_fast": nn.GELU(approximate="tanh"),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out


class NewGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_new(out, x)
        return out


class FastGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1]
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.gelu_fast(out, x)
        return out


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
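With the updated registry, "gelu_new" and "gelu_fast" now resolve to the custom CUDA-backed modules rather than nn.GELU(approximate="tanh"). A minimal usage sketch, assuming a CUDA device:

import torch
from vllm.model_executor.layers.activation import get_act_fn

act = get_act_fn("gelu_fast")          # returns the FastGELU module added above
x = torch.randn(8, 4096, dtype=torch.half, device='cuda')
y = act(x)                             # same shape as x: [num_tokens, d]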