chenpangpang / transformers · Commits

Unverified commit 36d46479, authored Jun 16, 2022 by Sylvain Gugger, committed via GitHub on Jun 16, 2022.
Refine Bf16 test for deepspeed (#17734)
* Refine BF16 check in CPU/GPU
* Fixes
* Renames
Parent: f44e2c2b
Showing 3 changed files with 29 additions and 15 deletions.
src/transformers/utils/__init__.py      +2  -0
src/transformers/utils/import_utils.py  +24 -12
tests/deepspeed/test_deepspeed.py       +3  -3
src/transformers/utils/__init__.py

@@ -125,6 +125,8 @@ from .import_utils import (
     is_tokenizers_available,
     is_torch_available,
     is_torch_bf16_available,
+    is_torch_bf16_cpu_available,
+    is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_fx_available,
     is_torch_fx_proxy,
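With these re-exports, downstream code can import the device-specific checks directly from transformers.utils. A two-line usage sketch (assuming a transformers build that contains this commit):

from transformers.utils import is_torch_bf16_cpu_available, is_torch_bf16_gpu_available

print("bf16 on CPU:", is_torch_bf16_cpu_available())
print("bf16 on GPU:", is_torch_bf16_gpu_available())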
src/transformers/utils/import_utils.py

@@ -272,7 +272,7 @@ def is_torch_cuda_available():
         return False


-def is_torch_bf16_available():
+def is_torch_bf16_gpu_available():
     if not is_torch_available():
         return False

@@ -288,30 +288,42 @@ def is_torch_bf16_available():
     # 4. torch.autocast exists
     # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
     # really only correct for the 0th gpu (or currently set default device if different from 0)
-    is_torch_gpu_bf16_available = True
-    is_torch_cpu_bf16_available = True
     if version.parse(torch.__version__) < version.parse("1.10"):
-        is_torch_gpu_bf16_available = False
-        is_torch_cpu_bf16_available = False
+        return False

     if torch.cuda.is_available() and torch.version.cuda is not None:
         if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-            is_torch_gpu_bf16_available = False
+            return False
         if int(torch.version.cuda.split(".")[0]) < 11:
-            is_torch_gpu_bf16_available = False
+            return False
         if not hasattr(torch.cuda.amp, "autocast"):
-            is_torch_gpu_bf16_available = False
+            return False
     else:
-        is_torch_gpu_bf16_available = False
+        return False

-    # checking CPU
+    return True
+
+
+def is_torch_bf16_cpu_available():
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if version.parse(torch.__version__) < version.parse("1.10"):
+        return False
+
     try:
         # multiple levels of AttributeError depending on the pytorch version so do them all in one check
         _ = torch.cpu.amp.autocast
     except AttributeError:
-        is_torch_cpu_bf16_available = False
+        return False

-    return is_torch_cpu_bf16_available or is_torch_gpu_bf16_available
+    return True
+
+
+def is_torch_bf16_available():
+    return is_torch_bf16_cpu_available() or is_torch_bf16_gpu_available()


 def is_torch_tf32_available():
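Taken together, the refactor replaces the flag-based single function with early-return, device-specific checks, and rebuilds the combined check as their disjunction. A minimal sketch of the resulting behavior (assuming a transformers build that contains this commit):

import torch

from transformers.utils import (
    is_torch_bf16_available,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
)

# After this commit, the combined check is exactly the disjunction of the
# two device-specific checks.
assert is_torch_bf16_available() == (
    is_torch_bf16_cpu_available() or is_torch_bf16_gpu_available()
)

# GPU bf16 requires an Ampere-or-newer GPU (compute capability >= 8),
# CUDA >= 11, torch >= 1.10, and torch.cuda.amp.autocast; CPU bf16 only
# requires torch >= 1.10 with torch.cpu.amp.autocast present.
if is_torch_bf16_gpu_available():
    x = torch.ones(2, 2, dtype=torch.bfloat16, device="cuda")
elif is_torch_bf16_cpu_available():
    x = torch.ones(2, 2, dtype=torch.bfloat16)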
tests/deepspeed/test_deepspeed.py

@@ -42,7 +42,7 @@ from transformers.testing_utils import (
     slow,
 )
 from transformers.trainer_utils import get_last_checkpoint, set_seed
-from transformers.utils import WEIGHTS_NAME, is_torch_bf16_available
+from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available


 if is_torch_available():
@@ -129,7 +129,7 @@ FP16 = "fp16"
 BF16 = "bf16"

 stages = [ZERO2, ZERO3]
-if is_torch_bf16_available():
+if is_torch_bf16_gpu_available():
     dtypes = [FP16, BF16]
 else:
     dtypes = [FP16]
@@ -920,7 +920,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
     @require_torch_multi_gpu
     @parameterized.expand(["bf16", "fp16", "fp32"])
     def test_inference(self, dtype):
-        if dtype == "bf16" and not is_torch_bf16_available():
+        if dtype == "bf16" and not is_torch_bf16_gpu_available():
             self.skipTest("test requires bfloat16 hardware support")

         # this is just inference, so no optimizer should be loaded
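The same guard can gate bf16 tests in any suite. A minimal, self-contained sketch in the style of the updated test (the class name, test name, and matmul payload are illustrative, not part of this commit):

import unittest

import torch

from transformers.utils import is_torch_bf16_gpu_available


class Bf16GpuTest(unittest.TestCase):
    def test_bf16_matmul(self):
        # Mirror the guard added in this commit: skip when the GPU lacks bf16 support.
        if not is_torch_bf16_gpu_available():
            self.skipTest("test requires bfloat16 hardware support")
        a = torch.randn(4, 4, dtype=torch.bfloat16, device="cuda")
        self.assertEqual((a @ a).dtype, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()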