Unverified commit 01d340ad authored by Teven, committed by GitHub

Floating-point operations logging in trainer (#6768)



* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d155b38d
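As a usage note (a hedged sketch, not part of the diff below): after this change every log entry carries the cumulative count of non-embedding floating-point operations under "total_flos", and the trainer writes its log history to log_history.json next to each saved checkpoint. Reading it back could look like the following; the checkpoint path is an assumed example, not one produced by this commit.

import json

# "checkpoint-500" is a hypothetical output directory created by the Trainer.
with open("checkpoint-500/log_history.json") as f:
    log_history = json.load(f)

for entry in log_history:
    # "step" and "total_flos" are written by Trainer._log() in this patch; "loss" is a typical training key.
    print(entry.get("step"), entry.get("loss"), entry.get("total_flos"))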
src/transformers/modeling_utils.py
......@@ -17,8 +17,9 @@
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor, device, dtype, nn
......@@ -45,7 +46,6 @@ from .utils import logging
logger = logging.get_logger(__name__)
try:
from torch.nn import Identity
except ImportError:
......@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
......@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get the number of (optionally, trainable or non-embedding) parameters in the module.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters.
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of non-embedding parameters.
Returns:
:obj:`int`: The number of parameters.
"""
# Embeddings cannot be detected from self.parameters() directly (parameters are not modules),
# so collect the embedding weight names from the sub-modules instead.
if exclude_embeddings:
    embedding_param_names = [
        f"{name}.weight" for name, module in self.named_modules() if isinstance(module, torch.nn.Embedding)
    ]
    parameters = [param for name, param in self.named_parameters() if name not in embedding_param_names]
else:
    parameters = list(self.parameters())
return sum(p.numel() for p in parameters if p.requires_grad or not only_trainable)
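A tiny sanity check of the intended exclude_embeddings behaviour (the toy module and its sizes are made-up assumptions, not part of this commit):

import torch
from torch import nn

class ToyModel(nn.Module, ModuleUtilsMixin):  # ModuleUtilsMixin is the class defined in this file
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)   # 1600 parameters
        self.proj = nn.Linear(16, 16)        # 256 weights + 16 biases = 272 parameters

toy = ToyModel()
print(toy.num_parameters())                         # 1872
print(toy.num_parameters(exclude_embeddings=True))  # 272, the embedding weights are skipped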
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.
Args:
input_dict (:obj:`dict`): The model inputs.
Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
else:
warnings.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0
def floating_point_ops(
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
) -> int:
"""
Get the number of (optionally, non-embedding) floating-point operations for the forward and backward passes of a
batch with this transformer model. The default approximation neglects the quadratic dependency on the number of
tokens (valid as long as :obj:`sequence_length << 12 * d_model`), as laid out in section 2.1 of
`this paper <https://arxiv.org/pdf/2001.08361.pdf>`__. Should be overridden for transformers with parameter
re-use, e.g. Albert or Universal Transformers, or when doing long-range modeling with very long sequences.
Args:
input_dict (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The model inputs; the number of tokens is estimated from every tensor whose key contains :obj:`"input"`.
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to exclude the embedding and softmax operations from the count.
Returns:
:obj:`int`: The number of floating-point operations.
"""
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
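A rough worked example of how the two helpers above combine (the tensor shapes and the 124M parameter count are assumed, GPT-2-small-sized numbers, not values taken from this commit):

import torch

# estimate_tokens() counts every tensor whose key contains "input": here input_ids and
# decoder_input_ids contribute their numel(), attention_mask does not.
fake_batch = {
    "input_ids": torch.zeros(8, 512, dtype=torch.long),
    "attention_mask": torch.ones(8, 512, dtype=torch.long),
    "decoder_input_ids": torch.zeros(8, 64, dtype=torch.long),
}
tokens = sum(t.numel() for k, t in fake_batch.items() if "input" in k)   # 8*512 + 8*64 = 4608

# floating_point_ops() then applies the 6 * tokens * parameters rule (2N for the forward pass,
# 4N for the backward pass, per token), with N the non-embedding parameter count.
non_embedding_params = 124_000_000
flos = 6 * tokens * non_embedding_params
print(f"{flos:.2e}")   # ~3.43e+12 floating-point operations for this batch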
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""
src/transformers/trainer.py
import inspect
import json
import math
import os
import re
......@@ -42,6 +43,8 @@ from .trainer_utils import (
TrainOutput,
default_compute_objective,
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
set_seed,
)
from .training_args import TrainingArguments
......@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
......@@ -241,6 +244,7 @@ class Trainer:
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
......@@ -284,6 +288,7 @@ class Trainer:
self.global_step = None
self.epoch = None
self.total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
......@@ -461,7 +466,11 @@ class Trainer:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
......@@ -663,6 +672,7 @@ class Trainer:
self.global_step = 0
self.epoch = 0
self.total_flos = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
......@@ -670,6 +680,8 @@ class Trainer:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
......@@ -678,9 +690,11 @@ class Trainer:
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.")
tr_loss = torch.tensor(0.0).to(self.args.device)
......@@ -714,6 +728,7 @@ class Trainer:
continue
tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
......@@ -784,7 +799,7 @@ class Trainer:
self.save_model(output_dir)
if self.is_world_process_zero():
self._rotate_checkpoints()
self._rotate_checkpoints(use_mtime=True)
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
......@@ -924,6 +939,13 @@ class Trainer:
if self.epoch is not None:
logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
......@@ -951,6 +973,8 @@ class Trainer:
if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
......@@ -1089,6 +1113,9 @@ class Trainer:
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
......@@ -1096,6 +1123,7 @@ class Trainer:
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
......@@ -1108,12 +1136,26 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
def _store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos
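Because _store_flos() writes the aggregated count onto model.config before save_pretrained(), the count travels with the checkpoint's config.json and is what the resume logic above reads back via getattr(model.config, "total_flos", 0). A small, hypothetical read-back (the path is an assumed example):

import json

with open("checkpoint-500/config.json") as f:   # hypothetical checkpoint directory
    config = json.load(f)
print(config.get("total_flos", 0))              # cumulative non-embedding FLOs so far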
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
......@@ -1245,13 +1287,11 @@ class Trainer:
self._past = None
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
samples_count = 0
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None:
eval_losses.append(loss * batch_size)
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
......@@ -1264,9 +1304,9 @@ class Trainer:
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
......@@ -1289,7 +1329,14 @@ class Trainer:
else:
metrics = {}
if len(eval_losses) > 0:
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)
# Prefix all keys with eval_
for key in list(metrics.keys()):
......@@ -1298,18 +1345,6 @@ class Trainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
......@@ -1355,3 +1390,32 @@ class Trainer:
if labels is not None:
labels = labels.detach()
return (loss, logits.detach(), labels)
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
:obj:`floating_point_ops` method to compute the number of floating-point operations for every backward + forward
pass. If using another model, either implement such a method in the model or subclass and override this method.
Args:
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
Returns:
:obj:`int`: The number of floating-point operations.
"""
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model
if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)
else:
return 0
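For a model that does not inherit from PreTrainedModel, the docstring above suggests providing the method on the model itself; a minimal sketch (the class name and body are assumptions, mirroring the 6 * tokens * parameters default from modeling_utils):

import torch
from torch import nn

class MyCustomModel(nn.Module):
    def floating_point_ops(self, inputs):
        # Trainer.floating_point_ops() only checks hasattr(model, "floating_point_ops"),
        # so any nn.Module can opt in by exposing a method with this signature.
        tokens = sum(t.numel() for k, t in inputs.items() if "input" in k)
        return 6 * tokens * sum(p.numel() for p in self.parameters())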
src/transformers/trainer_utils.py
import random
from typing import Any, Dict, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np
import torch
from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum
......@@ -126,3 +127,32 @@ default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
}
def distributed_concat(tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
    assert torch.distributed.is_initialized(), "distributed_concat requires distributed training"
    output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    # truncate the dummy elements added by SequentialDistributedSampler
    if num_total_examples is not None:
        concat = concat[:num_total_examples]
    return concat
def distributed_broadcast_scalars(
    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor:
    assert torch.distributed.is_initialized(), "distributed_broadcast_scalars requires distributed training"
    tensorized_scalar = torch.Tensor(scalars).cuda()
    output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(output_tensors, tensorized_scalar)
    concat = torch.cat(output_tensors, dim=0)
    # truncate the dummy elements added by SequentialDistributedSampler
    if num_total_examples is not None:
        concat = concat[:num_total_examples]
    return concat
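A hedged, single-process way to exercise the helper above (the rendezvous settings and the world size of 1 are assumptions chosen to avoid a launcher; a CUDA device is required because the scalars are moved to GPU):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

gathered = distributed_broadcast_scalars([1.0, 2.0, 3.0])
print(gathered)          # tensor([1., 2., 3.], device='cuda:0')
print(gathered.mean())   # with several ranks, every process sees the same concatenation,
                         # which is how total_flos and eval_loss are aggregated in trainer.py
dist.destroy_process_group()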