Commit e0649685 authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

enable zoomer for train_net

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/378

Some hooks need access to cfg to be initialized correctly. Pass cfg down the hook registration method.

Reviewed By: ertrue, miqueljubert

Differential Revision: D39303862

fbshipit-source-id: 931c356c7045f95fc0af5b20c7782ea4d1aff138
parent 1551bf13
......@@ -472,7 +472,7 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer,
]
trainer_hooks.append(hooks.PeriodicWriter(writers, cfg.WRITER_PERIOD))
update_hooks_from_registry(trainer_hooks)
update_hooks_from_registry(trainer_hooks, cfg)
trainer.register_hooks(trainer_hooks)
trainer.train(start_iter, max_iter)
......
......@@ -3,9 +3,12 @@
import logging
from typing import List
from d2go.config import CfgNode
from detectron2.engine import HookBase
from detectron2.utils.registry import Registry
logger = logging.getLogger(__name__)
# List of functions to add hooks for trainer, all functions in the registry will
......@@ -14,7 +17,7 @@ logger = logging.getLogger(__name__)
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]):
def update_hooks_from_registry(hooks: List[HookBase], cfg: CfgNode):
for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks)
hook_func(hooks, cfg)
......@@ -355,7 +355,7 @@ class TestDefaultRunner(unittest.TestCase):
counts = 0
@TRAINER_HOOKS_REGISTRY.register()
def _check_hook_func(hooks):
def _check_hook_func(hooks, cfg):
nonlocal counts
counts = len(hooks)
print(hooks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment