Unverified commit d64bf164, authored Aug 23, 2023 by Woosuk Kwon, committed by GitHub on Aug 23, 2023
Implement approximate GELU kernels (#828)
Parent: a41c2043
Showing 4 changed files with 164 additions and 18 deletions.
csrc/activation.cpp (+16, -0)
csrc/activation_kernels.cu (+68, -0)
tests/kernels/test_activation.py (+43, -1)
vllm/model_executor/layers/activation.py (+37, -17)
csrc/activation.cpp

@@ -4,9 +4,25 @@ void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  m.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  m.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");
}
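These bindings expose the new kernels to Python through the compiled activation_ops extension; the test file later in this commit calls them in exactly this way. A minimal usage sketch, assuming the extension has been built and a CUDA device is available:

import torch
from vllm import activation_ops  # compiled from the csrc/ sources above

# Inputs are [num_tokens, d]; the ops write element-wise results into a
# preallocated output tensor of the same shape and dtype.
x = torch.randn(16, 4096, dtype=torch.half, device="cuda")
out = torch.empty_like(x)

activation_ops.gelu_new(out, x)   # GELU approximation used in GPT-2
activation_ops.gelu_fast(out, x)  # alternative approximate GELU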
csrc/activation_kernels.cu

@@ -46,3 +46,71 @@ void silu_and_mul(
        d);
    });
}

namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, d]
  const scalar_t* __restrict__ input,   // [num_tokens, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace vllm

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
  int num_tokens = input.size(0);                                                         \
  int d = input.size(1);                                                                  \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  AT_DISPATCH_FLOATING_TYPES_AND2(                                                        \
    at::ScalarType::Half,                                                                 \
    at::ScalarType::BFloat16,                                                             \
    input.scalar_type(),                                                                  \
    "activation_kernel",                                                                  \
    [&] {                                                                                 \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });

namespace vllm {

template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace vllm

void gelu_new(
  torch::Tensor& out,      // [num_tokens, d]
  torch::Tensor& input)    // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
  torch::Tensor& out,      // [num_tokens, d]
  torch::Tensor& input)    // [num_tokens, d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
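Both device functions implement the tanh approximation of GELU, with 0.79788456 ≈ sqrt(2/π) and 0.044715 as the cubic coefficient; gelu_fast_kernel only factors the polynomial differently, so the two agree up to rounding. The launch macro schedules one thread block per token, with up to 1024 threads striding over the hidden size d. For reference, a plain PyTorch sketch of the same math (not part of the commit; the function names are illustrative):

import math
import torch

def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = math.sqrt(2.0 / math.pi)  # ~0.79788456
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x.pow(3))))

def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
    # Same approximation with the cubic term factored as x * (1 + 0.044715 * x^2)
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))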
tests/kernels/test_activation.py

import torch
import torch.nn.functional as F
from transformers.activations import get_activation

from vllm import activation_ops

...

@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_silu_and_mul(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_new(out, x)
    ref_out = get_activation("gelu_new")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_new() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_new(num_tokens, d, dtype)


@torch.inference_mode()
def run_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_gelu_fast() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_gelu_fast(num_tokens, d, dtype)
vllm/model_executor/layers/activation.py

@@ -4,23 +4,6 @@ import torch.nn as nn
 from vllm import activation_ops

-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    # NOTE: The following GELU functions may introduce small rounding errors.
-    "gelu_new": nn.GELU(approximate="tanh"),
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-}
-
-
-def get_act_fn(act_fn: str) -> nn.Module:
-    """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
-

 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

@@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
+
+
+class NewGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_new(out, x)
+        return out
+
+
+class FastGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_fast(out, x)
+        return out
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_fast": FastGELU(),
+    "gelu_new": NewGELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
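With the registry now mapping "gelu_new" and "gelu_fast" to the CUDA-backed modules defined above, model code continues to resolve activations through get_act_fn unchanged. A minimal usage sketch (assumes a CUDA build of vllm):

import torch
from vllm.model_executor.layers.activation import get_act_fn

act = get_act_fn("gelu_new")  # returns the NewGELU module defined above
x = torch.randn(4, 4096, dtype=torch.half, device="cuda")
y = act(x)                    # approximate GELU applied element-wise; same shape as x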