Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ade7af93
Unverified
Commit
ade7af93
authored
Nov 21, 2023
by
NielsRogge
Committed by
GitHub
Nov 21, 2023
Browse files
[ConvNext] Improve backbone (#27621)
* Improve convnext backbone * Fix convnext2
parent
0e6794ff
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
12 deletions
+10
-12
src/transformers/models/convnext/modeling_convnext.py
src/transformers/models/convnext/modeling_convnext.py
+5
-6
src/transformers/models/convnextv2/modeling_convnextv2.py
src/transformers/models/convnextv2/modeling_convnextv2.py
+5
-6
No files found.
src/transformers/models/convnext/modeling_convnext.py
View file @
ade7af93
...
@@ -529,14 +529,13 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
...
@@ -529,14 +529,13 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
outputs
=
self
.
encoder
(
outputs
=
self
.
encoder
(
embedding_output
,
embedding_output
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
return_dict
=
True
,
return_dict
=
return_dict
,
)
)
hidden_states
=
outputs
.
hidden_states
hidden_states
=
outputs
.
hidden_states
if
return_dict
else
outputs
[
1
]
feature_maps
=
()
feature_maps
=
()
# we skip the stem
for
stage
,
hidden_state
in
zip
(
self
.
stage_names
,
hidden_states
):
for
idx
,
(
stage
,
hidden_state
)
in
enumerate
(
zip
(
self
.
stage_names
[
1
:],
hidden_states
[
1
:])):
if
stage
in
self
.
out_features
:
if
stage
in
self
.
out_features
:
hidden_state
=
self
.
hidden_states_norms
[
stage
](
hidden_state
)
hidden_state
=
self
.
hidden_states_norms
[
stage
](
hidden_state
)
feature_maps
+=
(
hidden_state
,)
feature_maps
+=
(
hidden_state
,)
...
@@ -544,11 +543,11 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
...
@@ -544,11 +543,11 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
if
not
return_dict
:
if
not
return_dict
:
output
=
(
feature_maps
,)
output
=
(
feature_maps
,)
if
output_hidden_states
:
if
output_hidden_states
:
output
+=
(
outputs
.
hidden_states
,)
output
+=
(
hidden_states
,)
return
output
return
output
return
BackboneOutput
(
return
BackboneOutput
(
feature_maps
=
feature_maps
,
feature_maps
=
feature_maps
,
hidden_states
=
outputs
.
hidden_states
if
output_hidden_states
else
None
,
hidden_states
=
hidden_states
if
output_hidden_states
else
None
,
attentions
=
None
,
attentions
=
None
,
)
)
src/transformers/models/convnextv2/modeling_convnextv2.py
View file @
ade7af93
...
@@ -552,14 +552,13 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
...
@@ -552,14 +552,13 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
outputs
=
self
.
encoder
(
outputs
=
self
.
encoder
(
embedding_output
,
embedding_output
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
return_dict
=
True
,
return_dict
=
return_dict
,
)
)
hidden_states
=
outputs
.
hidden_states
hidden_states
=
outputs
.
hidden_states
if
return_dict
else
outputs
[
1
]
feature_maps
=
()
feature_maps
=
()
# we skip the stem
for
stage
,
hidden_state
in
zip
(
self
.
stage_names
,
hidden_states
):
for
idx
,
(
stage
,
hidden_state
)
in
enumerate
(
zip
(
self
.
stage_names
[
1
:],
hidden_states
[
1
:])):
if
stage
in
self
.
out_features
:
if
stage
in
self
.
out_features
:
hidden_state
=
self
.
hidden_states_norms
[
stage
](
hidden_state
)
hidden_state
=
self
.
hidden_states_norms
[
stage
](
hidden_state
)
feature_maps
+=
(
hidden_state
,)
feature_maps
+=
(
hidden_state
,)
...
@@ -567,11 +566,11 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
...
@@ -567,11 +566,11 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
if
not
return_dict
:
if
not
return_dict
:
output
=
(
feature_maps
,)
output
=
(
feature_maps
,)
if
output_hidden_states
:
if
output_hidden_states
:
output
+=
(
outputs
.
hidden_states
,)
output
+=
(
hidden_states
,)
return
output
return
output
return
BackboneOutput
(
return
BackboneOutput
(
feature_maps
=
feature_maps
,
feature_maps
=
feature_maps
,
hidden_states
=
outputs
.
hidden_states
if
output_hidden_states
else
None
,
hidden_states
=
hidden_states
if
output_hidden_states
else
None
,
attentions
=
None
,
attentions
=
None
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment