"vscode:/vscode.git/clone" did not exist on "287f387b8ce3983ad9a1f293bc1e4bf94b935328"
Commit bd6043ee authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Added hooks to report training progress to fblearner and keep alive.

Summary:
* Added a registry to register functions that could be used to register hooks for training.
  * TRAINER_HOOKS_REGISTRY: List of functions to add hooks for trainer, all functions in the registry will be called to add hooks
  * `func(hooks: List[HookBase]) -> None`

Reviewed By: zhanghang1989

Differential Revision: D27560806

fbshipit-source-id: fcfa02623bfd08508b6083db2d318d08f7e3c0b8
parent aeb24a92
...@@ -8,7 +8,7 @@ import math ...@@ -8,7 +8,7 @@ import math
import os import os
from collections import OrderedDict from collections import OrderedDict
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Type, Optional from typing import Type, Optional, List
import d2go.utils.abnormal_checker as abnormal_checker import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
...@@ -51,6 +51,7 @@ from detectron2.data import ( ...@@ -51,6 +51,7 @@ from detectron2.data import (
build_detection_train_loader as d2_build_detection_train_loader, build_detection_train_loader as d2_build_detection_train_loader,
MetadataCatalog, MetadataCatalog,
) )
from detectron2.engine import HookBase
from detectron2.engine import SimpleTrainer, AMPTrainer, hooks from detectron2.engine import SimpleTrainer, AMPTrainer, hooks
from detectron2.evaluation import ( from detectron2.evaluation import (
COCOEvaluator, COCOEvaluator,
...@@ -67,6 +68,7 @@ from detectron2.solver import ( ...@@ -67,6 +68,7 @@ from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler, build_lr_scheduler as d2_build_lr_scheduler,
) )
from detectron2.utils.events import CommonMetricPrinter, JSONWriter from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from detectron2.utils.registry import Registry
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
...@@ -221,6 +223,18 @@ class BaseRunner(object): ...@@ -221,6 +223,18 @@ class BaseRunner(object):
return d2_build_detection_train_loader(*args, **kwargs) return d2_build_detection_train_loader(*args, **kwargs)
# List of functions to add hooks for trainer, all functions in the registry will
# be called to add hooks
# func(hooks: List[HookBase]) -> None
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]):
for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks)
class Detectron2GoRunner(BaseRunner): class Detectron2GoRunner(BaseRunner):
def register(self, cfg): def register(self, cfg):
super().register(cfg) super().register(cfg)
...@@ -458,6 +472,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -458,6 +472,7 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer, tbx_writer,
] ]
trainer_hooks.append(hooks.PeriodicWriter(writers)) trainer_hooks.append(hooks.PeriodicWriter(writers))
update_hooks_from_registry(trainer_hooks)
trainer.register_hooks(trainer_hooks) trainer.register_hooks(trainer_hooks)
trainer.train(start_iter, max_iter) trainer.train(start_iter, max_iter)
......
...@@ -339,6 +339,26 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -339,6 +339,26 @@ class TestDefaultRunner(unittest.TestCase):
default_runner._close_all_tbx_writers() default_runner._close_all_tbx_writers()
def test_d2go_runner_trainer_hooks(self):
counts = 0
@default_runner.TRAINER_HOOKS_REGISTRY.register()
def _check_hook_func(hooks):
nonlocal counts
counts = len(hooks)
print(hooks)
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
runner = default_runner.Detectron2GoRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
runner.do_train(cfg, model, resume=True)
default_runner._close_all_tbx_writers()
self.assertGreater(counts, 0)
def _compare_state_dict(sd1, sd2, abs_error=1e-3): def _compare_state_dict(sd1, sd2, abs_error=1e-3):
if len(sd1) != len(sd2): if len(sd1) != len(sd2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment