Unverified Commit ade7af93 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[ConvNext] Improve backbone (#27621)

* Improve convnext backbone

* Fix convnext2
parent 0e6794ff
......@@ -529,14 +529,13 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
outputs = self.encoder(
embedding_output,
output_hidden_states=True,
return_dict=True,
return_dict=return_dict,
)
hidden_states = outputs.hidden_states
hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = ()
# we skip the stem
for idx, (stage, hidden_state) in enumerate(zip(self.stage_names[1:], hidden_states[1:])):
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
hidden_state = self.hidden_states_norms[stage](hidden_state)
feature_maps += (hidden_state,)
......@@ -544,11 +543,11 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
output += (hidden_states,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
hidden_states=hidden_states if output_hidden_states else None,
attentions=None,
)
......@@ -552,14 +552,13 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
outputs = self.encoder(
embedding_output,
output_hidden_states=True,
return_dict=True,
return_dict=return_dict,
)
hidden_states = outputs.hidden_states
hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = ()
# we skip the stem
for idx, (stage, hidden_state) in enumerate(zip(self.stage_names[1:], hidden_states[1:])):
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
hidden_state = self.hidden_states_norms[stage](hidden_state)
feature_maps += (hidden_state,)
......@@ -567,11 +566,11 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
output += (hidden_states,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
hidden_states=hidden_states if output_hidden_states else None,
attentions=None,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment