Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
76c74b37
Unverified
Commit
76c74b37
authored
Mar 17, 2022
by
Francesco Saverio Zuppichini
Committed by
GitHub
Mar 17, 2022
Browse files
VAN: update modules names (#16201)
* done * done
parent
99e2982f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
16 deletions
+18
-16
src/transformers/models/van/convert_van_to_pytorch.py
src/transformers/models/van/convert_van_to_pytorch.py
+4
-4
src/transformers/models/van/modeling_van.py
src/transformers/models/van/modeling_van.py
+14
-12
No files found.
src/transformers/models/van/convert_van_to_pytorch.py
View file @
76c74b37
...
...
@@ -207,10 +207,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
}
names_to_original_checkpoints
=
{
"van-tiny"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Tiny/resolve/main/van_tiny_754.pth.tar"
,
"van-small"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Small/resolve/main/van_small_811.pth.tar"
,
"van-base"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Base/resolve/main/van_base_828.pth.tar"
,
"van-large"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Large/resolve/main/van_large_839.pth.tar"
,
"van-tiny"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar"
,
"van-small"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar"
,
"van-base"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar"
,
"van-large"
:
"https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar"
,
}
if
model_name
:
...
...
src/transformers/models/van/modeling_van.py
View file @
76c74b37
...
...
@@ -154,8 +154,10 @@ class VanOverlappingPatchEmbedder(nn.Sequential):
def
__init__
(
self
,
in_channels
:
int
,
hidden_size
:
int
,
patch_size
:
int
=
7
,
stride
:
int
=
4
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
hidden_size
,
kernel_size
=
patch_size
,
stride
=
stride
,
padding
=
patch_size
//
2
)
self
.
norm
=
nn
.
BatchNorm2d
(
hidden_size
)
self
.
convolution
=
nn
.
Conv2d
(
in_channels
,
hidden_size
,
kernel_size
=
patch_size
,
stride
=
stride
,
padding
=
patch_size
//
2
)
self
.
normalization
=
nn
.
BatchNorm2d
(
hidden_size
)
class
VanMlpLayer
(
nn
.
Sequential
):
...
...
@@ -173,12 +175,12 @@ class VanMlpLayer(nn.Sequential):
dropout_rate
:
float
=
0.5
,
):
super
().
__init__
()
self
.
fc1
=
nn
.
Conv2d
(
in_channels
,
hidden_size
,
kernel_size
=
1
)
self
.
in_dense
=
nn
.
Conv2d
(
in_channels
,
hidden_size
,
kernel_size
=
1
)
self
.
depth_wise
=
nn
.
Conv2d
(
hidden_size
,
hidden_size
,
kernel_size
=
3
,
padding
=
1
,
groups
=
hidden_size
)
self
.
activation
=
ACT2FN
[
hidden_act
]
self
.
drop1
=
nn
.
Dropout
(
dropout_rate
)
self
.
fc2
=
nn
.
Conv2d
(
hidden_size
,
out_channels
,
kernel_size
=
1
)
self
.
drop2
=
nn
.
Dropout
(
dropout_rate
)
self
.
dropout1
=
nn
.
Dropout
(
dropout_rate
)
self
.
out_dense
=
nn
.
Conv2d
(
hidden_size
,
out_channels
,
kernel_size
=
1
)
self
.
dropout2
=
nn
.
Dropout
(
dropout_rate
)
class
VanLargeKernelAttention
(
nn
.
Sequential
):
...
...
@@ -267,10 +269,10 @@ class VanLayer(nn.Module):
):
super
().
__init__
()
self
.
drop_path
=
VanDropPath
(
drop_path
)
if
drop_path_rate
>
0.0
else
nn
.
Identity
()
self
.
pre_norm
=
nn
.
BatchNorm2d
(
hidden_size
)
self
.
pre_normomalization
=
nn
.
BatchNorm2d
(
hidden_size
)
self
.
attention
=
VanSpatialAttentionLayer
(
hidden_size
,
config
.
hidden_act
)
self
.
attention_scaling
=
VanLayerScaling
(
hidden_size
,
config
.
layer_scale_init_value
)
self
.
post_norm
=
nn
.
BatchNorm2d
(
hidden_size
)
self
.
post_normalization
=
nn
.
BatchNorm2d
(
hidden_size
)
self
.
mlp
=
VanMlpLayer
(
hidden_size
,
hidden_size
*
mlp_ratio
,
hidden_size
,
config
.
hidden_act
,
config
.
dropout_rate
)
...
...
@@ -279,7 +281,7 @@ class VanLayer(nn.Module):
def
forward
(
self
,
hidden_state
):
residual
=
hidden_state
# attention
hidden_state
=
self
.
pre_norm
(
hidden_state
)
hidden_state
=
self
.
pre_normomalization
(
hidden_state
)
hidden_state
=
self
.
attention
(
hidden_state
)
hidden_state
=
self
.
attention_scaling
(
hidden_state
)
hidden_state
=
self
.
drop_path
(
hidden_state
)
...
...
@@ -287,7 +289,7 @@ class VanLayer(nn.Module):
hidden_state
=
residual
+
hidden_state
residual
=
hidden_state
# mlp
hidden_state
=
self
.
post_norm
(
hidden_state
)
hidden_state
=
self
.
post_normalization
(
hidden_state
)
hidden_state
=
self
.
mlp
(
hidden_state
)
hidden_state
=
self
.
mlp_scaling
(
hidden_state
)
hidden_state
=
self
.
drop_path
(
hidden_state
)
...
...
@@ -325,7 +327,7 @@ class VanStage(nn.Module):
for
_
in
range
(
depth
)
]
)
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
normalization
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
hidden_state
):
hidden_state
=
self
.
embeddings
(
hidden_state
)
...
...
@@ -333,7 +335,7 @@ class VanStage(nn.Module):
# rearrange b c h w -> b (h w) c
batch_size
,
hidden_size
,
height
,
width
=
hidden_state
.
shape
hidden_state
=
hidden_state
.
flatten
(
2
).
transpose
(
1
,
2
)
hidden_state
=
self
.
norm
(
hidden_state
)
hidden_state
=
self
.
normalization
(
hidden_state
)
# rearrange b (h w) c -> b c h w
hidden_state
=
hidden_state
.
view
(
batch_size
,
height
,
width
,
hidden_size
).
permute
(
0
,
3
,
1
,
2
)
return
hidden_state
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment