Commit 729682ff authored by Anthony Chen, committed by Facebook GitHub Bot

custom precision dtype for AMP training on D2 backend

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4654

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/412

Support custom precision dtypes [float16, bfloat16] for AMP training on the D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) backend. There is an existing config key `SOLVER.AMP.PRECISION` that previously only worked on the lightning backend; this diff enables the same key on the D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) backend (the train_net binary) as well.
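For reference, a minimal sketch of how the key is used once AMP is enabled on the D2 backend; the runner choice and setup calls below are illustrative, and only the two `SOLVER.AMP.*` keys come from this change:

```python
from d2go.runner import Detectron2GoRunner  # illustrative runner choice

runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()

# The two keys involved in this change: turn AMP on and pick the autocast dtype.
cfg.SOLVER.AMP.ENABLED = True
cfg.SOLVER.AMP.PRECISION = "bfloat16"  # or "float16" (the default)
```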

Reviewed By: tax313, wat3rBro

Differential Revision: D40811604

fbshipit-source-id: 58da17ae1519a54243b5295eb4253c297e4d9296
parent d660903c
@@ -70,7 +70,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
     _C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
     assert not _C.SOLVER.AMP.ENABLED
-    # AMP precision is used by the lightning backend. Can be "float16" or "bfloat16".
+    # AMP precision is used by both D2 and lightning backend. Can be "float16" or "bfloat16".
     _C.SOLVER.AMP.PRECISION = "float16"
     # Betas are used in the AdamW optimizer
......
@@ -35,6 +35,7 @@ from d2go.runner.config_defaults import (
     get_generalized_rcnn_runner_default_cfg,
 )
 from d2go.runner.training_hooks import update_hooks_from_registry
+from d2go.trainer.helper import parse_precision_from_string
 from d2go.utils.flop_calculator import attach_profilers
 from d2go.utils.helper import D2Trainer, TensorboardXWriter
 from d2go.utils.misc import get_tensorboard_log_dir
@@ -464,9 +465,20 @@ class Detectron2GoRunner(BaseRunner):
             ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
             return ret
-        trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
-            _get_model_with_abnormal_checker(model), data_loader, optimizer
-        )
+        if cfg.SOLVER.AMP.ENABLED:
+            trainer = AMPTrainer(
+                _get_model_with_abnormal_checker(model),
+                data_loader,
+                optimizer,
+                precision=parse_precision_from_string(
+                    cfg.SOLVER.AMP.PRECISION, lightning=False
+                ),
+            )
+        else:
+            trainer = SimpleTrainer(
+                _get_model_with_abnormal_checker(model), data_loader, optimizer
+            )
         if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
             # Allow to use the TensorFloat32 (TF32) tensor cores, available on A100 GPUs.
             # For more details https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere.
......
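For context, the non-lightning path returns a `torch.dtype` because D2-style AMP training runs the forward pass under CUDA autocast with that dtype. A minimal sketch of the pattern (not the actual AMPTrainer implementation):

```python
import torch

def amp_forward_sketch(model, data, precision=torch.float16):
    # Illustrative only: run the forward pass under autocast with the
    # requested dtype; AMP trainers rely on this mechanism internally.
    with torch.cuda.amp.autocast(enabled=True, dtype=precision):
        loss_dict = model(data)
    return loss_dict
```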
from typing import Union

import torch


def parse_precision_from_string(
    precision: str, lightning=False
) -> Union[str, int, torch.dtype]:
    """
    Convert our string format for precision to what Detectron2 / lightning Trainer expects, controlled by the *lightning* flag
    """
    if precision == "float64":
        return torch.float64 if not lightning else 64
    if precision == "float32":
        return torch.float32 if not lightning else 32
    elif precision == "float16":
        return torch.float16 if not lightning else 16
    elif precision == "bfloat16":
        return torch.bfloat16 if not lightning else "bf16"
    else:
        raise ValueError(f"Invalid precision dtype {precision}")
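Read together with the helper above, the mapping is: `torch.dtype` values for the D2 / train_net path, and the values the lightning Trainer expects otherwise. A small illustration (assumes the import path used elsewhere in this diff):

```python
import torch

from d2go.trainer.helper import parse_precision_from_string

# D2 (train_net) path: torch dtypes, suitable for autocast.
assert parse_precision_from_string("bfloat16", lightning=False) is torch.bfloat16
# Lightning path: the values pytorch_lightning.Trainer(precision=...) accepts.
assert parse_precision_from_string("float16", lightning=True) == 16
assert parse_precision_from_string("bfloat16", lightning=True) == "bf16"
```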
@@ -13,6 +13,7 @@ from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import DefaultTask
 from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
 from d2go.trainer.api import TrainNetOutput
+from d2go.trainer.helper import parse_precision_from_string
 from d2go.trainer.lightning.training_loop import _do_test, _do_train
 from detectron2.utils.file_io import PathManager
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
@@ -58,18 +59,6 @@ def _get_accelerator(use_cpu: bool) -> str:
     return "cpu" if use_cpu else "gpu"
-def _get_lightning_precision(precision: str) -> Union[str, int]:
-    """
-    Convert our string format for precision to what lightning Trainer expects
-    """
-    if precision == "float16":
-        return 16
-    elif precision == "bfloat16":
-        return "bf16"
-    else:
-        return precision
 def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
     strategy = _get_strategy(cfg)
@@ -89,7 +78,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
-        "precision": _get_lightning_precision(cfg.SOLVER.AMP.PRECISION)
+        "precision": parse_precision_from_string(
+            cfg.SOLVER.AMP.PRECISION, lightning=True
+        )
         if cfg.SOLVER.AMP.ENABLED
         else 32,
     }
......
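Note that bfloat16 autocast needs hardware support (for example Ampere-class GPUs). A hedged helper, not part of this diff, that picks a precision string accordingly:

```python
import torch

def pick_amp_precision() -> str:
    # Illustrative only: prefer bfloat16 when the current GPU supports it,
    # otherwise fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "float16"
```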