Commit 382bec5b authored by Artsiom Sanakoyeu, committed by Facebook GitHub Bot

Add "float16" and "bfloat16" precision when training with lightning Task

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/381

Introduce an extra parameter, SOLVER.AMP.PRECISION, which can be used to control mixed precision training when the lightning backend is used.

The previous value `precision: "mixed"` was wrong and training failed (see the screenshot below):
{F777576618}

I had to make AMP.PRECISION a string and make sure it works with two values: "float16" and "bfloat16". Before feeding it to the Trainer, we convert the "float16" string to the integer value 16. This workaround was unavoidable because a D2Go config value cannot be both int and str at the same time.
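For illustration (a minimal sketch, not part of this diff), enabling bfloat16 mixed-precision training with the lightning backend would look like:

    cfg.SOLVER.AMP.ENABLED = True
    cfg.SOLVER.AMP.PRECISION = "bfloat16"  # or "float16"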

Reviewed By: wat3rBro

Differential Revision: D40035367

fbshipit-source-id: ed4f615ab29a2258164cbe179a9adba11559d804
parent e0649685
@@ -69,6 +69,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     _C.SOLVER.LR_MULTIPLIER_OVERWRITE = []
     _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
     _C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
+    assert not _C.SOLVER.AMP.ENABLED
+    # AMP precision is used by the lightning backend. Can be "float16" or "bfloat16".
+    _C.SOLVER.AMP.PRECISION = "float16"
     # Betas are used in the AdamW optimizer
     _C.SOLVER.BETAS = (0.9, 0.999)
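For context (a self-contained sketch assuming the yacs CfgNode API that detectron2-style configs are built on; not part of this diff), the new default behaves like:

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.SOLVER = CN()
    _C.SOLVER.AMP = CN()
    _C.SOLVER.AMP.ENABLED = False  # AMP is off by default, hence the assert above
    _C.SOLVER.AMP.PRECISION = "float16"  # new key; "float16" or "bfloat16"

    print(_C.SOLVER.AMP.PRECISION)  # float16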
@@ -58,6 +58,18 @@ def _get_accelerator(use_cpu: bool) -> str:
     return "cpu" if use_cpu else "gpu"


+def _get_lightning_precision(precision: str) -> Union[str, int]:
+    """
+    Convert our string format for precision to what lightning Trainer expects
+    """
+    if precision == "float16":
+        return 16
+    elif precision == "bfloat16":
+        return "bf16"
+    else:
+        return precision
+
+
 def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
     strategy = _get_strategy(cfg)
@@ -77,7 +89,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
-        "precision": "mixed" if cfg.SOLVER.AMP.ENABLED else 32,
+        "precision": _get_lightning_precision(cfg.SOLVER.AMP.PRECISION)
+        if cfg.SOLVER.AMP.ENABLED
+        else 32,
     }
     if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
         if (
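To make the string-to-Trainer mapping concrete (a hypothetical sanity check mirroring the helper above; not part of this diff):

    assert _get_lightning_precision("float16") == 16       # lightning Trainer expects the int 16
    assert _get_lightning_precision("bfloat16") == "bf16"  # lightning's bfloat16 flag
    assert _get_lightning_precision("32") == "32"          # anything else passes through unchanged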