"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "19547a57341bc4033b4f372d734c661a69e59311"
Commit fc3a3983 authored by Artsiom Sanakoyeu, committed by Facebook GitHub Bot
Browse files

Pass gradient clipping and mixed precision params to the lightning Trainer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/374

Automatic mixed precision (AMP) training is implemented for the native d2go Runner, but not for Lightning Tasks.

Now we pass the SOLVER.AMP.* and SOLVER.CLIP_GRADIENTS.* params to the Lightning Trainer as well.

Reviewed By: wat3rBro

Differential Revision: D39798007

fbshipit-source-id: e48560a91d37c21c56d953eed141876d8c759329
parent dc176d58
...@@ -63,7 +63,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]: ...@@ -63,7 +63,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
strategy = _get_strategy(cfg) strategy = _get_strategy(cfg)
accelerator = _get_accelerator(use_cpu) accelerator = _get_accelerator(use_cpu)
return { params = {
"max_epochs": -1, "max_epochs": -1,
"max_steps": cfg.SOLVER.MAX_ITER, "max_steps": cfg.SOLVER.MAX_ITER,
"val_check_interval": cfg.TEST.EVAL_PERIOD "val_check_interval": cfg.TEST.EVAL_PERIOD
...@@ -77,7 +77,20 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]: ...@@ -77,7 +77,20 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
"logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR), "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
"num_sanity_val_steps": 0, "num_sanity_val_steps": 0,
"replace_sampler_ddp": False, "replace_sampler_ddp": False,
"precision": "mixed" if cfg.SOLVER.AMP.ENABLED else 32,
} }
if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
if (
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE.lower() == "norm"
and cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE != 2.0
):
raise ValueError(
"D2Go Lightning backend supports only L2-norm for norm-based gradient clipping!"
)
params["gradient_clip_val"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
params["gradient_clip_algorithm"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE
return params
def main( def main(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment