Unverified Commit 70996a54 authored by Jamie DeAntonis's avatar Jamie DeAntonis Committed by GitHub
Browse files

WIP: Support for Training with BF16 (#13207)



* started bf16 integration

* minor changes

* code now runs

* style

* lay foundation for bf16 testing

* lay foundation for bf16 testing

* start the tests

* better bf16 check

* style

* 2 separate checkers - one for bf16 support, another for bf16+autocast

* Update src/transformers/training_args.py
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* a couple of comment resolutions

* more comment resolutions

* resolved a small bug

* just some print statemtns

* added todo marking

* added a todo

* adjust for API change s/fast_dtype/dtype/

* fix style

* merge 2 bf16 util functions

* bf16 now does scaling too

* Add support for bfloat16

* Revert T5 layernorm to float32

This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920

 .

* Add comment about conversion to float32 before returning the numpy data

* Add comment about AMP-bfloat16 incompatibility

* Fix formatting

* typo

* reformer / bf16

* cleanup

* require at least pt-1.10

* fix

* will deal with deepspeed separately

* cleanup

* revert

* cleanup

* fp16_full_eval and bf16_full_eval are separate modes

* proper deprecation

* cleanup

* test and fixes

* spelling

* cleanup

* add a note that this API is experimental
Co-authored-by: default avatarjamie <jamie@cortx.com>
Co-authored-by: default avatarStas Bekman <stas@stason.org>
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
Co-authored-by: default avatarsuriya <suriya@cortx.com>
Co-authored-by: default avatarManuel R. Ciosici <manuelrciosici@gmail.com>
parent fc1d97f2
......@@ -320,6 +320,37 @@ def is_torch_cuda_available():
return False
def is_torch_bf16_available():
if is_torch_available():
import torch
# since currently no utility function is available we build our own.
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed:
# 1. the hardware needs to support bf16 (arch >= Ampere)
# 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
# 3. CUDA >= 11
# 4. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
if not torch.cuda.is_available() or torch.version.cuda is None:
return False
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
if not version.parse(torch.__version__) >= version.parse("1.10"):
return False
if not hasattr(torch, "autocast"):
return False
return True
else:
return False
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
......
......@@ -233,7 +233,7 @@ class ModuleUtilsMixin:
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype == torch.float32:
elif self.dtype in [torch.bfloat16, torch.float32]:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
......
......@@ -242,9 +242,10 @@ class T5LayerNorm(nn.Module):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into float16 if necessary
if self.weight.dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
......
......@@ -49,6 +49,7 @@ from .file_utils import (
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_tpu_available,
is_torchaudio_available,
is_vision_available,
......@@ -493,6 +494,14 @@ def require_torch_gpu(test_case):
return test_case
def require_torch_bf16(test_case):
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
if not is_torch_bf16_available():
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
else:
return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
......
......@@ -353,13 +353,13 @@ class Trainer:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
# 3. full fp16 eval - since the model needs to be half'ed first
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or args.deepspeed
or (args.fp16_full_eval and not args.do_train)
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
self.place_model_on_device = False
......@@ -424,18 +424,24 @@ class Trainer:
# Mixed precision setup
self.use_apex = False
self.use_amp = False
self.fp16_backend = None
if args.fp16:
if args.fp16_backend == "auto":
self.fp16_backend = "amp" if _is_native_amp_available else "apex"
else:
self.fp16_backend = args.fp16_backend
logger.info(f"Using {self.fp16_backend} fp16 backend")
if args.fp16 or args.bf16:
if args.half_precision_backend == "auto":
if _is_native_amp_available:
args.half_precision_backend = "amp"
else:
if args.bf16:
raise ValueError("Tried to use `bf16` but native amp is not available")
else:
args.half_precision_backend = "apex"
logger.info(f"Using {args.half_precision_backend} half precision backend")
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.do_grad_scaling = False
if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision
if args.half_precision_backend == "amp":
self.use_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
self.do_grad_scaling = True
if is_sagemaker_mp_enabled():
self.scaler = smp.amp.GradScaler()
elif self.sharded_ddp is not None:
......@@ -975,7 +981,7 @@ class Trainer:
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16
mixed_precision = self.args.fp16 or self.args.bf16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
......@@ -1043,7 +1049,7 @@ class Trainer:
# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if args.fp16_full_eval and not args.do_train:
if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
self._move_model_to_device(self.model, args.device)
if "model_path" in kwargs:
......@@ -1341,7 +1347,7 @@ class Trainer:
if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
# deepspeed does its own clipping
if self.use_amp:
if self.do_grad_scaling:
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
......@@ -1364,7 +1370,7 @@ class Trainer:
pass # called outside the loop
elif is_torch_tpu_available():
xm.optimizer_step(self.optimizer)
elif self.use_amp:
elif self.do_grad_scaling:
scale_before = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
......@@ -1588,7 +1594,7 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
......@@ -1596,7 +1602,7 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
# Determine the new best metric / best model checkpoint
......@@ -1684,7 +1690,7 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
def hyperparameter_search(
......@@ -1846,12 +1852,12 @@ class Trainer:
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
scaler = self.scaler if self.use_amp else None
scaler = self.scaler if self.do_grad_scaling else None
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
return loss_mb.reduce_mean().detach().to(self.args.device)
if self.use_amp:
with autocast():
with autocast(dtype=self.amp_dtype):
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
......@@ -1863,7 +1869,7 @@ class Trainer:
# deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
......@@ -2220,12 +2226,12 @@ class Trainer:
Works both with or without labels.
"""
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
args = self.args
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train init deepspeed here
if self.args.deepspeed and not self.deepspeed:
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
......@@ -2238,10 +2244,13 @@ class Trainer:
model = self._wrap_model(self.model, training=False)
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = dataloader.batch_size
......@@ -2259,9 +2268,9 @@ class Trainer:
eval_dataset = dataloader.dataset
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
if self.args.past_index >= 0:
if args.past_index >= 0:
self._past = None
# Initialize containers
......@@ -2301,10 +2310,10 @@ class Trainer:
labels = self._pad_across_processes(labels)
labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
......@@ -2320,7 +2329,7 @@ class Trainer:
# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None
if self.args.past_index and hasattr(self, "_past"):
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
......@@ -2492,11 +2501,12 @@ class Trainer:
else:
if has_labels:
if self.use_amp:
with autocast():
with autocast(dtype=self.amp_dtype):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
......@@ -2504,7 +2514,7 @@ class Trainer:
else:
loss = None
if self.use_amp:
with autocast():
with autocast(dtype=self.amp_dtype):
outputs = model(**inputs)
else:
outputs = model(**inputs)
......@@ -2719,14 +2729,14 @@ class Trainer:
Works both with or without labels.
"""
args = self.args
if not isinstance(dataloader.dataset, collections.abc.Sized):
raise ValueError("dataset must implement __len__")
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train init deepspeed here
if self.args.deepspeed and not self.deepspeed:
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
......@@ -2742,10 +2752,13 @@ class Trainer:
model = self._wrap_model(self.model, training=False)
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = dataloader.batch_size
num_examples = self.num_examples(dataloader)
......@@ -2756,7 +2769,7 @@ class Trainer:
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
world_size = max(1, self.args.world_size)
world_size = max(1, args.world_size)
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
if not prediction_loss_only:
......@@ -2771,9 +2784,9 @@ class Trainer:
model.eval()
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
if self.args.past_index >= 0:
if args.past_index >= 0:
self._past = None
self.callback_handler.eval_dataloader = dataloader
......@@ -2787,10 +2800,10 @@ class Trainer:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
......@@ -2799,7 +2812,7 @@ class Trainer:
# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None
if self.args.past_index and hasattr(self, "_past"):
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
......
......@@ -136,7 +136,13 @@ def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()
t = tensors.cpu()
if t.dtype == torch.bfloat16:
# As of Numpy 1.21.4, NumPy does not support bfloat16 (see
# https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
# Until Numpy adds bfloat16, we must convert float32.
t = t.to(torch.float32)
return t.numpy()
def nested_detach(tensors):
......
......@@ -207,18 +207,26 @@ class TrainingArguments:
Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
:func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
initialized parameters.
bf16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
NVIDIA architecture. This is an experimental API and it may change.
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use 16-bit (mixed) precision training instead of 32-bit training.
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
This argument is deprecated. Use ``half_precision_backend`` instead.
half_precision_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
other choices will force the requested backend.
bf16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
can harm metric values.
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values.
local_rank (:obj:`int`, `optional`, defaults to -1):
Rank of the process during distributed training.
xpu_backend (:obj:`str`, `optional`):
......@@ -507,10 +515,15 @@ class TrainingArguments:
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
bf16: bool = field(
default=False,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
},
)
fp16: bool = field(
default=False,
metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
)
fp16_opt_level: str = field(
default="O1",
......@@ -521,13 +534,19 @@ class TrainingArguments:
)
},
)
fp16_backend: str = field(
half_precision_backend: str = field(
default="auto",
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
)
bf16_full_eval: bool = field(
default=False,
metadata={
"help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change."
},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
xpu_backend: str = field(
......@@ -666,6 +685,10 @@ class TrainingArguments:
},
)
# Deprecated arguments
fp16_backend: str = field(
default="auto",
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
)
push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
......@@ -754,10 +777,31 @@ class TrainingArguments:
if self.run_name is None:
self.run_name = self.output_dir
if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
if self.fp16_backend and self.fp16_backend != "auto":
warnings.warn(
"`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `half_precision_backend` instead",
FutureWarning,
)
self.half_precision_backend = self.fp16_backend
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
if self.bf16:
if self.half_precision_backend == "apex":
raise ValueError(
" `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
if (
is_torch_available()
and self.device.type != "cuda"
and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
):
raise ValueError(
"Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
)
if self.report_to is None:
logger.info(
"The default value for the training argument `--report_to` will change in v5 (from all installed "
......
......@@ -53,6 +53,7 @@ from transformers.testing_utils import (
require_sigopt,
require_tokenizers,
require_torch,
require_torch_bf16,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
......@@ -476,6 +477,21 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
@require_torch_gpu
@require_torch_bf16
def test_mixed_bf16(self):
# very basic test
trainer = get_regression_trainer(learning_rate=0.1, bf16=True)
trainer.train()
self.check_trained_model(trainer.model)
# --bf16 --half_precision_backend apex can't be used together
with self.assertRaises(ValueError):
trainer = get_regression_trainer(learning_rate=0.1, bf16=True, half_precision_backend="apex")
# will add more specific tests once there are some bugs to fix
@require_torch
@require_sentencepiece
......@@ -1323,6 +1339,66 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
@require_torch_gpu
@require_torch_bf16
def test_bf16_full_eval(self):
# note: most of the logic is the same as test_fp16_full_eval
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
# it's using pretty large safety margins, but small enough to detect broken functionality.
debug = 0
n_gpus = get_gpu_count()
bs = 8
eval_len = 16 * n_gpus
# make the params somewhat big so that there will be enough RAM consumed to be able to
# measure things. We should get about 64KB for a+b in fp32
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
# 1. with mem metrics enabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
metrics = trainer.evaluate()
del trainer
gc.collect()
fp32_init = metrics["init_mem_gpu_alloc_delta"]
fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
if debug:
print(f"fp32_init {fp32_init}")
print(f"fp32_eval {fp32_eval}")
# here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
# perfect world: fp32_init == 64<<10
self.assertGreater(fp32_init, 59_000)
# after eval should be no extra memory allocated - with a small margin (other than the peak
# memory consumption for the forward calculation that gets recovered)
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
# 2. with mem metrics disabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
metrics = trainer.evaluate()
bf16_init = metrics["init_mem_gpu_alloc_delta"]
bf16_eval = metrics["eval_mem_gpu_alloc_delta"]
if debug:
print(f"bf16_init {bf16_init}")
print(f"bf16_eval {bf16_eval}")
# here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
# perfect world: bf16_init == close to zero
self.assertLess(bf16_init, 5_000)
# here we put the model on device in eval and only `half()` of it, i.e. about 32K,(again we ignore the peak margin which gets returned back)
# perfect world: fp32_init == 32<<10
self.assertGreater(bf16_eval, 27_000)
# 3. relative comparison fp32 vs full bf16
# should be about half of bf16_init
# perfect world: fp32_init/2 == bf16_eval
self.assertAlmostEqual(bf16_eval, fp32_init / 2, delta=5_000)
def test_no_wd_param_group(self):
model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
trainer = Trainer(model=model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment