Commit 04a2956d authored by Fei Sun's avatar Fei Sun Committed by Facebook GitHub Bot
Browse files

Pass the zero_grad_before_forward flag to the trainer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/525

In d2go, pass the argument ZERO_GRAD_BEFORE_FORWARD to the detectron runtime.

Reviewed By: tglik

Differential Revision: D44267319

fbshipit-source-id: 3bd5874bea96ac381fb49972a2dfe9bb52005a7d
parent abdeafb0
......@@ -118,6 +118,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# Specify whether to perform NUMA binding
_C.NUMA_BINDING = False
# Specify whether to zero the gradients before forward
_C.ZERO_GRAD_BEFORE_FORWARD = False
def _add_rcnn_default_config(_C: CN) -> None:
_C.EXPORT_CAFFE2 = CN()
......
......@@ -548,6 +548,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
data_loader,
optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD,
zero_grad_before_forward=cfg.ZERO_GRAD_BEFORE_FORWARD,
grad_scaler=get_grad_scaler(cfg),
precision=parse_precision_from_string(
cfg.SOLVER.AMP.PRECISION, lightning=False
......@@ -560,6 +561,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
data_loader,
optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD,
zero_grad_before_forward=cfg.ZERO_GRAD_BEFORE_FORWARD,
)
if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment