Unverified Commit 6bc61aa7 authored by Xuehai Pan, committed by GitHub

Set `TF32` flag for PyTorch cuDNN backend (#25075)

parent 5dba88b2
@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
```
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
......
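For context, a minimal sketch of how the two flags from this documentation hunk might be set together in user code, guarded by a CUDA and Ampere-or-newer check. The helper name and the capability check are illustrative, not part of this diff:

```python
import torch

def maybe_enable_tf32() -> bool:
    """Illustrative helper: enable TF32 when the GPU supports it."""
    # TF32 requires an NVIDIA GPU with compute capability >= 8.0 (Ampere or newer).
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True  # matmul kernels (existing flag)
        torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions (flag added here)
        return True
    return False
```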
@@ -1432,6 +1432,7 @@ class TrainingArguments:
                    " otherwise."
                )
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            else:
                logger.warning(
                    "The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
@@ -1440,11 +1441,13 @@ class TrainingArguments:
        if self.tf32:
            if is_torch_tf32_available():
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            else:
                raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
        else:
            if is_torch_tf32_available():
                torch.backends.cuda.matmul.allow_tf32 = False
                torch.backends.cudnn.allow_tf32 = False
            # no need to assert on else
        if self.report_to is None:
......
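As a hedged usage sketch (assuming a TF32-capable setup where `is_torch_tf32_available()` returns True), passing `tf32=True` to `TrainingArguments` should now flip both backend flags rather than only the matmul one; the `--tf32` CLI flag behaves the same way:

```python
import torch
from transformers import TrainingArguments

# On a TF32-capable GPU this sets both flags during argument post-processing;
# otherwise it raises the ValueError shown in the diff above.
args = TrainingArguments(output_dir="tmp_trainer", tf32=True)

print(torch.backends.cuda.matmul.allow_tf32)  # True
print(torch.backends.cudnn.allow_tf32)        # True (new with this change)
```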
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
    @slow
    def test_conditioning(self):
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        labels = self.prepare_inputs()
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
    @slow
    def test_primed_sampling(self):
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        set_seed(0)
......
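The test changes apply the same pattern in reverse: TF32 is switched off so the slow Jukebox tests stay bit-comparable to fp32 reference outputs. A minimal sketch of that pattern in a standalone test; the layer, seed, and assertion are placeholders, not taken from this diff:

```python
import unittest
import torch

class TF32DisabledTest(unittest.TestCase):
    def test_fp32_reproducibility(self):
        # Disable TF32 so matmuls and convolutions run in full fp32 precision,
        # keeping outputs comparable to precomputed reference values.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

        torch.manual_seed(0)
        layer = torch.nn.Linear(4, 4)
        out = layer(torch.randn(2, 4))
        # Placeholder check; real tests compare against stored expected tensors.
        self.assertEqual(out.shape, (2, 4))
```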