Unverified Commit 404d9253 authored by Hari's avatar Hari Committed by GitHub
Browse files

add conditional statement for auxiliary loss calculation (#23899)

* add conditional statement for auxiliary loss calculation

* fix style and copies
parent c63bfc30
...@@ -1192,8 +1192,10 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): ...@@ -1192,8 +1192,10 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
# compute weighted loss # compute weighted loss
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
main_loss = loss_fct(upsampled_logits, labels) main_loss = loss_fct(upsampled_logits, labels)
loss = main_loss
if auxiliary_logits is not None:
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss loss += self.config.auxiliary_loss_weight * auxiliary_loss
return loss return loss
......
...@@ -1120,8 +1120,10 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel): ...@@ -1120,8 +1120,10 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
# compute weighted loss # compute weighted loss
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
main_loss = loss_fct(upsampled_logits, labels) main_loss = loss_fct(upsampled_logits, labels)
loss = main_loss
if auxiliary_logits is not None:
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss loss += self.config.auxiliary_loss_weight * auxiliary_loss
return loss return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment