Unverified Commit 622a8c59 authored by Stas Bekman, committed by GitHub

[trainer] add Trainer methods for metrics logging and saving (#10266)



* make logging and saving trainer built-in

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 94d8767b
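
With these helpers built into Trainer, example scripts can drop their local save_json helper and per-split formatting loop and call the new methods directly. A minimal usage sketch, assuming a trainer, train_dataset and max_train_samples already set up as in the seq2seq example script:

    train_result = trainer.train()
    metrics = train_result.metrics
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))

    if trainer.is_world_process_zero():
        trainer.log_metrics("train", metrics)   # logs an aligned "***** train metrics *****" block
        trainer.save_metrics("train", metrics)  # writes <output_dir>/train_results.json
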
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import json
import logging
import os
import re
@@ -55,11 +54,6 @@ with FileLock(".lock") as lock:
logger = logging.getLogger(__name__)


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)


@dataclass
class ModelArguments:
    """
@@ -596,13 +590,8 @@ def main():
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
        if trainer.is_world_process_zero():
            metrics_formatted = trainer.metrics_format(metrics)
            logger.info("***** train metrics *****")
            k_width = max(len(str(x)) for x in metrics_formatted.keys())
            v_width = max(len(str(x)) for x in metrics_formatted.values())
            for key in sorted(metrics_formatted.keys()):
                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
            save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
            trainer.log_metrics("train", metrics)
            trainer.save_metrics("train", metrics)
            all_metrics.update(metrics)

        # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
@@ -620,13 +609,8 @@ def main():
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

        if trainer.is_world_process_zero():
            metrics_formatted = trainer.metrics_format(metrics)
            logger.info("***** val metrics *****")
            k_width = max(len(str(x)) for x in metrics_formatted.keys())
            v_width = max(len(str(x)) for x in metrics_formatted.values())
            for key in sorted(metrics_formatted.keys()):
                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
            save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)
            all_metrics.update(metrics)

    if training_args.do_predict:
@@ -643,13 +627,8 @@ def main():
        metrics["test_samples"] = min(max_test_samples, len(test_dataset))

        if trainer.is_world_process_zero():
            metrics_formatted = trainer.metrics_format(metrics)
            logger.info("***** test metrics *****")
            k_width = max(len(str(x)) for x in metrics_formatted.keys())
            v_width = max(len(str(x)) for x in metrics_formatted.values())
            for key in sorted(metrics_formatted.keys()):
                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
            save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
            trainer.log_metrics("test", metrics)
            trainer.save_metrics("test", metrics)
            all_metrics.update(metrics)

            if training_args.predict_with_generate:
@@ -662,7 +641,7 @@ def main():
                    writer.write("\n".join(test_preds))

    if trainer.is_world_process_zero():
        save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
        trainer.save_metrics("all", metrics)

    return results
......
@@ -19,6 +19,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import collections
import gc
import inspect
import json
import math
import os
import re
@@ -1370,6 +1371,38 @@ class Trainer:
        return metrics_copy

    def log_metrics(self, split, metrics):
        """
        Log metrics in a specially formatted way

        Args:
            split (:obj:`str`):
                Mode/split name: one of ``train``, ``eval``, ``test``
            metrics (:obj:`Dict[str, float]`):
                The metrics returned from train/evaluate/predict
        """

        logger.info(f"***** {split} metrics *****")
        metrics_formatted = self.metrics_format(metrics)
        k_width = max(len(str(x)) for x in metrics_formatted.keys())
        v_width = max(len(str(x)) for x in metrics_formatted.values())
        for key in sorted(metrics_formatted.keys()):
            logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")

    def save_metrics(self, split, metrics):
        """
        Save metrics into a json file for that split, e.g. ``train_results.json``.

        Args:
            split (:obj:`str`):
                Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
            metrics (:obj:`Dict[str, float]`):
                The metrics returned from train/evaluate/predict
        """
        path = os.path.join(self.args.output_dir, f"{split}_results.json")
        with open(path, "w") as f:
            json.dump(metrics, f, indent=4, sort_keys=True)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
......
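
Since save_metrics simply serializes the metrics dict with json.dump into <output_dir>/<split>_results.json, downstream tooling can read the numbers back with plain json. A small sketch of that round-trip; the load_split_metrics name and the "output" directory are illustrative, not part of this commit:

    import json
    import os

    def load_split_metrics(output_dir, split):
        # Read back what Trainer.save_metrics wrote for a given split,
        # e.g. train_results.json, eval_results.json, test_results.json or all_results.json.
        path = os.path.join(output_dir, f"{split}_results.json")
        with open(path) as f:
            return json.load(f)

    eval_metrics = load_split_metrics("output", "eval")
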