Commit fc3a3983 authored by Artsiom Sanakoyeu, committed by Facebook GitHub Bot

Pass gradient clipping and mixed precision params to the Lightning Trainer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/374

AMP (automatic mixed precision) training is implemented for the native d2go Runner, but not for Lightning Tasks.

Now we pass the SOLVER.AMP.* and SOLVER.CLIP_GRADIENTS.* params to the Lightning Trainer as well.
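
For illustration, a minimal sketch of the resulting mapping (the config values below are hypothetical; the keys come from d2go's detectron2-style config):

    # Sketch only: how the new SOLVER options translate into
    # pytorch_lightning.Trainer keyword arguments.
    #   SOLVER.AMP.ENABLED: True                -> precision="mixed"
    #   SOLVER.CLIP_GRADIENTS.CLIP_VALUE: 1.0   -> gradient_clip_val=1.0
    #   SOLVER.CLIP_GRADIENTS.CLIP_TYPE: "norm" -> gradient_clip_algorithm="norm"
    from pytorch_lightning import Trainer

    trainer = Trainer(
        precision="mixed",
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
    )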

Reviewed By: wat3rBro

Differential Revision: D39798007

fbshipit-source-id: e48560a91d37c21c56d953eed141876d8c759329
parent dc176d58
@@ -63,7 +63,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     strategy = _get_strategy(cfg)
     accelerator = _get_accelerator(use_cpu)
-    return {
+    params = {
         "max_epochs": -1,
         "max_steps": cfg.SOLVER.MAX_ITER,
         "val_check_interval": cfg.TEST.EVAL_PERIOD
@@ -77,7 +77,20 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
+        "precision": "mixed" if cfg.SOLVER.AMP.ENABLED else 32,
     }
+    if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
+        if (
+            cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE.lower() == "norm"
+            and cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE != 2.0
+        ):
+            raise ValueError(
+                "D2Go Lightning backend supports only L2-norm for norm-based gradient clipping!"
+            )
+        params["gradient_clip_val"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+        params["gradient_clip_algorithm"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE
+    return params

 def main(
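
For context, a usage sketch under the assumption that d2go's Lightning entry point builds the Trainer directly from this dict (the call site is not shown in this diff; cfg is a CfgNode loaded elsewhere):

    import pytorch_lightning as pl

    # Assumed call site: expand the params dict, which now carries the
    # precision and gradient clipping keys, into the Trainer constructor.
    trainer = pl.Trainer(**get_trainer_params(cfg))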