Unverified Commit 6bc61aa7 authored by Xuehai Pan, committed by GitHub

Set `TF32` flag for PyTorch cuDNN backend (#25075)

parent 5dba88b2
@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
 ```
 import torch
 torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
 ```
 CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
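
For context beyond the diff: TF32 keeps FP32's dynamic range but rounds matmul operands to a 10-bit mantissa, so results are close to, but not bitwise identical to, full FP32. A minimal sketch (not part of this commit) showing the effect on an Ampere-or-newer GPU:

```
import torch

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    # Full FP32 matmul as the reference.
    torch.backends.cuda.matmul.allow_tf32 = False
    ref = a @ b

    # With TF32 allowed, Ampere and newer GPUs may route the matmul
    # through tensor cores for a large speedup at slightly lower precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    fast = a @ b

    print((ref - fast).abs().max())  # small nonzero difference on Ampere+
```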
@@ -1432,6 +1432,7 @@ class TrainingArguments:
                         " otherwise."
                     )
                     torch.backends.cuda.matmul.allow_tf32 = True
+                    torch.backends.cudnn.allow_tf32 = True
             else:
                 logger.warning(
                     "The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here."
@@ -1440,11 +1441,13 @@ class TrainingArguments:
             if self.tf32:
                 if is_torch_tf32_available():
                     torch.backends.cuda.matmul.allow_tf32 = True
+                    torch.backends.cudnn.allow_tf32 = True
                 else:
                     raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
             else:
                 if is_torch_tf32_available():
                     torch.backends.cuda.matmul.allow_tf32 = False
+                    torch.backends.cudnn.allow_tf32 = False
                 # no need to assert on else
         if self.report_to is None:
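
The gate above relies on `is_torch_tf32_available` from `transformers.utils`, which checks the requirements named in the error message. A rough, illustrative approximation of such a check (this is not the library's actual implementation):

```
import torch
from packaging import version

def tf32_available() -> bool:
    # Illustrative approximation only: TF32 needs an NVIDIA Ampere GPU
    # (compute capability >= 8), CUDA >= 11, and torch >= 1.7.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    return version.parse(torch.__version__) >= version.parse("1.7")
```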
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
     @slow
     def test_conditioning(self):
         torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
         model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
         labels = self.prepare_inputs()
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
     @slow
     def test_primed_sampling(self):
         torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
         model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
         set_seed(0)
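
The tests above force both flags off so the slow integration checks compare against deterministic FP32 reference outputs. Because the flags are process-global, a save-and-restore wrapper keeps a test-local override from leaking into later tests; a hypothetical sketch (the helper is illustrative, not part of transformers):

```
import contextlib
import torch

@contextlib.contextmanager
def tf32_disabled():
    # Hypothetical helper: snapshot the global TF32 flags, force them
    # off for the duration of the block, then restore the old values.
    old_matmul = torch.backends.cuda.matmul.allow_tf32
    old_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_matmul
        torch.backends.cudnn.allow_tf32 = old_cudnn
```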