Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6bc61aa7
Unverified
Commit
6bc61aa7
authored
Jul 25, 2023
by
Xuehai Pan
Committed by
GitHub
Jul 25, 2023
Browse files
Set `TF32` flag for PyTorch cuDNN backend (#25075)
parent
5dba88b2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
0 deletions
+6
-0
docs/source/en/perf_train_gpu_one.md
docs/source/en/perf_train_gpu_one.md
+1
-0
src/transformers/training_args.py
src/transformers/training_args.py
+3
-0
tests/models/jukebox/test_modeling_jukebox.py
tests/models/jukebox/test_modeling_jukebox.py
+2
-0
No files found.
docs/source/en/perf_train_gpu_one.md
View file @
6bc61aa7
...
...
@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
```
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
...
...
src/transformers/training_args.py
View file @
6bc61aa7
...
...
@@ -1432,6 +1432,7 @@ class TrainingArguments:
" otherwise."
)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else:
    logger.warning(
        "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
...
...
@@ -1440,11 +1441,13 @@ class TrainingArguments:
if self.tf32:
    if is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    else:
        raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
else:
    if is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    # no need to assert on else

if self.report_to is None:
...
...
tests/models/jukebox/test_modeling_jukebox.py
View file @
6bc61aa7
...
...
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow
def test_conditioning(self):
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
    labels = self.prepare_inputs()
...
...
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow
def test_primed_sampling(self):
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
    set_seed(0)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment