gaoqiong / flash-attention · Commit abbc1311

[LayerNorm] Switch from CUDA to Triton implementation

Authored Jan 05, 2024 by Tri Dao
Parent: f5b308e2

Showing 6 changed files with 82 additions and 143 deletions (+82 -143)
csrc/layer_norm/README.md    +4   -0
flash_attn/models/bert.py    +11  -10
flash_attn/models/gpt.py     +22  -59
flash_attn/models/vit.py     +8   -9
flash_attn/modules/block.py  +35  -59
training/Dockerfile          +2   -6
csrc/layer_norm/README.md (+4 -0)

@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
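For readers migrating off the CUDA extension, here is a minimal sketch (not part of the commit) of calling the Triton replacement directly. The keyword name `eps=` comes from the calls in this commit; the device, dtype, and sizes are illustrative assumptions, and a CUDA device with Triton is assumed.

```python
# Minimal sketch: Triton layer norm in place of the compiled csrc/layer_norm extension.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn

hidden_size = 1024  # illustrative
ln = nn.LayerNorm(hidden_size).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 512, hidden_size, device="cuda", dtype=torch.float16)

out = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)

# Compare against the PyTorch reference; only fp16 rounding noise is expected.
print((out - ln(x)).abs().max())
```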
flash_attn/models/bert.py (+11 -10)

@@ -40,9 +40,10 @@ except ImportError:
     FusedDense = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
 
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ class BertPredictionHeadTransform(nn.Module):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ class BertPredictionHeadTransform(nn.Module):
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states
@@ -345,8 +346,8 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size % self.pad_vocab_size_multiple
         )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ class BertModel(BertPreTrainedModel):
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)
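The bert.py changes all follow one pattern: keep `nn.LayerNorm` as the fallback and, when `fused_dropout_add_ln` is set, call `layer_norm_fn` with `eps` passed as a keyword. Below is a condensed sketch of that pattern; the module name is hypothetical and simplified from the classes in the diff, but the calls mirror the ones shown above.

```python
import torch.nn as nn

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


class HeadTransformSketch(nn.Module):
    """Hypothetical, condensed stand-in for BertPredictionHeadTransform."""

    def __init__(self, hidden_size, fused_dropout_add_ln=False, eps=1e-12):
        super().__init__()
        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        self.layer_norm = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, hidden_states):
        if not self.fused_dropout_add_ln:
            return self.layer_norm(hidden_states)
        # Same call shape as the diff: eps is now a keyword argument.
        return layer_norm_fn(
            hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
        )
```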
flash_attn/models/gpt.py (+22 -59)

-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import logging
 import math
@@ -47,29 +47,14 @@ except ImportError:
     ColumnParallelLinear = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    FusedDenseSqreluDense = None
 
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    FusedDenseSqreluDense = None
+    layer_norm_fn, RMSNorm = None, None
 
 logger = logging.getLogger(__name__)
@@ -481,13 +466,15 @@ class GPTModel(GPTPreTrainedModel):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
+        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
+            for layer in self.layers[1:]:
+                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
 
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ class GPTModel(GPTPreTrainedModel):
             hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
         else:
             # Set prenorm=False here since we don't need the residual
-            if not self.parallel_block:
-                fused_add_norm_fn = (
-                    dropout_add_rms_norm
-                    if isinstance(self.ln_f, RMSNorm)
-                    else dropout_add_layer_norm
-                )
-                hidden_states = fused_add_norm_fn(
-                    hidden_states,
-                    residual,
-                    self.ln_f.weight,
-                    self.ln_f.bias,
-                    self.drop_f.p if self.training else 0.0,
-                    self.ln_f.eps,
-                    prenorm=False,
-                    residual_in_fp32=self.residual_in_fp32,
-                )
-            else:
-                fused_add_norm_fn = (
-                    dropout_add_rms_norm_parallel_residual
-                    if isinstance(self.ln_f, RMSNorm)
-                    else dropout_add_layer_norm_parallel_residual
-                )
-                hidden_states, _ = fused_add_norm_fn(
-                    hidden_states,
-                    hidden_states2,
-                    residual,
-                    self.ln_f.weight,
-                    self.ln_f.bias,
-                    None,
-                    None,
-                    self.drop_f.p if self.training else 0.0,
-                    self.ln_f.eps,
-                    prenorm=False,
-                    residual_in_fp32=self.residual_in_fp32,
-                )
+            hidden_states = layer_norm_fn(
+                hidden_states,
+                self.ln_f.weight,
+                self.ln_f.bias,
+                residual=residual,
+                x1=None if not self.parallel_block else hidden_states2,
+                eps=self.ln_f.eps,
+                dropout_p=self.drop_f.p if self.training else 0.0,
+                prenorm=False,
+                is_rms_norm=isinstance(self.ln_f, RMSNorm)
+            )
         return hidden_states
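gpt.py previously dispatched to one of four fused kernels (layer norm vs. RMSNorm, with or without the parallel residual); the Triton `layer_norm_fn` covers all of those cases through the `x1` and `is_rms_norm` arguments shown in the diff. The sketch below (not part of the commit) exercises the `is_rms_norm=True` path against a hand-written reference; it assumes the usual RMSNorm definition `x / sqrt(mean(x**2) + eps) * weight`, a CUDA device with Triton, and illustrative shapes.

```python
# Sketch: one Triton kernel now serves both LayerNorm and RMSNorm.
import torch

from flash_attn.ops.triton.layer_norm import layer_norm_fn

hidden = 2048
x = torch.randn(4, 256, hidden, device="cuda", dtype=torch.float16)
weight = torch.randn(hidden, device="cuda", dtype=torch.float16)

# RMSNorm path: no mean subtraction; bias is None (flash-attn's RMSNorm has no bias by default).
out = layer_norm_fn(x, weight, None, eps=1e-6, is_rms_norm=True)

# Reference, assuming the standard RMSNorm formula.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + 1e-6)
ref = (ref * weight.float()).to(x.dtype)
print((out - ref).abs().max())  # expected: small fp16 rounding difference
```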
flash_attn/models/vit.py (+8 -9)

@@ -20,9 +20,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 
 
 def create_mixer_cls(
@@ -229,8 +229,8 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim)
 
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
 
         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ class VisionTransformer(nn.Module):
                 )
             )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
-                residual_in_fp32=True,
             )
         return hidden_states
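In vit.py the arguments likewise move to keywords, and the fused call still performs dropout, residual add, and LayerNorm in one kernel. The sketch below (not part of the commit) compares the `prenorm=False` path against the unfused equivalent `norm(x + residual)`, with dropout disabled and `rowscale` left at its default so the check stays deterministic; the names, shapes, and dtype are illustrative assumptions and a CUDA device with Triton is required.

```python
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn

dim = 768
norm = nn.LayerNorm(dim).to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 197, dim, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)

# Fused: dropout(x) + residual followed by LayerNorm, returned as a single tensor.
fused = layer_norm_fn(
    x, norm.weight, norm.bias, residual=residual, eps=norm.eps, dropout_p=0.0, prenorm=False
)
ref = norm(x + residual)
print((fused - ref).abs().max())  # small fp16 difference expected
```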
flash_attn/modules/block.py (+35 -59)

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 from functools import partial
 from typing import Optional
@@ -13,24 +13,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    layer_norm_fn, RMSNorm = None, None
 
 
 class Block(nn.Module):
@@ -91,8 +76,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
-            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -137,11 +121,6 @@ class Block(nn.Module):
                 before applying the query projection. Useful for e.g., ViT where we only care
                 about the CLS token in the last layer.
         """
-        fused_add_norm_fn = (
-            dropout_add_rms_norm
-            if RMSNorm and isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm
-        )
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -160,16 +139,17 @@ class Block(nn.Module):
                         dtype=hidden_states.dtype,
                     )
                 )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=residual,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -196,16 +176,17 @@ class Block(nn.Module):
                         dtype=hidden_states.dtype,
                     )
                 )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=residual,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
                 hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
@@ -231,15 +212,16 @@ class Block(nn.Module):
                     mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                 )
             )
-            hidden_states = fused_add_norm_fn(
+            hidden_states = layer_norm_fn(
                 mixer_out,
-                hidden_states,
                 self.norm1.weight,
                 self.norm1.bias,
-                self.dropout1.p if self.training else 0.0,
-                self.norm1.eps,
+                residual=hidden_states,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
                 rowscale=rowscale1,
                 prenorm=False,
+                is_rms_norm=isinstance(self.norm1, RMSNorm)
             )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -260,15 +242,16 @@ class Block(nn.Module):
                     mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                 )
             )
-            hidden_states = fused_add_norm_fn(
+            hidden_states = layer_norm_fn(
                 mlp_out,
-                hidden_states,
                 self.norm2.weight,
                 self.norm2.bias,
-                self.dropout2.p if self.training else 0.0,
-                self.norm2.eps,
+                residual=hidden_states,
+                eps=self.norm2.eps,
+                dropout_p=self.dropout2.p if self.training else 0.0,
                 rowscale=rowscale2,
                 prenorm=False,
+                is_rms_norm=isinstance(self.norm2, RMSNorm)
             )
             return hidden_states
@@ -320,12 +303,7 @@ class ParallelBlock(nn.Module):
            self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert (
-                dropout_add_layer_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
-            assert (
-                dropout_add_rms_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -370,11 +348,6 @@ class ParallelBlock(nn.Module):
         """
         # TODO: Ideally we should only do the allgather / allreduce once for
         # the Linear to MLP & Attention
-        fused_add_norm_fn = (
-            dropout_add_rms_norm_parallel_residual
-            if isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm_parallel_residual
-        )
         if not self.fused_dropout_add_ln:
             dropped1 = self.dropout1(hidden_states1)
             # For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -399,21 +372,24 @@ class ParallelBlock(nn.Module):
             weight2, bias2 = (
                 (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
             )
-            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
+            hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
-                hidden_states2,
-                residual,
                 self.norm1.weight,
                 self.norm1.bias,
-                weight2,
-                bias2,
-                self.dropout1.p if self.training else 0.0,
-                self.norm1.eps,
+                residual=residual,
+                x1=hidden_states2,
+                weight1=weight2,
+                bias1=bias2,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm1, RMSNorm)
             )
             if self.tied_norm:
                 hidden_states2 = hidden_states1
+            else:
+                hidden_states2, = rest
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
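Block.forward relies on the other return convention: with `prenorm=True`, the fused call returns both the normalized output and the updated residual stream, matching the two values the old `dropout_add_layer_norm` returned. Below is a sketch (not part of the commit) of that contract using the keyword names from this diff; dropout is disabled and the tensors are illustrative, whereas the real module passes `self.dropout1.p` and a `rowscale`. A CUDA device with Triton is assumed.

```python
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

dim = 1024
norm1 = nn.LayerNorm(dim).to(device="cuda", dtype=torch.float16)
hidden_states = torch.randn(4, 128, dim, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden_states)

hidden_states, residual = layer_norm_fn(
    hidden_states,
    norm1.weight,
    norm1.bias,
    residual=residual,
    eps=norm1.eps,
    dropout_p=0.0,                        # self.dropout1.p if self.training in the real Block
    prenorm=True,                         # also return the pre-norm residual stream
    residual_in_fp32=True,
    is_rms_norm=isinstance(norm1, RMSNorm),
)
# hidden_states feeds the mixer/MLP; residual carries dropout(x) + residual forward.
```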
training/Dockerfile (+2 -6)

@@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
 RUN pip install flash-attn==2.4.2
 
-# Install CUDA extensions for fused dense, layer norm
-RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.4.2 \
-    && cd csrc/layer_norm && pip install . && cd ../../ \
-    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
-    && cd .. && rm -rf flash-attention
+# Install CUDA extensions for fused dense
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib
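Since the Triton kernels ship inside the `flash-attn` wheel installed above, the image no longer needs to clone the repo and build `csrc/layer_norm`. A quick, illustrative sanity check one might run in the built container:

```python
# Illustrative check: the Triton layer norm imports straight from the installed
# flash-attn wheel, with no compiled csrc/layer_norm extension present.
from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

print(layer_norm_fn.__module__)  # flash_attn.ops.triton.layer_norm
print(RMSNorm.__name__)          # RMSNorm
```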