Commit bd6043ee authored by Peizhao Zhang, committed by Facebook GitHub Bot
Browse files

Added hooks to report training progress to fblearner and keep alive.

Summary:
* Added a registry of functions that register hooks for training.
  * `TRAINER_HOOKS_REGISTRY`: functions that add hooks to the trainer; every function in the registry is called to add its hooks.
  * Each registered function has the signature `func(hooks: List[HookBase]) -> None`.

Reviewed By: zhanghang1989

Differential Revision: D27560806

fbshipit-source-id: fcfa02623bfd08508b6083db2d318d08f7e3c0b8
parent aeb24a92
......@@ -8,7 +8,7 @@ import math
import os
from collections import OrderedDict
from functools import lru_cache, partial
from typing import Type, Optional
from typing import Type, Optional, List
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm
......@@ -51,6 +51,7 @@ from detectron2.data import (
build_detection_train_loader as d2_build_detection_train_loader,
MetadataCatalog,
)
from detectron2.engine import HookBase
from detectron2.engine import SimpleTrainer, AMPTrainer, hooks
from detectron2.evaluation import (
COCOEvaluator,
......@@ -67,6 +68,7 @@ from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
)
from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from detectron2.utils.registry import Registry
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
......@@ -221,6 +223,18 @@ class BaseRunner(object):
return d2_build_detection_train_loader(*args, **kwargs)
# Registry of functions that add hooks to the trainer. Every function in the
# registry is called by update_hooks_from_registry() with the trainer's hook
# list, so each one can add its own hooks to it.
# Expected signature of a registered function:
#     func(hooks: List[HookBase]) -> None
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]):
    """Hand ``hooks`` to every function stored in TRAINER_HOOKS_REGISTRY.

    Iterating the registry yields ``(name, function)`` pairs. For each pair
    the name is logged at info level and the function is called with the
    same ``hooks`` list object; any return value is discarded.

    Args:
        hooks: the trainer hook list passed on to each registered function.
    """
    for entry in TRAINER_HOOKS_REGISTRY:
        name, add_hooks = entry
        logger.info(f"Update trainer hooks from {name}...")
        add_hooks(hooks)
class Detectron2GoRunner(BaseRunner):
def register(self, cfg):
super().register(cfg)
......@@ -458,6 +472,7 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer,
]
trainer_hooks.append(hooks.PeriodicWriter(writers))
update_hooks_from_registry(trainer_hooks)
trainer.register_hooks(trainer_hooks)
trainer.train(start_iter, max_iter)
......
......@@ -339,6 +339,26 @@ class TestDefaultRunner(unittest.TestCase):
default_runner._close_all_tbx_writers()
# Checks that a function registered in TRAINER_HOOKS_REGISTRY is called while
# Detectron2GoRunner.do_train() runs, and that it receives a non-empty hook list.
def test_d2go_runner_trainer_hooks(self):
# Number of hooks seen by the registered function; stays 0 if it is never called.
counts = 0
# NOTE(review): the function is registered in the module-level registry and
# this test does not remove it afterwards -- confirm that this does not
# affect other tests running in the same process.
@default_runner.TRAINER_HOOKS_REGISTRY.register()
def _check_hook_func(hooks):
# Record the size of the hook list in the enclosing test's counter.
nonlocal counts
counts = len(hooks)
print(hooks)
# Build a dataset, config and model under a temporary directory, then train.
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
runner = default_runner.Detectron2GoRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
runner.do_train(cfg, model, resume=True)
default_runner._close_all_tbx_writers()
# The registered function must have been called with at least one hook.
self.assertGreater(counts, 0)
def _compare_state_dict(sd1, sd2, abs_error=1e-3):
if len(sd1) != len(sd2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment