Unverified Commit 0034a1d2 authored by Prajjwal Bhargava, committed by GitHub

Add Pytorch Native AMP support in Trainer (#6151)

* fixed type; add Pytorch Native CUDA AMP support

* reverted commit on modeling_utils

* conforming to HF black formatting rule

* changed bool value of _use_apex

* scaler support for gradient clipping

* fix in-place operation of clip_grad_norm_

* removed `not` from the version comparison
parent 7231f7b5
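For orientation, below is a minimal sketch of the PyTorch-native AMP pattern this commit wires into `Trainer` (available from PyTorch 1.6, CUDA only). The model, optimizer, loss function, and loader are illustrative placeholders, not Trainer internals, and the loop assumes a CUDA device.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Illustrative stand-ins for the Trainer's model/optimizer/dataloader;
# none of these names come from trainer.py itself.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = GradScaler()  # created once, like self.scaler in the diff below

loader = [(torch.randn(8, 10, device="cuda"), torch.randint(0, 2, (8,), device="cuda"))
          for _ in range(3)]

for inputs, labels in loader:
    optimizer.zero_grad()
    with autocast():                   # forward pass (and loss) in mixed precision
        loss = loss_fn(model(inputs), labels)
    scaler.scale(loss).backward()      # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)             # unscales gradients, skips the step on inf/NaN
    scaler.update()                    # adjust the loss scale for the next iteration
```

The diff below plugs the same three pieces, an `autocast` forward, a scaled backward, and a `GradScaler`-driven step, into the Trainer's existing Apex and TPU branches.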
@@ -19,7 +19,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
 from .data.data_collator import DataCollator, default_data_collator
-from .file_utils import is_apex_available, is_torch_tpu_available
+from .file_utils import is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import (
@@ -33,8 +33,19 @@ from .trainer_utils import (
 from .training_args import TrainingArguments
 
-if is_apex_available():
-    from apex import amp
+_use_native_amp = False
+_use_apex = False
+
+# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
+if version.parse(torch.__version__) < version.parse("1.6"):
+    from transformers.file_utils import is_apex_available
+
+    if is_apex_available():
+        from apex import amp
+    _use_apex = True
+else:
+    _use_native_amp = True
+    from torch.cuda.amp import autocast
 
 if is_torch_tpu_available():
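To confirm which backend a given environment will select, the same gate can be reproduced outside the Trainer. Note the hunk calls `version.parse` without adding an import, so `packaging.version` is assumed to already be in scope in trainer.py; the print line below is purely for illustration.

```python
import torch
from packaging import version  # the diff assumes `version` is already importable in trainer.py

_use_native_amp = False
_use_apex = False

# PyTorch >= 1.6 ships torch.cuda.amp; older versions fall back to NVIDIA apex.
if version.parse(torch.__version__) < version.parse("1.6"):
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast  # noqa: F401

print(torch.__version__, "native amp:", _use_native_amp, "apex:", _use_apex)
```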
@@ -225,6 +236,8 @@ class Trainer:
                 ),
                 FutureWarning,
             )
+        if self.args.fp16 and _use_native_amp:
+            self.scaler = torch.cuda.amp.GradScaler()
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
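As a side note on the construction above, `torch.cuda.amp.GradScaler` also accepts an `enabled` flag. The sketch below is a hypothetical alternative, not what the diff does: a single, possibly disabled scaler whose `scale()`/`step()`/`update()` calls become pass-throughs, which would remove the need for per-call-site guards on the native-AMP path.

```python
import torch
from torch.cuda.amp import GradScaler

# Hypothetical alternative to the guarded call sites in the diff: a GradScaler
# built with enabled=False turns scale()/step()/update() into pass-throughs.
fp16 = torch.cuda.is_available()   # stand-in for `args.fp16 and _use_native_amp`
scaler = GradScaler(enabled=fp16)

device = "cuda" if fp16 else "cpu"
loss = torch.zeros((), device=device, requires_grad=True) + 1.0
scaler.scale(loss).backward()      # scaled when enabled, plain backward otherwise
```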
@@ -428,7 +441,7 @@
             scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
 
         model = self.model
-        if self.args.fp16:
+        if self.args.fp16 and _use_apex:
             if not is_apex_available():
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
@@ -525,13 +538,20 @@
                     len(epoch_iterator) <= self.args.gradient_accumulation_steps
                     and (step + 1) == len(epoch_iterator)
                 ):
-                    if self.args.fp16:
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    elif self.args.fp16 and _use_apex:
                         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
                     if is_torch_tpu_available():
                         xm.optimizer_step(optimizer)
+
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.step(optimizer)
+                        self.scaler.update()
                     else:
                         optimizer.step()
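The ordering in this hunk is the important part: gradients produced by a scaled backward are themselves scaled, so they are unscaled in place before `clip_grad_norm_` measures them, and `scaler.step` then skips the optimizer update if the unscaled gradients contain inf/NaN. A standalone illustration of that ordering with placeholder model and optimizer (not Trainer code, assumes a CUDA device):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(4, 4).cuda()      # placeholder module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()
max_grad_norm = 1.0                       # stands in for args.max_grad_norm

x = torch.randn(2, 4, device="cuda")
with autocast():
    loss = model(x).sum()

scaler.scale(loss).backward()             # gradients are still scaled here
scaler.unscale_(optimizer)                # unscale in place so clipping sees true norms
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
scaler.step(optimizer)                    # step() knows the grads were already unscaled
scaler.update()
```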
@@ -697,6 +717,11 @@
         model.train()
         inputs = self._prepare_inputs(inputs, model)
 
-        outputs = model(**inputs)
-        # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        loss = outputs[0]
+        if self.args.fp16 and _use_native_amp:
+            with autocast():
+                outputs = model(**inputs)
+                loss = outputs[0]
+        else:
+            outputs = model(**inputs)
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs[0]
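Only the forward pass and loss extraction are wrapped in `autocast()`; the backward pass stays outside the block, as recommended for native AMP. A small placeholder example of that scoping (not Trainer code, assumes a CUDA device):

```python
import torch
from torch.cuda.amp import autocast

# Placeholder module and batch, standing in for model(**inputs) in the diff.
model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(2, 16, device="cuda")

with autocast():
    outputs = model(x)      # matmul-heavy ops run in float16 inside this block
    loss = outputs.sum()    # keep the loss computation in the same region

print(outputs.dtype)        # torch.float16 for a Linear layer under CUDA autocast
loss.backward()             # the backward pass runs outside the autocast block
```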
@@ -706,10 +731,13 @@
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
         if self.args.gradient_accumulation_steps > 1:
             loss = loss / self.args.gradient_accumulation_steps
 
-        if self.args.fp16:
+        if self.args.fp16 and _use_native_amp:
+            self.scaler.scale(loss).backward()
+        elif self.args.fp16 and _use_apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
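One interaction worth noting in this last hunk: the loss has already been divided by `gradient_accumulation_steps`, so each micro-batch contributes a scaled backward and the `scaler.step`/`scaler.update` pair shown earlier runs once per accumulation window. A rough standalone sketch with placeholder names (not Trainer code, assumes a CUDA device):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()      # placeholder model
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler()
accum_steps = 4                           # stands in for args.gradient_accumulation_steps

batches = [torch.randn(4, 8, device="cuda") for _ in range(8)]
for step, x in enumerate(batches):
    with autocast():
        loss = model(x).sum() / accum_steps   # average the loss over the window
    scaler.scale(loss).backward()             # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)                # one optimizer update per window
        scaler.update()
        optimizer.zero_grad()
```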