Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9ef46659
Unverified
Commit
9ef46659
authored
Nov 22, 2022
by
NielsRogge
Committed by
GitHub
Nov 22, 2022
Browse files
Improve backbone (#20380)
Co-authored-by:
Niels Rogge
<
nielsrogge@Nielss-MacBook-Pro.local
>
parent
5efd074a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
10 deletions
+10
-10
src/transformers/models/resnet/modeling_resnet.py
src/transformers/models/resnet/modeling_resnet.py
+6
-7
tests/models/resnet/test_modeling_resnet.py
tests/models/resnet/test_modeling_resnet.py
+4
-3
No files found.
src/transformers/models/resnet/modeling_resnet.py
View file @
9ef46659
...
@@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel):
...
@@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel):
self
.
out_features
=
config
.
out_features
self
.
out_features
=
config
.
out_features
self
.
out_feature_channels
=
{
out_feature_channels
=
{}
"stem"
:
config
.
embedding_size
,
out_feature_channels
[
"stem"
]
=
config
.
embedding_size
"stage1"
:
config
.
hidden_sizes
[
0
],
for
idx
,
stage
in
enumerate
(
self
.
stage_names
[
1
:]):
"stage2"
:
config
.
hidden_sizes
[
1
],
out_feature_channels
[
stage
]
=
config
.
hidden_sizes
[
idx
]
"stage3"
:
config
.
hidden_sizes
[
2
],
"stage4"
:
config
.
hidden_sizes
[
3
],
self
.
out_feature_channels
=
out_feature_channels
}
# initialize weights and apply final processing
# initialize weights and apply final processing
self
.
post_init
()
self
.
post_init
()
...
...
tests/models/resnet/test_modeling_resnet.py
View file @
9ef46659
...
@@ -55,7 +55,7 @@ class ResNetModelTester:
...
@@ -55,7 +55,7 @@ class ResNetModelTester:
hidden_act
=
"relu"
,
hidden_act
=
"relu"
,
num_labels
=
3
,
num_labels
=
3
,
scope
=
None
,
scope
=
None
,
out_features
=
[
"stage1"
,
"stage2"
,
"stage3"
,
"stage4"
],
out_features
=
[
"stage2"
,
"stage3"
,
"stage4"
],
):
):
self
.
parent
=
parent
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -121,10 +121,11 @@ class ResNetModelTester:
...
@@ -121,10 +121,11 @@ class ResNetModelTester:
# verify hidden states
# verify hidden states
self
.
parent
.
assertEqual
(
len
(
result
.
feature_maps
),
len
(
config
.
out_features
))
self
.
parent
.
assertEqual
(
len
(
result
.
feature_maps
),
len
(
config
.
out_features
))
self
.
parent
.
assertListEqual
(
list
(
result
.
feature_maps
[
0
].
shape
),
[
3
,
10
,
8
,
8
])
self
.
parent
.
assertListEqual
(
list
(
result
.
feature_maps
[
0
].
shape
),
[
self
.
batch_size
,
self
.
hidden_sizes
[
1
]
,
4
,
4
])
# verify channels
# verify channels
self
.
parent
.
assertListEqual
(
model
.
channels
,
config
.
hidden_sizes
)
self
.
parent
.
assertEqual
(
len
(
model
.
channels
),
len
(
config
.
out_features
))
self
.
parent
.
assertListEqual
(
model
.
channels
,
config
.
hidden_sizes
[
1
:])
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
prepare_config_and_inputs
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment