Commit 19c5392d authored by Yanghan Wang, committed by Facebook GitHub Bot

move common data api into separate module

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/427

Re-try the previously reverted diff D41350485 (https://github.com/facebookresearch/d2go/commit/0ea6bc1b61ab736ccf1840c58c2b19ed2e9a1282). The problem was essentially that `DefaultTask` is not a subclass of `Runner`, so calling `Runner`'s class methods from `DefaultTask` does not work whenever the `Runner` method in turn calls other methods that exist on `Runner` but not on `DefaultTask`. The fix is simply to split the data-related APIs out into a separate mixin class and have both `DefaultTask` and `Runner` subclass it.
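
To illustrate the classmethod-resolution issue described above, here is a minimal, hypothetical sketch (the names `DataAPIMixin`, `Runner`, `DefaultTask`, `get_mapper`, and `build_train_loader` are simplified stand-ins, not the real d2go API): when both classes inherit the shared method from a mixin, `cls` resolves to whichever class the method is actually invoked on, so any helper or override that class defines is picked up.

```python
# Minimal sketch with made-up names; not the actual d2go code.
class DataAPIMixin:
    @classmethod
    def get_mapper(cls, cfg):
        return "default mapper"

    @classmethod
    def build_train_loader(cls, cfg):
        # `cls` is the class this classmethod was invoked on, so an override
        # of get_mapper() defined on that class is used automatically.
        return ("loader", cls.get_mapper(cfg))


class Runner(DataAPIMixin):
    pass


class DefaultTask(DataAPIMixin):  # not a Runner subclass
    @classmethod
    def get_mapper(cls, cfg):
        return "task-specific mapper"


# Forwarding to Runner.build_train_loader(cfg) would bind `cls` to Runner and
# ignore anything defined only on DefaultTask; going through the shared mixin,
# DefaultTask's own get_mapper() is used.
print(DefaultTask.build_train_loader(cfg=None))  # ('loader', 'task-specific mapper')
```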

Reviewed By: tglik

Differential Revision: D41507448

fbshipit-source-id: 8b26c129811436c0bd35e1c6b0705e7035d7e823
parent 1a75101f
@@ -199,7 +199,55 @@ class BaseRunner(object):
        return d2_build_detection_train_loader(*args, **kwargs)


class Detectron2GoRunner(BaseRunner):
class D2GoDataAPIMixIn:
    @staticmethod
    def get_mapper(cfg, is_train):
        tfm_gens = build_transform_gen(cfg, is_train)
        mapper = build_dataset_mapper(cfg, is_train, tfm_gens=tfm_gens)
        return mapper

    @classmethod
    def build_detection_test_loader(
        cls, cfg, dataset_name: Union[str, List[str]], mapper=None
    ):
        logger.info(
            "Building detection test loader for dataset: {} ...".format(dataset_name)
        )
        with configure_dataset_creation(cfg):
            mapper = mapper or cls.get_mapper(cfg, is_train=False)
            logger.info("Using dataset mapper:\n{}".format(mapper))
            return d2_build_detection_test_loader(cfg, dataset_name, mapper=mapper)

    @classmethod
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        with configure_dataset_creation(cfg):
            mapper = mapper or cls.get_mapper(cfg, is_train=True)
            data_loader = build_d2go_train_loader(cfg, mapper)
            return cls._attach_visualizer_to_data_loader(cfg, data_loader)

    @classmethod
    def _attach_visualizer_to_data_loader(cls, cfg, data_loader):
        if comm.is_main_process():
            data_loader_type = cls.get_data_loader_vis_wrapper()
            if data_loader_type is not None:
                tbx_writer = cls.get_tbx_writer(cfg)
                data_loader = data_loader_type(cfg, tbx_writer, data_loader)
        return data_loader

    @classmethod
    def get_tbx_writer(cls, cfg):
        return _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))

    @staticmethod
    def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
        return DataLoaderVisWrapper

    @staticmethod
    def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
        return VisualizationEvaluator


class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
    def register(self, cfg):
        super().register(cfg)
        self.original_cfg = cfg.clone()
@@ -280,10 +328,6 @@ class Detectron2GoRunner(BaseRunner):
    def build_lr_scheduler(self, cfg, optimizer):
        return d2_build_lr_scheduler(cfg, optimizer)

    @classmethod
    def get_tbx_writer(cls, cfg):
        return _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))

    def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
        """train_iter: Current iteration of the model, None means final iteration"""
        assert len(cfg.DATASETS.TEST)
@@ -516,29 +560,6 @@ class Detectron2GoRunner(BaseRunner):
        trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
        return {"model_final": trained_cfg}

    @classmethod
    def build_detection_test_loader(
        cls, cfg, dataset_name: Union[str, List[str]], mapper=None
    ):
        logger.info(
            "Building detection test loader for dataset: {} ...".format(dataset_name)
        )
        with configure_dataset_creation(cfg):
            mapper = mapper or cls.get_mapper(cfg, is_train=False)
            logger.info("Using dataset mapper:\n{}".format(mapper))
            return d2_build_detection_test_loader(cfg, dataset_name, mapper=mapper)

    @classmethod
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        with configure_dataset_creation(cfg):
            mapper = mapper or cls.get_mapper(cfg, is_train=True)
            data_loader = build_d2go_train_loader(cfg, mapper)
            return cls._attach_visualizer_to_data_loader(cfg, data_loader)

    @staticmethod
    def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
        return DataLoaderVisWrapper

    @staticmethod
    def get_evaluator(cfg, dataset_name, output_folder):
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
@@ -568,29 +589,10 @@ class Detectron2GoRunner(BaseRunner):
            dataset_evaluators = DatasetEvaluators([dataset_evaluators])
        return dataset_evaluators

    @staticmethod
    def get_mapper(cfg, is_train):
        tfm_gens = build_transform_gen(cfg, is_train)
        mapper = build_dataset_mapper(cfg, is_train, tfm_gens=tfm_gens)
        return mapper

    @staticmethod
    def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
        return VisualizationEvaluator

    @staticmethod
    def final_model_name():
        return "model_final"

    @classmethod
    def _attach_visualizer_to_data_loader(cls, cfg, data_loader):
        if comm.is_main_process():
            data_loader_type = cls.get_data_loader_vis_wrapper()
            if data_loader_type is not None:
                tbx_writer = cls.get_tbx_writer(cfg)
                data_loader = data_loader_type(cfg, tbx_writer, data_loader)
        return data_loader

    def _create_after_step_hook(
        self, cfg, model, optimizer, scheduler, periodic_checkpointer
    ):
@@ -5,13 +5,11 @@ import logging
import os
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple
import detectron2.utils.comm as comm
import pytorch_lightning as pl
import torch
from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import update_cfg_if_using_adhoc_dataset
from d2go.modeling.api import build_meta_arch
@@ -20,16 +18,13 @@ from d2go.optimizer import build_optimizer_mapper
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import (
    _get_tbx_writer,
    D2GoDataAPIMixIn,
    Detectron2GoRunner,
    GeneralizedRCNNRunner,
)
from d2go.utils.ema_state import EMAState
from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.visualization import VisualizationEvaluator
from detectron2.solver import (
    build_lr_scheduler as d2_build_lr_scheduler,
    build_optimizer as d2_build_optimizer,
)
from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.logger import _flatten_dict
@@ -127,7 +122,7 @@ class ModelTag(str, Enum):
    EMA = "ema"


class DefaultTask(pl.LightningModule):
class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.register(cfg)
@@ -393,37 +388,6 @@ class DefaultTask(pl.LightningModule):
            cfg=cfg, dataset_name=dataset_name, output_folder=output_folder
        )

    @staticmethod
    def get_mapper(cfg, is_train):
        return Detectron2GoRunner.get_mapper(cfg, is_train)

    @staticmethod
    def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
        return Detectron2GoRunner.get_visualization_evaluator()

    @staticmethod
    def get_data_loader_vis_wrapper():
        return Detectron2GoRunner.get_data_loader_vis_wrapper()

    @classmethod
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        mapper = mapper or cls.get_mapper(cfg, is_train=True)
        data_loader = build_d2go_train_loader(cfg, mapper)
        return cls._attach_visualizer_to_data_loader(cfg, data_loader)

    @staticmethod
    def build_detection_test_loader(cfg, dataset_name, mapper=None):
        return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper)

    @classmethod
    def _attach_visualizer_to_data_loader(cls, cfg, data_loader):
        if comm.is_main_process():
            data_loader_type = cls.get_data_loader_vis_wrapper()
            if data_loader_type is not None:
                tbx_writer = Detectron2GoRunner.get_tbx_writer(cfg)
                data_loader = data_loader_type(cfg, tbx_writer, data_loader)
        return data_loader

    # ---------------------------------------------------------------------------
    # Hooks
    # ---------------------------------------------------------------------------