Commit 8d5c70e9 authored by Pavel Pidlypenskyi, committed by Facebook GitHub Bot

Allow modifying train hooks.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/204

Being able to add extra train hooks is useful, especially when one wants to report additional training metrics via hooks. This change factors the default hook list out into an overridable _get_trainer_hooks method (see the usage sketch after the diff).

Reviewed By: tglik

Differential Revision: D35377418

fbshipit-source-id: ca8e00a3c64f992fe9f6975689e50a8b846a1a37
parent 6b4dbb31
@@ -419,6 +419,25 @@ class Detectron2GoRunner(BaseRunner):
         return results
 
+    def _get_trainer_hooks(
+        self, cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
+    ):
+        return [
+            hooks.IterationTimer(),
+            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
+            self._create_data_loader_hook(cfg),
+            self._create_after_step_hook(
+                cfg, model, optimizer, scheduler, periodic_checkpointer
+            ),
+            hooks.EvalHook(
+                cfg.TEST.EVAL_PERIOD,
+                lambda: self.do_test(cfg, model, train_iter=trainer.iter),
+                eval_after_train=False,  # done by a separate do_test call in tools/train_net.py
+            ),
+            kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
+            self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
+        ]
+
     def do_train(self, cfg, model, resume):
         # Note that flops at the beginning of training is often inaccurate,
         # if a model has input-dependent logic
@@ -463,21 +482,9 @@ class Detectron2GoRunner(BaseRunner):
         trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
             _get_model_with_abnormal_checker(model), data_loader, optimizer
         )
-        trainer_hooks = [
-            hooks.IterationTimer(),
-            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
-            self._create_data_loader_hook(cfg),
-            self._create_after_step_hook(
-                cfg, model, optimizer, scheduler, periodic_checkpointer
-            ),
-            hooks.EvalHook(
-                cfg.TEST.EVAL_PERIOD,
-                lambda: self.do_test(cfg, model, train_iter=trainer.iter),
-                eval_after_train=False,  # done by a separate do_test call in tools/train_net.py
-            ),
-            kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
-            self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
-        ]
+        trainer_hooks = self._get_trainer_hooks(
+            cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
+        )
         if comm.is_main_process():
             tbx_writer = self.get_tbx_writer(cfg)
...
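Usage sketch (not part of this commit): with the hook list factored out, a runner subclass can override _get_trainer_hooks to extend the defaults. MyRunner and MyMetricsHook below are hypothetical names used only for illustration.

from d2go.runner import Detectron2GoRunner
from detectron2.engine import HookBase


class MyMetricsHook(HookBase):
    # Hypothetical hook that reports an extra training metric after each step.
    def after_step(self):
        # Write a custom scalar into the trainer's EventStorage.
        self.trainer.storage.put_scalar("custom/iter_parity", self.trainer.iter % 2)


class MyRunner(Detectron2GoRunner):
    def _get_trainer_hooks(
        self, cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
    ):
        # Start from the default hooks and append a custom one. None entries
        # (e.g. a disabled EMAHook) are filtered out by detectron2's
        # register_hooks, so appending to the raw list is safe.
        trainer_hooks = super()._get_trainer_hooks(
            cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
        )
        trainer_hooks.append(MyMetricsHook())
        return trainer_hooks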