gaoqiong / flash-attention · Commits

Commit ada4710d
Authored Aug 17, 2023 by Tri Dao
Parent: a81900d4

[ViT] Run black on vit.py

Showing 1 changed file with 177 additions and 108 deletions

flash_attn/models/vit.py (+177, -108)
@@ -2,26 +2,21 @@
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from collections import OrderedDict
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
@@ -29,11 +24,18 @@ except ImportError:
    dropout_add_layer_norm = None


def create_mixer_cls(
    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
):
    mixer_cls = partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return mixer_cls
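create_mixer_cls does not build an attention module itself; it returns a functools.partial with everything bound except the positional arguments the caller supplies later (for MHA, the embedding dimension). A minimal sketch of that deferred-construction pattern, using an illustrative stand-in class rather than the real MHA constructor:

from functools import partial


class DummyAttention:
    # Stand-in for MHA: first positional argument is the embedding dimension.
    def __init__(self, embed_dim, num_heads=8, dropout=0.0, cross_attn=False):
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.dropout, self.cross_attn = dropout, cross_attn


# Bind the attention hyperparameters now, leave the embedding dimension for later.
mixer_cls = partial(DummyAttention, num_heads=12, dropout=0.1, cross_attn=False)

# Later, the block constructs the mixer for its hidden size.
mixer = mixer_cls(768)
assert (mixer.embed_dim, mixer.num_heads, mixer.dropout) == (768, 12, 0.1)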
@@ -46,54 +48,85 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    return mlp_cls


def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    mixer_cls = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    block = Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_mlp=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
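In create_block, cross_attn is enabled only when last_layer_subset is set and the block is the final one, so with CLS-token pooling exactly one block (the last) runs cross-attention. A quick sketch of how the flag resolves per layer index:

depth = 12
last_layer_subset = True  # set when global_pool == "token"

# Mirrors cross_attn=(last_layer_subset and layer_idx == n_layer - 1) for each layer.
cross_attn_flags = [last_layer_subset and i == depth - 1 for i in range(depth)]
assert cross_attn_flags == [False] * (depth - 1) + [True]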
@@ -119,40 +152,45 @@ class VisionTransformer(nn.Module):
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
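The dpr comprehension is the usual linear stochastic-depth decay rule: drop probabilities rise linearly from 0 at the first block to drop_path_rate at the last. A small sketch of the values it produces for a ViT-Base-sized model:

import torch

depth, drop_path_rate = 12, 0.1
# Same expression as in __init__: linearly spaced drop probabilities, one per block.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(p, 3) for p in dpr])
# [0.0, 0.009, 0.018, 0.027, ...] up to 0.1 -- block 0 never drops, the last block drops most often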
@@ -160,31 +198,47 @@ class VisionTransformer(nn.Module):
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError("dropout_add_layer_norm is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)
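The reordering described in the comments above is what allows dropout + add + layer_norm to be fused into a single kernel. For reference, the unfused fallback in forward_features (shown in a later hunk) performs the same computation step by step; a standalone sketch with illustrative shapes, omitting the DropPath wrapper:

import torch
import torch.nn as nn

norm = nn.LayerNorm(768)
dropout = nn.Dropout(p=0.1)
hidden_states = torch.randn(2, 197, 768)  # (batch, tokens, embed_dim)
residual = torch.randn(2, 197, 768)

# Unfused path: dropout the branch output, add it to the residual stream,
# then layer-norm the updated residual.
residual = dropout(hidden_states) + residual
hidden_states = norm(residual.to(dtype=norm.weight.dtype))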
@@ -195,7 +249,7 @@ class VisionTransformer(nn.Module):
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        if self.no_embed_class:
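no_weight_decay follows the timm convention of telling the optimizer setup which parameters to exclude from weight decay. A sketch of how a training script might consume it; the grouping helper below is an assumption about typical usage, not part of this file:

import torch.nn as nn


def param_groups(model: nn.Module, weight_decay: float = 0.05):
    # Parameters the model lists in no_weight_decay() (here: pos_embed, cls_token) go into
    # a group with weight_decay=0.0; everything else gets the regular decay.
    skip = model.no_weight_decay() if hasattr(model, "no_weight_decay") else set()
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (no_decay if name in skip else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# e.g. torch.optim.AdamW(param_groups(model), lr=1e-3)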
@@ -220,8 +274,8 @@ class VisionTransformer(nn.Module):
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            # if True:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
@@ -229,8 +283,9 @@ class VisionTransformer(nn.Module):
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
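The cross-attention trick in the comments relies on the fact that each attention output row depends only on its own query, so restricting the query to the CLS token reproduces row 0 of the full output while skipping the remaining rows. A single-head sketch without the Wq/Wkv projections:

import torch

batch, seqlen, dim = 2, 197, 64
x = torch.randn(batch, seqlen, dim)

q = x[:, 0:1]  # (batch, 1, dim): only the CLS token acts as a query (cf. mixer_subset=slice(0, 1))
k, v = x, x    # keys/values still cover the whole sequence
attn = torch.softmax(q @ k.transpose(-1, -2) / dim**0.5, dim=-1)
cls_out = attn @ v  # (batch, 1, dim): equals row 0 of full self-attention over x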
@@ -238,21 +293,30 @@ class VisionTransformer(nn.Module):
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need to the residual
            hidden_states = dropout_add_layer_norm(
                hidden_states,
                residual,
                self.norm.weight,
                self.norm.bias,
                self.dropout.p if self.training else 0.0,
                self.norm.eps,
                rowscale=rowscale,
                prenorm=False,
                residual_in_fp32=True,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
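forward_head pools either by taking the CLS token (global_pool == "token") or by averaging the patch tokens while skipping the prefix tokens, as in the line above. A small sketch of the two modes with illustrative sizes:

import torch

batch, num_prefix_tokens, num_patches, dim = 2, 1, 196, 768
x = torch.randn(batch, num_prefix_tokens + num_patches, dim)

cls_pooled = x[:, 0]                               # "token": just the CLS token
avg_pooled = x[:, num_prefix_tokens:].mean(dim=1)  # "avg": mean over patch tokens only
assert cls_pooled.shape == avg_pooled.shape == (batch, dim)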
@@ -261,41 +325,46 @@ class VisionTransformer(nn.Module):
        return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained

...
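load_state_dict adapts timm-style checkpoints: attn.qkv / attn.proj keys are renamed to the mixer's Wqkv / out_proj, and for a cross-attention last layer the packed Wqkv is split into Wq (the first embed_dim rows) and Wkv (the rest). A standalone sketch of those two transformations with illustrative shapes:

import re
import torch

embed_dim = 768
state_dict = {
    "blocks.11.attn.qkv.weight": torch.randn(3 * embed_dim, embed_dim),
    "blocks.11.attn.qkv.bias": torch.randn(3 * embed_dim),
}

# Rename timm keys to the flash_attn mixer naming, mirroring key_mapping_attn.
state_dict = {
    re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", k): v
    for k, v in state_dict.items()
}

# Split the packed QKV projection: Q rows first, then the stacked K/V rows.
Wqkv = state_dict.pop("blocks.11.mixer.Wqkv.weight")
bqkv = state_dict.pop("blocks.11.mixer.Wqkv.bias")
state_dict["blocks.11.mixer.Wq.weight"] = Wqkv[:embed_dim]   # (embed_dim, embed_dim)
state_dict["blocks.11.mixer.Wkv.weight"] = Wqkv[embed_dim:]  # (2 * embed_dim, embed_dim)
state_dict["blocks.11.mixer.Wq.bias"] = bqkv[:embed_dim]
state_dict["blocks.11.mixer.Wkv.bias"] = bqkv[embed_dim:]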