Unverified Commit bdbb2c75 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer] move secondary methods into a separate file (#10363)

* move secondary methods into a separate file

* cleanup

* style
parent 5f2a3d72
...@@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune ...@@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import collections import collections
import gc import gc
import inspect import inspect
import json
import math import math
import os import os
import re import re
...@@ -82,7 +81,6 @@ from .trainer_pt_utils import ( ...@@ -82,7 +81,6 @@ from .trainer_pt_utils import (
SequentialDistributedSampler, SequentialDistributedSampler,
distributed_broadcast_scalars, distributed_broadcast_scalars,
distributed_concat, distributed_concat,
get_learning_rate,
nested_concat, nested_concat,
nested_detach, nested_detach,
nested_numpify, nested_numpify,
...@@ -226,6 +224,8 @@ class Trainer: ...@@ -226,6 +224,8 @@ class Trainer:
""" """
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
def __init__( def __init__(
self, self,
model: Union[PreTrainedModel, torch.nn.Module] = None, model: Union[PreTrainedModel, torch.nn.Module] = None,
...@@ -1130,7 +1130,7 @@ class Trainer: ...@@ -1130,7 +1130,7 @@ class Trainer:
tr_loss -= tr_loss tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = get_learning_rate(self) logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step self._globalstep_last_logged = self.state.global_step
...@@ -1345,61 +1345,6 @@ class Trainer: ...@@ -1345,61 +1345,6 @@ class Trainer:
self.state.log_history.append(output) self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format
Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""
metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)
return metrics_copy
def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
""" """
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
......
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
Torch utilities for the Trainer class. Torch utilities for the Trainer class.
""" """
import json
import math import math
import os
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterator, List, Optional, Union from typing import Dict, Iterator, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -263,29 +265,6 @@ def _get_first_shape(arrays): ...@@ -263,29 +265,6 @@ def _get_first_shape(arrays):
return arrays.shape return arrays.shape
def get_learning_rate(trainer):
if trainer.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = trainer.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
trainer.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else trainer.lr_scheduler.get_lr()[0]
)
return last_lr
class DistributedTensorGatherer: class DistributedTensorGatherer:
""" """
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks. A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
...@@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler): ...@@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler):
assert len(indices) == self.num_samples assert len(indices) == self.num_samples
return iter(indices) return iter(indices)
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here
def _get_learning_rate(self):
if self.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
return last_lr
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format
Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""
metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)
return metrics_copy
def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment