Unverified commit 7e73601f, authored by Fan Zhang, committed by GitHub
Browse files

modify qa-trainer (#11872)

* modify qa-trainer

* fix flax model
parent 9ec0f01b
...@@ -2555,8 +2555,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel): ...@@ -2555,8 +2555,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
......
...@@ -1472,8 +1472,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel): ...@@ -1472,8 +1472,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
......
...@@ -1068,8 +1068,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel): ...@@ -1068,8 +1068,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
......
...@@ -941,8 +941,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel): ...@@ -941,8 +941,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
......
...@@ -1862,8 +1862,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1862,8 +1862,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None total_loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment