Unverified Commit f8bd8c6c authored by François Lagunas's avatar François Lagunas Committed by GitHub
Browse files

Fixes a bug that appears when using QA BERT and distillation. (#12026)

* Fixing a bug that appears when using distillation (and potentially other uses).
During backward pass Pytorch complains with:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
This happens because the QA model code modifies the start_positions and end_positions input tensors in place, using the clamp_ function: as a consequence, the teacher and the student both modify the inputs, and the backward pass fails.

* Fixing the QA clamp_ bug in all models.
parent 59f75d53
...@@ -1230,8 +1230,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel): ...@@ -1230,8 +1230,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1578,8 +1578,8 @@ class BartForQuestionAnswering(BartPretrainedModel): ...@@ -1578,8 +1578,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1813,8 +1813,8 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1813,8 +1813,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -2995,8 +2995,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): ...@@ -2995,8 +2995,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -2783,8 +2783,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel): ...@@ -2783,8 +2783,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1305,8 +1305,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel): ...@@ -1305,8 +1305,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1376,8 +1376,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel): ...@@ -1376,8 +1376,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1500,8 +1500,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): ...@@ -1500,8 +1500,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -740,8 +740,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): ...@@ -740,8 +740,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1330,8 +1330,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel): ...@@ -1330,8 +1330,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1561,8 +1561,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel): ...@@ -1561,8 +1561,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1331,8 +1331,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel): ...@@ -1331,8 +1331,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -2607,8 +2607,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): ...@@ -2607,8 +2607,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -2029,8 +2029,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel): ...@@ -2029,8 +2029,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1585,8 +1585,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): ...@@ -1585,8 +1585,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1806,8 +1806,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel): ...@@ -1806,8 +1806,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1383,8 +1383,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel): ...@@ -1383,8 +1383,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1035,8 +1035,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel): ...@@ -1035,8 +1035,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -2567,8 +2567,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel): ...@@ -2567,8 +2567,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1484,8 +1484,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel): ...@@ -1484,8 +1484,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment