Unverified Commit df91ff53 authored by Michael Murray's avatar Michael Murray Committed by GitHub
Browse files

Check for case where `auxiliary_head` is `None` in `UperNetPreTrainedModel` (#25514)

check for case where auxiliary_head is None in UperNetPreTrainedModel
parent b42010bb
......@@ -305,13 +305,15 @@ class UperNetPreTrainedModel(PreTrainedModel):
if isinstance(module, UperNetPreTrainedModel):
module.backbone.init_weights()
module.decode_head.init_weights()
module.auxiliary_head.init_weights()
if module.auxiliary_head is not None:
module.auxiliary_head.init_weights()
def init_weights(self):
"""Initialize the weights"""
self.backbone.init_weights()
self.decode_head.init_weights()
self.auxiliary_head.init_weights()
if self.auxiliary_head is not None:
self.auxiliary_head.init_weights()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BackboneMixin):
......@@ -429,9 +431,10 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
else:
# compute weighted loss
loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
main_loss = loss_fct(logits, labels)
auxiliary_loss = loss_fct(auxiliary_logits, labels)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
loss = loss_fct(logits, labels)
if auxiliary_logits is not None:
auxiliary_loss = loss_fct(auxiliary_logits, labels)
loss += self.config.auxiliary_loss_weight * auxiliary_loss
if not return_dict:
if output_hidden_states:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment