Unverified commit bb2cfd18, authored by regisss and committed by GitHub
Browse files

Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py (#19502)

* Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py

* Code improvement
parent 69b81c0a
...@@ -52,7 +52,8 @@ class QuestionAnsweringTrainer(Trainer): ...@@ -52,7 +52,8 @@ class QuestionAnsweringTrainer(Trainer):
finally: finally:
self.compute_metrics = compute_metrics self.compute_metrics = compute_metrics
if self.post_process_function is not None and self.compute_metrics is not None: if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save:
# Only the main node writes the results by default
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds) metrics = self.compute_metrics(eval_preds)
...@@ -60,11 +61,13 @@ class QuestionAnsweringTrainer(Trainer): ...@@ -60,11 +61,13 @@ class QuestionAnsweringTrainer(Trainer):
for key in list(metrics.keys()): for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"): if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
self.log(metrics)
else: else:
metrics = {} metrics = {}
if self.args.should_log:
# Only the main node logs the results by default
self.log(metrics)
if self.args.tpu_metrics_debug or self.args.debug: if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report()) xm.master_print(met.metrics_report())
......
...@@ -84,7 +84,8 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -84,7 +84,8 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
) )
) )
if self.post_process_function is not None and self.compute_metrics is not None: if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save:
# Only the main node writes the results by default
eval_preds = self.post_process_function(eval_examples, eval_dataset, output) eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
metrics = self.compute_metrics(eval_preds) metrics = self.compute_metrics(eval_preds)
...@@ -94,7 +95,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -94,7 +95,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
output.metrics.update(metrics) output.metrics.update(metrics)
else:
metrics = {}
if self.args.should_log:
# Only the main node logs the results by default
self.log(metrics) self.log(metrics)
if self.args.tpu_metrics_debug or self.args.debug: if self.args.tpu_metrics_debug or self.args.debug:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment