gaoqiong / flash-attention · Commits

Commit b630aef5
Authored Apr 18, 2023 by Tri Dao
Parent: ac3b684c

Implement GatedMlp
Showing 3 changed files with 57 additions and 25 deletions (+57 −25)
flash_attn/models/gpt.py       +24 −17
flash_attn/modules/mlp.py      +25 −0
flash_attn/ops/fused_dense.py  +8  −8
flash_attn/models/gpt.py
@@ -16,8 +16,9 @@ from transformers import GPT2Config
 from einops import rearrange
 
+from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -43,10 +44,9 @@ except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
     FusedDenseSqreluDense = None
-    sqrelu_fwd = None
 
 logger = logging.getLogger(__name__)
@@ -90,7 +90,6 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
-    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
         assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
@@ -102,17 +101,25 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     if process_group is not None:
         assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
-        if config.activation_function == 'relu':
-            activation = partial(F.relu, inplace=True)
-        elif config.activation_function == 'sqrelu':
-            assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
-            activation = sqrelu_fwd
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+                                              'sqrelu', 'glu', 'swiglu', 'geglu']
+        if config.activation_function in ['glu', 'swiglu', 'geglu']:
+            activation = (F.sigmoid if config.activation_function == 'glu'
+                          else (F.silu if config.activation_function == 'swiglu'
+                                else F.gelu))
+            mlp_cls = partial(GatedMlp, hidden_features=config.n_inner,
+                              activation=activation, **factory_kwargs)
         else:
-            approximate = ('tanh' if config.activation_function
-                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
-            activation = partial(F.gelu, approximate=approximate)
-        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
+            if config.activation_function == 'relu':
+                activation = partial(F.relu, inplace=True)
+            elif config.activation_function == 'sqrelu':
+                activation = sqrelu_fwd
+            else:
+                approximate = ('tanh' if config.activation_function
+                               in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
+                activation = partial(F.gelu, approximate=approximate)
+            mlp_cls = partial(Mlp, hidden_features=config.n_inner,
+                              activation=activation, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
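For reference, a minimal usage sketch of the routing added above (not part of this commit; the config values are hypothetical, and it assumes a flash-attn install that includes this change). Setting activation_function to 'glu', 'swiglu', or 'geglu' now makes create_mlp_cls return a partial of GatedMlp rather than Mlp, with the gate activation chosen accordingly (sigmoid, SiLU, or GELU):

from transformers import GPT2Config
from flash_attn.models.gpt import create_mlp_cls

config = GPT2Config(n_embd=1024, n_inner=None, activation_function='swiglu')
mlp_cls = create_mlp_cls(config)           # partial(GatedMlp, hidden_features=None, activation=F.silu, ...)
mlp = mlp_cls(in_features=config.n_embd)   # with n_inner=None, GatedMlp falls back to its own default width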
@@ -128,12 +135,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             parallel_kwargs = ({'process_group': process_group,
                                 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                                if process_group is not None else {})
-            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                               checkpoint_lvl=mlp_checkpoint_lvl,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None
-            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
+            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                               checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
         else:
             raise RuntimeError('MLP type not supported')
@@ -252,7 +259,7 @@ class GPTModel(GPTPreTrainedModel):
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
-                                              'relu', 'sqrelu']
+                                              'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
flash_attn/modules/mlp.py
@@ -28,3 +28,28 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
+
+
+class GatedMlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
+                 multiple_of=128, return_residual=False, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or int(8 * in_features / 3)
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.return_residual = return_residual
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)
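As a quick sanity check on the new module (a hedged sketch, not from the commit; the sizes are illustrative): fc1 projects to twice the hidden width so the output can be split into a value half and a gate half, and with hidden_features=None the hidden width defaults to 8/3 * in_features rounded up to a multiple of multiple_of.

import torch
import torch.nn.functional as F
from flash_attn.modules.mlp import GatedMlp

mlp = GatedMlp(in_features=1024, activation=F.silu)   # SwiGLU-style gating
# Default hidden width: int(8 * 1024 / 3) = 2730, rounded up to a multiple of 128 -> 2816,
# so fc1 maps 1024 -> 2 * 2816 = 5632 and fc2 maps 2816 -> 1024.
assert mlp.fc1.out_features == 5632 and mlp.fc2.in_features == 2816

x = torch.randn(2, 16, 1024)   # (batch, seqlen, hidden)
y = mlp(x)                     # fc2(value * silu(gate)), where value, gate = fc1(x).chunk(2, dim=-1)
print(y.shape)                 # torch.Size([2, 16, 1024])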
flash_attn/ops/fused_dense.py
@@ -404,7 +404,7 @@ def fused_mlp_func(
 class FusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True,
                  bias2=True, activation='gelu_approx', return_residual=False,
                  checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
         """
@@ -432,8 +432,8 @@ class FusedMLP(nn.Module):
         assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
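The net effect of the two FusedMLP hunks above (ParallelFusedMLP gets the same treatment below) is that hidden_features is now optional and defaults to 4 * in_features, matching the unfused Mlp. A hypothetical usage sketch, assuming a CUDA GPU and a flash-attn build that includes the fused_dense extension:

import torch
from flash_attn.ops.fused_dense import FusedMLP

# hidden_features can now be omitted; it defaults to 4 * 1024 = 4096,
# and out_features defaults to in_features.
mlp = FusedMLP(in_features=1024, device='cuda', dtype=torch.float16)
x = torch.randn(2, 16, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)
print(y.shape)   # torch.Size([2, 16, 1024])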
@@ -469,9 +469,9 @@ class FusedMLP(nn.Module):
 class ParallelFusedMLP(nn.Module):
 
-    def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
-                 process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 activation='gelu_approx', process_group: ProcessGroup = None,
+                 bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
                  device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
@@ -494,8 +494,8 @@ class ParallelFusedMLP(nn.Module):
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        if out_features is None:
-            out_features = in_features
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.activation = activation
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel