Unverified Commit ad895af9 authored by Sylvain Gugger, committed by GitHub

Add possibility to switch between APEX and AMP in Trainer (#9137)



* Add possibility to switch between APEX and AMP in Trainer

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 0b2f46fa
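Before the diff, a minimal usage sketch of the new option (not part of the commit; the output directory is a placeholder). `fp16_backend` defaults to `"auto"`, which resolves to native AMP on PyTorch >= 1.6 and to APEX otherwise; passing `"amp"` or `"apex"` forces a specific backend.

```python
# Minimal sketch (not from this commit): picking the mixed precision backend
# through the new TrainingArguments field. "output" is a placeholder directory;
# actually training in fp16 still requires a CUDA device (and an APEX install
# if that backend is selected), but building the arguments works anywhere.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    fp16=True,            # enable mixed precision training
    fp16_backend="auto",  # "auto" (default), "amp" or "apex"
)
print(args.fp16, args.fp16_backend)
```

A `Trainer` built with these arguments sets `self.use_amp` or `self.use_apex` during `__init__`, as shown in the first file below.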
src/transformers/trainer.py

@@ -93,8 +93,7 @@ from .training_args import TrainingArguments
 from .utils import logging
 
 
-_use_native_amp = False
-_use_apex = False
+_is_native_amp_available = False
 
 DEFAULT_CALLBACKS = [DefaultFlowCallback]
 DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -110,16 +109,10 @@ if version.parse(torch.__version__) < version.parse("1.6"):
     if is_apex_available():
         from apex import amp
 
-        _use_apex = True
 else:
-    _use_native_amp = True
+    _is_native_amp_available = True
     from torch.cuda.amp import autocast
 
-if version.parse(torch.__version__) < version.parse("1.2"):
-    _use_ddp_no_sync = False
-else:
-    _use_ddp_no_sync = True
-
 if is_datasets_available():
     import datasets
@@ -292,13 +285,30 @@ class Trainer:
             if isinstance(eval_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.eval_dataset, description="evaluation")
 
+        # Mixed precision setup
+        self.use_apex = False
+        self.use_amp = False
+        if args.fp16:
+            if args.fp16_backend == "auto":
+                backend = "amp" if _is_native_amp_available else "apex"
+            else:
+                backend = args.fp16_backend
+
+            if backend == "amp":
+                self.use_amp = True
+                self.scaler = torch.cuda.amp.GradScaler()
+            else:
+                if not is_apex_available():
+                    raise ImportError(
+                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
+                    )
+                self.use_apex = True
+
         self.state = TrainerState()
         self.control = TrainerControl()
         # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
         # state at each call to self.log.
         self._total_flos = None
-        if self.args.fp16 and _use_native_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
         self.hp_search_backend = None
         self.use_tune_checkpoints = False
         default_label_names = (
@@ -625,9 +635,7 @@ class Trainer:
 
         # Mixed precision training with apex (torch < 1.6)
         model = self.model
-        if self.args.fp16 and _use_apex:
-            if not is_apex_available():
-                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+        if self.use_apex:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
 
         # Multi-gpu training (should be after apex fp16 initialization)
@@ -756,11 +764,8 @@ class Trainer:
                 if (step + 1) % self.args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
 
-                if (
-                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
-                    and self.args.local_rank != -1
-                    and _use_ddp_no_sync
-                ):
+                if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
+                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
                         tr_loss += self.training_step(model, inputs)
                 else:
@@ -772,17 +777,17 @@ class Trainer:
                     steps_in_epoch <= self.args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
-                    if self.args.fp16 and _use_native_amp:
+                    if self.use_amp:
                         self.scaler.unscale_(self.optimizer)
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-                    elif self.args.fp16 and _use_apex:
+                    elif self.use_apex:
                         torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
                     if is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
-                    elif self.args.fp16 and _use_native_amp:
+                    elif self.use_amp:
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                     else:
@@ -1089,7 +1094,7 @@ class Trainer:
         model.train()
         inputs = self._prepare_inputs(inputs)
 
-        if self.args.fp16 and _use_native_amp:
+        if self.use_amp:
             with autocast():
                 loss = self.compute_loss(model, inputs)
         else:
@@ -1101,9 +1106,9 @@ class Trainer:
         if self.args.gradient_accumulation_steps > 1:
             loss = loss / self.args.gradient_accumulation_steps
 
-        if self.args.fp16 and _use_native_amp:
+        if self.use_amp:
             self.scaler.scale(loss).backward()
-        elif self.args.fp16 and _use_apex:
+        elif self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -1498,7 +1503,7 @@ class Trainer:
                 ignore_keys = []
 
         with torch.no_grad():
-            if self.args.fp16 and _use_native_amp:
+            if self.use_amp:
                 with autocast():
                     outputs = model(**inputs)
             else:
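To make the two paths the `Trainer` now switches between easier to compare, here is a stripped-down sketch of one fp16 optimization step per backend. This is not the `Trainer` implementation: `model`, `optimizer`, `scaler` and `inputs` are assumed to be supplied by the caller, and the APEX path assumes `amp.initialize` has already wrapped the model and optimizer, as in the hunk above.

```python
# Stripped-down sketch of one fp16 training step per backend (not the Trainer code).
# Assumes `model` returns an output object with a `.loss` attribute and lives on CUDA.
import torch


def amp_step(model, optimizer, scaler, inputs):
    # Native AMP (torch >= 1.6): autocast the forward pass, scale the loss,
    # then step and update through the GradScaler -- the `self.use_amp` branch.
    with torch.cuda.amp.autocast():
        loss = model(**inputs).loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def apex_step(model, optimizer, inputs):
    # APEX: relies on a prior `model, optimizer = amp.initialize(...)` call,
    # then scales the loss through `amp.scale_loss` -- the `self.use_apex` branch.
    from apex import amp  # requires https://www.github.com/nvidia/apex

    loss = model(**inputs).loss
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

The `scaler.unscale_` call in the gradient clipping hunk exists because, under native AMP, gradients must be unscaled before `clip_grad_norm_` sees them; APEX instead exposes its master parameters through `amp.master_params`.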
src/transformers/training_args.py

@@ -211,6 +211,10 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
             step can take a long time) but will not yield the same results as the interrupted training would have.
+        fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
+            The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
+            :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
+            other choices will force the requested backend.
     """
 
     output_dir: str = field(
@@ -378,6 +382,10 @@ class TrainingArguments:
             "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
         },
     )
+    fp16_backend: str = field(
+        default="auto",
+        metadata={"help": "The backend to be used for mixed precision. Should be one of 'auto', 'amp' or 'apex'."},
+    )
 
     def __post_init__(self):
         if self.disable_tqdm is None:
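Because `TrainingArguments` is a dataclass consumed by `HfArgumentParser`, the new field also surfaces as a `--fp16_backend` flag in scripts that parse their arguments this way. A small sketch under that assumption (the explicit argument list stands in for `sys.argv` in a real training script):

```python
# Sketch of the new field reaching a script as a command-line flag via
# HfArgumentParser; the args list below stands in for sys.argv.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--fp16", "--fp16_backend", "apex"]
)
print(training_args.fp16, training_args.fp16_backend)  # True apex
```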
tests/test_trainer.py

@@ -798,7 +798,9 @@ class TrainerIntegrationTest(unittest.TestCase):
     def test_early_stopping_callback(self):
         # early stopping stops training before num_training_epochs
+        with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(
+                output_dir=tmp_dir,
                 num_train_epochs=20,
                 gradient_accumulation_steps=1,
                 per_device_train_batch_size=16,
@@ -812,7 +814,9 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertLess(train_output.global_step, 20 * 64 / 16)
 
         # Invalid inputs to trainer with early stopping callback result in assertion error
+        with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(
+                output_dir=tmp_dir,
                 num_train_epochs=20,
                 gradient_accumulation_steps=1,
                 per_device_train_batch_size=16,