Unverified commit 70996a54, authored by Jamie DeAntonis, committed by GitHub

WIP: Support for Training with BF16 (#13207)



* started bf16 integration

* minor changes

* code now runs

* style

* lay foundation for bf16 testing

* lay foundation for bf16 testing

* start the tests

* better bf16 check

* style

* 2 separate checkers - one for bf16 support, another for bf16+autocast

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* a couple of comment resolutions

* more comment resolutions

* resolved a small bug

* just some print statements

* added todo marking

* added a todo

* adjust for API change s/fast_dtype/dtype/

* fix style

* merge 2 bf16 util functions

* bf16 now does scaling too

* Add support for bfloat16

* Revert T5 layernorm to float32

This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920.

* Add comment about conversion to float32 before returning the numpy data

* Add comment about AMP-bfloat16 incompatibility

* Fix formatting

* typo

* reformer / bf16

* cleanup

* require at least pt-1.10

* fix

* will deal with deepspeed separately

* cleanup

* revert

* cleanup

* fp16_full_eval and bf16_full_eval are separate modes

* proper deprecation

* cleanup

* test and fixes

* spelling

* cleanup

* add a note that this API is experimental
Co-authored-by: jamie <jamie@cortx.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: suriya <suriya@cortx.com>
Co-authored-by: Manuel R. Ciosici <manuelrciosici@gmail.com>
parent fc1d97f2
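For context, here is a minimal usage sketch of the options this PR adds. It assumes an Ampere-or-newer GPU, CUDA >= 11, and PyTorch >= 1.10; the output directory is a placeholder and the model/dataset wiring is omitted.

```python
# Hypothetical sketch of enabling the new bf16 options (experimental API, per this PR).
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    output_dir="out",                 # placeholder
    bf16=True,                        # bf16 mixed-precision training
    bf16_full_eval=True,              # run evaluation fully in bfloat16
    half_precision_backend="amp",     # bf16 only works with native amp, not apex
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```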
@@ -320,6 +320,37 @@ def is_torch_cuda_available():
         return False
 
 
+def is_torch_bf16_available():
+    if is_torch_available():
+        import torch
+
+        # since currently no utility function is available we build our own.
+        # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
+        # with an additional check for the torch version
+        # to succeed:
+        # 1. the hardware needs to support bf16 (arch >= Ampere)
+        # 2. torch >= 1.10 (1.9 should be enough for AMP, but the autocast API changed in 1.10, so we use 1.10 as the minimum)
+        # 3. CUDA >= 11
+        # 4. torch.autocast exists
+        # XXX: one problem here is that it may give invalid results on a mixed-GPU setup, so it's
+        # really only correct for the 0th GPU (or the currently set default device if different from 0)
+        if not torch.cuda.is_available() or torch.version.cuda is None:
+            return False
+        if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+            return False
+        if int(torch.version.cuda.split(".")[0]) < 11:
+            return False
+        if not version.parse(torch.__version__) >= version.parse("1.10"):
+            return False
+        if not hasattr(torch, "autocast"):
+            return False
+
+        return True
+    else:
+        return False
+
+
 _torch_fx_available = _torch_onnx_dict_inputs_support_available = False
 if _torch_available:
     torch_version = version.parse(importlib_metadata.version("torch"))
...
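The four conditions above can also be probed outside of Transformers. A standalone sketch of the same capability check (my own paraphrase, not part of the diff):

```python
# Probe whether this machine satisfies the bf16 requirements listed above.
import torch
from packaging import version

def bf16_capable() -> bool:
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False  # needs a CUDA build of PyTorch and a visible GPU
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False  # needs Ampere (compute capability 8.x) or newer
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False  # needs CUDA >= 11
    if version.parse(torch.__version__) < version.parse("1.10"):
        return False  # autocast(dtype=...) requires PyTorch >= 1.10
    return hasattr(torch, "autocast")

print(bf16_capable())
```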
@@ -233,7 +233,7 @@ class ModuleUtilsMixin:
         if self.dtype == torch.float16:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
-        elif self.dtype == torch.float32:
+        elif self.dtype in [torch.bfloat16, torch.float32]:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:
             raise ValueError(
...
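The additive mask constant depends on the dtype because -1e9 overflows float16 (max is about 65504) but fits comfortably in bfloat16, which shares float32's exponent range. A quick illustration (not from the diff):

```python
# Show why -1e9 is a safe mask value for bf16/fp32 but not for fp16.
import torch

print(torch.tensor(-1e9, dtype=torch.float16))   # -inf: out of fp16 range
print(torch.tensor(-1e9, dtype=torch.bfloat16))  # ~ -1e9: representable in bf16
print(torch.tensor(-1e4, dtype=torch.float16))   # -10000.0: the value used for fp16
```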
@@ -242,9 +242,10 @@ class T5LayerNorm(nn.Module):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into float16 if necessary
-        if self.weight.dtype == torch.float16:
-            hidden_states = hidden_states.to(torch.float16)
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
 
         return self.weight * hidden_states
...
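The pattern being extended here, computing the RMS statistics in float32 and only casting the normalized activations back to whatever half precision the weights use, can be written as a standalone module. A sketch under that assumption, not the exact T5 code:

```python
# RMS-style layer norm that keeps the variance accumulation in float32 for stability
# and casts back only when the weights are stored in fp16 or bf16.
import torch
from torch import nn

class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states

x = torch.randn(2, 8, dtype=torch.bfloat16)
print(RMSNormSketch(8).to(torch.bfloat16)(x).dtype)  # torch.bfloat16
```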
@@ -49,6 +49,7 @@ from .file_utils import (
     is_timm_available,
     is_tokenizers_available,
     is_torch_available,
+    is_torch_bf16_available,
     is_torch_tpu_available,
     is_torchaudio_available,
     is_vision_available,
@@ -493,6 +494,14 @@ def require_torch_gpu(test_case):
     return test_case
 
 
+def require_torch_bf16(test_case):
+    """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
+    if not is_torch_bf16_available():
+        return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
+    else:
+        return test_case
+
+
 def require_datasets(test_case):
     """Decorator marking a test that requires datasets."""
...
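Hypothetical usage of the new decorator in a test module (the test body is a placeholder of my own):

```python
# Skips cleanly on machines without bf16-capable hardware or with PyTorch < 1.10.
import unittest
import torch
from transformers.testing_utils import require_torch_bf16, require_torch_gpu

class Bf16SmokeTest(unittest.TestCase):
    @require_torch_gpu
    @require_torch_bf16
    def test_bf16_matmul(self):
        a = torch.randn(4, 4, dtype=torch.bfloat16, device="cuda")
        self.assertEqual((a @ a).dtype, torch.bfloat16)
```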
@@ -353,13 +353,13 @@ class Trainer:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         #    and we only use deepspeed for training at the moment
-        # 3. full fp16 eval - since the model needs to be half'ed first
+        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
         # 4. Sharded DDP - same as MP
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
             or args.deepspeed
-            or (args.fp16_full_eval and not args.do_train)
+            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
         ):
             self.place_model_on_device = False
@@ -424,18 +424,24 @@ class Trainer:
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
-        self.fp16_backend = None
 
-        if args.fp16:
-            if args.fp16_backend == "auto":
-                self.fp16_backend = "amp" if _is_native_amp_available else "apex"
-            else:
-                self.fp16_backend = args.fp16_backend
-            logger.info(f"Using {self.fp16_backend} fp16 backend")
+        if args.fp16 or args.bf16:
+            if args.half_precision_backend == "auto":
+                if _is_native_amp_available:
+                    args.half_precision_backend = "amp"
+                else:
+                    if args.bf16:
+                        raise ValueError("Tried to use `bf16` but native amp is not available")
+                    else:
+                        args.half_precision_backend = "apex"
+            logger.info(f"Using {args.half_precision_backend} half precision backend")
 
-        if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
-            if self.fp16_backend == "amp":
+        self.do_grad_scaling = False
+        if (args.fp16 or args.bf16) and not args.deepspeed:  # deepspeed manages its own half precision
+            if args.half_precision_backend == "amp":
                 self.use_amp = True
+                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
+                self.do_grad_scaling = True
                 if is_sagemaker_mp_enabled():
                     self.scaler = smp.amp.GradScaler()
                 elif self.sharded_ddp is not None:
@@ -975,7 +981,7 @@ class Trainer:
             if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 model = ShardedDDP(model, self.optimizer)
             else:
-                mixed_precision = self.args.fp16
+                mixed_precision = self.args.fp16 or self.args.bf16
                 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                 zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                 # XXX: Breaking the self.model convention but I see no way around it for now.
@@ -1043,7 +1049,7 @@ class Trainer:
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
-        if args.fp16_full_eval and not args.do_train:
+        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
             self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:
@@ -1341,7 +1347,7 @@ class Trainer:
                 if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                     # deepspeed does its own clipping
-                    if self.use_amp:
+                    if self.do_grad_scaling:
                         # AMP: gradients need unscaling
                         self.scaler.unscale_(self.optimizer)
@@ -1364,7 +1370,7 @@ class Trainer:
                     pass  # called outside the loop
                 elif is_torch_tpu_available():
                     xm.optimizer_step(self.optimizer)
-                elif self.use_amp:
+                elif self.do_grad_scaling:
                     scale_before = self.scaler.get_scale()
                     self.scaler.step(self.optimizer)
                     self.scaler.update()
@@ -1588,7 +1594,7 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
-            if self.use_amp:
+            if self.do_grad_scaling:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
@@ -1596,7 +1602,7 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
-            if self.use_amp:
+            if self.do_grad_scaling:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
         # Determine the new best metric / best model checkpoint
@@ -1684,7 +1690,7 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)
-        if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
+        if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
             self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
 
     def hyperparameter_search(
@@ -1846,12 +1852,12 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)
 
         if is_sagemaker_mp_enabled():
-            scaler = self.scaler if self.use_amp else None
+            scaler = self.scaler if self.do_grad_scaling else None
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         if self.use_amp:
-            with autocast():
+            with autocast(dtype=self.amp_dtype):
                 loss = self.compute_loss(model, inputs)
         else:
             loss = self.compute_loss(model, inputs)
@@ -1863,7 +1869,7 @@ class Trainer:
             # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
             loss = loss / self.args.gradient_accumulation_steps
 
-        if self.use_amp:
+        if self.do_grad_scaling:
             self.scaler.scale(loss).backward()
         elif self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
@@ -2220,12 +2226,12 @@ class Trainer:
         Works both with or without labels.
         """
-        prediction_loss_only = (
-            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
-        )
+        args = self.args
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
         # if eval is called w/o train init deepspeed here
-        if self.args.deepspeed and not self.deepspeed:
+        if args.deepspeed and not self.deepspeed:
             # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
             # from the checkpoint eventually
@@ -2238,10 +2244,13 @@ class Trainer:
         model = self._wrap_model(self.model, training=False)
 
-        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, halve it first and then put on device
-        if not self.is_in_train and self.args.fp16_full_eval:
-            model = model.half().to(self.args.device)
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
 
         batch_size = dataloader.batch_size
@@ -2259,9 +2268,9 @@ class Trainer:
         eval_dataset = dataloader.dataset
 
         if is_torch_tpu_available():
-            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
+            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
 
-        if self.args.past_index >= 0:
+        if args.past_index >= 0:
             self._past = None
 
         # Initialize containers
@@ -2301,10 +2310,10 @@ class Trainer:
                 labels = self._pad_across_processes(labels)
                 labels = self._nested_gather(labels)
                 labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
-            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
+            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                 if losses_host is not None:
                     losses = nested_numpify(losses_host)
                     all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
@@ -2320,7 +2329,7 @@ class Trainer:
                 # Set back to None to begin a new accumulation
                 losses_host, preds_host, labels_host = None, None, None
 
-        if self.args.past_index and hasattr(self, "_past"):
+        if args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of the evaluation loop
             delattr(self, "_past")
@@ -2492,11 +2501,12 @@ class Trainer:
             else:
                 if has_labels:
                     if self.use_amp:
-                        with autocast():
+                        with autocast(dtype=self.amp_dtype):
                             loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     else:
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
+
                     if isinstance(outputs, dict):
                         logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                     else:
@@ -2504,7 +2514,7 @@ class Trainer:
             else:
                 loss = None
                 if self.use_amp:
-                    with autocast():
+                    with autocast(dtype=self.amp_dtype):
                         outputs = model(**inputs)
                 else:
                     outputs = model(**inputs)
@@ -2719,14 +2729,14 @@ class Trainer:
         Works both with or without labels.
         """
+        args = self.args
+
         if not isinstance(dataloader.dataset, collections.abc.Sized):
             raise ValueError("dataset must implement __len__")
-        prediction_loss_only = (
-            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
-        )
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
         # if eval is called w/o train init deepspeed here
-        if self.args.deepspeed and not self.deepspeed:
+        if args.deepspeed and not self.deepspeed:
             # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
             # from the checkpoint eventually
@@ -2742,10 +2752,13 @@ class Trainer:
         model = self._wrap_model(self.model, training=False)
 
-        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, halve it first and then put on device
-        if not self.is_in_train and self.args.fp16_full_eval:
-            model = model.half().to(self.args.device)
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
 
         batch_size = dataloader.batch_size
         num_examples = self.num_examples(dataloader)
@@ -2756,7 +2769,7 @@ class Trainer:
         preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
 
-        world_size = max(1, self.args.world_size)
+        world_size = max(1, args.world_size)
 
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
         if not prediction_loss_only:
@@ -2771,9 +2784,9 @@ class Trainer:
         model.eval()
 
         if is_torch_tpu_available():
-            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
+            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
 
-        if self.args.past_index >= 0:
+        if args.past_index >= 0:
             self._past = None
 
         self.callback_handler.eval_dataloader = dataloader
@@ -2787,10 +2800,10 @@ class Trainer:
                 preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
             if labels is not None:
                 labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
-            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
+            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                 eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                 if not prediction_loss_only:
                     preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
@@ -2799,7 +2812,7 @@ class Trainer:
                 # Set back to None to begin a new accumulation
                 losses_host, preds_host, labels_host = None, None, None
 
-        if self.args.past_index and hasattr(self, "_past"):
+        if args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of the evaluation loop
             delattr(self, "_past")
...
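The core mechanism the Trainer changes above rely on is passing a dtype to autocast and keeping the GradScaler logic behind the new `do_grad_scaling` flag. A condensed sketch of that loop (my own example, assuming a CUDA device and PyTorch >= 1.10):

```python
# Mixed-precision step with a selectable autocast dtype, mirroring the pattern above.
import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
amp_dtype = torch.bfloat16            # torch.float16 for --fp16, torch.bfloat16 for --bf16
scaler = torch.cuda.amp.GradScaler()  # scaling mainly matters for fp16; the PR enables it for both

x = torch.randn(8, 16, device="cuda")
with torch.autocast("cuda", dtype=amp_dtype):
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```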
@@ -136,7 +136,13 @@ def nested_numpify(tensors):
     "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_numpify(t) for t in tensors)
-    return tensors.cpu().numpy()
+    t = tensors.cpu()
+    if t.dtype == torch.bfloat16:
+        # As of NumPy 1.21.4, NumPy does not support bfloat16 (see
+        # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
+        # Until NumPy adds bfloat16, we must convert to float32.
+        t = t.to(torch.float32)
+    return t.numpy()
 
 
 def nested_detach(tensors):
...
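The upcast is needed because NumPy has no bfloat16 dtype, so `.numpy()` refuses bf16 tensors. A quick illustration (assuming current NumPy still lacks bfloat16):

```python
# Direct conversion of a bf16 tensor fails; upcasting to float32 first works.
import torch

t = torch.ones(3, dtype=torch.bfloat16)
try:
    t.numpy()
except TypeError as e:
    print("direct .numpy() failed:", e)
print(t.to(torch.float32).numpy())  # [1. 1. 1.]
```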
@@ -207,18 +207,26 @@ class TrainingArguments:
             Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
             :func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
             initialized parameters.
+        bf16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
+            NVIDIA architecture. This is an experimental API and it may change.
         fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to use 16-bit (mixed) precision training instead of 32-bit training.
+            Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
         fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
             For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
             on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
         fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
+            This argument is deprecated. Use ``half_precision_backend`` instead.
+        half_precision_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
             The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
             :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
             other choices will force the requested backend.
+        bf16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+            metric values. This is an experimental API and it may change.
         fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
-            can harm metric values.
+            Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+            metric values.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
         xpu_backend (:obj:`str`, `optional`):
@@ -507,10 +515,15 @@ class TrainingArguments:
     )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    bf16: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
+        },
+    )
     fp16: bool = field(
         default=False,
-        metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
+        metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
     )
     fp16_opt_level: str = field(
         default="O1",
@@ -521,13 +534,19 @@ class TrainingArguments:
             )
         },
     )
-    fp16_backend: str = field(
+    half_precision_backend: str = field(
         default="auto",
-        metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
+        metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
     )
+    bf16_full_eval: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change."
+        },
+    )
     fp16_full_eval: bool = field(
         default=False,
-        metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
+        metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
     xpu_backend: str = field(
@@ -666,6 +685,10 @@ class TrainingArguments:
         },
     )
     # Deprecated arguments
+    fp16_backend: str = field(
+        default="auto",
+        metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
+    )
     push_to_hub_model_id: str = field(
         default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
     )
@@ -754,10 +777,31 @@ class TrainingArguments:
         if self.run_name is None:
             self.run_name = self.output_dir
 
-        if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
+        if self.fp16_backend and self.fp16_backend != "auto":
+            warnings.warn(
+                "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `half_precision_backend` instead",
+                FutureWarning,
+            )
+            self.half_precision_backend = self.fp16_backend
+
+        if self.fp16 and self.bf16:
+            raise ValueError("At most one of fp16 and bf16 can be True, but not both")
+
+        if self.bf16:
+            if self.half_precision_backend == "apex":
+                raise ValueError(
+                    " `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead"
+                )
+            if not (self.sharded_ddp == "" or not self.sharded_ddp):
+                raise ValueError("sharded_ddp is not supported with bf16")
+
+        if (
+            is_torch_available()
+            and self.device.type != "cuda"
+            and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
+        ):
             raise ValueError(
-                "Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
+                "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
             )
 
         if self.report_to is None:
             logger.info(
                 "The default value for the training argument `--report_to` will change in v5 (from all installed "
...
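A short sketch of how the new validation in `__post_init__` behaves (the error strings are those from the diff above; both checks fire before the CUDA requirement is evaluated, so this should run even without a GPU):

```python
# The two hard errors added by this PR: fp16+bf16 together, and bf16 with the apex backend.
from transformers import TrainingArguments

try:
    TrainingArguments(output_dir="out", fp16=True, bf16=True)
except ValueError as e:
    print(e)  # At most one of fp16 and bf16 can be True, but not both

try:
    TrainingArguments(output_dir="out", bf16=True, half_precision_backend="apex")
except ValueError as e:
    print(e)  # bf16 is not supported by apex

# Passing the deprecated fp16_backend now only emits a FutureWarning and is copied
# into half_precision_backend.
```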
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
     require_sigopt,
     require_tokenizers,
     require_torch,
+    require_torch_bf16,
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
@@ -476,6 +477,21 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
 
+    @require_torch_gpu
+    @require_torch_bf16
+    def test_mixed_bf16(self):
+
+        # very basic test
+        trainer = get_regression_trainer(learning_rate=0.1, bf16=True)
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # --bf16 --half_precision_backend apex can't be used together
+        with self.assertRaises(ValueError):
+            trainer = get_regression_trainer(learning_rate=0.1, bf16=True, half_precision_backend="apex")
+
+        # will add more specific tests once there are some bugs to fix
+
 
 @require_torch
 @require_sentencepiece
@@ -1323,6 +1339,66 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # perfect world: fp32_init/2 == fp16_eval
         self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
 
+    @require_torch_gpu
+    @require_torch_bf16
+    def test_bf16_full_eval(self):
+        # note: most of the logic is the same as test_fp16_full_eval
+
+        # this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
+        # it's using pretty large safety margins, but small enough to detect broken functionality.
+        debug = 0
+        n_gpus = get_gpu_count()
+
+        bs = 8
+        eval_len = 16 * n_gpus
+        # make the params somewhat big so that there will be enough RAM consumed to be able to
+        # measure things. We should get about 64KB for a+b in fp32
+        a = torch.ones(1000, bs) + 0.001
+        b = torch.ones(1000, bs) - 0.001
+
+        # 1. baseline fp32 eval, with memory metrics enabled
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
+        metrics = trainer.evaluate()
+        del trainer
+        gc.collect()
+
+        fp32_init = metrics["init_mem_gpu_alloc_delta"]
+        fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
+
+        if debug:
+            print(f"fp32_init {fp32_init}")
+            print(f"fp32_eval {fp32_eval}")
+
+        # here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
+        # perfect world: fp32_init == 64<<10
+        self.assertGreater(fp32_init, 59_000)
+        # after eval there should be no extra memory allocated - with a small margin (other than the peak
+        # memory consumption for the forward calculation that gets recovered)
+        # perfect world: fp32_eval == close to zero
+        self.assertLess(fp32_eval, 5_000)
+
+        # 2. with bf16_full_eval enabled (memory metrics still on for comparison)
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
+        metrics = trainer.evaluate()
+        bf16_init = metrics["init_mem_gpu_alloc_delta"]
+        bf16_eval = metrics["eval_mem_gpu_alloc_delta"]
+
+        if debug:
+            print(f"bf16_init {bf16_init}")
+            print(f"bf16_eval {bf16_eval}")
+
+        # here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
+        # perfect world: bf16_init == close to zero
+        self.assertLess(bf16_init, 5_000)
+        # here we put the model on device in eval and cast it to bf16, i.e. about 32K (again we ignore
+        # the peak margin which gets returned back)
+        # perfect world: bf16_eval == 32<<10
+        self.assertGreater(bf16_eval, 27_000)
+
+        # 3. relative comparison fp32 vs full bf16
+        # should be about half of fp32_init
+        # perfect world: fp32_init/2 == bf16_eval
+        self.assertAlmostEqual(bf16_eval, fp32_init / 2, delta=5_000)
+
     def test_no_wd_param_group(self):
         model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
         trainer = Trainer(model=model)
...