Unverified Commit c19d0462 authored by Stas Bekman, committed by GitHub

[finetune_trainer] enhancements and fixes (#9042)



* trainer and finetune_trainer enhancements and fixes

* add fallback default

* move the fixing of incorrect keys back into finetune trainer

* s/eval/val/ to match the split

* trainer can now use a different prefix than eval_ for metrics

* document new arg

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* use 'eval' as the default for metric_key_prefix

* complete adjust var names + disambiguate

* fix logger

* add clarifying comment

* add clarifying comment

* style

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/trainer.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* complete removal of optional for metric_key_prefix

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 251eb70c
@@ -16,6 +16,7 @@
 import logging
 import os
 import sys
+import time
 from dataclasses import dataclass, field
 from typing import Optional
@@ -119,6 +120,46 @@ class DataTrainingArguments:
     )
+
+
+def speed_metrics(split, start_time, num_samples):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this
+    function should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: one of train, val, test
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    """
+    runtime = time.time() - start_time
+    result = {}
+    samples_per_second = 1 / (runtime / num_samples)
+    result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    result[f"{split}_runtime"] = round(runtime, 4)
result[f"{split}_n_ojbs"] = num_samples
+    return result
+
+
+def handle_metrics(split, metrics, output_dir):
+    """
+    Log and save metrics
+
+    Args:
+    - split: one of train, val, test
+    - metrics: metrics dict
+    - output_dir: where to save the metrics
+    """
+
+    logger.info(f"***** {split} metrics *****")
+    for key, value in metrics.items():
+        logger.info(f"  {key} = {value}")
+    save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
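As a hedged sketch of how the two new helpers are meant to be used together inside this script (the values in the comments are purely illustrative, and `logger` / `save_json` are the module-level objects already defined in the examples code):

start_time = time.time()
# ... run the phase being measured, e.g. trainer.train(...) ...
metrics = speed_metrics("val", start_time, num_samples=1000)
# e.g. {"val_samples_per_second": 123.456, "val_runtime": 8.1023, "val_n_objs": 1000}
handle_metrics("val", metrics, "output_dir")
# logs each key/value pair and writes output_dir/val_results.json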
@@ -265,45 +306,56 @@ def main():
         data_args=data_args,
     )
 
+    all_metrics = {}
     # Training
     if training_args.do_train:
+        logger.info("*** Train ***")
+
+        start_time = time.time()
         trainer.train(
             model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
         )
-        trainer.save_model()
-        # For convenience, we also re-save the tokenizer to the same directory,
-        # so that you can share your model easily on huggingface.co/models =)
+        metrics = speed_metrics("train", start_time, data_args.n_train)
+
+        trainer.save_model()  # this also saves the tokenizer
+
         if trainer.is_world_process_zero():
+            handle_metrics("train", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
+
+            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
             trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+
+            # For convenience, we also re-save the tokenizer to the same directory,
+            # so that you can share your model easily on huggingface.co/models =)
             tokenizer.save_pretrained(training_args.output_dir)
 
     # Evaluation
-    eval_results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        result = trainer.evaluate()
+        start_time = time.time()
+        metrics = trainer.evaluate(metric_key_prefix="val")
+        metrics.update(speed_metrics("val", start_time, data_args.n_val))
+        metrics["val_loss"] = round(metrics["val_loss"], 4)
 
         if trainer.is_world_process_zero():
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-            save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
-            eval_results.update(result)
+            handle_metrics("val", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
 
     if training_args.do_predict:
-        logging.info("*** Test ***")
+        logger.info("*** Predict ***")
 
-        test_output = trainer.predict(test_dataset=test_dataset)
-        test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
+        start_time = time.time()
+        test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
+        metrics = test_output.metrics
+        metrics.update(speed_metrics("test", start_time, data_args.n_test))
 
         if trainer.is_world_process_zero():
-            logger.info("***** Test results *****")
-            for key, value in test_metrics.items():
-                logger.info("  %s = %s", key, value)
-            save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
-            eval_results.update(test_metrics)
+            metrics["test_loss"] = round(metrics["test_loss"], 4)
+            handle_metrics("test", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
 
             if training_args.predict_with_generate:
                 test_preds = tokenizer.batch_decode(
@@ -313,8 +365,9 @@ def main():
                 write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
 
     if trainer.is_world_process_zero():
-        save_json(eval_results, "all_results.json")
-    return eval_results
+        save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
+
+    return all_metrics
 
 
 def _mp_fn(index):
...
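Since main() now funnels everything into all_metrics and saves it as all_results.json in the output dir, a hypothetical post-run inspection (file name taken from the code above, contents depending on which phases ran and on compute_metrics) could look like:

import json

with open("output_dir/all_results.json") as f:  # whatever --output_dir was
    all_metrics = json.load(f)

# Keys are prefixed per split, e.g. "train_runtime", "val_loss", "test_bleu".
print({k: v for k, v in all_metrics.items() if k.startswith("val_")})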
@@ -462,7 +462,7 @@ def save_git_info(folder_path: str) -> None:
 
 
 def save_json(content, path, indent=4, **json_dump_kwargs):
     with open(path, "w") as f:
-        json.dump(content, f, indent=indent, **json_dump_kwargs)
+        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
 
 
 def load_json(path):
...
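The only functional change to save_json is sort_keys=True, which makes the written files deterministic across runs; a quick standalone illustration of the effect:

import json

metrics = {"val_runtime": 35.7, "val_loss": 1.2345, "val_bleu": 27.1}
print(json.dumps(metrics, indent=4, sort_keys=True))
# Keys come out alphabetically (val_bleu, val_loss, val_runtime),
# so repeated runs produce diff-friendly, stable output files.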
@@ -1243,7 +1243,10 @@ class Trainer:
             shutil.rmtree(checkpoint)
 
     def evaluate(
-        self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -1261,6 +1264,9 @@
             ignore_keys (:obj:`List[str]`, `optional`):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
+            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be
+                named "eval_bleu" if the prefix is "eval" (default).
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
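A short usage sketch for the new argument on the evaluation side (assumes a trainer whose compute_metrics reports, for example, bleu):

metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
# Keys now come back as "val_loss", "val_bleu", ... instead of "eval_loss", "eval_bleu".
print(metrics["val_loss"])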
@@ -1278,6 +1284,7 @@
             # self.args.prediction_loss_only
             prediction_loss_only=True if self.compute_metrics is None else None,
             ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
         )
 
         self.log(output.metrics)
@@ -1289,7 +1296,9 @@
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
         return output.metrics
 
-    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
+    def predict(
+        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
+    ) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
@@ -1303,6 +1312,9 @@
             ignore_keys (:obj:`List[str]`, `optional`):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
+            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be
+                named "eval_bleu" if the prefix is "eval" (default).
 
         .. note::
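And the matching sketch on the prediction side (trainer and test_dataset assumed to exist):

test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
# test_output.metrics carries "test_loss", "test_bleu", ... directly,
# so callers no longer have to rename "eval_*" keys by hand as finetune_trainer.py used to.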
@@ -1322,7 +1334,9 @@
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
+        return self.prediction_loop(
+            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+        )
 
     def prediction_loop(
         self,
@@ -1330,6 +1344,7 @@
         description: str,
         prediction_loss_only: Optional[bool] = None,
         ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
     ) -> PredictionOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1421,12 +1436,12 @@
         metrics = {}
         if eval_loss is not None:
-            metrics["eval_loss"] = eval_loss.mean().item()
+            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
 
-        # Prefix all keys with eval_
+        # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
-            if not key.startswith("eval_"):
-                metrics[f"eval_{key}"] = metrics.pop(key)
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
...
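Finally, a tiny standalone sketch of what the prefixing loop in prediction_loop does to a metrics dict (no Trainer required):

metric_key_prefix = "test"
metrics = {"test_loss": 1.3, "bleu": 26.4, "gen_len": 31.0}

# Any key that does not already carry the prefix is renamed in place.
for key in list(metrics.keys()):
    if not key.startswith(f"{metric_key_prefix}_"):
        metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

print(metrics)  # {'test_loss': 1.3, 'test_bleu': 26.4, 'test_gen_len': 31.0}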