Commit 6f43a43a authored by Anthony Chen, committed by Facebook GitHub Bot
Browse files

Add an option to specify the period of metric gathering and writing in Trainer

Summary:
X-link: https://github.com/fairinternal/detectron2/pull/591

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/469

X-link: https://github.com/facebookresearch/detectron2/pull/4785

Add an option to specify the period of metric gathering and writing in Trainer.

This feature is needed to optimize training speed for large-scale training jobs such as generative AI. The all_gather call made during metric writing at every iteration is time-consuming when hundreds of GPUs are used; it takes ~10% of the total training time. With this feature, we can set the trainer's metric gathering/writing period (cfg.GATHER_METRIC_PERIOD) to match cfg.WRITER_PERIOD=20, which reduces training time while keeping the metric logging seen by users the same.

Reviewed By: miqueljubert, wat3rBro

Differential Revision:
D43098985

Privacy Context Container: 2011691122555468

fbshipit-source-id: 63c93a7331aa63badce5125e5240d2d5f7e61b74
parent f0f55cdc
...@@ -137,6 +137,8 @@ def get_base_runner_default_cfg(cfg: CN) -> CN: ...@@ -137,6 +137,8 @@ def get_base_runner_default_cfg(cfg: CN) -> CN:
cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"] cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"]
# Frequency of metric gathering in trainer.
cfg.GATHER_METRIC_PERIOD = 1
# Frequency of metric printer, tensorboard writer, etc. # Frequency of metric printer, tensorboard writer, etc.
cfg.WRITER_PERIOD = 20 cfg.WRITER_PERIOD = 20
......
...@@ -71,8 +71,8 @@ ALL_TB_WRITERS = [] ...@@ -71,8 +71,8 @@ ALL_TB_WRITERS = []
@lru_cache() @lru_cache()
def _get_tbx_writer(log_dir): def _get_tbx_writer(log_dir, window_size=20):
ret = TensorboardXWriter(log_dir) ret = TensorboardXWriter(log_dir, window_size=window_size)
ALL_TB_WRITERS.append(ret) ALL_TB_WRITERS.append(ret)
return ret return ret
...@@ -236,7 +236,10 @@ class D2GoDataAPIMixIn: ...@@ -236,7 +236,10 @@ class D2GoDataAPIMixIn:
@classmethod @classmethod
def get_tbx_writer(cls, cfg): def get_tbx_writer(cls, cfg):
return _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) return _get_tbx_writer(
get_tensorboard_log_dir(cfg.OUTPUT_DIR),
window_size=cfg.get("WRITER_PERIOD", 20),
)
@staticmethod @staticmethod
def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]: def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
...@@ -535,6 +538,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -535,6 +538,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
_get_model_with_abnormal_checker(model), _get_model_with_abnormal_checker(model),
data_loader, data_loader,
optimizer, optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD,
grad_scaler=get_grad_scaler(cfg), grad_scaler=get_grad_scaler(cfg),
precision=parse_precision_from_string( precision=parse_precision_from_string(
cfg.SOLVER.AMP.PRECISION, lightning=False cfg.SOLVER.AMP.PRECISION, lightning=False
...@@ -542,7 +546,10 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -542,7 +546,10 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
) )
else: else:
trainer = SimpleTrainer( trainer = SimpleTrainer(
_get_model_with_abnormal_checker(model), data_loader, optimizer _get_model_with_abnormal_checker(model),
data_loader,
optimizer,
gather_metric_period=cfg.GATHER_METRIC_PERIOD,
) )
if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available(): if cfg.SOLVER.AMP.ENABLED and torch.cuda.is_available():
...@@ -556,10 +563,17 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -556,10 +563,17 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
) )
if comm.is_main_process(): if comm.is_main_process():
assert (
cfg.GATHER_METRIC_PERIOD <= cfg.WRITER_PERIOD
and cfg.WRITER_PERIOD % cfg.GATHER_METRIC_PERIOD == 0
), "WRITER_PERIOD needs to be divisible by GATHER_METRIC_PERIOD"
tbx_writer = self.get_tbx_writer(cfg) tbx_writer = self.get_tbx_writer(cfg)
writers = [ writers = [
CommonMetricPrinter(max_iter), CommonMetricPrinter(max_iter, window_size=cfg.WRITER_PERIOD),
JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), JSONWriter(
os.path.join(cfg.OUTPUT_DIR, "metrics.json"),
window_size=cfg.WRITER_PERIOD,
),
tbx_writer, tbx_writer,
] ]
trainer_hooks.append(hooks.PeriodicWriter(writers, cfg.WRITER_PERIOD)) trainer_hooks.append(hooks.PeriodicWriter(writers, cfg.WRITER_PERIOD))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment