gaoqiong / flash-attention · commit abbc1311 (parent f5b308e2)

[LayerNorm] Switch from CUDA to Triton implementation

Authored Jan 05, 2024 by Tri Dao
Showing 6 changed files with 82 additions and 143 deletions:

  csrc/layer_norm/README.md      +4   -0
  flash_attn/models/bert.py      +11  -10
  flash_attn/models/gpt.py       +22  -59
  flash_attn/models/vit.py       +8   -9
  flash_attn/modules/block.py    +35  -59
  training/Dockerfile            +2   -6
csrc/layer_norm/README.md

````diff
@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
````
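For context, here is a minimal sketch of calling the Triton-based replacement directly. The import path and the `(x, weight, bias, eps=...)` call shape come from the diffs below; the tensor sizes and dtypes are illustrative assumptions.

```python
# Minimal sketch of the Triton layer norm that replaces the CUDA extension.
# Assumes flash-attn (with Triton available) is installed; shapes/dtypes are illustrative.
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

# Plain fused layer norm: no residual add, no dropout.
out = layer_norm_fn(x, weight, bias, eps=1e-5)
```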
flash_attn/models/bert.py

```diff
@@ -40,9 +40,10 @@ except ImportError:
     FusedDense = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
 
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ class BertPredictionHeadTransform(nn.Module):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ class BertPredictionHeadTransform(nn.Module):
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states
@@ -345,8 +346,8 @@ class BertModel(BertPreTrainedModel):
                 config.vocab_size % self.pad_vocab_size_multiple
             )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
 
         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ class BertModel(BertPreTrainedModel):
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)
```
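The pattern above (guarded import, `fused_dropout_add_ln` flag, eager `nn.LayerNorm` fallback) recurs in several models. Below is a self-contained sketch of that pattern with a hypothetical `FusedLNHead` module standing in for the repo's classes; only the `layer_norm_fn` call shape is taken from the diff.

```python
# Illustrative pattern, not a class from the repo: guard the Triton import and
# fall back to eager nn.LayerNorm when it is unavailable, mirroring
# BertPredictionHeadTransform / BertModel above.
import torch
import torch.nn as nn

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


class FusedLNHead(nn.Module):  # hypothetical helper for illustration
    def __init__(self, hidden_size, fused_dropout_add_ln=False, eps=1e-12):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

    def forward(self, hidden_states):
        if not self.fused_dropout_add_ln:
            # Eager fallback path.
            return self.layer_norm(hidden_states)
        # Fused Triton path: same weights, eps passed as a keyword argument.
        return layer_norm_fn(
            hidden_states,
            self.layer_norm.weight,
            self.layer_norm.bias,
            eps=self.layer_norm.eps,
        )
```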
flash_attn/models/gpt.py

```diff
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import logging
 import math
@@ -47,29 +47,14 @@ except ImportError:
     ColumnParallelLinear = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
-except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
-
-try:
     from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
     FusedDenseSqreluDense = None
 
+try:
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
+except ImportError:
+    layer_norm_fn, RMSNorm = None, None
+
 logger = logging.getLogger(__name__)
@@ -481,13 +466,15 @@ class GPTModel(GPTPreTrainedModel):
                 for i in range(config.num_hidden_layers)
             ]
         )
         rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
         if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
             for layer in self.layers[1:]:
                 layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
 
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
                 # Set prenorm=False here since we don't need the residual
-                if not self.parallel_block:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm
-                    )
-                    hidden_states = fused_add_norm_fn(
-                        hidden_states,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
-                else:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm_parallel_residual
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm_parallel_residual
-                    )
-                    hidden_states, _ = fused_add_norm_fn(
-                        hidden_states,
-                        hidden_states2,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        None,
-                        None,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
+                hidden_states = layer_norm_fn(
+                    hidden_states,
+                    self.ln_f.weight,
+                    self.ln_f.bias,
+                    residual=residual,
+                    x1=None if not self.parallel_block else hidden_states2,
+                    eps=self.ln_f.eps,
+                    dropout_p=self.drop_f.p if self.training else 0.0,
+                    prenorm=False,
+                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
+                )
         return hidden_states
```
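The final-norm change above collapses the four `dropout_add_*` variants into a single `layer_norm_fn` call. The sketch below exercises that one call outside the model, using only keyword arguments visible in the diff; the tensor sizes, dtype, and zero dropout are illustrative assumptions.

```python
# Sketch of the single fused call that replaces both branches above.
# Keyword arguments follow the diff; sizes, dtype, and dropout_p are illustrative.
import torch
import torch.nn as nn
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm

hidden = torch.randn(4, 128, 768, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden)  # residual stream accumulated by the blocks
ln_f = nn.LayerNorm(768, device="cuda", dtype=torch.float16)

out = layer_norm_fn(
    hidden,
    ln_f.weight,
    ln_f.bias,
    residual=residual,                    # residual add fused into the kernel
    x1=None,                              # second stream, only used by parallel blocks
    eps=ln_f.eps,
    dropout_p=0.0,                        # dropout is folded into the same kernel
    prenorm=False,                        # only the normalized output is needed here
    is_rms_norm=isinstance(ln_f, RMSNorm),  # same entry point covers LayerNorm and RMSNorm
)
```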
flash_attn/models/vit.py

```diff
@@ -20,9 +20,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 
 
 def create_mixer_cls(
@@ -229,8 +229,8 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim)
 
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
 
         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@
                 )
             )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
                 residual_in_fp32=True,
             )
         return hidden_states
```
flash_attn/modules/block.py

```diff
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 from functools import partial
 from typing import Optional
@@ -13,24 +13,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    layer_norm_fn, RMSNorm = None, None
 
 
 class Block(nn.Module):
@@ -91,8 +76,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
-            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -137,11 +121,6 @@
             before applying the query projection. Useful for e.g., ViT where we only care
             about the CLS token in the last layer.
         """
-        fused_add_norm_fn = (
-            dropout_add_rms_norm
-            if RMSNorm and isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm
-        )
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -160,16 +139,17 @@
                             dtype=hidden_states.dtype,
                         )
                     )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=residual,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -196,16 +176,17 @@
                             dtype=hidden_states.dtype,
                         )
                     )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=residual,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
@@ -231,15 +212,16 @@
                         mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                     )
                 )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mixer_out,
-                    hidden_states,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=hidden_states,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -260,15 +242,16 @@
                         mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                    )
                )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mlp_out,
-                    hidden_states,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=hidden_states,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             return hidden_states
@@ -320,12 +303,7 @@ class ParallelBlock(nn.Module):
            self.norm2 = norm_cls(dim)
 
        if self.fused_dropout_add_ln:
-            assert (
-                dropout_add_layer_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
-            assert (
-                dropout_add_rms_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )
@@ -370,11 +348,6 @@ class ParallelBlock(nn.Module):
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
-        fused_add_norm_fn = (
-            dropout_add_rms_norm_parallel_residual
-            if isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm_parallel_residual
-        )
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -399,21 +372,24 @@
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
-            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
+            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
-                hidden_states2,
-                residual,
                self.norm1.weight,
                self.norm1.bias,
-                weight2,
-                bias2,
-                self.dropout1.p if self.training else 0.0,
-                self.norm1.eps,
+                residual=residual,
+                x1=hidden_states2,
+                weight1=weight2,
+                bias1=bias2,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
+            else:
+                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
```
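In `Block`, the prenorm path relies on `prenorm=True`, which makes the same kernel return both the normalized output and the updated residual stream. Below is a hedged sketch of that calling convention, with illustrative module sizes and the first-block case of `residual=None`, mirroring how `Block.forward` uses it above.

```python
# Sketch of the prenorm=True calling convention used in Block above: one fused
# kernel applies dropout, adds the residual, and returns (output, new residual).
# Sizes, dtypes, and the residual=None first-block case are illustrative.
import torch
import torch.nn as nn
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm

norm1 = nn.LayerNorm(768, device="cuda", dtype=torch.float16)
dropout1 = nn.Dropout(0.1)

x = torch.randn(2, 256, 768, device="cuda", dtype=torch.float16)
residual = None  # first block: no incoming residual stream yet

hidden_states, residual = layer_norm_fn(
    x,
    norm1.weight,
    norm1.bias,
    residual=residual,
    eps=norm1.eps,
    dropout_p=dropout1.p,                    # pass 0.0 at inference time
    prenorm=True,                            # also return the updated residual
    residual_in_fp32=True,                   # keep the residual stream in fp32
    is_rms_norm=isinstance(norm1, RMSNorm),
)
```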
training/Dockerfile

```diff
@@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
 RUN pip install flash-attn==2.4.2
 
-# Install CUDA extensions for fused dense, layer norm
-RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.4.2 \
-    && cd csrc/layer_norm && pip install . && cd ../../ \
-    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
-    && cd .. && rm -rf flash-attention
+# Install CUDA extensions for fused dense
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib
```