"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df4594a9dadca89482b69ec9a388456114a81491"
Unverified Commit 0034a1d2 authored by Prajjwal Bhargava, committed by GitHub

Add Pytorch Native AMP support in Trainer (#6151)

* fixed type; add Pytorch Native CUDA AMP support

* reverted commit on modeling_utils

* conforming to HF black formatting rule

* changed bool value of _use_apex

* scaler support for gradient clipping

* fix inplace operation of clip_grad_norm

* removed `not` from the version comparison
parent 7231f7b5
@@ -19,7 +19,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
 from .data.data_collator import DataCollator, default_data_collator
-from .file_utils import is_apex_available, is_torch_tpu_available
+from .file_utils import is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import (
@@ -33,8 +33,19 @@ from .trainer_utils import (
 from .training_args import TrainingArguments
 
-if is_apex_available():
-    from apex import amp
+_use_native_amp = False
+_use_apex = False
+
+# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
+if version.parse(torch.__version__) < version.parse("1.6"):
+    from transformers.file_utils import is_apex_available
+
+    if is_apex_available():
+        from apex import amp
+    _use_apex = True
+else:
+    _use_native_amp = True
+    from torch.cuda.amp import autocast
 
 if is_torch_tpu_available():
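For reference, a minimal standalone sketch of this version-gated backend selection, assuming only `torch` and `packaging` are installed, and using a plain try/except import in place of `is_apex_available()`:

```python
# Hypothetical standalone module, not the Trainer code itself.
import torch
from packaging import version

_use_native_amp = False
_use_apex = False

if version.parse(torch.__version__) < version.parse("1.6"):
    # Older PyTorch: mixed precision requires NVIDIA apex, if it is importable.
    try:
        from apex import amp  # noqa: F401
        _use_apex = True
    except ImportError:
        pass
else:
    # PyTorch >= 1.6 ships native AMP: torch.cuda.amp.autocast + GradScaler.
    from torch.cuda.amp import autocast  # noqa: F401
    _use_native_amp = True
```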
@@ -225,6 +236,8 @@ class Trainer:
                 ),
                 FutureWarning,
             )
+        if self.args.fp16 and _use_native_amp:
+            self.scaler = torch.cuda.amp.GradScaler()
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
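A quick illustration of the new constructor logic in isolation; `fp16` and `use_native_amp` below stand in for `self.args.fp16` and the module-level `_use_native_amp` flag, and the scaler is simply kept around for later use in the training loop:

```python
import torch

fp16 = True             # stand-in for self.args.fp16
use_native_amp = True   # stand-in for _use_native_amp

# The scaler is created once and reused for every optimizer step.
# GradScaler also accepts enabled=..., which achieves the same with less branching.
scaler = torch.cuda.amp.GradScaler() if (fp16 and use_native_amp) else None
```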
@@ -428,7 +441,7 @@ class Trainer:
             scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
 
         model = self.model
-        if self.args.fp16:
+        if self.args.fp16 and _use_apex:
             if not is_apex_available():
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
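On the pre-1.6 path the model/optimizer pair is still wrapped by apex. A hedged sketch of that branch only, with a hypothetical helper name, assuming apex is installed and `model`, `optimizer`, and an opt-level string such as "O1" already exist:

```python
from apex import amp  # requires NVIDIA apex; not part of stock PyTorch

def wrap_with_apex(model, optimizer, fp16_opt_level="O1"):
    # amp.initialize patches the model and optimizer for mixed precision;
    # "O1" mirrors the usual default for args.fp16_opt_level.
    return amp.initialize(model, optimizer, opt_level=fp16_opt_level)
```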
@@ -525,13 +538,20 @@ class Trainer:
                     len(epoch_iterator) <= self.args.gradient_accumulation_steps
                     and (step + 1) == len(epoch_iterator)
                 ):
-                    if self.args.fp16:
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    elif self.args.fp16 and _use_apex:
                         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
                     if is_torch_tpu_available():
                         xm.optimizer_step(optimizer)
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.step(optimizer)
+                        self.scaler.update()
                     else:
                         optimizer.step()
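This hunk is the part the commit message calls "scaler support for gradient clipping": gradients must be unscaled in place before their norm is clipped, and `scaler.step`/`scaler.update` replace the bare `optimizer.step()`. A sketch of that sequence outside the Trainer, with a hypothetical helper name and `model`, `optimizer`, `scaler`, and `max_grad_norm` assumed to exist:

```python
import torch

def clip_and_step(model, optimizer, scaler, max_grad_norm):
    # Unscale first so clip_grad_norm_ sees real (fp32-scale) gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # scaler.step() skips the optimizer step if any gradient is inf/NaN,
    # and scaler.update() adjusts the loss scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
```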
@@ -697,19 +717,27 @@ class Trainer:
         model.train()
         inputs = self._prepare_inputs(inputs, model)
 
-        outputs = model(**inputs)
-        # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        loss = outputs[0]
+        if self.args.fp16 and _use_native_amp:
+            with autocast():
+                outputs = model(**inputs)
+                loss = outputs[0]
+        else:
+            outputs = model(**inputs)
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs[0]
 
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
 
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
         if self.args.gradient_accumulation_steps > 1:
             loss = loss / self.args.gradient_accumulation_steps
 
-        if self.args.fp16:
+        if self.args.fp16 and _use_native_amp:
+            self.scaler.scale(loss).backward()
+        elif self.args.fp16 and _use_apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
...
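Putting the forward and backward changes together, here is a hedged, self-contained sketch of one native-AMP training step in the spirit of the method above; `training_step` is a hypothetical name, and `model`, `scaler`, and a dict of tensor `inputs` are assumed to exist:

```python
from torch.cuda.amp import autocast

def training_step(model, inputs, scaler):
    model.train()
    # The forward pass runs under autocast so eligible ops execute in fp16.
    with autocast():
        outputs = model(**inputs)
        loss = outputs[0]  # first element is the loss when labels are provided
    # Scale the loss before backward to avoid fp16 gradient underflow;
    # the unscale/clip/step/update sequence happens later, once per optimizer step.
    scaler.scale(loss).backward()
    return loss.item()
```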