Unverified Commit 01d340ad authored by Teven's avatar Teven Committed by GitHub
Browse files

Floating-point operations logging in trainer (#6768)



* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d155b38d
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
import inspect import inspect
import os import os
import re import re
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch import torch
from torch import Tensor, device, dtype, nn from torch import Tensor, device, dtype, nn
...@@ -45,7 +46,6 @@ from .utils import logging ...@@ -45,7 +46,6 @@ from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
try: try:
from torch.nn import Identity from torch.nn import Identity
except ImportError: except ImportError:
...@@ -91,20 +91,6 @@ class ModuleUtilsMixin: ...@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
""" """
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod @staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs): def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try: try:
...@@ -307,9 +293,77 @@ class ModuleUtilsMixin: ...@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
elif head_mask.dim() == 2: elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask return head_mask
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get number of (optionally, trainable or non-embeddings) parameters in the module.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of non-embeddings parameters
Returns:
:obj:`int`: The number of parameters.
"""
def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not (
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.
Args:
inputs (:obj:`dict`): The model inputs.
Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
else:
warnings.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0
def floating_point_ops(
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
) -> int:
"""
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
if doing long-range modeling with very high sequence lengths.
Args:
batch_size (:obj:`int`):
The batch size for the forward pass.
sequence_length (:obj:`int`):
The number of tokens in each line of the batch.
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to count embedding and softmax operations.
Returns:
:obj:`int`: The number of floating-point operations.
"""
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r""" r"""
......
import inspect import inspect
import json
import math import math
import os import os
import re import re
...@@ -42,6 +43,8 @@ from .trainer_utils import ( ...@@ -42,6 +43,8 @@ from .trainer_utils import (
TrainOutput, TrainOutput,
default_compute_objective, default_compute_objective,
default_hp_space, default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
set_seed, set_seed,
) )
from .training_args import TrainingArguments from .training_args import TrainingArguments
...@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler): ...@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert ( assert (
len(indices) == self.num_samples len(indices) == self.num_samples
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched" ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices) return iter(indices)
...@@ -241,6 +244,7 @@ class Trainer: ...@@ -241,6 +244,7 @@ class Trainer:
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
) )
self.tb_writer = tb_writer self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs: if "prediction_loss_only" in kwargs:
warnings.warn( warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.", "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
...@@ -284,6 +288,7 @@ class Trainer: ...@@ -284,6 +288,7 @@ class Trainer:
self.global_step = None self.global_step = None
self.epoch = None self.epoch = None
self.total_flos = None
if self.args.fp16 and _use_native_amp: if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler() self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None self.hp_search_backend = None
...@@ -461,7 +466,11 @@ class Trainer: ...@@ -461,7 +466,11 @@ class Trainer:
logger.info( logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
) )
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
wandb.init( wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
) )
...@@ -663,6 +672,7 @@ class Trainer: ...@@ -663,6 +672,7 @@ class Trainer:
self.global_step = 0 self.global_step = 0
self.epoch = 0 self.epoch = 0
self.total_flos = 0
epochs_trained = 0 epochs_trained = 0
steps_trained_in_current_epoch = 0 steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint # Check if continuing training from a checkpoint
...@@ -670,6 +680,8 @@ class Trainer: ...@@ -670,6 +680,8 @@ class Trainer:
# set global_step to global_step of last saved checkpoint from model path # set global_step to global_step of last saved checkpoint from model path
try: try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % ( steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps len(train_dataloader) // self.args.gradient_accumulation_steps
...@@ -678,9 +690,11 @@ class Trainer: ...@@ -678,9 +690,11 @@ class Trainer:
logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step) logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError: except ValueError:
self.global_step = 0 self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.") logger.info(" Starting fine-tuning.")
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
...@@ -714,6 +728,7 @@ class Trainer: ...@@ -714,6 +728,7 @@ class Trainer:
continue continue
tr_loss += self.training_step(model, inputs) tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps # last step in epoch but step is always smaller than gradient_accumulation_steps
...@@ -784,7 +799,7 @@ class Trainer: ...@@ -784,7 +799,7 @@ class Trainer:
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_process_zero(): if self.is_world_process_zero():
self._rotate_checkpoints() self._rotate_checkpoints(use_mtime=True)
if is_torch_tpu_available(): if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states") xm.rendezvous("saving_optimizer_states")
...@@ -924,6 +939,13 @@ class Trainer: ...@@ -924,6 +939,13 @@ class Trainer:
if self.epoch is not None: if self.epoch is not None:
logs["epoch"] = self.epoch logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = self.total_flos
if self.global_step is None: if self.global_step is None:
# when logging evaluation metrics without training # when logging evaluation metrics without training
self.global_step = 0 self.global_step = 0
...@@ -951,6 +973,8 @@ class Trainer: ...@@ -951,6 +973,8 @@ class Trainer:
if experiment is not None: if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers") experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}} output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
if iterator is not None: if iterator is not None:
iterator.write(output) iterator.write(output)
else: else:
...@@ -1089,6 +1113,9 @@ class Trainer: ...@@ -1089,6 +1113,9 @@ class Trainer:
if xm.is_master_ordinal(): if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
# Save a trained model and configuration using `save_pretrained()`. # Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
...@@ -1096,6 +1123,7 @@ class Trainer: ...@@ -1096,6 +1123,7 @@ class Trainer:
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint") xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
...@@ -1108,12 +1136,26 @@ class Trainer: ...@@ -1108,12 +1136,26 @@ class Trainer:
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel): if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
def _store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = [] ordering_and_checkpoint_path = []
...@@ -1245,13 +1287,11 @@ class Trainer: ...@@ -1245,13 +1287,11 @@ class Trainer:
self._past = None self._past = None
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
samples_count = 0
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm): for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0] batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None: if loss is not None:
eval_losses.append(loss * batch_size) eval_losses.extend([loss] * batch_size)
if logits is not None: if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0) preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None: if labels is not None:
...@@ -1264,9 +1304,9 @@ class Trainer: ...@@ -1264,9 +1304,9 @@ class Trainer:
if self.args.local_rank != -1: if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes: # In distributed mode, concatenate all results from all nodes:
if preds is not None: if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None: if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available(): elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None: if preds is not None:
...@@ -1289,7 +1329,14 @@ class Trainer: ...@@ -1289,7 +1329,14 @@ class Trainer:
else: else:
metrics = {} metrics = {}
if len(eval_losses) > 0: if len(eval_losses) > 0:
metrics["eval_loss"] = np.sum(eval_losses) / samples_count if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)
# Prefix all keys with eval_ # Prefix all keys with eval_
for key in list(metrics.keys()): for key in list(metrics.keys()):
...@@ -1298,18 +1345,6 @@ class Trainer: ...@@ -1298,18 +1345,6 @@ class Trainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output
def prediction_step( def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
...@@ -1355,3 +1390,32 @@ class Trainer: ...@@ -1355,3 +1390,32 @@ class Trainer:
if labels is not None: if labels is not None:
labels = labels.detach() labels = labels.detach()
return (loss, logits.detach(), labels) return (loss, logits.detach(), labels)
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PretrainedModel`, uses
that method to compute the number of floating point operations for every backward + forward pass. If using
another model, either implement such a method in the model or subclass and override this method.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
Returns:
:obj:`int`: The number of floating-point operations.
"""
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model
if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)
else:
return 0
import random import random
from typing import Any, Dict, NamedTuple, Optional from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch
from .file_utils import is_tf_available, is_torch_available from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum from .tokenization_utils_base import ExplicitEnum
...@@ -126,3 +127,32 @@ default_hp_space = { ...@@ -126,3 +127,32 @@ default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna, HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray, HPSearchBackend.RAY: default_hp_space_ray,
} }
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
assert self.args.local_rank != -1
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
def distributed_broadcast_scalars(
self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor:
assert self.args.local_rank != -1
tensorized_scalar = torch.Tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment