Unverified Commit e342ac7e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add some warning for Dynamo and enable TF32 when it's set (#20515)

parent 68cfffc4
@@ -1148,6 +1148,15 @@ class TrainingArguments:
" (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
# Torchdynamo only delivers meaningful speedups on Ampere-or-newer GPUs. When the
# user enabled it, opportunistically turn on TF32 matmuls; otherwise warn that the
# hardware likely won't benefit.
if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
    if is_torch_tf32_available():
        # NOTE(review): the original unparenthesized condition
        # `self.tf32 is None and not self.fp16 or self.bf16` binds as written
        # below (Python's `and` binds tighter than `or`). Parentheses added for
        # clarity only — behavior is unchanged. Confirm the intent was not
        # `self.tf32 is None and not (self.fp16 or self.bf16)`, since as written
        # `--bf16` alone forces TF32 on even when the user passed `--tf32 false`.
        if (self.tf32 is None and not self.fp16) or self.bf16:
            logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
            torch.backends.cuda.matmul.allow_tf32 = True
    else:
        # Fixed user-facing message: "wih" -> "with", repaired grammar.
        logger.warning(
            "The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here."
        )
if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
    if self.tf32:
        if is_torch_tf32_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment