"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f84b57c80efa37f4956484ccb45ea984a170e4f5"
Unverified commit a2d34b7c, authored by Stas Bekman, committed by GitHub

deprecate is_torch_bf16_available (#17738)

* deprecate is_torch_bf16_available

* address suggestions
parent 132402d7
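
For downstream users, the practical change is that one ambiguous capability check splits into two device-specific ones. A minimal migration sketch (not part of this commit; it assumes the new utils are exported from transformers.utils, as the training_args import hunk below suggests):

    from transformers.utils import is_torch_bf16_cpu_available, is_torch_bf16_gpu_available

    def bf16_supported(no_cuda: bool) -> bool:
        # pick the check matching the device the run will actually use
        if no_cuda:
            return is_torch_bf16_cpu_available()
        return is_torch_bf16_gpu_available()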
@@ -67,7 +67,8 @@ from .utils import (
     is_timm_available,
     is_tokenizers_available,
     is_torch_available,
-    is_torch_bf16_available,
+    is_torch_bf16_cpu_available,
+    is_torch_bf16_gpu_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
@@ -486,11 +487,19 @@ def require_torch_gpu(test_case):
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


-def require_torch_bf16(test_case):
-    """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
+def require_torch_bf16_gpu(test_case):
+    """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
     return unittest.skipUnless(
-        is_torch_bf16_available(),
-        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
+        is_torch_bf16_gpu_available(),
+        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
+    )(test_case)
+
+
+def require_torch_bf16_cpu(test_case):
+    """Decorator marking a test that requires torch>=1.10, using CPU."""
+    return unittest.skipUnless(
+        is_torch_bf16_cpu_available(),
+        "test requires torch>=1.10, using CPU",
     )(test_case)
...
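
For illustration, a hedged sketch of how the split decorators would be used in a test module (the test class and method names here are made up for this example):

    import unittest

    from transformers.testing_utils import require_torch_bf16_cpu, require_torch_bf16_gpu

    class Bf16SmokeTest(unittest.TestCase):
        @require_torch_bf16_gpu
        def test_bf16_gpu_path(self):
            # skipped unless torch>=1.10 with an Ampere+ GPU and cuda>=11.0
            pass

        @require_torch_bf16_cpu
        def test_bf16_cpu_path(self):
            # skipped unless torch>=1.10 provides cpu bf16 support
            pass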
@@ -39,7 +39,8 @@ from .utils import (
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_available,
-    is_torch_bf16_available,
+    is_torch_bf16_cpu_available,
+    is_torch_bf16_gpu_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
     logging,
@@ -1036,14 +1037,23 @@ class TrainingArguments:
             )
             self.half_precision_backend = self.fp16_backend

-        if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available() and not self.no_cuda:
-            raise ValueError(
-                "Your setup doesn't support bf16. You need torch>=1.10, using Ampere GPU with cuda>=11.0 or using CPU"
-                " (no_cuda)"
-            )
+        if self.bf16 or self.bf16_full_eval:
+            if self.no_cuda and not is_torch_bf16_cpu_available():
+                # cpu
+                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+            elif not is_torch_bf16_gpu_available():
+                # gpu
+                raise ValueError(
+                    "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
+                )

         if self.fp16 and self.bf16:
             raise ValueError("At most one of fp16 and bf16 can be True, but not both")

+        if self.fp16_full_eval and self.bf16_full_eval:
+            raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
+
         if self.bf16:
             if self.half_precision_backend == "apex":
                 raise ValueError(
...
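
With this rework, TrainingArguments consults the check matching the requested device, so the error names the device that failed. A sketch of the observable behavior on an unsupported setup (the output_dir value is arbitrary; on a supported machine neither call raises):

    from transformers import TrainingArguments

    # gpu path: raises "Your setup doesn't support bf16/gpu. ..." when the
    # machine lacks torch>=1.10 + Ampere GPU + cuda>=11.0
    try:
        TrainingArguments(output_dir="tmp", bf16=True)
    except ValueError as e:
        print(e)

    # cpu path: with no_cuda=True only the cpu bf16 capability is checked
    try:
        TrainingArguments(output_dir="tmp", bf16=True, no_cuda=True)
    except ValueError as e:
        print(e)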
@@ -19,6 +19,7 @@ import importlib.util
 import json
 import os
 import sys
+import warnings
 from collections import OrderedDict
 from functools import wraps
 from itertools import chain
@@ -323,7 +324,14 @@ def is_torch_bf16_cpu_available():


 def is_torch_bf16_available():
-    return is_torch_bf16_cpu_available() or is_torch_bf16_gpu_available()
+    # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
+    # has become ambiguous and therefore deprecated
+    warnings.warn(
+        "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
+        "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
+        FutureWarning,
+    )
+    return is_torch_bf16_gpu_available()


 def is_torch_tf32_available():
...
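
To see the deprecation fire, one can record warnings around a call to the old util (a sketch; it assumes is_torch_bf16_available remains importable during the deprecation window, as the diff above implies):

    import warnings

    from transformers.utils import is_torch_bf16_available

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        is_torch_bf16_available()  # now delegates to the gpu-only check

    assert any(issubclass(w.category, FutureWarning) for w in caught)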
@@ -306,7 +306,7 @@ stages = [ZERO2, ZERO3]
 #
 # dtypes = [FP16]
 # so just hardcoding --fp16 for now
-# if is_torch_bf16_available():
+# if is_torch_bf16_gpu_available():
 #     dtypes += [BF16]
...
@@ -57,7 +57,8 @@ from transformers.testing_utils import (
     require_sigopt,
     require_tokenizers,
     require_torch,
-    require_torch_bf16,
+    require_torch_bf16_cpu,
+    require_torch_bf16_gpu,
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
@@ -554,7 +555,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)

     @require_torch_gpu
-    @require_torch_bf16
+    @require_torch_bf16_gpu
     def test_mixed_bf16(self):

         # very basic test
@@ -641,7 +642,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)

-    @require_torch_bf16
+    @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_number_of_steps_in_training_with_ipex(self):
         for mix_bf16 in [True, False]:
@@ -885,7 +886,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
         self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

-    @require_torch_bf16
+    @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_evaluate_with_ipex(self):
         for mix_bf16 in [True, False]:
@@ -1005,7 +1006,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
         self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

-    @require_torch_bf16
+    @require_torch_bf16_cpu
     @require_intel_extension_for_pytorch
     def test_predict_with_ipex(self):
         for mix_bf16 in [True, False]:
@@ -1888,7 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertGreater(orig_peak_mem, peak_mem * 2)

     @require_torch_gpu
-    @require_torch_bf16
+    @require_torch_bf16_gpu
     def test_bf16_full_eval(self):
         # note: most of the logic is the same as test_fp16_full_eval
...
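
Note how the ipex tests stack require_torch_bf16_cpu with require_intel_extension_for_pytorch: each decorator wraps the test in its own unittest.skipUnless, so the test runs only when every condition holds. A hypothetical shorthand would compose the same way:

    from transformers.testing_utils import (
        require_intel_extension_for_pytorch,
        require_torch_bf16_cpu,
    )

    def require_ipex_bf16(test_case):
        # hypothetical convenience wrapper, not in the repo: applies both skip conditions
        return require_torch_bf16_cpu(require_intel_extension_for_pytorch(test_case))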