Commit c37ecd66 authored by Jiaxu Zhu, committed by Facebook GitHub Bot

Training Reproducibility

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4955

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/540

Allow users to launch deterministic training jobs: given the same training config, repeated runs produce identical training results.
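A minimal usage sketch (hedged: the runner class and the get_default_cfg call follow the usual d2go pattern and are assumptions here; only SOLVER.DETERMINISTIC itself comes from this diff):

    # Usage sketch; everything except cfg.SOLVER.DETERMINISTIC is an assumption.
    from d2go.runner import GeneralizedRCNNRunner

    runner = GeneralizedRCNNRunner()
    cfg = runner.get_default_cfg()
    cfg.SOLVER.DETERMINISTIC = True  # opt in to deterministic training
    # Launch training as usual; two jobs sharing this cfg should now produce
    # identical results.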

Reviewed By: dilinwang820

Differential Revision: D45370627

fbshipit-source-id: 88db388c992500b0d789b8341952502cd1f8f995
parent 64a0e9a7
@@ -90,6 +90,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     _C.SOLVER.BETAS = (0.9, 0.999)
     _C.SOLVER.EPS = 1e-08
     _C.SOLVER.FUSED = False
+    _C.SOLVER.DETERMINISTIC = False
     # RECOMPUTE_BOXES for LSJ Training
     _C.INPUT.RECOMPUTE_BOXES = False
@@ -123,6 +124,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     # Specify whether to zero the gradients before forward
     _C.ZERO_GRAD_BEFORE_FORWARD = False
+    # Whether to enforce rebuilding data loaders for datasets that have expiration
+    _C.DATALOADER.ENFORE_EXPIRATION = False
 
 def _add_rcnn_default_config(_C: CN) -> None:
     _C.EXPORT_CAFFE2 = CN()
...
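This diff only registers the flag; how DATALOADER.ENFORE_EXPIRATION is consumed is not shown here. A purely illustrative sketch, assuming a hypothetical dataset_is_expired check and a d2go-style build_detection_train_loader helper:

    # Illustrative only: every name except cfg.DATALOADER.ENFORE_EXPIRATION
    # is an assumption, not part of this commit.
    def maybe_rebuild_train_loader(cfg, runner, train_loader, dataset_is_expired):
        # Rebuild the loader when its dataset has expired and rebuilding is enforced.
        if cfg.DATALOADER.ENFORE_EXPIRATION and dataset_is_expired():
            return runner.build_detection_train_loader(cfg)
        return train_loader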
@@ -295,6 +295,12 @@ def setup_after_launch(
     # scale the config after dumping so that dumped config files keep original world size
     auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
 
+    # avoid nondeterministic PyTorch and CUDA algorithms during training
+    if cfg.SOLVER.DETERMINISTIC:
+        logging.warning("Using deterministic training for reproducibility")
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+
     return runner
...
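The two switches above cover cuDNN autotuning and PyTorch algorithm selection. For fully reproducible runs one typically also fixes the RNG seeds and, on CUDA 10.2+, sets the CUBLAS_WORKSPACE_CONFIG environment variable that torch.use_deterministic_algorithms requires; a minimal sketch of that standard PyTorch setup (not part of this commit):

    # Standard PyTorch determinism setup; not part of this commit.
    import os
    import random

    import numpy as np
    import torch

    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # deterministic cuBLAS workspaces
    torch.backends.cudnn.benchmark = False             # disable input-size autotuning
    torch.use_deterministic_algorithms(True)           # raise on nondeterministic ops

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU RNG and all CUDA devices

Note that data loader workers (num_workers > 0) additionally need per-worker seeding, e.g. via the worker_init_fn and generator arguments of torch.utils.data.DataLoader.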