Unverified Commit 622a8c59 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer] add Trainer methods for metrics logging and saving (#10266)



* make logging and saving trainer built-in

* Update src/transformers/trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 94d8767b
...@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence. ...@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
""" """
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import json
import logging import logging
import os import os
import re import re
...@@ -55,11 +54,6 @@ with FileLock(".lock") as lock: ...@@ -55,11 +54,6 @@ with FileLock(".lock") as lock:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def save_json(content, path, indent=4, **json_dump_kwargs):
    """
    Serialize ``content`` as pretty-printed JSON to ``path``.

    Args:
        content: Any JSON-serializable object.
        path (:obj:`str`): Destination file path.
        indent (:obj:`int`, `optional`, defaults to 4): Indentation passed to ``json.dump``.
        json_dump_kwargs: Extra keyword arguments forwarded to ``json.dump``; may
            override the ``sort_keys=True`` default.
    """
    # Build the kwargs dict first so a caller-supplied ``sort_keys`` overrides the
    # default instead of raising ``TypeError: got multiple values for keyword
    # argument 'sort_keys'`` (which the old positional ``sort_keys=True`` did).
    kwargs = {"sort_keys": True, **json_dump_kwargs}
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **kwargs)
@dataclass @dataclass
class ModelArguments: class ModelArguments:
""" """
...@@ -596,13 +590,8 @@ def main(): ...@@ -596,13 +590,8 @@ def main():
) )
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) metrics["train_samples"] = min(max_train_samples, len(train_dataset))
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics) trainer.log_metrics("train", metrics)
logger.info("***** train metrics *****") trainer.save_metrics("train", metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
all_metrics.update(metrics) all_metrics.update(metrics)
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
...@@ -620,13 +609,8 @@ def main(): ...@@ -620,13 +609,8 @@ def main():
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics) trainer.log_metrics("eval", metrics)
logger.info("***** val metrics *****") trainer.save_metrics("eval", metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
all_metrics.update(metrics) all_metrics.update(metrics)
if training_args.do_predict: if training_args.do_predict:
...@@ -643,13 +627,8 @@ def main(): ...@@ -643,13 +627,8 @@ def main():
metrics["test_samples"] = min(max_test_samples, len(test_dataset)) metrics["test_samples"] = min(max_test_samples, len(test_dataset))
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics) trainer.log_metrics("test", metrics)
logger.info("***** test metrics *****") trainer.save_metrics("test", metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
all_metrics.update(metrics) all_metrics.update(metrics)
if training_args.predict_with_generate: if training_args.predict_with_generate:
...@@ -662,7 +641,7 @@ def main(): ...@@ -662,7 +641,7 @@ def main():
writer.write("\n".join(test_preds)) writer.write("\n".join(test_preds))
if trainer.is_world_process_zero(): if trainer.is_world_process_zero():
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json")) trainer.save_metrics("all", metrics)
return results return results
......
...@@ -19,6 +19,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune ...@@ -19,6 +19,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import collections import collections
import gc import gc
import inspect import inspect
import json
import math import math
import os import os
import re import re
...@@ -1370,6 +1371,38 @@ class Trainer: ...@@ -1370,6 +1371,38 @@ class Trainer:
return metrics_copy return metrics_copy
def log_metrics(self, split, metrics):
    """
    Log metrics in a specially formatted way.

    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
    """
    logger.info(f"***** {split} metrics *****")
    metrics_formatted = self.metrics_format(metrics)
    if not metrics_formatted:
        # Guard the empty case: max() over an empty sequence raises ValueError.
        return
    # Align keys and values into two fixed-width columns for readability.
    k_width = max(len(str(x)) for x in metrics_formatted.keys())
    v_width = max(len(str(x)) for x in metrics_formatted.values())
    for key in sorted(metrics_formatted.keys()):
        logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
    """
    Save metrics into a json file for that split, e.g. ``train_results.json``.

    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
    """
    # One results file per split, written under the configured output directory.
    output_file = os.path.join(self.args.output_dir, f"{split}_results.json")
    serialized = json.dumps(metrics, indent=4, sort_keys=True)
    with open(output_file, "w") as writer:
        writer.write(serialized)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
""" """
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment