Unverified Commit 7e73601f authored by Fan Zhang, committed by GitHub

modify qa-trainer (#11872)

* modify qa-trainer

* fix flax model
parent 9ec0f01b
@@ -692,7 +692,11 @@ def main():
             if completed_steps >= args.max_train_steps:
                 break
-    # Validation
+    # Evaluation
+    logger.info("***** Running Evaluation *****")
+    logger.info(f"  Num examples = {len(eval_dataset)}")
+    logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
     all_start_logits = []
     all_end_logits = []
     for step, batch in enumerate(eval_dataloader):
@@ -725,6 +729,10 @@ def main():
     # Prediction
     if args.do_predict:
+        logger.info("***** Running Prediction *****")
+        logger.info(f"  Num examples = {len(predict_dataset)}")
+        logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
         all_start_logits = []
         all_end_logits = []
         for step, batch in enumerate(predict_dataloader):
...
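For context, the new `logger.info` calls sit at the top of the script's evaluation and prediction loops, which accumulate start/end logits batch by batch. Below is a minimal sketch of that loop pattern as it is typically written in the Accelerate-based no-trainer QA scripts; names like `model`, `eval_dataloader`, and `accelerator` follow the script, but the pad/gather excerpt is a reconstruction under that assumption, not verbatim source.

```python
# Sketch of the evaluation loop the logging lines precede (assumed
# structure; `model`, `eval_dataloader`, `accelerator` come from the
# surrounding script).
import torch

all_start_logits = []
all_end_logits = []
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
    # Batches on different processes may have different sequence
    # lengths, so pad to a common length before gathering.
    start_logits = accelerator.pad_across_processes(outputs.start_logits, dim=1, pad_index=-100)
    end_logits = accelerator.pad_across_processes(outputs.end_logits, dim=1, pad_index=-100)
    all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
    all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
```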
@@ -1218,8 +1218,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
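The same two-line change repeats across every QA head below: `split()` returns views into the original logits tensor, so the squeezed start/end logits are not laid out contiguously in memory, and collective ops such as `torch.distributed.all_gather` (used when gathering predictions across processes) can require contiguous tensors. A minimal, self-contained sketch of the effect; the shapes are illustrative, not taken from the models:

```python
import torch

# Stand-in for qa_outputs(sequence_output): (batch, seq_len, 2) logits.
logits = torch.randn(4, 384, 2)

start_logits, end_logits = logits.split(1, dim=-1)

# split() returns views sharing storage with `logits`, so squeezing the
# last dim leaves a strided, non-dense tensor.
print(start_logits.squeeze(-1).is_contiguous())   # False

# .contiguous() copies the data into a dense layout, which distributed
# gather utilities expect.
print(start_logits.squeeze(-1).contiguous().is_contiguous())  # True
```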
@@ -1556,8 +1556,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1801,8 +1801,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -2983,8 +2983,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
         logits = logits - logits_mask * 1e6
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -2761,8 +2761,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1293,8 +1293,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1364,8 +1364,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1488,8 +1488,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -728,8 +728,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
         logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
-        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)
+        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
+        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -241,8 +241,8 @@ class DPRSpanPredictor(PreTrainedModel):
         # compute logits
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
         # resize
...
@@ -1318,8 +1318,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1549,8 +1549,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
         logits = self.qa_outputs(last_hidden_state)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1319,8 +1319,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -2585,8 +2585,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -2017,8 +2017,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1563,8 +1563,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1794,8 +1794,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1371,8 +1371,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...
@@ -1023,8 +1023,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
         total_loss = None
         if start_positions is not None and end_positions is not None:
...