gaoqiong / flash-attention · Commits · ada4710d

Commit ada4710d, authored Aug 17, 2023 by Tri Dao
Parent: a81900d4

[ViT] Run black on vit.py

Showing 1 changed file with 177 additions and 108 deletions:
flash_attn/models/vit.py (+177, -108)
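The commit is a pure reformatting pass: Black normalizes string quotes to double quotes, adds trailing commas, and wraps long signatures and calls. A minimal sketch of reproducing that kind of rewrite with Black's Python API is shown below; the 100-character line length is an assumption about the project's configuration, not something stated on this page.

# Minimal sketch: reproducing the kind of rewrite shown in this diff with Black's Python API.
# Assumption (not stated on this page): the project formats with a 100-character line length.
import black

src = (
    "def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False):\n"
    "    return 'token'\n"
)
formatted = black.format_str(src, mode=black.Mode(line_length=100))
print(formatted)
# The signature gets wrapped inside its parentheses and 'token' becomes "token",
# matching the style of the added (+) lines in the hunks below.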
...
...
@@ -2,26 +2,21 @@
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 import math
 import re
-from functools import partial
-from copy import deepcopy
 from collections import OrderedDict
+from copy import deepcopy
+from functools import partial

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import trunc_normal_
-from torchvision.ops import StochasticDepth
 from einops import rearrange
-from timm.models.helpers import named_apply
 from flash_attn.layers.patch_embed import PatchEmbed
-from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedMLP
 from flash_attn.modules.block import Block
+from flash_attn.modules.mha import MHA
+from flash_attn.modules.mlp import FusedMLP, Mlp
+from timm.models.helpers import named_apply
+from torch.nn.init import trunc_normal_
+from torchvision.ops import StochasticDepth

 try:
     from flash_attn.ops.layer_norm import dropout_add_layer_norm
...
...
@@ -29,11 +24,18 @@ except ImportError:
     dropout_add_layer_norm = None


-def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
-                     cross_attn=False):
-    mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, qkv_proj_bias=qkv_bias,
-                        dropout=attn_drop, fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
+def create_mixer_cls(
+    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
+):
+    mixer_cls = partial(
+        MHA,
+        num_heads=num_heads,
+        cross_attn=cross_attn,
+        qkv_proj_bias=qkv_bias,
+        dropout=attn_drop,
+        fused_bias_fc=fused_bias_fc,
+        use_flash_attn=use_flash_attn,
+    )
     return mixer_cls
...
...
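For readers following the code rather than the formatting: `create_mixer_cls` only pre-binds keyword arguments with `functools.partial`, so what it returns is a callable that is still missing the positional argument `MHA` expects (the embedding dimension), which the block supplies later. A standalone sketch of that pattern follows; `ToyMHA` is a hypothetical stand-in so the snippet runs without flash-attn installed, and its argument list is simplified.

# Standalone sketch of the partial-binding pattern used by create_mixer_cls.
# ToyMHA is a hypothetical stand-in for flash_attn.modules.mha.MHA (the real class takes more arguments).
from functools import partial


class ToyMHA:
    def __init__(self, embed_dim, num_heads=8, dropout=0.0, use_flash_attn=False):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.use_flash_attn = use_flash_attn


def create_toy_mixer_cls(num_heads, attn_drop, use_flash_attn):
    # Pre-bind everything except the embedding dimension, mirroring create_mixer_cls.
    return partial(ToyMHA, num_heads=num_heads, dropout=attn_drop, use_flash_attn=use_flash_attn)


mixer_cls = create_toy_mixer_cls(num_heads=12, attn_drop=0.1, use_flash_attn=False)
mixer = mixer_cls(768)  # the block supplies the embedding dimension later
assert mixer.num_heads == 12 and mixer.embed_dim == 768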
@@ -46,54 +48,85 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
     return mlp_cls


-def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
-                 drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
-                 fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
-                 last_layer_subset=False):
-    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
-                                 cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
+def create_block(
+    embed_dim,
+    num_heads,
+    mlp_ratio,
+    qkv_bias,
+    drop_rate,
+    attn_drop_rate,
+    drop_path1,
+    drop_path2,
+    norm_layer,
+    act_layer,
+    use_flash_attn,
+    fused_bias_fc,
+    fused_mlp,
+    fused_dropout_add_ln,
+    layer_idx=None,
+    n_layer=None,
+    last_layer_subset=False,
+):
+    mixer_cls = create_mixer_cls(
+        num_heads,
+        qkv_bias,
+        attn_drop_rate,
+        use_flash_attn,
+        fused_bias_fc,
+        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
+    )
     mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
     # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
-    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, prenorm=True,
-                  resid_dropout1=drop_rate, resid_dropout2=drop_rate, drop_path1=drop_path1,
-                  drop_path2=drop_path2, fused_dropout_add_ln=fused_dropout_add_ln,
-                  residual_in_fp32=True)
+    block = Block(
+        embed_dim,
+        mixer_cls,
+        mlp_cls,
+        norm_cls=norm_layer,
+        prenorm=True,
+        resid_dropout1=drop_rate,
+        resid_dropout2=drop_rate,
+        drop_path1=drop_path1,
+        drop_path2=drop_path2,
+        fused_dropout_add_ln=fused_dropout_add_ln,
+        residual_in_fp32=True,
+    )
     return block


 class VisionTransformer(nn.Module):
-    """ Vision Transformer
+    """Vision Transformer

     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """

     def __init__(
-            self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='token',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            init_values=None,
-            class_token=True,
-            no_embed_class=False,
-            pre_norm=False,
-            fc_norm=None,
-            drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            weight_init='',
-            embed_layer=PatchEmbed,
-            norm_layer=None,
-            act_layer=None,
-            use_flash_attn=False,
-            fused_bias_fc=False,
-            fused_mlp=False,
-            fused_dropout_add_ln=False,
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        num_classes=1000,
+        global_pool="token",
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        init_values=None,
+        class_token=True,
+        no_embed_class=False,
+        pre_norm=False,
+        fc_norm=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        weight_init="",
+        embed_layer=PatchEmbed,
+        norm_layer=None,
+        act_layer=None,
+        use_flash_attn=False,
+        fused_bias_fc=False,
+        fused_mlp=False,
+        fused_dropout_add_ln=False,
     ):
         """
         Args:
...
...
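As a usage note for the constructor above (a hedged sketch, not part of this commit): the defaults correspond to a ViT-B/16 configuration, and the `use_flash_attn` / `fused_*` flags switch in the optimized kernels. FlashAttention kernels generally require a CUDA device and fp16/bf16 activations, so the dtype and device handling here is an assumption rather than something shown in the diff.

# Hedged sketch: constructing the model defined in this file with FlashAttention enabled.
# Assumes a CUDA device and half precision, which the FlashAttention kernels expect.
import torch
from flash_attn.models.vit import VisionTransformer

model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=1000,
    use_flash_attn=True,   # use the FlashAttention kernel inside MHA
    fused_bias_fc=False,   # the fused kernels are optional extras
    fused_mlp=False,
    fused_dropout_add_ln=False,
).to(device="cuda", dtype=torch.float16)

x = torch.randn(2, 3, 224, 224, device="cuda", dtype=torch.float16)
logits = model(x)  # expected shape (2, 1000)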
@@ -119,40 +152,45 @@ class VisionTransformer(nn.Module):
             act_layer: (nn.Module): MLP activation layer
         """
         super().__init__()
-        assert global_pool == 'token', 'Only support pooling with CLS token'
+        assert global_pool == "token", "Only support pooling with CLS token"
         assert class_token
-        assert init_values is None, 'LayerScale is not supported yet'
-        assert weight_init == ''
+        assert init_values is None, "LayerScale is not supported yet"
+        assert weight_init == ""
         assert fc_norm is None
         # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
         assert not pre_norm
-        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU

         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_features = (
+            self.embed_dim
+        ) = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class

-        patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed else {})
+        patch_embed_extra_kwargs = (
+            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
+        )
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
-            **patch_embed_extra_kwargs
+            **patch_embed_extra_kwargs,
         )
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
         # We change the order of dropout, residual and layer norm:
         # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
...
...
@@ -160,31 +198,47 @@ class VisionTransformer(nn.Module):
         # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
         # nn.Dropout probabilities are changed.
         # This is for performance reason: we can fuse dropout + add + layer_norm.
-        self.blocks = nn.ModuleList([create_block(
-            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
-            drop_path1=dpr[i - 1] if i > 0 else 0., drop_path2=dpr[i],
-            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
-            fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
-            fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
-            last_layer_subset=(global_pool == 'token')
-        ) for i in range(depth)])
+        self.blocks = nn.ModuleList(
+            [
+                create_block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias,
+                    drop_rate,
+                    attn_drop_rate,
+                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
+                    drop_path2=dpr[i],
+                    norm_layer=norm_layer,
+                    act_layer=act_layer,
+                    use_flash_attn=use_flash_attn,
+                    fused_bias_fc=fused_bias_fc,
+                    fused_mlp=fused_mlp,
+                    fused_dropout_add_ln=fused_dropout_add_ln,
+                    layer_idx=i,
+                    n_layer=depth,
+                    last_layer_subset=(global_pool == "token"),
+                )
+                for i in range(depth)
+            ]
+        )

         self.dropout = nn.Dropout(p=drop_rate)
-        self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
+        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
         self.norm = norm_layer(embed_dim)

         self.fused_dropout_add_ln = fused_dropout_add_ln
         if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError('dropout_add_layer_norm is not installed')
+            raise ImportError("dropout_add_layer_norm is not installed")

         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

         self.init_weights(weight_init)

-    def init_weights(self, mode=''):
-        assert mode == ''
-        trunc_normal_(self.pos_embed, std=.02)
+    def init_weights(self, mode=""):
+        assert mode == ""
+        trunc_normal_(self.pos_embed, std=0.02)
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(init_weights_vit_timm, self)
...
...
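The `drop_path1=dpr[i - 1] if i > 0 else 0.0, drop_path2=dpr[i]` indexing in the hunk above follows from the reordered residual path described in the comments: the dropout/add that used to close one sublayer now happens at the start of the next, so each block's first residual branch uses the previous block's rate. A small sketch of the decay rule, just to make the shifted mapping concrete:

# Sketch of the stochastic depth decay rule used above and the shifted per-block mapping.
import torch

depth, drop_path_rate = 12, 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # linearly increasing rates

for i in range(depth):
    drop_path1 = dpr[i - 1] if i > 0 else 0.0  # rate inherited from the previous block's sublayer
    drop_path2 = dpr[i]
    print(f"block {i}: drop_path1={drop_path1:.3f}, drop_path2={drop_path2:.3f}")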
@@ -195,7 +249,7 @@ class VisionTransformer(nn.Module):
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
+        return {"pos_embed", "cls_token"}

     def _pos_embed(self, x):
         if self.no_embed_class:
...
...
@@ -220,8 +274,8 @@ class VisionTransformer(nn.Module):
         x = self.patch_embed(x)
         hidden_states = self._pos_embed(x)
         residual = None
-        if self.global_pool != 'token' or all_tokens:  # if True:
+        if self.global_pool != "token" or all_tokens:  # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
...
...
@@ -229,8 +283,9 @@ class VisionTransformer(nn.Module):
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states, residual = self.blocks[-1](hidden_states, residual,
-                                                      mixer_subset=slice(0, 1))
+            hidden_states, residual = self.blocks[-1](
+                hidden_states, residual, mixer_subset=slice(0, 1)
+            )
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
...
...
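The `mixer_subset=slice(0, 1)` call above implements the optimization described in the comment: when only the CLS token is pooled, the last block computes attention for just the first query position against the full key/value sequence. A toy illustration of that idea with plain scaled-dot-product attention follows; it mimics the effect, not the MHA implementation used in the file.

# Toy illustration: attend with only the first (CLS) query against the full sequence.
import math
import torch

B, S, D = 2, 197, 64  # batch, sequence length (1 CLS + 196 patches), head dimension
x = torch.randn(B, S, D)

q = x[:, :1]   # queries restricted to the CLS token, shape (B, 1, D)
k, v = x, x    # keys/values still cover the whole sequence

attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(D), dim=-1)  # (B, 1, S)
cls_out = attn @ v  # (B, 1, D): only the CLS token's output is computed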
@@ -238,21 +293,30 @@ class VisionTransformer(nn.Module):
            if self.drop_path.p == 0 or not self.training:
                 rowscale = None
             else:
-                rowscale = self.drop_path(torch.ones(
-                    hidden_states.shape[:-1], device=hidden_states.device,
-                    dtype=hidden_states.dtype)
-                )
+                rowscale = self.drop_path(
+                    torch.ones(
+                        hidden_states.shape[:-1],
+                        device=hidden_states.device,
+                        dtype=hidden_states.dtype,
+                    )
+                )
             # Set prenorm=False here since we don't need to the residual
             hidden_states = dropout_add_layer_norm(
-                hidden_states, residual, self.norm.weight, self.norm.bias,
-                self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
-                prenorm=False, residual_in_fp32=True
+                hidden_states,
+                residual,
+                self.norm.weight,
+                self.norm.bias,
+                self.dropout.p if self.training else 0.0,
+                self.norm.eps,
+                rowscale=rowscale,
+                prenorm=False,
+                residual_in_fp32=True,
             )
         return hidden_states

     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
         return x if pre_logits else self.head(x)

     def forward(self, x):
...
...
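For reference, the fused `dropout_add_layer_norm` call above computes the same thing as the unfused branch earlier in `forward_features`: dropout and drop-path on the hidden states, add to the residual kept in fp32, then LayerNorm. A plain-PyTorch sketch of that reference computation follows; the standalone tensors and module instances here are hypothetical stand-ins, not objects from the file.

# Unfused reference for what the fused dropout_add_layer_norm call computes
# (a sketch mirroring the `if not self.fused_dropout_add_ln` branch; names are hypothetical).
import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

B, S, D = 2, 197, 768
hidden_states = torch.randn(B, S, D)
residual = torch.randn(B, S, D, dtype=torch.float32)  # residual kept in fp32

dropout = nn.Dropout(p=0.1)
drop_path = StochasticDepth(p=0.1, mode="row")
norm = nn.LayerNorm(D, eps=1e-6)

residual = drop_path(dropout(hidden_states)) + residual  # dropout -> drop-path -> add
out = norm(residual.to(dtype=norm.weight.dtype))         # then LayerNorm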
@@ -261,41 +325,46 @@ class VisionTransformer(nn.Module):
         return x

     def load_state_dict(self, state_dict, strict=True):
-        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        patch_embed_weight = state_dict["patch_embed.proj.weight"]
         if patch_embed_weight.dim() == 4:
             # convert from Conv2d to Linear
-            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight, 'o c h w -> o (c h w)')
+            state_dict["patch_embed.proj.weight"] = rearrange(
+                patch_embed_weight, "o c h w -> o (c h w)"
+            )

         def key_mapping_attn(key):
-            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
-            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
+            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
             return key

         state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

         n_layer = len(self.blocks)
         # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
-        if (self.blocks[-1].mixer.cross_attn
-                and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
-            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
-            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
-            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
-            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
-            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
-            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        if (
+            self.blocks[-1].mixer.cross_attn
+            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
+        ):
+            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
+            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
+            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
+            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
+            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
+            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
         return super().load_state_dict(state_dict, strict=strict)


-def init_weights_vit_timm(module: nn.Module, name: str = ''):
-    """ ViT weight initialization, original timm impl (for reproducibility)
-    """
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
     if isinstance(module, nn.Linear):
-        trunc_normal_(module.weight, std=.02)
+        trunc_normal_(module.weight, std=0.02)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
-    elif hasattr(module, 'init_weights'):
+    elif hasattr(module, "init_weights"):
         module.init_weights()


 def vit_base_patch16_224(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     assert not pretrained
...
...
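The last-layer conversion in `load_state_dict` above splits a packed QKV projection into separate query and key/value projections to match the cross-attention layout of the final block. A small sketch of the shapes involved (the `embed_dim` value here is just an example):

# Sketch of splitting a packed Wqkv weight/bias into Wq and Wkv, as load_state_dict does
# for the last (cross-attention) block. embed_dim=768 is an example value.
import torch

embed_dim = 768
Wqkv = torch.randn(3 * embed_dim, embed_dim)  # packed [q; k; v] projection
bqkv = torch.randn(3 * embed_dim)

Wq, Wkv = Wqkv[:embed_dim], Wqkv[embed_dim:]  # (768, 768) and (1536, 768)
bq, bkv = bqkv[:embed_dim], bqkv[embed_dim:]

assert Wq.shape == (embed_dim, embed_dim)
assert Wkv.shape == (2 * embed_dim, embed_dim)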