"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "973218fd3bcaac34254dcc485cabc1d575a8a7f5"
Unverified Commit ee04b698, authored by Stas Bekman, committed by GitHub

[examples] better model example (#10427)

* refactors

* typo
parent a85eb616
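In short, the refactor below moves the rank-0 guards and the manual all_metrics bookkeeping out of the example script and into three Trainer helpers. As a minimal sketch of the resulting call pattern (assuming a configured Trainer or Seq2SeqTrainer instance and a train_dataset already exist; this is not a verbatim copy of the example script):

def report_train(trainer, train_dataset, train_result):
    # train_result is the TrainOutput returned by trainer.train()
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)   # rank-0 gating now happens inside the helper
    trainer.save_metrics("train", metrics)  # writes train_results.json, folds into all_results.json
    trainer.save_state()                    # writes trainer_state.json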
@@ -572,7 +572,6 @@ def main():
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

-   all_metrics = {}
    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
@@ -589,13 +588,10 @@ def main():
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

-       if trainer.is_world_process_zero():
-           trainer.log_metrics("train", metrics)
-           trainer.save_metrics("train", metrics)
-           all_metrics.update(metrics)
-           # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
-           trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+       trainer.log_metrics("train", metrics)
+       trainer.save_metrics("train", metrics)
+       trainer.save_state()

    # Evaluation
    results = {}
@@ -608,10 +604,8 @@ def main():
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

-       if trainer.is_world_process_zero():
-           trainer.log_metrics("eval", metrics)
-           trainer.save_metrics("eval", metrics)
-           all_metrics.update(metrics)
+       trainer.log_metrics("eval", metrics)
+       trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Test ***")
@@ -626,11 +620,10 @@ def main():
        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
        metrics["test_samples"] = min(max_test_samples, len(test_dataset))

-       if trainer.is_world_process_zero():
-           trainer.log_metrics("test", metrics)
-           trainer.save_metrics("test", metrics)
-           all_metrics.update(metrics)
+       trainer.log_metrics("test", metrics)
+       trainer.save_metrics("test", metrics)

+       if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                test_preds = tokenizer.batch_decode(
                    test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -640,9 +633,6 @@ def main():
                with open(output_test_preds_file, "w") as writer:
                    writer.write("\n".join(test_preds))

-   if trainer.is_world_process_zero():
-       trainer.save_metrics("all", metrics)
-
    return results
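One detail worth noting in the test/predict hunk above: the explicit trainer.is_world_process_zero() check survives only around the file write that the script performs itself, since only the Trainer helpers gate internally after this change. A minimal sketch of that remaining guard (the function name and output file name are illustrative, not taken from the diff):

import os

def save_test_predictions(trainer, tokenizer, test_results, output_dir):
    # Writes done by the script itself (not by a Trainer helper) still need
    # an explicit rank-0 guard so that only one process touches the file.
    if not trainer.is_world_process_zero():
        return
    test_preds = tokenizer.batch_decode(
        test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    output_test_preds_file = os.path.join(output_dir, "test_preds.txt")  # illustrative name
    with open(output_test_preds_file, "w") as writer:
        writer.write("\n".join(test_preds))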
@@ -231,7 +231,7 @@ class Trainer:
    """

-   from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
+   from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
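The change to the Trainer class body (presumably src/transformers/trainer.py) only extends an import that sits inside the class definition. For readers unfamiliar with the trick: functions defined at module level in trainer_pt_utils become bound methods of Trainer because they are imported in the class body. A self-contained toy version of the pattern, with illustrative names:

def save_metrics(self, split, metrics, combined=True):
    print(f"[{split}] {metrics}")

def save_state(self):
    print("state saved")

class Trainer:
    # In transformers this line reads
    #   from .trainer_pt_utils import ..., save_metrics, save_state
    # binding the module-level functions at class scope has the same effect:
    save_metrics = save_metrics
    save_state = save_state

t = Trainer()
t.save_metrics("train", {"loss": 0.1})  # the plain functions now behave as methods
t.save_state()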
@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
    """
    Log metrics in a specially formatted way

+   Under distributed environment this is done only for a process with rank 0.
+
    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
    """
+   if not self.is_world_process_zero():
+       return
+
    logger.info(f"***** {split} metrics *****")
    metrics_formatted = self.metrics_format(metrics)
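The early return added here (and in save_metrics and save_state below) is what lets the example scripts drop their own guards. As a rough, self-contained approximation of what Trainer.is_world_process_zero() amounts to in the common torch.distributed/DDP case (the real method also handles TPU and other launcher setups):

import torch.distributed as dist

def is_world_process_zero() -> bool:
    # True when not running distributed at all, or when this process is global rank 0.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0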
@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
        logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")


-def save_metrics(self, split, metrics):
+def save_metrics(self, split, metrics, combined=True):
    """
    Save metrics into a json file for that split, e.g. ``train_results.json``.

+   Under distributed environment this is done only for a process with rank 0.
+
    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
+       combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
+           Creates combined metrics by updating ``all_results.json`` with metrics of this call
    """
+   if not self.is_world_process_zero():
+       return
+
    path = os.path.join(self.args.output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

+   if combined:
+       path = os.path.join(self.args.output_dir, "all_results.json")
+       if os.path.exists(path):
+           with open(path, "r") as f:
+               all_metrics = json.load(f)
+       else:
+           all_metrics = {}
+
+       all_metrics.update(metrics)
+       with open(path, "w") as f:
+           json.dump(all_metrics, f, indent=4, sort_keys=True)
+
+
+def save_state(self):
+   """
+   Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
+
+   Under distributed environment this is done only for a process with rank 0.
+   """
+   if not self.is_world_process_zero():
+       return
+
+   path = os.path.join(self.args.output_dir, "trainer_state.json")
+   self.state.save_to_json(path)
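To make the combined=True behaviour concrete, here is a standalone re-creation of the merge logic added above (plain Python, not the transformers API; the helper name is illustrative):

import json
import os

def save_split_metrics(output_dir, split, metrics, combined=True):
    # Each call writes <split>_results.json and folds the metrics into all_results.json.
    with open(os.path.join(output_dir, f"{split}_results.json"), "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)
    if combined:
        path = os.path.join(output_dir, "all_results.json")
        all_metrics = {}
        if os.path.exists(path):
            with open(path, "r") as f:
                all_metrics = json.load(f)
        all_metrics.update(metrics)  # later splits overwrite any shared keys
        with open(path, "w") as f:
            json.dump(all_metrics, f, indent=4, sort_keys=True)

save_split_metrics(".", "train", {"train_loss": 1.23, "train_samples": 100})
save_split_metrics(".", "eval", {"eval_loss": 0.98, "eval_samples": 20})
# ./train_results.json, ./eval_results.json and ./all_results.json now exist;
# all_results.json holds the union of both metric dicts.

With this in place the example script no longer needs its own all_metrics dict or a final save_metrics("all", ...) call, which is exactly what the first hunks of this commit remove.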