Unverified Commit ade7af93 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[ConvNext] Improve backbone (#27621)

* Improve convnext backbone

* Fix convnext2
parent 0e6794ff
...@@ -529,14 +529,13 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin): ...@@ -529,14 +529,13 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
outputs = self.encoder( outputs = self.encoder(
embedding_output, embedding_output,
output_hidden_states=True, output_hidden_states=True,
return_dict=True, return_dict=return_dict,
) )
hidden_states = outputs.hidden_states hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = () feature_maps = ()
# we skip the stem for stage, hidden_state in zip(self.stage_names, hidden_states):
for idx, (stage, hidden_state) in enumerate(zip(self.stage_names[1:], hidden_states[1:])):
if stage in self.out_features: if stage in self.out_features:
hidden_state = self.hidden_states_norms[stage](hidden_state) hidden_state = self.hidden_states_norms[stage](hidden_state)
feature_maps += (hidden_state,) feature_maps += (hidden_state,)
...@@ -544,11 +543,11 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin): ...@@ -544,11 +543,11 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
if not return_dict: if not return_dict:
output = (feature_maps,) output = (feature_maps,)
if output_hidden_states: if output_hidden_states:
output += (outputs.hidden_states,) output += (hidden_states,)
return output return output
return BackboneOutput( return BackboneOutput(
feature_maps=feature_maps, feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None, hidden_states=hidden_states if output_hidden_states else None,
attentions=None, attentions=None,
) )
...@@ -552,14 +552,13 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin): ...@@ -552,14 +552,13 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
outputs = self.encoder( outputs = self.encoder(
embedding_output, embedding_output,
output_hidden_states=True, output_hidden_states=True,
return_dict=True, return_dict=return_dict,
) )
hidden_states = outputs.hidden_states hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = () feature_maps = ()
# we skip the stem for stage, hidden_state in zip(self.stage_names, hidden_states):
for idx, (stage, hidden_state) in enumerate(zip(self.stage_names[1:], hidden_states[1:])):
if stage in self.out_features: if stage in self.out_features:
hidden_state = self.hidden_states_norms[stage](hidden_state) hidden_state = self.hidden_states_norms[stage](hidden_state)
feature_maps += (hidden_state,) feature_maps += (hidden_state,)
...@@ -567,11 +566,11 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin): ...@@ -567,11 +566,11 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
if not return_dict: if not return_dict:
output = (feature_maps,) output = (feature_maps,)
if output_hidden_states: if output_hidden_states:
output += (outputs.hidden_states,) output += (hidden_states,)
return output return output
return BackboneOutput( return BackboneOutput(
feature_maps=feature_maps, feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None, hidden_states=hidden_states if output_hidden_states else None,
attentions=None, attentions=None,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment