vllm · commit fd5dcc5c (unverified)

Optimize GeGLU layer in Gemma (#2975)

Authored Feb 21, 2024 by Woosuk Kwon; committed by GitHub on Feb 21, 2024.
Parent: 93dc5a28

Changes: 6 changed files with 108 additions and 77 deletions (+108 −77).
csrc/activation_kernels.cu                  +48 −25
csrc/ops.h                                  +4  −0
csrc/pybind.cpp                             +4  −0
tests/kernels/test_activation.py            +15 −35
vllm/model_executor/layers/activation.py    +23 −0
vllm/model_executor/models/gemma.py         +14 −17
csrc/activation_kernels.cu

@@ -2,19 +2,16 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cmath>

 #include "cuda_compat.h"
 #include "dispatch_utils.h"

 namespace vllm {

-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }

+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace vllm

+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                             \
+  int d = input.size(-1) / 2;                                                             \
+  int64_t num_tokens = input.numel() / input.size(-1);                                    \
+  dim3 grid(num_tokens);                                                                  \
+  dim3 block(std::min(d, 1024));                                                          \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
+    input.scalar_type(),                                                                  \
+    "act_and_mul_kernel",                                                                 \
+    [&] {                                                                                 \
+      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(   \
+        out.data_ptr<scalar_t>(),                                                         \
+        input.data_ptr<scalar_t>(),                                                       \
+        d);                                                                               \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }

 namespace vllm {
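For orientation, here is a minimal PyTorch sketch (not part of the commit) of what the fused act_and_mul_kernel computes per token: the first half of the last dimension goes through the activation and is multiplied element-wise by the second half. The erf-based branch mirrors the 'none' GELU approximation referenced in gelu_kernel; the function name below is illustrative only.

import math
import torch

def act_and_mul_ref(x: torch.Tensor, act: str = "gelu") -> torch.Tensor:
    # Reference for out[..., i] = ACT_FN(x[..., i]) * x[..., d + i].
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    if act == "silu":
        a = a * torch.sigmoid(a)  # x * sigmoid(x), as in silu_kernel
    else:
        # erf-based GELU, PyTorch's 'none' approximation, as in gelu_kernel
        a = 0.5 * a * (1.0 + torch.erf(a / math.sqrt(2.0)))
    return a * b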
csrc/ops.h

@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);
csrc/pybind.cpp

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,
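A sketch of how the new binding would be called from Python once the extension is built. Like silu_and_mul, it uses a destination-passing style: the caller pre-allocates the output and the op fills it in place. The "from vllm._C import ops" path and the sizes are assumptions for illustration.

import torch
from vllm._C import ops  # assumed import path for the compiled extension

num_tokens, d = 8, 11008  # illustrative sizes
x = torch.randn(num_tokens, 2 * d, dtype=torch.half, device="cuda")
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)  # writes GELU(x[..., :d]) * x[..., d:] into out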
tests/kernels/test_activation.py

+from typing import Type
+
 import pytest
 import torch

-from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
+                                                   NewGELU, SiluAndMul)
 from allclose_default import get_default_atol, get_default_rtol

 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@ CUDA_DEVICES = [
 ]


+@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_silu_and_mul(
+def test_act_and_mul(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = SiluAndMul()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
+    # The SiLU and GELU implementations are equivalent to the native PyTorch
+    # implementations, so we can do exact comparison.
+    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)


+@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gelu_new(
-    num_tokens: int,
-    d: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = NewGELU()
-    out = layer(x)
-    ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
-
-
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("d", D)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_gelu_fast(
+def test_activation(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = FastGELU()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
     assert torch.allclose(out,
vllm/model_executor/layers/activation.py

@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
         return out


+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):

     def _forward(self, x: torch.Tensor) -> torch.Tensor:
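A short usage sketch (hypothetical shapes; requires a CUDA build of vLLM) exercising the new GeluAndMul layer and checking the kernel path against the PyTorch-native _forward, mirroring the exact comparison asserted in the test above.

import torch
from vllm.model_executor.layers.activation import GeluAndMul

layer = GeluAndMul()
x = torch.randn(2, 16, 2 * 256, dtype=torch.half, device="cuda")  # (batch, seq_len, 2 * d)
out = layer(x)           # CUDA kernel path -> (2, 16, 256)
ref = layer._forward(x)  # PyTorch-native reference
assert out.shape == (2, 16, 256)
assert torch.allclose(out, ref, atol=0.0, rtol=0.0)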
vllm/model_executor/models/gemma.py

@@ -21,10 +21,11 @@ from torch import nn
 from transformers import GemmaConfig

 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ class GemmaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(hidden_size,
-                                              intermediate_size,
-                                              bias=False,
-                                              linear_method=linear_method)
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=False,
-                                            linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = nn.GELU()
+        self.act_fn = GeluAndMul()

     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        gate = self.act_fn(gate)
-        up, _ = self.up_proj(x)
-        fuse = gate * up
-        outputs, _ = self.down_proj(fuse)
-        return outputs
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x


 class GemmaAttention(nn.Module):
@@ -294,6 +289,8 @@ class GemmaForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
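The GemmaMLP rewrite relies on the fact that stacking the gate and up projection weights and doing a single matmul is equivalent to running the two projections separately; that is why the new weight-loading entries map gate_proj and up_proj onto shards 0 and 1 of gate_up_proj. A minimal CPU sketch with plain tensors (illustrative sizes, no tensor parallelism), not the vLLM implementation:

import torch
import torch.nn.functional as F

hidden, inter = 32, 64
w_gate = torch.randn(inter, hidden)  # would land in shard 0 of gate_up_proj
w_up = torch.randn(inter, hidden)    # would land in shard 1 of gate_up_proj
x = torch.randn(5, hidden)

# Old GemmaMLP: two projections, then GELU(gate) * up.
ref = F.gelu(x @ w_gate.T) * (x @ w_up.T)

# New GemmaMLP: one merged projection, then a GeluAndMul-style split.
gate_up = x @ torch.cat([w_gate, w_up], dim=0).T  # [..., 2 * inter]
out = F.gelu(gate_up[..., :inter]) * gate_up[..., inter:]

assert torch.allclose(ref, out, atol=1e-4, rtol=1e-4)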