"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c486bb14ad386257b4125ae08a11c3f87b5f41fc"
Commit 382bec5b authored by Artsiom Sanakoyeu's avatar Artsiom Sanakoyeu Committed by Facebook GitHub Bot
Browse files

Add "float16" and "bfloat16" precision when training with lightning Task

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/381

Introduce extra parameter SOLVER.AMP.PRECISION which can be used to control the mixed precision training when lightning backend is used.

Previous value `precision: "mixed"` was wrong and the training failed (See screenshot below)
{F777576618}

I had to make AMP.PRECISION a string and make sure that it can work with two values: "float16" and "bfloat16". Before feeding it to the Trainer we convert the "float16" string to the integer value 16. Such a workaround was unavoidable because D2Go's config value cannot be of int and str type at the same time.

Reviewed By: wat3rBro

Differential Revision: D40035367

fbshipit-source-id: ed4f615ab29a2258164cbe179a9adba11559d804
parent e0649685
...@@ -69,6 +69,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None: ...@@ -69,6 +69,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
_C.SOLVER.LR_MULTIPLIER_OVERWRITE = [] _C.SOLVER.LR_MULTIPLIER_OVERWRITE = []
_C.SOLVER.WEIGHT_DECAY_EMBED = 0.0 _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
_C.SOLVER.WEIGHT_DECAY_OVERWRITE = [] _C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
assert not _C.SOLVER.AMP.ENABLED
# AMP precision is used by the lightning backend. Can be "float16" or "bfloat16".
_C.SOLVER.AMP.PRECISION = "float16"
# Betas are used in the AdamW optimizer # Betas are used in the AdamW optimizer
_C.SOLVER.BETAS = (0.9, 0.999) _C.SOLVER.BETAS = (0.9, 0.999)
......
...@@ -58,6 +58,18 @@ def _get_accelerator(use_cpu: bool) -> str: ...@@ -58,6 +58,18 @@ def _get_accelerator(use_cpu: bool) -> str:
return "cpu" if use_cpu else "gpu" return "cpu" if use_cpu else "gpu"
def _get_lightning_precision(precision: str) -> Union[str, int]:
"""
Convert our string format for precision to what lightning Trainer expects
"""
if precision == "float16":
return 16
elif precision == "bfloat16":
return "bf16"
else:
return precision
def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]: def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
use_cpu = cfg.MODEL.DEVICE.lower() == "cpu" use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
strategy = _get_strategy(cfg) strategy = _get_strategy(cfg)
...@@ -77,7 +89,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]: ...@@ -77,7 +89,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
"logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR), "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
"num_sanity_val_steps": 0, "num_sanity_val_steps": 0,
"replace_sampler_ddp": False, "replace_sampler_ddp": False,
"precision": "mixed" if cfg.SOLVER.AMP.ENABLED else 32, "precision": _get_lightning_precision(cfg.SOLVER.AMP.PRECISION)
if cfg.SOLVER.AMP.ENABLED
else 32,
} }
if cfg.SOLVER.CLIP_GRADIENTS.ENABLED: if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
if ( if (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment