Unverified Commit fae0b9fa authored by Stas Bekman, committed by GitHub

[trainer] conditional ctx managers into one wrapper (#14663)



* [trainer] conditional ctx managers into one wrapper

* workaround for contextlib.nullcontext for py<3.7

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* one more autocast

* style
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 39f1dff5
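
In short: instead of repeating an `if self.use_amp: ... else: ...` ladder around every `autocast` call site, the commit adds one factory method that decides which context manager applies, so each call site shrinks to a single `with`. A condensed, standalone sketch of the pattern (`TrainerLike` is a hypothetical stand-in, not the actual Trainer class):

    import contextlib
    import sys

    import torch
    from packaging import version
    from torch.cuda.amp import autocast


    class TrainerLike:
        use_amp: bool = False
        amp_dtype = torch.float16

        def autocast_smart_context_manager(self):
            # One place picks the right context manager for the configuration.
            if self.use_amp:
                if version.parse(torch.__version__) >= version.parse("1.10"):
                    return autocast(dtype=self.amp_dtype)  # dtype arg exists since torch 1.10
                return autocast()
            # No AMP: return a no-op context manager (nullcontext needs Python >= 3.7).
            return contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()


    # Every call site then collapses to:
    #     with self.autocast_smart_context_manager():
    #         loss = self.compute_loss(model, inputs)

The diff below applies exactly this refactor to `Trainer` and `Seq2SeqTrainer`.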
src/transformers/trainer.py
@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
 """
 import collections
+import contextlib
 import inspect
 import math
 import os
@@ -1837,6 +1838,21 @@ class Trainer:
 
         return inputs
 
+    def autocast_smart_context_manager(self):
+        """
+        A helper wrapper that creates an appropriate context manager for :obj:`autocast` while feeding it the desired
+        arguments, depending on the situation.
+        """
+        if self.use_amp:
+            if version.parse(torch.__version__) >= version.parse("1.10"):
+                ctx_manager = autocast(dtype=self.amp_dtype)
+            else:
+                ctx_manager = autocast()
+        else:
+            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
+
+        return ctx_manager
+
     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
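
Aside on the `else` branch above: `contextlib.nullcontext` only exists on Python >= 3.7, so on 3.6 the commit falls back to `contextlib.suppress()`, which, called with no exception types, suppresses nothing and therefore acts as a do-nothing context manager. A quick stdlib-only check:

    import contextlib
    import sys

    # suppress() with no arguments catches no exceptions -- an empty, no-op context.
    ctx = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
    with ctx:
        x = 1 + 1
    print(x)  # 2 -- the body ran normally and nothing was swallowed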
@@ -1863,14 +1879,7 @@ class Trainer:
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
-        if self.use_amp:
-            if version.parse(torch.__version__) >= version.parse("1.10"):
-                with autocast(dtype=self.amp_dtype):
-                    loss = self.compute_loss(model, inputs)
-            else:
-                with autocast():
-                    loss = self.compute_loss(model, inputs)
-        else:
+        with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
         if self.args.n_gpu > 1:
@@ -2514,14 +2523,7 @@ class Trainer:
                 logits = smp_nested_concat(logits_mb)
             else:
                 if has_labels:
-                    if self.use_amp:
-                        if version.parse(torch.__version__) >= version.parse("1.10"):
-                            with autocast(dtype=self.amp_dtype):
-                                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                        else:
-                            with autocast():
-                                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                    else:
+                    with self.autocast_smart_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
@@ -2531,14 +2533,7 @@ class Trainer:
                         logits = outputs[1:]
                 else:
                     loss = None
-                    if self.use_amp:
-                        if version.parse(torch.__version__) >= version.parse("1.10"):
-                            with autocast(dtype=self.amp_dtype):
-                                outputs = model(**inputs)
-                        else:
-                            with autocast():
-                                outputs = model(**inputs)
-                    else:
+                    with self.autocast_smart_context_manager():
                         outputs = model(**inputs)
                 if isinstance(outputs, dict):
                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
src/transformers/trainer_seq2seq.py
@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 from torch.utils.data import Dataset
@@ -25,10 +24,6 @@ from .trainer_utils import PredictionOutput
 from .utils import logging
 
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
-    from torch.cuda.amp import autocast
-
-
 logger = logging.get_logger(__name__)
@@ -180,10 +175,7 @@ class Seq2SeqTrainer(Trainer):
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
 
         with torch.no_grad():
-            if self.use_amp:
-                with autocast():
-                    outputs = model(**inputs)
-            else:
+            with self.autocast_smart_context_manager():
                 outputs = model(**inputs)
 
         if has_labels:
             if self.label_smoother is not None:
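
Finally, a toy end-to-end check that the same call site works whether AMP is enabled or not, mirroring the `torch.no_grad()` + smart-context nesting in `Seq2SeqTrainer.prediction_step` above (`TinyTrainer` is a hypothetical harness, not transformers code):

    import contextlib
    import sys

    import torch
    from torch import nn


    class TinyTrainer:
        def __init__(self, use_amp: bool):
            self.use_amp = use_amp

        def autocast_smart_context_manager(self):
            if self.use_amp:
                return torch.cuda.amp.autocast()
            return contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()


    model = nn.Linear(4, 2)
    inputs = torch.randn(3, 4)
    for use_amp in (False, torch.cuda.is_available()):
        # Identical call site either way -- the wrapper absorbs all the branching.
        with torch.no_grad():
            with TinyTrainer(use_amp).autocast_smart_context_manager():
                print(model(inputs).shape)  # torch.Size([3, 2])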