norm / vllm · Commits · fd5dcc5c

Commit fd5dcc5c (unverified), authored Feb 21, 2024 by Woosuk Kwon, committed by GitHub on Feb 21, 2024

Optimize GeGLU layer in Gemma (#2975)
Parent: 93dc5a28

Showing 6 changed files with 108 additions and 77 deletions (+108, -77)
csrc/activation_kernels.cu                +48  -25
csrc/ops.h                                 +4   -0
csrc/pybind.cpp                            +4   -0
tests/kernels/test_activation.py          +15  -35
vllm/model_executor/layers/activation.py  +23   -0
vllm/model_executor/models/gemma.py       +14  -17
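For orientation: GeGLU is the gated GELU used in Gemma's MLP, and this commit fuses the activation and the elementwise gating into a single CUDA kernel. The reference math (the same computation as the GeluAndMul._forward added below) is only a few lines of PyTorch; this is an illustrative sketch, and the helper name geglu_reference is ours, not part of the commit.

    import torch
    import torch.nn.functional as F

    def geglu_reference(x: torch.Tensor) -> torch.Tensor:
        # x packs the gate and up halves along the last dim: [..., 2 * d].
        d = x.shape[-1] // 2
        # GELU the gate half, then scale the up half by it.
        return F.gelu(x[..., :d]) * x[..., d:]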
csrc/activation_kernels.cu  (view file @ fd5dcc5c)

@@ -2,19 +2,16 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <cmath>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace vllm {
 
-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace vllm
 
+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                                 \
+  int d = input.size(-1) / 2;                                                                 \
+  int64_t num_tokens = input.numel() / input.size(-1);                                        \
+  dim3 grid(num_tokens);                                                                      \
+  dim3 block(std::min(d, 1024));                                                              \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                           \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                               \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                               \
+    input.scalar_type(),                                                                      \
+    "act_and_mul_kernel",                                                                     \
+    [&] {                                                                                     \
+      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(       \
+        out.data_ptr<scalar_t>(),                                                             \
+        input.data_ptr<scalar_t>(),                                                           \
+        d);                                                                                   \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
 namespace vllm {
...
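The kernel change replaces the SiLU-specific kernel with act_and_mul_kernel, templated on a device function pointer (ACT_FN), so silu_and_mul and gelu_and_mul share one gating kernel and one launch macro. The gelu_kernel above implements the exact, erf-based GELU (PyTorch's 'none' approximation); a quick PyTorch-side check of the same formula, as a sketch (the helper name gelu_exact is ours):

    import math
    import torch
    import torch.nn.functional as F

    def gelu_exact(x: torch.Tensor) -> torch.Tensor:
        # Same math as gelu_kernel: 0.5 * x * (1 + erf(x / sqrt(2))).
        return 0.5 * x * (1.0 + torch.erf(x * math.sqrt(0.5)))

    x = torch.randn(1024)
    # F.gelu defaults to approximate="none", i.e. the erf form.
    assert torch.allclose(gelu_exact(x), F.gelu(x))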
csrc/ops.h  (view file @ fd5dcc5c)

@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);
...
csrc/pybind.cpp  (view file @ fd5dcc5c)

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,
...
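With the binding registered, the fused op is callable from Python as ops.gelu_and_mul(out, x), which is exactly what the new GeluAndMul.forward does below. A minimal direct call might look like this sketch; the vllm._C import path is our assumption for this revision, and it requires a CUDA build of the extension.

    import torch
    from vllm._C import ops  # assumed import path for the compiled extension

    num_tokens, d = 16, 128
    x = torch.randn(num_tokens, 2 * d, device="cuda", dtype=torch.half)
    out = torch.empty(num_tokens, d, device="cuda", dtype=torch.half)
    ops.gelu_and_mul(out, x)  # writes GELU(x[..., :d]) * x[..., d:] into out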
tests/kernels/test_activation.py  (view file @ fd5dcc5c)

+from typing import Type
+
 import pytest
 import torch
 
-from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
+                                                   NewGELU, SiluAndMul)
 from allclose_default import get_default_atol, get_default_rtol
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@ CUDA_DEVICES = [
 ]
 
 
+@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_silu_and_mul(
+def test_act_and_mul(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = SiluAndMul()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
+    # The SiLU and GELU implementations are equivalent to the native PyTorch
+    # implementations, so we can do exact comparison.
+    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
 
 
+@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gelu_new(
-    num_tokens: int,
-    d: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = NewGELU()
-    out = layer(x)
-    ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
-
-
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("d", D)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_gelu_fast(
+def test_activation(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = FastGELU()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
     assert torch.allclose(out,
...
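The two GLU-style layers are now covered by a single parametrized test, and because the CUDA kernels match the PyTorch-native _forward implementations, the comparison is exact (atol=0.0, rtol=0.0). Roughly the same check can be reproduced by hand, as in this sketch (assumes a CUDA build):

    import torch
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul

    torch.set_default_device("cuda")
    for layer_cls in (SiluAndMul, GeluAndMul):
        layer = layer_cls()
        x = torch.randn(7, 2 * 512, dtype=torch.half)
        out, ref = layer(x), layer._forward(x)
        assert torch.allclose(out, ref, atol=0.0, rtol=0.0)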
vllm/model_executor/layers/activation.py  (view file @ fd5dcc5c)

@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
         return out
 
 
+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
...
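GeluAndMul mirrors the existing SiluAndMul: _forward is the PyTorch-native reference used by the tests, while forward allocates the output and dispatches to the fused CUDA op. The shape contract halves the last dimension, e.g. in this sketch (assumes a CUDA build):

    import torch
    from vllm.model_executor.layers.activation import GeluAndMul

    layer = GeluAndMul()
    x = torch.randn(2, 4, 2 * 256, device="cuda", dtype=torch.bfloat16)  # (batch, seq, 2 * d)
    y = layer(x)                                                         # (batch, seq, d)
    assert y.shape == (2, 4, 256)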
vllm/model_executor/models/gemma.py  (view file @ fd5dcc5c)

@@ -21,10 +21,11 @@ from torch import nn
 from transformers import GemmaConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ class GemmaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(hidden_size,
-                                              intermediate_size,
-                                              bias=False,
-                                              linear_method=linear_method)
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=False,
-                                            linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = nn.GELU()
+        self.act_fn = GeluAndMul()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        gate = self.act_fn(gate)
-        up, _ = self.up_proj(x)
-        fuse = gate * up
-        outputs, _ = self.down_proj(fuse)
-        return outputs
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class GemmaAttention(nn.Module):
@@ -294,6 +289,8 @@ class GemmaForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
...
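The MLP change is the heart of the optimization: the separate gate_proj and up_proj matmuls are merged into one gate_up_proj, and the GELU plus elementwise multiply are fused into the single gelu_and_mul kernel (with the new stacked_params_mapping entries splicing the original gate_proj/up_proj checkpoint weights into the merged layer). A minimal sketch of the resulting structure, using a plain nn.Linear as a hypothetical stand-in for the tensor-parallel MergedColumnParallelLinear:

    import torch
    import torch.nn as nn
    from vllm.model_executor.layers.activation import GeluAndMul

    class ToyGemmaMLPFused(nn.Module):
        """Sketch of the new layout: one matmul, then one fused GeGLU kernel."""

        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            # Stand-in for MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2).
            self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = GeluAndMul()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            gate_up = self.gate_up_proj(x)   # [..., 2 * intermediate_size]
            x = self.act_fn(gate_up)         # GELU(gate) * up in one kernel
            return self.down_proj(x)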