Unverified commit bcc069dd, authored Dec 08, 2022 by jeffhataws, committed by GitHub on Dec 08, 2022

Enable bf16 option for XLA devices (#20684)

parent 9858ecd7
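For context, a minimal usage sketch of what the change enables (illustrative, not part of the commit): with bf16 accepted on XLA devices, a `TrainingArguments` like the one below no longer fails validation when the script runs on a TPU or Trainium (NeuronCore) host. The output directory and batch size are placeholder values.

```python
# Illustrative sketch only -- assumes a host with a working torch_xla setup
# (TPU or AWS Trainium/NeuronCore); values below are placeholders.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",                 # placeholder
    per_device_train_batch_size=8,      # placeholder
    bf16=True,                          # accepted on XLA devices after this commit
)
print(args.bf16)  # True once __post_init__ validation passes
```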
Showing 2 changed files with 5 additions and 4 deletions:

src/transformers/trainer.py (+1, -1)
src/transformers/training_args.py (+4, -3)
src/transformers/trainer.py

@@ -565,7 +565,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")

         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
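A condensed paraphrase of the guard this hunk touches (a sketch, not the actual `Trainer` code): the new `is_torch_tpu_available()` term keeps the Trainer from configuring its own AMP/grad-scaling path when an XLA backend will handle bf16 itself.

```python
# Paraphrased guard, not the real Trainer implementation.
from transformers.utils import is_torch_tpu_available


def trainer_manages_half_precision(args, sagemaker_mp_enabled: bool) -> bool:
    """Return True only when the Trainer itself must configure mixed precision."""
    wants_half_precision = args.fp16 or args.bf16
    # DeepSpeed, SageMaker Model Parallel, and (after this commit) XLA/TPU
    # backends manage half precision on their own.
    managed_elsewhere = bool(args.deepspeed) or sagemaker_mp_enabled or is_torch_tpu_available()
    return wants_half_precision and not managed_elsewhere
```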
src/transformers/training_args.py

@@ -1122,9 +1122,9 @@ class TrainingArguments:
         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(

@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )

         if self.torchdynamo is not None:
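Taken together, the two training_args.py hunks relax the bf16 device validation. A simplified approximation of the first check (helper names come from the diff itself; the function below is a sketch, not the actual `__post_init__` logic):

```python
# Simplified approximation of the bf16/no_cuda check changed above.
from transformers.utils import is_torch_bf16_cpu_available, is_torch_tpu_available


def check_bf16_support(no_cuda: bool, bf16: bool, bf16_full_eval: bool) -> None:
    if not (bf16 or bf16_full_eval):
        return
    # After this commit, an available TPU/NeuronCore also satisfies the check,
    # instead of bf16 being rejected whenever CUDA is disabled.
    if no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
        raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
```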