Commit 8d58b499 authored by Miquel Jubert Hermoso's avatar Miquel Jubert Hermoso Committed by Facebook GitHub Bot
Browse files

Refactor runner and runner_fb, to be autodeps friendly and user OSS option 2

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/245

At the moment D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go's runner still uses the OSS pattern 1 (see wiki), where the files get remapped. This does not work with D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go, and makes it necessary to use some renaming tricks.

This diff refactors the runner setup, to reduce the number of classes, and rely on fb_overwrite to add the correct fields to the config.

Reviewed By: wat3rBro

Differential Revision: D36316955

fbshipit-source-id: 4aaaece121b8df802f9395648c97a647fa7db857
parent d9c04ecc
...@@ -5,13 +5,9 @@ ...@@ -5,13 +5,9 @@
import importlib import importlib
from typing import Optional, Type, Union from typing import Optional, Type, Union
from .default_runner import ( from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
BaseRunner,
Detectron2GoRunner,
GeneralizedRCNNRunner,
TRAINER_HOOKS_REGISTRY,
)
from .lightning_task import DefaultTask from .lightning_task import DefaultTask
from .training_hooks import TRAINER_HOOKS_REGISTRY
__all__ = [ __all__ = [
......
...@@ -27,10 +27,12 @@ from d2go.modeling import build_model, kmeans_anchors, model_ema ...@@ -27,10 +27,12 @@ from d2go.modeling import build_model, kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model
from d2go.runner.training_hooks import update_hooks_from_registry
from d2go.utils.flop_calculator import attach_profilers from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import D2Trainer, TensorboardXWriter from d2go.utils.helper import D2Trainer, TensorboardXWriter
from d2go.utils.misc import get_tensorboard_log_dir from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.oss_helper import fb_overwritable
from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import ( from detectron2.data import (
...@@ -38,7 +40,7 @@ from detectron2.data import ( ...@@ -38,7 +40,7 @@ from detectron2.data import (
build_detection_train_loader as d2_build_detection_train_loader, build_detection_train_loader as d2_build_detection_train_loader,
MetadataCatalog, MetadataCatalog,
) )
from detectron2.engine import AMPTrainer, HookBase, hooks, SimpleTrainer from detectron2.engine import AMPTrainer, hooks, SimpleTrainer
from detectron2.evaluation import ( from detectron2.evaluation import (
COCOEvaluator, COCOEvaluator,
DatasetEvaluators, DatasetEvaluators,
...@@ -51,7 +53,6 @@ from detectron2.evaluation import ( ...@@ -51,7 +53,6 @@ from detectron2.evaluation import (
from detectron2.modeling import GeneralizedRCNNWithTTA from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from detectron2.utils.events import CommonMetricPrinter, JSONWriter from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from detectron2.utils.registry import Registry
from mobile_cv.predictor.api import PredictorWrapper from mobile_cv.predictor.api import PredictorWrapper
...@@ -128,6 +129,16 @@ def default_scale_quantization_configs(cfg, new_world_size): ...@@ -128,6 +129,16 @@ def default_scale_quantization_configs(cfg, new_world_size):
) )
@fb_overwritable()
def add_fb_base_runner_default_configs(cfg: CfgNode) -> CfgNode:
return cfg
@fb_overwritable()
def prepare_fb_model(cfg: CfgNode, model: torch.nn.Module) -> torch.nn.Module:
return model
class BaseRunner(object): class BaseRunner(object):
def __init__(self): def __init__(self):
identifier = f"D2Go.Runner.{self.__class__.__name__}" identifier = f"D2Go.Runner.{self.__class__.__name__}"
...@@ -163,8 +174,6 @@ class BaseRunner(object): ...@@ -163,8 +174,6 @@ class BaseRunner(object):
cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"] cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"]
cfg.PROFILERS = ["default_flop_counter"]
return cfg return cfg
def build_model(self, cfg, eval_only=False): def build_model(self, cfg, eval_only=False):
...@@ -193,18 +202,6 @@ class BaseRunner(object): ...@@ -193,18 +202,6 @@ class BaseRunner(object):
return d2_build_detection_train_loader(*args, **kwargs) return d2_build_detection_train_loader(*args, **kwargs)
# List of functions to add hooks for trainer, all functions in the registry will
# be called to add hooks
# func(hooks: List[HookBase]) -> None
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]):
for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks)
class Detectron2GoRunner(BaseRunner): class Detectron2GoRunner(BaseRunner):
def register(self, cfg): def register(self, cfg):
super().register(cfg) super().register(cfg)
...@@ -216,8 +213,12 @@ class Detectron2GoRunner(BaseRunner): ...@@ -216,8 +213,12 @@ class Detectron2GoRunner(BaseRunner):
@staticmethod @staticmethod
def get_default_cfg(): def get_default_cfg():
_C = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg() cfg = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg()
return get_default_cfg(_C)
cfg.PROFILERS = ["default_flop_counter"]
cfg = add_fb_base_runner_default_configs(cfg)
return get_default_cfg(cfg)
# temporary API # temporary API
def _build_model(self, cfg, eval_only=False): def _build_model(self, cfg, eval_only=False):
...@@ -258,6 +259,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -258,6 +259,7 @@ class Detectron2GoRunner(BaseRunner):
def build_model(self, cfg, eval_only=False): def build_model(self, cfg, eval_only=False):
model = self._build_model(cfg, eval_only) model = self._build_model(cfg, eval_only)
model = prepare_fb_model(cfg, model)
# Note: the _visualize_model API is experimental # Note: the _visualize_model API is experimental
if comm.is_main_process(): if comm.is_main_process():
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import List
from detectron2.engine import HookBase
from detectron2.utils.registry import Registry
logger = logging.getLogger(__name__)
# List of functions to add hooks for trainer, all functions in the registry will
# be called to add hooks
# func(hooks: List[HookBase]) -> None
TRAINER_HOOKS_REGISTRY = Registry("TRAINER_HOOKS_REGISTRY")
def update_hooks_from_registry(hooks: List[HookBase]):
for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks)
...@@ -10,6 +10,7 @@ import unittest ...@@ -10,6 +10,7 @@ import unittest
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.runner import create_runner from d2go.runner import create_runner
from d2go.runner.training_hooks import TRAINER_HOOKS_REGISTRY
from d2go.utils.testing import helper from d2go.utils.testing import helper
from d2go.utils.testing.data_loader_helper import create_local_dataset from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
...@@ -353,7 +354,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -353,7 +354,7 @@ class TestDefaultRunner(unittest.TestCase):
def test_d2go_runner_trainer_hooks(self): def test_d2go_runner_trainer_hooks(self):
counts = 0 counts = 0
@default_runner.TRAINER_HOOKS_REGISTRY.register() @TRAINER_HOOKS_REGISTRY.register()
def _check_hook_func(hooks): def _check_hook_func(hooks):
nonlocal counts nonlocal counts
counts = len(hooks) counts = len(hooks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment