Commit 6c1682f9 authored by Francisc Bungiu, committed by Facebook GitHub Bot

Allow to use TensorFloat32

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/403

`cfg.SOLVER.AMP.ENABLED` enables mixed precision, but this only works on V100 GPUs.
For A100s, the equivalent is to enable TF32.
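As a usage sketch (the runner import and `get_default_cfg()` call follow the usual d2go entry points and are assumptions for illustration; only `cfg.SOLVER.AMP.ENABLED` comes from this diff), enabling AMP in the config is enough to pick up TF32 on A100s after this change:

```python
from d2go.runner import Detectron2GoRunner  # runner class modified in this diff

runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()

# With this change, enabling AMP also allows TF32 matmul/cuDNN kernels whenever
# CUDA is available; on pre-Ampere GPUs these flags are harmless no-ops.
cfg.SOLVER.AMP.ENABLED = True
```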

Reviewed By: tglik

Differential Revision: D40675242

fbshipit-source-id: 5cc3d12cd3d7ec76665e0907ecc87fc5f64d73f0
parent b94b23ee
@@ -467,6 +467,12 @@ class Detectron2GoRunner(BaseRunner):
         trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
             _get_model_with_abnormal_checker(model), data_loader, optimizer
         )
+        if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
+            # Allow to use the TensorFloat32 (TF32) tensor cores, available on A100 GPUs.
+            # For more details https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere.
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
         trainer_hooks = self._get_trainer_hooks(
             cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
         )