Unverified commit 1198ba8f, authored by Sylvain Gugger and committed by GitHub

Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
parent 9a25c5bd
@@ -16,7 +16,6 @@
 import logging
 import os
 import sys
-import time
 from dataclasses import dataclass, field
 from typing import Optional
@@ -120,30 +119,6 @@ class DataTrainingArguments:
 )

-def speed_metrics(split, start_time, num_samples):
-    """
-    Measure and return speed performance metrics.
-    This function requires a time snapshot `start_time` before the operation to be measured starts and this
-    function should be run immediately after the operation to be measured has completed.
-    Args:
-    - split: one of train, val, test
-    - start_time: operation start time
-    - num_samples: number of samples processed
-    """
-    runtime = time.time() - start_time
-    result = {}
-    samples_per_second = 1 / (runtime / num_samples)
-    result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
-    result[f"{split}_runtime"] = round(runtime, 4)
-    result[f"{split}_n_ojbs"] = num_samples
-    return result

 def handle_metrics(split, metrics, output_dir):
     """
     Log and save metrics
@@ -155,8 +130,8 @@ def handle_metrics(split, metrics, output_dir):
     """
     logger.info(f"***** {split} metrics *****")
-    for key, value in metrics.items():
-        logger.info(f"  {key} = {value}")
+    for key in sorted(metrics.keys()):
+        logger.info(f"  {key} = {metrics[key]}")
     save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
@@ -311,11 +286,11 @@ def main():
     if training_args.do_train:
         logger.info("*** Train ***")
-        start_time = time.time()
-        trainer.train(
+        train_result = trainer.train(
             model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
         )
-        metrics = speed_metrics("train", start_time, data_args.n_train)
+        metrics = train_result.metrics
+        metrics["train_n_objs"] = data_args.n_train

         trainer.save_model()  # this also saves the tokenizer
@@ -334,9 +309,8 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        start_time = time.time()
         metrics = trainer.evaluate(metric_key_prefix="val")
-        metrics.update(speed_metrics("val", start_time, data_args.n_val))
+        metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)

         if trainer.is_world_process_zero():
@@ -347,10 +321,9 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")
-        start_time = time.time()
         test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
         metrics = test_output.metrics
-        metrics.update(speed_metrics("test", start_time, data_args.n_test))
+        metrics["test_n_objs"] = data_args.n_test

         if trainer.is_world_process_zero():
            metrics["test_loss"] = round(metrics["test_loss"], 4)
...
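For readability, here is the pattern the example script follows after this change, condensed into one place. This is a sketch, not a drop-in snippet: `trainer`, `data_args`, `training_args` and `handle_metrics` are assumed to exist exactly as in the script above, and the runtime/throughput keys now come from the Trainer rather than from a local helper.

# Sketch of the updated flow: timing metrics come from the Trainer itself,
# the script only adds the sample counts it knows about.
train_result = trainer.train(model_path=None)
train_metrics = train_result.metrics                  # train_runtime, train_samples_per_second, ...
train_metrics["train_n_objs"] = data_args.n_train     # sample count still supplied by the script
handle_metrics("train", train_metrics, training_args.output_dir)

val_metrics = trainer.evaluate(metric_key_prefix="val")   # val_runtime, val_samples_per_second added by the Trainer
val_metrics["val_n_objs"] = data_args.n_val
val_metrics["val_loss"] = round(val_metrics["val_loss"], 4)
handle_metrics("val", val_metrics, training_args.output_dir)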
@@ -97,9 +97,7 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             result = run_glue.main()
-            del result["eval_loss"]
-            for value in result.values():
-                self.assertGreaterEqual(value, 0.75)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

     @require_torch_non_multi_gpu_but_fix_me
     def test_run_clm(self):
...
@@ -22,6 +22,7 @@ import math
 import os
 import re
 import shutil
+import time
 import warnings
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     set_seed,
+    speed_metrics,
 )
 from .training_args import TrainingArguments
 from .utils import logging
@@ -707,6 +709,7 @@ class Trainer:
         logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
+        start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
@@ -870,15 +873,17 @@ class Trainer:
             state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
             self.model.load_state_dict(state_dict)

+        metrics = speed_metrics("train", start_time, self.state.max_steps)
         if self._total_flos is not None:
             self.store_flos()
-            self.log({"total_flos": self.state.total_flos})
+            metrics["total_flos"] = self.state.total_flos
+        self.log(metrics)

         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()

-        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
+        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)

     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
@@ -1317,6 +1322,7 @@ class Trainer:
             raise ValueError("eval_dataset must implement __len__")

         eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        start_time = time.time()

         output = self.prediction_loop(
             eval_dataloader,
@@ -1328,6 +1334,8 @@ class Trainer:
             metric_key_prefix=metric_key_prefix,
         )

+        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
         self.log(output.metrics)

         if self.args.tpu_metrics_debug or self.args.debug:
@@ -1374,10 +1382,13 @@ class Trainer:
             raise ValueError("test_dataset must implement __len__")

         test_dataloader = self.get_test_dataloader(test_dataset)
+        start_time = time.time()

-        return self.prediction_loop(
+        output = self.prediction_loop(
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+        return output

     def prediction_loop(
         self,
...
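From the caller's side, the effect of the Trainer changes above is that evaluate() and predict() now report wall-clock timing without any extra code. A minimal sketch, assuming `trainer`, `eval_dataset` and `test_dataset` were built the usual way:

# The key prefix follows metric_key_prefix ("eval" by default for evaluate()).
metrics = trainer.evaluate(eval_dataset)
print(metrics["eval_runtime"])               # seconds spent in the prediction loop
print(metrics["eval_samples_per_second"])    # len(eval_dataset) / eval_runtime, rounded

test_output = trainer.predict(test_dataset, metric_key_prefix="test")
print(test_output.metrics["test_runtime"], test_output.metrics["test_samples_per_second"])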
@@ -18,6 +18,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
 import copy
 import random
+import time
 from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

 import numpy as np
@@ -70,6 +71,7 @@ class PredictionOutput(NamedTuple):
 class TrainOutput(NamedTuple):
     global_step: int
     training_loss: float
+    metrics: Dict[str, float]

 PREFIX_CHECKPOINT_DIR = "checkpoint"
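Since TrainOutput gains a third field, anything that unpacked the old two-element tuple needs updating. A self-contained illustration of the new shape (the numbers are invented, purely to show the fields):

from typing import Dict, NamedTuple

class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float
    metrics: Dict[str, float]

# Hypothetical values, only to show how the result is accessed.
out = TrainOutput(500, 0.1234, {"train_runtime": 12.34, "train_samples_per_second": 40.5})
step, loss, metrics = out                  # three fields now instead of two
print(out.metrics["train_runtime"])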
@@ -179,3 +181,23 @@ def total_processes_number(local_rank):
         return torch.distributed.get_world_size()
     return 1

+def speed_metrics(split, start_time, num_samples=None):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: name to prefix metric (like train, eval, test...)
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    """
+    runtime = time.time() - start_time
+    result = {f"{split}_runtime": round(runtime, 4)}
+    if num_samples is not None:
+        samples_per_second = 1 / (runtime / num_samples)
+        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    return result
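The helper is easy to try in isolation. The following standalone snippet repeats the function body from the diff so it runs on its own; the sleep merely stands in for the operation being timed:

import time

def speed_metrics(split, start_time, num_samples=None):
    # Same body as the helper added above, copied here so the snippet is standalone.
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        samples_per_second = 1 / (runtime / num_samples)
        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
    return result

start = time.time()
time.sleep(0.5)                                  # stand-in for the measured operation
print(speed_metrics("eval", start, num_samples=100))
# Roughly: {'eval_runtime': 0.5005, 'eval_samples_per_second': 199.8}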
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import dataclasses
 import json
 import os
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
@@ -411,7 +410,16 @@ class TrainingArguments:
             self.run_name = self.output_dir

         if is_torch_available() and self.device.type != "cuda" and self.fp16:
-            raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
+            raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+
+    def __repr__(self):
+        # We override the default repr to remove deprecated arguments from the repr. This method should be removed
+        # once those deprecated arguments are removed from TrainingArguments. (TODO: v5)
+        self_as_dict = asdict(self)
+        del self_as_dict["per_gpu_train_batch_size"]
+        del self_as_dict["per_gpu_eval_batch_size"]
+        attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
+        return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"

     @property
     def train_batch_size(self) -> int:
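The override builds the repr from asdict() and drops the deprecated keys so they stop showing up when the arguments are printed or logged. The same idea on a toy dataclass (field names here are invented for illustration):

from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class ToyArguments:
    output_dir: str = "out"
    per_device_train_batch_size: int = 8
    per_gpu_train_batch_size: Optional[int] = None  # deprecated field we want hidden

    def __repr__(self):
        # Serialize to a dict, remove the deprecated key, and rebuild the repr by hand.
        self_as_dict = asdict(self)
        del self_as_dict["per_gpu_train_batch_size"]
        attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
        return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"

print(ToyArguments())  # ToyArguments(output_dir=out, per_device_train_batch_size=8)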
@@ -523,7 +531,7 @@ class TrainingArguments:
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).
         """
-        d = dataclasses.asdict(self)
+        d = asdict(self)
         for k, v in d.items():
             if isinstance(v, Enum):
                 d[k] = v.value
...
@@ -265,6 +265,21 @@ class TrainerIntegrationTest(unittest.TestCase):
             metrics = trainer.evaluate()
             self.assertEqual(metrics[metric], best_value)

+    def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
+        # We'll pop things so operate on copies.
+        state = trainer_state.copy()
+        state1 = trainer_state1.copy()
+        # The log history may contain different logs for the time metrics (after resuming a training).
+        log_history = state.pop("log_history", None)
+        log_history1 = state1.pop("log_history", None)
+        self.assertEqual(state, state1)
+        for log, log1 in zip(log_history, log_history1):
+            _ = log.pop("train_runtime", None)
+            _ = log1.pop("train_runtime", None)
+            _ = log.pop("train_samples_per_second", None)
+            _ = log1.pop("train_samples_per_second", None)
+            self.assertEqual(log, log1)
+
     def test_trainer_works_with_dict(self):
         # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
         # anything.
@@ -552,7 +567,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

             # Now check with a later checkpoint that it also works when we span over one epoch
             checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -566,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -590,7 +605,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

             # Now check with a later checkpoint that it also works when we span over one epoch
             checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -606,7 +621,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
@@ -638,7 +653,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             state1 = dataclasses.asdict(trainer.state)
             self.assertEqual(a, a1)
             self.assertEqual(b, b1)
-            self.assertEqual(state, state1)
+            self.check_trainer_state_are_the_same(state, state1)

     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
...