Commit e0649685 authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

enable zoomer for train_net

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/378

Some hooks need access to cfg to be initialized correctly. Pass cfg down the hook registration method.

Reviewed By: ertrue, miqueljubert

Differential Revision: D39303862

fbshipit-source-id: 931c356c7045f95fc0af5b20c7782ea4d1aff138
parent 1551bf13
...@@ -472,7 +472,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -472,7 +472,7 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer, tbx_writer,
] ]
trainer_hooks.append(hooks.PeriodicWriter(writers, cfg.WRITER_PERIOD)) trainer_hooks.append(hooks.PeriodicWriter(writers, cfg.WRITER_PERIOD))
update_hooks_from_registry(trainer_hooks) update_hooks_from_registry(trainer_hooks, cfg)
trainer.register_hooks(trainer_hooks) trainer.register_hooks(trainer_hooks)
trainer.train(start_iter, max_iter) trainer.train(start_iter, max_iter)
......
...@@ -3,9 +3,12 @@ ...@@ -3,9 +3,12 @@
import logging import logging
from typing import List from typing import List
from d2go.config import CfgNode
from detectron2.engine import HookBase from detectron2.engine import HookBase
from detectron2.utils.registry import Registry from detectron2.utils.registry import Registry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# List of functions to add hooks for trainer, all functions in the registry will # List of functions to add hooks for trainer, all functions in the registry will
...@@ -14,7 +17,7 @@ logger = logging.getLogger(__name__) ...@@ -14,7 +17,7 @@ logger = logging.getLogger(__name__)
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY") TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]): def update_hooks_from_registry(hooks: List[HookBase], cfg: CfgNode):
for name, hook_func in TRAINER_HOOKS_REGISTRY: for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...") logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks) hook_func(hooks, cfg)
...@@ -355,7 +355,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -355,7 +355,7 @@ class TestDefaultRunner(unittest.TestCase):
counts = 0 counts = 0
@TRAINER_HOOKS_REGISTRY.register() @TRAINER_HOOKS_REGISTRY.register()
def _check_hook_func(hooks): def _check_hook_func(hooks, cfg):
nonlocal counts nonlocal counts
counts = len(hooks) counts = len(hooks)
print(hooks) print(hooks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment