"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a042909c836c794a34508b314b3ce8ce93a96284"
Commit 04a2956d authored by Fei Sun's avatar Fei Sun Committed by Facebook GitHub Bot
Browse files

Pass the zero_grad_before_forward flag to the trainer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/525

In d2go, pass the argument ZERO_GRAD_BEFORE_FORWARD to the detectron runtime.

Reviewed By: tglik

Differential Revision: D44267319

fbshipit-source-id: 3bd5874bea96ac381fb49972a2dfe9bb52005a7d
parent abdeafb0
...@@ -118,6 +118,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None: ...@@ -118,6 +118,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# Specify whether to perform NUMA binding # Specify whether to perform NUMA binding
_C.NUMA_BINDING = False _C.NUMA_BINDING = False
# Specify whether to zero the gradients before forward
_C.ZERO_GRAD_BEFORE_FORWARD = False
def _add_rcnn_default_config(_C: CN) -> None: def _add_rcnn_default_config(_C: CN) -> None:
_C.EXPORT_CAFFE2 = CN() _C.EXPORT_CAFFE2 = CN()
......
...@@ -548,6 +548,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -548,6 +548,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
data_loader, data_loader,
optimizer, optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD, gather_metric_period=cfg.GATHER_METRIC_PERIOD,
zero_grad_before_forward=cfg.ZERO_GRAD_BEFORE_FORWARD,
grad_scaler=get_grad_scaler(cfg), grad_scaler=get_grad_scaler(cfg),
precision=parse_precision_from_string( precision=parse_precision_from_string(
cfg.SOLVER.AMP.PRECISION, lightning=False cfg.SOLVER.AMP.PRECISION, lightning=False
...@@ -560,6 +561,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -560,6 +561,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
data_loader, data_loader,
optimizer, optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD, gather_metric_period=cfg.GATHER_METRIC_PERIOD,
zero_grad_before_forward=cfg.ZERO_GRAD_BEFORE_FORWARD,
) )
if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available(): if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment