chenpangpang / transformers · Commits

Commit 76c74b37 (unverified)
Authored Mar 17, 2022 by Francesco Saverio Zuppichini, committed by GitHub on Mar 17, 2022
Parent: 99e2982f

VAN: update modules names (#16201)

* done
* done
Showing 2 changed files with 18 additions and 16 deletions

src/transformers/models/van/convert_van_to_pytorch.py   +4 -4
src/transformers/models/van/modeling_van.py              +14 -12
src/transformers/models/van/convert_van_to_pytorch.py
@@ -207,10 +207,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
     }
     names_to_original_checkpoints = {
-        "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny/resolve/main/van_tiny_754.pth.tar",
-        "van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small/resolve/main/van_small_811.pth.tar",
-        "van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base/resolve/main/van_base_828.pth.tar",
-        "van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large/resolve/main/van_large_839.pth.tar",
+        "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
+        "van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar",
+        "van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar",
+        "van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar",
     }
     if model_name:
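For context (this is not part of the commit): the conversion script uses this mapping to download the original VAN weights before porting them to the Hugging Face format. Below is a minimal sketch of fetching one of the new "-original" URLs, assuming those repositories host the same .pth.tar archives; the "state_dict" key layout inside the checkpoint is an assumption here, not something the diff shows.

# Hypothetical usage sketch; not taken from convert_van_to_pytorch.py itself.
import torch

names_to_original_checkpoints = {
    "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
}

# Downloads the file once, caches it locally, and torch.load()s it.
checkpoint = torch.hub.load_state_dict_from_url(
    names_to_original_checkpoints["van-tiny"], map_location="cpu"
)
# The original VAN releases may wrap the tensors (e.g. under a "state_dict" key);
# inspect checkpoint.keys() if this assumption does not hold.
state_dict = checkpoint.get("state_dict", checkpoint)
print(f"loaded {len(state_dict)} entries")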
src/transformers/models/van/modeling_van.py
@@ -154,8 +154,10 @@ class VanOverlappingPatchEmbedder(nn.Sequential):
     def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4):
         super().__init__()
-        self.conv = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2)
-        self.norm = nn.BatchNorm2d(hidden_size)
+        self.convolution = nn.Conv2d(
+            in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
+        )
+        self.normalization = nn.BatchNorm2d(hidden_size)


 class VanMlpLayer(nn.Sequential):
@@ -173,12 +175,12 @@ class VanMlpLayer(nn.Sequential):
         dropout_rate: float = 0.5,
     ):
         super().__init__()
-        self.fc1 = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
+        self.in_dense = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
         self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
         self.activation = ACT2FN[hidden_act]
-        self.drop1 = nn.Dropout(dropout_rate)
-        self.fc2 = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
-        self.drop2 = nn.Dropout(dropout_rate)
+        self.dropout1 = nn.Dropout(dropout_rate)
+        self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
+        self.dropout2 = nn.Dropout(dropout_rate)


 class VanLargeKernelAttention(nn.Sequential):
@@ -267,10 +269,10 @@ class VanLayer(nn.Module):
     ):
         super().__init__()
         self.drop_path = VanDropPath(drop_path) if drop_path_rate > 0.0 else nn.Identity()
-        self.pre_norm = nn.BatchNorm2d(hidden_size)
+        self.pre_normomalization = nn.BatchNorm2d(hidden_size)
         self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act)
         self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
-        self.post_norm = nn.BatchNorm2d(hidden_size)
+        self.post_normalization = nn.BatchNorm2d(hidden_size)
         self.mlp = VanMlpLayer(
             hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate
         )
@@ -279,7 +281,7 @@ class VanLayer(nn.Module):
     def forward(self, hidden_state):
         residual = hidden_state
         # attention
-        hidden_state = self.pre_norm(hidden_state)
+        hidden_state = self.pre_normomalization(hidden_state)
         hidden_state = self.attention(hidden_state)
         hidden_state = self.attention_scaling(hidden_state)
         hidden_state = self.drop_path(hidden_state)
@@ -287,7 +289,7 @@ class VanLayer(nn.Module):
         hidden_state = residual + hidden_state
         residual = hidden_state
         # mlp
-        hidden_state = self.post_norm(hidden_state)
+        hidden_state = self.post_normalization(hidden_state)
         hidden_state = self.mlp(hidden_state)
         hidden_state = self.mlp_scaling(hidden_state)
         hidden_state = self.drop_path(hidden_state)
@@ -325,7 +327,7 @@ class VanStage(nn.Module):
                 for _ in range(depth)
             ]
         )
-        self.norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+        self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

     def forward(self, hidden_state):
         hidden_state = self.embeddings(hidden_state)
@@ -333,7 +335,7 @@ class VanStage(nn.Module):
         # rearrange b c h w -> b (h w) c
         batch_size, hidden_size, height, width = hidden_state.shape
         hidden_state = hidden_state.flatten(2).transpose(1, 2)
-        hidden_state = self.norm(hidden_state)
+        hidden_state = self.normalization(hidden_state)
         # rearrange b (h w) c- > b c h w
         hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
         return hidden_state
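As a rough sanity check of the renames in modeling_van.py (not part of the commit; assumes a transformers build that already includes this patch and uses a default, randomly initialized VanConfig), the longer attribute names should now show up in the module tree:

# Hypothetical check, not from the diff: list submodules registered under the new names.
from transformers import VanConfig, VanModel

# Small randomly initialized model; no pretrained weights are downloaded.
model = VanModel(VanConfig())

# After this commit the norm layers are registered under the longer names
# instead of "norm", "pre_norm" and "post_norm".
renamed = ("normalization", "pre_normomalization", "post_normalization")
for name, module in model.named_modules():
    if name.split(".")[-1] in renamed:
        print(name, "->", type(module).__name__)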