gaoqiong / flash-attention · Commits · ef085cfc

Commit ef085cfc, authored Jan 15, 2023 by Tri Dao
[ViT] Fix extra norm_0, use new LN order in Block
Parent: ff34123b

Showing 2 changed files with 46 additions and 34 deletions (+46 -34):

  flash_attn/models/vit.py      +45 -33
  flash_attn/modules/block.py   +1 -1
flash_attn/models/vit.py

```diff
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.init import trunc_normal_
+from torchvision.ops import StochasticDepth
 from einops import rearrange
 from timm.models.helpers import named_apply
```
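The newly imported `torchvision.ops.StochasticDepth` is used further down as the model-level drop-path; in `"row"` mode it drops whole samples rather than individual activations. A minimal, standalone sketch of its behavior (not taken from this repo):

```python
import torch
from torchvision.ops import StochasticDepth

# "row" mode zeroes whole samples with probability p during training and
# rescales the survivors by 1 / (1 - p); in eval mode it is the identity.
drop_path = StochasticDepth(p=0.1, mode="row")

x = torch.randn(8, 197, 384)          # (batch, tokens, dim), an arbitrary example shape
drop_path.train()
y = drop_path(x)                       # some batch rows are zeroed out
drop_path.eval()
assert torch.equal(drop_path(x), x)    # no-op at inference time
```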
```diff
@@ -41,15 +43,18 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
     return mlp_cls


-def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
-                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
-                 fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False):
+def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
+                 drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
+                 fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
+                 last_layer_subset=False):
     mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
                                  cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
     mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
-    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
-                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
-                  fused_dropout_add_ln=fused_dropout_add_ln)
+    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
+    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
+                  prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
+                  drop_path1=drop_path1, drop_path2=drop_path2,
+                  fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
     return block
```
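Beyond splitting `drop_path` into `drop_path1`/`drop_path2`, the Block is now built with `resid_dropout1`/`resid_dropout2` and `residual_in_fp32=True`. The sketch below is an illustrative stand-in, not the real `flash_attn.modules.block.Block`, for the pre-norm data flow those keywords describe; attribute names such as `dropout1`, `norm1`, `mixer`, and `mlp` are assumptions made for the sake of the example:

```python
import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

class PreNormBlockSketch(nn.Module):
    """Illustrative stand-in for the pre-norm path that the new keyword names
    describe; the real Block in flash_attn/modules/block.py additionally handles
    the fused dropout_add_layer_norm kernel, mixer kwargs, post-norm, etc."""

    def __init__(self, dim, mixer, mlp, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0., residual_in_fp32=True):
        super().__init__()
        self.mixer, self.mlp = mixer, mlp
        self.dropout1 = nn.Dropout(resid_dropout1)   # the Dropout of Dropout -> Add -> LN -> Attn
        self.dropout2 = nn.Dropout(resid_dropout2)   # the one feeding the MLP sub-block
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.drop_path2 = StochasticDepth(drop_path2, mode="row")
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.residual_in_fp32 = residual_in_fp32

    def forward(self, hidden_states, residual=None):
        # Dropout -> Add -> LN -> Attention
        dropped = self.drop_path1(self.dropout1(hidden_states))
        residual = dropped if residual is None else dropped + residual
        if self.residual_in_fp32:
            residual = residual.float()
        hidden_states = self.mixer(self.norm1(residual.to(self.norm1.weight.dtype)))
        # Dropout -> Add -> LN -> MLP
        dropped = self.drop_path2(self.dropout2(hidden_states))
        residual = dropped + residual
        hidden_states = self.mlp(self.norm2(residual.to(self.norm2.weight.dtype)))
        return hidden_states, residual

# e.g. a block with identity "attention" and "MLP", just to exercise the flow:
blk = PreNormBlockSketch(384, nn.Identity(), nn.Identity(),
                         resid_dropout1=0.1, resid_dropout2=0.1)
h, r = blk(torch.randn(2, 197, 384))
```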
```diff
@@ -143,32 +148,32 @@ class VisionTransformer(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
-        self.pos_drop = nn.Dropout(p=drop_rate)

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-        # We change the order of residual and layer norm:
+        # We change the order of dropout, residual and layer norm:
         # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
-        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
-        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
-        # nn.LayerNorm weights are changed.
+        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
+        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
+        # nn.Dropout probabilities are changed.
         # This is for performance reason: we can fuse dropout + add + layer_norm.
-        # self.norm_0 is the first layer norm in the model, while self.norm
-        # (in the pretrained weight) is the final layer norm.
-        self.norm_0 = norm_layer(embed_dim)
-        self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError('dropout_add_layer_norm is not installed')
         self.blocks = nn.ModuleList([create_block(
-            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
+            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
+            drop_path1=dpr[i - 1] if i > 0 else 0., drop_path2=dpr[i],
             norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
             fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
             fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
             last_layer_subset=(global_pool == 'token')
         ) for i in range(depth)])

+        self.dropout = nn.Dropout(p=drop_rate)
+        self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
+        self.norm = norm_layer(embed_dim)
+
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+            raise ImportError('dropout_add_layer_norm is not installed')

         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
```
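Because drop-path now runs before the Add -> LN that opens block `i`, rather than after block `i`'s own sub-layers, the per-block rates shift by one: block `i` receives `drop_path1=dpr[i - 1]` (`0.` for the first block) and `drop_path2=dpr[i]`, and the leftover `dpr[-1]` is applied by the new model-level `self.drop_path` just before the final `self.norm`. A small standalone sketch of the decay rule and the shift, using nothing beyond `torch.linspace`:

```python
import torch

depth, drop_path_rate = 12, 0.1
# Linearly increasing stochastic-depth rates, one per block (the usual timm decay rule).
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

for i in range(depth):
    drop_path1 = dpr[i - 1] if i > 0 else 0.   # rate of the branch closed by block i's first LN
    drop_path2 = dpr[i]                        # rate of the branch closed by block i's second LN
    print(i, round(drop_path1, 4), round(drop_path2, 4))

# dpr[-1] is the one rate left over by the shift; it is applied by the
# model-level self.drop_path right before the final self.norm.
```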
```diff
@@ -210,18 +215,8 @@ class VisionTransformer(nn.Module):
             cls token.
         """
         x = self.patch_embed(x)
-        x = self._pos_embed(x)
-        # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
-        if not self.fused_dropout_add_ln:
-            residual = self.pos_drop(x).float()
-            hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
-        else:
-            hidden_states, residual = dropout_add_layer_norm(
-                x, None, self.norm_0.weight, self.norm_0.bias,
-                self.pos_drop.p if self.training else 0.0, self.norm_0.eps, prenorm=True,
-                residual_in_fp32=True
-            )
-        hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
+        hidden_states = self._pos_embed(x)
+        residual = None
         if self.global_pool != 'token' or all_tokens:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
```
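With `self.norm_0` gone, the first block starts from `residual=None`, and its own Dropout -> Add -> LN reproduces what the standalone `self.pos_drop` followed by `self.norm_0` used to compute (block 0 is built above with `resid_dropout1=drop_rate` and `drop_path1=0.`). A toy check of that equivalence under eval-mode assumptions (dropout disabled, shared LayerNorm weights):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 197, 384)           # output of _pos_embed (illustrative shape)
norm = nn.LayerNorm(384)
pos_drop = nn.Dropout(p=0.0)           # eval-style: dropout is a no-op
block_dropout1 = nn.Dropout(p=0.0)     # the first block's resid_dropout1 plays the same role

# Old path: standalone pos_drop + norm_0 before the first block.
old_residual = pos_drop(x).float()
old_hidden = norm(old_residual.to(dtype=norm.weight.dtype))

# New path: the first block gets residual=None, so its Dropout -> Add -> LN
# (with drop_path1=0. for block 0) degenerates to the same computation.
new_residual = block_dropout1(x).float()     # Add is a no-op when residual is None
new_hidden = norm(new_residual.to(dtype=norm.weight.dtype))

assert torch.allclose(old_hidden, new_hidden)
```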
```diff
@@ -232,8 +227,25 @@ class VisionTransformer(nn.Module):
             # where the query is the 1st token and the key/value is the whole sequence.
             hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
             residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
-                                               mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
+                                                      mixer_kwargs={'x_kv': hidden_states})
+        if not self.fused_dropout_add_ln:
+            residual = self.drop_path(self.dropout(hidden_states)) + residual
+            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+        else:
+            if self.drop_path.p == 0 or not self.training:
+                rowscale = None
+            else:
+                rowscale = self.drop_path(torch.ones(
+                    hidden_states.shape[:-1], device=hidden_states.device,
+                    dtype=hidden_states.dtype)
+                )
+            # Set prenorm=False here since we don't need to the residual
+            hidden_states = dropout_add_layer_norm(
+                hidden_states, residual, self.norm.weight, self.norm.bias,
+                self.dropout.p if self.training else 0.0, self.norm.eps,
+                rowscale=rowscale, prenorm=False, residual_in_fp32=True
+            )
         return hidden_states

     def forward_head(self, x, pre_logits: bool = False):
```
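In the fused branch, drop-path cannot be applied to `hidden_states` directly, since dropout, add, and layer norm all happen inside the one `dropout_add_layer_norm` kernel; instead it is folded in as `rowscale`, obtained by running `self.drop_path` over a tensor of ones so that each row carries either `0` or `1 / (1 - p)` and drop-path becomes a per-row multiplication the kernel can perform. A standalone sketch of just that construction (the fused kernel itself comes from the optional extension and is used only as shown in the diff):

```python
import torch
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.1, mode="row")   # stands in for the model-level self.drop_path
drop_path.train()

hidden_states = torch.randn(4, 197, 384)         # (batch, tokens, dim)

# Per-row factors: 0 for dropped samples, 1 / (1 - p) for the survivors.
rowscale = drop_path(torch.ones(hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype))

# Scaling by rowscale reproduces what drop_path would do to hidden_states;
# this standalone line only illustrates the shapes and scaling involved
# (the fused kernel draws its own mask, so the randomness is not shared here).
scaled = hidden_states * rowscale.unsqueeze(-1)
```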
flash_attn/modules/block.py

```diff
@@ -94,7 +94,7 @@ class Block(nn.Module):
         Args:
             hidden_states: the sequence to the encoder layer (required).
-            residual: if postnorm, residual=None, If prenorm, hidden_states = LayerNorm(residual)
+            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
```
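The docstring now matches the reordered computation: in pre-norm mode a Block returns `hidden_states` equal to Attn/MLP applied to `LN(residual)`, alongside the updated residual. A minimal, self-contained stand-in (not the real Block, and with a single Linear in place of Attn/MLP) showing that contract:

```python
import torch
import torch.nn as nn

class TinyPreNormBlock(nn.Module):
    """Minimal stand-in just to illustrate the (hidden_states, residual) contract
    the corrected docstring describes; the real Block lives in flash_attn/modules/block.py."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mixer = nn.Linear(dim, dim)

    def forward(self, hidden_states, residual=None):
        residual = hidden_states if residual is None else hidden_states + residual
        hidden_states = self.mixer(self.norm(residual))   # hidden_states = Attn/MLP(LN(residual))
        return hidden_states, residual

blk = TinyPreNormBlock(32)
h, r = blk(torch.randn(2, 5, 32))   # first call: residual starts as None
h, r = blk(h, r)                    # later calls keep threading the residual through
```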