"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "66fa8ceaeaa6fe12f1bd4a5e6b0a924f59f715d9"
Unverified Commit 71b1bf7e authored by Stas Bekman, committed by GitHub

[trainer] add tf32-mode control (#14606)



* [trainer] add --tf32 support

* it's pt>=1.7

* it's pt>=1.7

* flip the default to True

* add experimental note

* simplify logic

* style

* switch to 3-state logic

* doc

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* re-style code
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent aada989a
@@ -358,8 +358,13 @@ Like all cases with reduced precision this may or may not be satisfactory for yo

If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.

You can enable this mode in the 🤗 Trainer with `--tf32`, or disable it with `--tf32 0` or `--no_tf32`.
If neither is passed, the PyTorch default is used.

Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)` as `torch.tf32` doesn't exist.

Note: you need `torch>=1.7` to enjoy this feature.
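
Outside of the Trainer the same switch can be flipped directly on the PyTorch backend. A minimal sketch for reference (the cuDNN knob is listed only for completeness and is not touched by this change):

import torch

# effective only on Ampere (compute capability >= 8) GPUs with CUDA >= 11 and torch >= 1.7
torch.backends.cuda.matmul.allow_tf32 = True  # the flag that `--tf32` controls for matmul kernels
torch.backends.cudnn.allow_tf32 = True        # separate knob for cuDNN convolution kernels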

### Gradient Checkpointing
@@ -321,34 +321,52 @@ def is_torch_cuda_available():


def is_torch_bf16_available():
    if not is_torch_available():
        return False

    import torch

    # since currently no utility function is available we build our own.
    # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
    # with an additional check for the torch version
    # to succeed:
    # 1. the hardware needs to support bf16 (arch >= Ampere)
    # 2. torch >= 1.10 (1.9 should be enough, but the AMP API changed in 1.10, so 1.10 is the minimum)
    # 3. CUDA >= 11
    # 4. torch.autocast exists
    # XXX: one problem here is that it may give invalid results on a mixed-gpu setup, so it's
    # really only correct for the 0th gpu (or the currently set default device if different from 0)
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    if version.parse(torch.__version__) < version.parse("1.10"):
        return False
    if not hasattr(torch, "autocast"):
        return False

    return True


def is_torch_tf32_available():
    if not is_torch_available():
        return False

    import torch

    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    if version.parse(torch.__version__) < version.parse("1.7"):
        return False

    return True


_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
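
A quick way to probe the two helpers interactively; a minimal sketch, assuming this branch is installed and the interpreter can see the GPU in question:

from transformers.file_utils import is_torch_bf16_available, is_torch_tf32_available

# both need an Ampere-or-newer GPU and CUDA >= 11;
# bf16 additionally needs torch >= 1.10, tf32 only needs torch >= 1.7
print(is_torch_bf16_available())
print(is_torch_tf32_available())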
@@ -50,6 +50,7 @@ from .file_utils import (
    is_tokenizers_available,
    is_torch_available,
    is_torch_bf16_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    is_torchaudio_available,
    is_vision_available,
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):


def require_torch_bf16(test_case):
    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
    if not is_torch_bf16_available():
        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
    else:
        return test_case


def require_torch_tf32(test_case):
    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
    if not is_torch_tf32_available():
        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
    else:
        return test_case
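
For illustration, a minimal sketch of how the new decorator composes with a plain unittest test case; the class and method names here are hypothetical, and the PR's own usage appears in the test_trainer.py hunk further down:

import unittest

import torch
from transformers.testing_utils import require_torch_tf32


class Tf32SmokeTest(unittest.TestCase):
    @require_torch_tf32
    def test_matmul_runs_with_tf32_enabled(self):
        # skipped automatically unless an Ampere+ GPU, CUDA >= 11 and torch >= 1.7 are present
        torch.backends.cuda.matmul.allow_tf32 = True
        a = torch.randn(64, 64, device="cuda")
        self.assertEqual(tuple((a @ a).shape), (64, 64))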
@@ -29,6 +29,7 @@ from .file_utils import (
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    torch_required,
)
@@ -227,6 +228,9 @@ class TrainingArguments:
        fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
            metric values.
        tf32 (:obj:`bool`, `optional`):
            Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
            and it may change.
        local_rank (:obj:`int`, `optional`, defaults to -1):
            Rank of the process during distributed training.
        xpu_backend (:obj:`str`, `optional`):
@@ -548,6 +552,12 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
    )
    tf32: bool = field(
        default=None,
        metadata={
            "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
        },
    )
    local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
    xpu_backend: str = field(
        default=None,
@@ -802,6 +812,17 @@ class TrainingArguments:
                "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
            )

        if is_torch_available() and self.tf32 is not None:
            if self.tf32:
                if is_torch_tf32_available():
                    torch.backends.cuda.matmul.allow_tf32 = True
                else:
                    raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
            else:
                if is_torch_tf32_available():
                    torch.backends.cuda.matmul.allow_tf32 = False
                # no need to assert on else

        if self.report_to is None:
            logger.info(
                "The default value for the training argument `--report_to` will change in v5 (from all installed "
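
To make the 3-state behavior concrete, a minimal sketch of the Python-side equivalent of the CLI flags; the `output_dir` value is just a placeholder:

from transformers import TrainingArguments

TrainingArguments(output_dir="out", tf32=True)   # like `--tf32`: force-enable, raises on unsupported setups
TrainingArguments(output_dir="out", tf32=False)  # like `--tf32 0` / `--no_tf32`: force-disable where supported
TrainingArguments(output_dir="out")              # tf32=None (default): leave the PyTorch default untouched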
@@ -57,6 +57,7 @@ from transformers.testing_utils import (
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    require_torch_tf32,
    require_torch_up_to_2_gpus,
    slow,
)
@@ -492,6 +493,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
        # will add more specific tests once there are some bugs to fix

    @require_torch_gpu
    @require_torch_tf32
    def test_tf32(self):
        # very basic test
        trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
        trainer.train()
        self.check_trained_model(trainer.model)

    @require_torch
    @require_sentencepiece