Unverified Commit ee04b698 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[examples] better model example (#10427)

* refactors

* typo
parent a85eb616
......@@ -572,7 +572,6 @@ def main():
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
all_metrics = {}
# Training
if training_args.do_train:
if last_checkpoint is not None:
......@@ -589,13 +588,10 @@ def main():
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
all_metrics.update(metrics)
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
results = {}
......@@ -608,10 +604,8 @@ def main():
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
all_metrics.update(metrics)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
......@@ -626,11 +620,10 @@ def main():
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
all_metrics.update(metrics)
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
......@@ -640,9 +633,6 @@ def main():
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
if trainer.is_world_process_zero():
trainer.save_metrics("all", metrics)
return results
......
......@@ -231,7 +231,7 @@ class Trainer:
"""
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
def __init__(
self,
......
......@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Under distributed environment this is done only for a process with rank 0.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
if not self.is_world_process_zero():
return
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
......@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
def save_metrics(self, split, metrics, combined=True):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Under distributed environment this is done only for a process with rank 0.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
Creates combined metrics by updating ``all_results.json`` with metrics of this call
"""
if not self.is_world_process_zero():
return
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
if combined:
path = os.path.join(self.args.output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
all_metrics = json.load(f)
else:
all_metrics = {}
all_metrics.update(metrics)
with open(path, "w") as f:
json.dump(all_metrics, f, indent=4, sort_keys=True)
def save_state(self):
"""
Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
Under distributed environment this is done only for a process with rank 0.
"""
if not self.is_world_process_zero():
return
path = os.path.join(self.args.output_dir, "trainer_state.json")
self.state.save_to_json(path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment