"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4c2538b863d8949a98d6b8dc1dea9ed4cf96a5df"
Unverified Commit ad895af9 authored by Sylvain Gugger, committed by GitHub

Add possibility to switch between APEX and AMP in Trainer (#9137)



* Add possibility to switch between APEX and AMP in Trainer

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 0b2f46fa
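
The new TrainingArguments option added by this commit can be exercised as in the short sketch below. This is a hedged usage example, not part of the diff: the output_dir value is a placeholder, and only the fp16 and fp16_backend arguments are taken from the change itself.

# Usage sketch (assumes a transformers version that contains this commit).
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",    # placeholder directory
    fp16=True,              # enable mixed precision training
    fp16_backend="apex",    # force APEX; "amp" forces native AMP, "auto" (default) picks one
)
# Pass `args` to a Trainer as usual; the Trainer then sets use_amp/use_apex from it.
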
src/transformers/trainer.py

@@ -93,8 +93,7 @@ from .training_args import TrainingArguments
 from .utils import logging

-_use_native_amp = False
-_use_apex = False
+_is_native_amp_available = False

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
 DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -110,16 +109,10 @@ if version.parse(torch.__version__) < version.parse("1.6"):
     if is_apex_available():
         from apex import amp
-        _use_apex = True
 else:
-    _use_native_amp = True
+    _is_native_amp_available = True
     from torch.cuda.amp import autocast

-if version.parse(torch.__version__) < version.parse("1.2"):
-    _use_ddp_no_sync = False
-else:
-    _use_ddp_no_sync = True
-
 if is_datasets_available():
     import datasets
@@ -292,13 +285,30 @@ class Trainer:
             if isinstance(eval_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.eval_dataset, description="evaluation")

+        # Mixed precision setup
+        self.use_apex = False
+        self.use_amp = False
+        if args.fp16:
+            if args.fp16_backend == "auto":
+                backend = "amp" if _is_native_amp_available else "apex"
+            else:
+                backend = args.fp16_backend
+
+            if backend == "amp":
+                self.use_amp = True
+                self.scaler = torch.cuda.amp.GradScaler()
+            else:
+                if not is_apex_available():
+                    raise ImportError(
+                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
+                    )
+                self.use_apex = True
+
         self.state = TrainerState()
         self.control = TrainerControl()
         # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
         # state at each call to self.log.
         self._total_flos = None
-        if self.args.fp16 and _use_native_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
         self.hp_search_backend = None
         self.use_tune_checkpoints = False
         default_label_names = (
@@ -625,9 +635,7 @@ class Trainer:
         # Mixed precision training with apex (torch < 1.6)
         model = self.model
-        if self.args.fp16 and _use_apex:
-            if not is_apex_available():
-                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+        if self.use_apex:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

         # Multi-gpu training (should be after apex fp16 initialization)
@@ -756,11 +764,8 @@ class Trainer:
                 if (step + 1) % self.args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

-                if (
-                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
-                    and self.args.local_rank != -1
-                    and _use_ddp_no_sync
-                ):
+                if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
                         tr_loss += self.training_step(model, inputs)
                 else:
@@ -772,17 +777,17 @@ class Trainer:
                     steps_in_epoch <= self.args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
-                    if self.args.fp16 and _use_native_amp:
+                    if self.use_amp:
                         self.scaler.unscale_(self.optimizer)
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-                    elif self.args.fp16 and _use_apex:
+                    elif self.use_apex:
                         torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                     if is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
-                    elif self.args.fp16 and _use_native_amp:
+                    elif self.use_amp:
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                     else:
@@ -1089,7 +1094,7 @@ class Trainer:
         model.train()
         inputs = self._prepare_inputs(inputs)

-        if self.args.fp16 and _use_native_amp:
+        if self.use_amp:
             with autocast():
                 loss = self.compute_loss(model, inputs)
         else:
@@ -1101,9 +1106,9 @@ class Trainer:
         if self.args.gradient_accumulation_steps > 1:
             loss = loss / self.args.gradient_accumulation_steps

-        if self.args.fp16 and _use_native_amp:
+        if self.use_amp:
             self.scaler.scale(loss).backward()
-        elif self.args.fp16 and _use_apex:
+        elif self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -1498,7 +1503,7 @@ class Trainer:
                 ignore_keys = []

         with torch.no_grad():
-            if self.args.fp16 and _use_native_amp:
+            if self.use_amp:
                 with autocast():
                     outputs = model(**inputs)
             else:
src/transformers/training_args.py

@@ -211,6 +211,10 @@ class TrainingArguments:
            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
            step can take a long time) but will not yield the same results as the interrupted training would have.
+        fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
+            The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
+            :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
+            other choices will force the requested backend.
     """

     output_dir: str = field(
@@ -378,6 +382,10 @@ class TrainingArguments:
             "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
         },
     )
+    fp16_backend: str = field(
+        default="auto",
+        metadata={"help": "The backend to be used for mixed precision. Should be one of 'auto', 'amp' or 'apex'."},
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
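
The "auto" behaviour documented above can be illustrated with a small standalone helper. This is a hedged sketch, not library code: resolve_fp16_backend is a hypothetical name, and its version check mirrors the _is_native_amp_available flag set in trainer.py (native AMP requires PyTorch >= 1.6).

# Hypothetical helper (not part of transformers) mirroring the backend resolution above.
import torch
from packaging import version


def resolve_fp16_backend(fp16_backend: str = "auto") -> str:
    # Native AMP (torch.cuda.amp) is only available from PyTorch 1.6 onwards.
    native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")
    if fp16_backend == "auto":
        return "amp" if native_amp_available else "apex"
    return fp16_backend  # "amp" or "apex", explicitly forced by the user
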
tests/test_trainer.py

@@ -798,34 +798,38 @@ class TrainerIntegrationTest(unittest.TestCase):
     def test_early_stopping_callback(self):
         # early stopping stops training before num_training_epochs
-        trainer = get_regression_trainer(
-            num_train_epochs=20,
-            gradient_accumulation_steps=1,
-            per_device_train_batch_size=16,
-            load_best_model_at_end=True,
-            evaluation_strategy=EvaluationStrategy.EPOCH,
-            compute_metrics=AlmostAccuracy(),
-            metric_for_best_model="accuracy",
-        )
-        trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
-        train_output = trainer.train()
-        self.assertLess(train_output.global_step, 20 * 64 / 16)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                num_train_epochs=20,
+                gradient_accumulation_steps=1,
+                per_device_train_batch_size=16,
+                load_best_model_at_end=True,
+                evaluation_strategy=EvaluationStrategy.EPOCH,
+                compute_metrics=AlmostAccuracy(),
+                metric_for_best_model="accuracy",
+            )
+            trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
+            train_output = trainer.train()
+            self.assertLess(train_output.global_step, 20 * 64 / 16)

         # Invalid inputs to trainer with early stopping callback result in assertion error
-        trainer = get_regression_trainer(
-            num_train_epochs=20,
-            gradient_accumulation_steps=1,
-            per_device_train_batch_size=16,
-            evaluation_strategy=EvaluationStrategy.EPOCH,
-            compute_metrics=AlmostAccuracy(),
-            metric_for_best_model="accuracy",
-        )
-        trainer.add_callback(EarlyStoppingCallback(1))
-        self.assertEqual(trainer.state.global_step, 0)
-        try:
-            trainer.train()
-        except AssertionError:
-            self.assertEqual(trainer.state.global_step, 0)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                num_train_epochs=20,
+                gradient_accumulation_steps=1,
+                per_device_train_batch_size=16,
+                evaluation_strategy=EvaluationStrategy.EPOCH,
+                compute_metrics=AlmostAccuracy(),
+                metric_for_best_model="accuracy",
+            )
+            trainer.add_callback(EarlyStoppingCallback(1))
+            self.assertEqual(trainer.state.global_step, 0)
+            try:
+                trainer.train()
+            except AssertionError:
+                self.assertEqual(trainer.state.global_step, 0)

     def test_flos_extraction(self):
         trainer = get_regression_trainer(learning_rate=0.1)