Unverified Commit f8bd8c6c authored by François Lagunas's avatar François Lagunas Committed by GitHub
Browse files

Fixes bug that appears when using QA BERT and distillation. (#12026)

* Fixing bug that appears when using distillation (and potentially other uses).
During backward pass Pytorch complains with:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
This happens because the QA model code modifies the start_positions and end_positions input tensors in place, using the clamp_ function: as a consequence the teacher and the student both modify the inputs, and the backward pass fails.

* Fixing all models QA clamp_ bug.
parent 59f75d53
...@@ -1554,8 +1554,8 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel): ...@@ -1554,8 +1554,8 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1080,8 +1080,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel): ...@@ -1080,8 +1080,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -953,8 +953,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel): ...@@ -953,8 +953,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1874,8 +1874,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1874,8 +1874,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
...@@ -1516,8 +1516,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ...@@ -1516,8 +1516,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
...@@ -3066,8 +3066,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ...@@ -3066,8 +3066,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions = start_positions.clamp(0, ignored_index)
end_positions.clamp_(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment