"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "acdd78db08151a51f7341a1c8cb623ac78496c78"
Unverified Commit fae0b9fa authored by Stas Bekman, committed by GitHub

[trainer] conditional ctx managers into one wrapper (#14663)



* [trainer] conditional ctx managers into one wrapper

* workaround for contextlib.nullcontext for py<3.7

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* one more autocast

* style
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 39f1dff5
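
The diff below collapses the repeated `if self.use_amp: ... with autocast(): ...` branches into one helper that simply returns the right context manager. As a minimal standalone sketch of that pattern (the `use_amp`/`amp_dtype` flags mirror the Trainer attributes used in the diff; the `smart_autocast_ctx` and `run_forward` names are hypothetical and only illustrate the idea):

```python
# Sketch of the "conditional context managers into one wrapper" pattern.
# Not the Trainer code itself; names here are illustrative.
import contextlib
import sys

import torch
from packaging import version
from torch.cuda.amp import autocast


def smart_autocast_ctx(use_amp: bool, amp_dtype=torch.float16):
    """Return the appropriate context manager instead of branching at every call site."""
    if use_amp:
        # torch >= 1.10 lets autocast take an explicit dtype (e.g. torch.bfloat16).
        if version.parse(torch.__version__) >= version.parse("1.10"):
            return autocast(dtype=amp_dtype)
        return autocast()
    # contextlib.nullcontext only exists on Python >= 3.7; contextlib.suppress()
    # called with no exception types is an equivalent no-op context manager,
    # which is the py<3.7 workaround mentioned in the commit message.
    return contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()


def run_forward(model, inputs, use_amp=False, amp_dtype=torch.float16):
    # One `with` statement replaces the nested if/else blocks removed below.
    with smart_autocast_ctx(use_amp, amp_dtype):
        return model(**inputs)
```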
src/transformers/trainer.py
@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
 """
 import collections
+import contextlib
 import inspect
 import math
 import os
@@ -1837,6 +1838,21 @@ class Trainer:
 
         return inputs
 
+    def autocast_smart_context_manager(self):
+        """
+        A helper wrapper that creates an appropriate context manager for :obj:`autocast` while feeding it the desired
+        arguments, depending on the situation.
+        """
+        if self.use_amp:
+            if version.parse(torch.__version__) >= version.parse("1.10"):
+                ctx_manager = autocast(dtype=self.amp_dtype)
+            else:
+                ctx_manager = autocast()
+        else:
+            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
+
+        return ctx_manager
+
     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
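
For a quick look at what the new helper resolves to in each configuration, a probing sketch (not the intended usage; the `SimpleNamespace` stand-in works only because the helper reads nothing but `use_amp` and `amp_dtype`, and it assumes transformers at or after this commit is installed):

```python
from types import SimpleNamespace

import torch
from transformers import Trainer

# No AMP: the helper returns a no-op context manager.
no_amp_ctx = Trainer.autocast_smart_context_manager(SimpleNamespace(use_amp=False))
print(type(no_amp_ctx))  # contextlib.nullcontext (contextlib.suppress on Python < 3.7)

# AMP enabled: the helper returns an autocast context, with dtype on torch >= 1.10.
amp_ctx = Trainer.autocast_smart_context_manager(
    SimpleNamespace(use_amp=True, amp_dtype=torch.float16)
)
print(type(amp_ctx))  # torch.cuda.amp.autocast
```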
@@ -1863,14 +1879,7 @@
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
-        if self.use_amp:
-            if version.parse(torch.__version__) >= version.parse("1.10"):
-                with autocast(dtype=self.amp_dtype):
-                    loss = self.compute_loss(model, inputs)
-            else:
-                with autocast():
-                    loss = self.compute_loss(model, inputs)
-        else:
+        with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
         if self.args.n_gpu > 1:
@@ -2514,14 +2523,7 @@
                     logits = smp_nested_concat(logits_mb)
             else:
                 if has_labels:
-                    if self.use_amp:
-                        if version.parse(torch.__version__) >= version.parse("1.10"):
-                            with autocast(dtype=self.amp_dtype):
-                                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                        else:
-                            with autocast():
-                                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                    else:
+                    with self.autocast_smart_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
@@ -2531,14 +2533,7 @@
                         logits = outputs[1:]
                 else:
                     loss = None
-                    if self.use_amp:
-                        if version.parse(torch.__version__) >= version.parse("1.10"):
-                            with autocast(dtype=self.amp_dtype):
-                                outputs = model(**inputs)
-                        else:
-                            with autocast():
-                                outputs = model(**inputs)
-                    else:
+                    with self.autocast_smart_context_manager():
                         outputs = model(**inputs)
                     if isinstance(outputs, dict):
                         logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
src/transformers/trainer_seq2seq.py
@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 from torch.utils.data import Dataset
 
@@ -25,10 +24,6 @@ from .trainer_utils import PredictionOutput
 from .utils import logging
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
-    from torch.cuda.amp import autocast
-
 logger = logging.get_logger(__name__)
@@ -180,10 +175,7 @@ class Seq2SeqTrainer(Trainer):
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
 
         with torch.no_grad():
-            if self.use_amp:
-                with autocast():
-                    outputs = model(**inputs)
-            else:
+            with self.autocast_smart_context_manager():
                 outputs = model(**inputs)
 
         if has_labels:
             if self.label_smoother is not None:
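
With the torch-version check and the local `autocast` import gone from `trainer_seq2seq.py`, subclasses only need the inherited helper. A hedged sketch of a hypothetical custom prediction step reusing it (the class name, the simplified return handling, and the placeholder values are illustrative, not part of this commit):

```python
import torch
from transformers import Seq2SeqTrainer


class MyTrainer(Seq2SeqTrainer):
    # Hypothetical override: the inherited helper replaces the old
    # `if self.use_amp: with autocast(): ...` branching in custom code too.
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.autocast_smart_context_manager():
                outputs = model(**inputs)
        # Return the usual (loss, logits, labels) triple; logits/labels omitted here.
        loss = outputs["loss"].mean().detach() if "loss" in outputs else None
        return loss, None, None
```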