Unverified Commit ff8dcb5e authored by Observer46's avatar Observer46 Committed by GitHub
Browse files

Fix arguments passed to predict function in QA Seq2seq training script (#21026)

fix args passed to predict function
parent 35a7052b
...@@ -151,7 +151,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -151,7 +151,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
if self.post_process_function is None or self.compute_metrics is None: if self.post_process_function is None or self.compute_metrics is None:
return output return output
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict")
metrics = self.compute_metrics(predictions) metrics = self.compute_metrics(predictions)
# Prefix all keys with metric_key_prefix + '_' # Prefix all keys with metric_key_prefix + '_'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment