"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "88466358733da21a4ab45d85300ee6960f588e7d"
Commit 729682ff authored by Anthony Chen, committed by Facebook GitHub Bot

custom precision dtype for AMP training on D2 backend

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4654

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/412

Support a custom precision dtype [float16, bfloat16] for AMP training on the D2 backend. There is an old config key `SOLVER.AMP.PRECISION` that previously only worked on the lightning backend. This diff enables the config key on the D2 backend (the `train_net` binary) as well.
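For illustration, a minimal sketch of turning this on for the D2 backend. Only the `SOLVER.AMP.*` keys come from this diff; the runner class and its `get_default_cfg()` accessor are assumed from d2go's usual runner API and are not part of this change:

```python
# Hypothetical config setup (import path and accessor assumed, not from this diff).
from d2go.runner import Detectron2GoRunner

cfg = Detectron2GoRunner().get_default_cfg()
cfg.SOLVER.AMP.ENABLED = True          # turn on mixed-precision training
cfg.SOLVER.AMP.PRECISION = "bfloat16"  # or "float16" (the default)
```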

Reviewed By: tax313, wat3rBro

Differential Revision: D40811604

fbshipit-source-id: 58da17ae1519a54243b5295eb4253c297e4d9296
parent d660903c
@@ -70,7 +70,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
     _C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
     assert not _C.SOLVER.AMP.ENABLED
-    # AMP precision is used by the lightning backend. Can be "float16" or "bfloat16".
+    # AMP precision is used by both the D2 and lightning backends. Can be "float16" or "bfloat16".
     _C.SOLVER.AMP.PRECISION = "float16"
     # Betas are used in the AdamW optimizer
...
@@ -35,6 +35,7 @@ from d2go.runner.config_defaults import (
     get_generalized_rcnn_runner_default_cfg,
 )
 from d2go.runner.training_hooks import update_hooks_from_registry
+from d2go.trainer.helper import parse_precision_from_string
 from d2go.utils.flop_calculator import attach_profilers
 from d2go.utils.helper import D2Trainer, TensorboardXWriter
 from d2go.utils.misc import get_tensorboard_log_dir
@@ -464,9 +465,20 @@ class Detectron2GoRunner(BaseRunner):
             ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
             return ret

-        trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
-            _get_model_with_abnormal_checker(model), data_loader, optimizer
-        )
+        if cfg.SOLVER.AMP.ENABLED:
+            trainer = AMPTrainer(
+                _get_model_with_abnormal_checker(model),
+                data_loader,
+                optimizer,
+                precision=parse_precision_from_string(
+                    cfg.SOLVER.AMP.PRECISION, lightning=False
+                ),
+            )
+        else:
+            trainer = SimpleTrainer(
+                _get_model_with_abnormal_checker(model), data_loader, optimizer
+            )
         if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
             # Allow use of TensorFloat32 (TF32) tensor cores, available on A100 GPUs.
             # For more details: https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere.
...
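On the D2 side, the `precision` dtype ultimately controls the `torch.autocast` region of the AMP step (per the linked detectron2 PR #4654). A minimal sketch of that pattern, with placeholder model/data/optimizer and a hand-rolled step rather than detectron2's actual `AMPTrainer`:

```python
import torch

def amp_train_step(model, data, optimizer, grad_scaler, precision=torch.float16):
    # Run the forward pass under autocast with the configured dtype.
    with torch.autocast(device_type="cuda", dtype=precision):
        loss_dict = model(data)           # detectron2-style models return a dict of losses
        losses = sum(loss_dict.values())
    optimizer.zero_grad()
    grad_scaler.scale(losses).backward()  # scale gradients to avoid float16 underflow
    grad_scaler.step(optimizer)
    grad_scaler.update()
```

With `precision=torch.bfloat16` the loss scaling matters less (bfloat16 keeps float32's exponent range), but the structure of the step is the same.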
(new file, `d2go/trainer/helper.py` per the imports above)

+from typing import Union
+
+import torch
+
+
+def parse_precision_from_string(
+    precision: str, lightning=False
+) -> Union[str, int, torch.dtype]:
+    """
+    Convert our string format for precision to what the Detectron2 / lightning
+    Trainer expects, controlled by the `lightning` flag.
+    """
+    if precision == "float64":
+        return torch.float64 if not lightning else 64
+    elif precision == "float32":
+        return torch.float32 if not lightning else 32
+    elif precision == "float16":
+        return torch.float16 if not lightning else 16
+    elif precision == "bfloat16":
+        return torch.bfloat16 if not lightning else "bf16"
+    else:
+        raise ValueError(f"Invalid precision dtype {precision}")
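A quick sanity check of the mapping (hypothetical snippet, not part of the commit): the D2 backend wants torch dtypes, while lightning wants the `Trainer`'s precision codes.

```python
import torch
from d2go.trainer.helper import parse_precision_from_string

assert parse_precision_from_string("float16", lightning=False) is torch.float16
assert parse_precision_from_string("float16", lightning=True) == 16
assert parse_precision_from_string("bfloat16", lightning=False) is torch.bfloat16
assert parse_precision_from_string("bfloat16", lightning=True) == "bf16"
```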
@@ -13,6 +13,7 @@ from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import DefaultTask
 from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
 from d2go.trainer.api import TrainNetOutput
+from d2go.trainer.helper import parse_precision_from_string
 from d2go.trainer.lightning.training_loop import _do_test, _do_train
 from detectron2.utils.file_io import PathManager
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
@@ -58,18 +59,6 @@ def _get_accelerator(use_cpu: bool) -> str:
     return "cpu" if use_cpu else "gpu"

-def _get_lightning_precision(precision: str) -> Union[str, int]:
-    """
-    Convert our string format for precision to what lightning Trainer expects
-    """
-    if precision == "float16":
-        return 16
-    elif precision == "bfloat16":
-        return "bf16"
-    else:
-        return precision
-
-
 def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
     strategy = _get_strategy(cfg)
@@ -89,7 +78,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
-        "precision": _get_lightning_precision(cfg.SOLVER.AMP.PRECISION)
+        "precision": parse_precision_from_string(
+            cfg.SOLVER.AMP.PRECISION, lightning=True
+        )
         if cfg.SOLVER.AMP.ENABLED
         else 32,
     }
...
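On the lightning side, the parsed value is what lands in the `Trainer` kwargs above; it is roughly equivalent to the following (illustrative snippet using pytorch_lightning 1.x-style arguments, consistent with the `replace_sampler_ddp` flag this code passes):

```python
import pytorch_lightning as pl
from d2go.trainer.helper import parse_precision_from_string

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=parse_precision_from_string("bfloat16", lightning=True),  # -> "bf16"
)
```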