"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "183ad176b9934513b0557d1147fdec45ce49ceba"
Commit 8d5c70e9 authored by Pavel Pidlypenskyi's avatar Pavel Pidlypenskyi Committed by Facebook GitHub Bot
Browse files

Allow modifying train hooks.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/204

Being able to add extra train hooks is useful, especially when one wants to report additional training metrics via hooks.

Reviewed By: tglik

Differential Revision: D35377418

fbshipit-source-id: ca8e00a3c64f992fe9f6975689e50a8b846a1a37
parent 6b4dbb31
......@@ -419,6 +419,25 @@ class Detectron2GoRunner(BaseRunner):
return results
def _get_trainer_hooks(
    self, cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
):
    """Build the list of trainer hooks used by ``do_train``.

    The list is produced by a dedicated method so that the set of hooks
    can be customised without touching ``do_train`` itself.

    Args:
        cfg: config node; ``MODEL_EMA.ENABLED``, ``TEST.EVAL_PERIOD`` and
            ``QUANTIZATION.QAT.ENABLED`` are read here.
        model: model handed to the EMA hook, the after-step hook and
            ``do_test``.
        optimizer: forwarded to ``_create_after_step_hook``.
        scheduler: forwarded to ``_create_after_step_hook``.
        periodic_checkpointer: forwarded to ``_create_after_step_hook``.
        trainer: trainer whose ``iter`` attribute is read each time the
            evaluation callback runs.

    Returns:
        list: hooks in the order they were created. The EMA and QAT slots
        hold ``None`` when the respective feature is disabled in ``cfg``.
    """

    def _run_eval():
        # ``trainer.iter`` is looked up at call time, not at hook creation.
        return self.do_test(cfg, model, train_iter=trainer.iter)

    trainer_hooks = [hooks.IterationTimer()]

    if cfg.MODEL_EMA.ENABLED:
        trainer_hooks.append(model_ema.EMAHook(cfg, model))
    else:
        trainer_hooks.append(None)

    trainer_hooks.append(self._create_data_loader_hook(cfg))
    trainer_hooks.append(
        self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        )
    )

    # Final evaluation is done by a separate do_test call in
    # tools/train_net.py, hence eval_after_train=False.
    trainer_hooks.append(
        hooks.EvalHook(cfg.TEST.EVAL_PERIOD, _run_eval, eval_after_train=False)
    )

    trainer_hooks.append(kmeans_anchors.compute_kmeans_anchors_hook(self, cfg))

    if cfg.QUANTIZATION.QAT.ENABLED:
        trainer_hooks.append(self._create_qat_hook(cfg))
    else:
        trainer_hooks.append(None)

    return trainer_hooks
def do_train(self, cfg, model, resume):
# Note that flops at the beginning of training is often inaccurate,
# if a model has input-dependent logic
......@@ -463,21 +482,9 @@ class Detectron2GoRunner(BaseRunner):
trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
_get_model_with_abnormal_checker(model), data_loader, optimizer
)
trainer_hooks = [
hooks.IterationTimer(),
model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
self._create_data_loader_hook(cfg),
self._create_after_step_hook(
cfg, model, optimizer, scheduler, periodic_checkpointer
),
hooks.EvalHook(
cfg.TEST.EVAL_PERIOD,
lambda: self.do_test(cfg, model, train_iter=trainer.iter),
eval_after_train=False, # done by a separate do_test call in tools/train_net.py
),
kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
]
trainer_hooks = self._get_trainer_hooks(
cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
)
if comm.is_main_process():
tbx_writer = self.get_tbx_writer(cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment