Commit 4bae056b authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

new interface for build_d2go_train_loader

Summary:
#Facebook:

`build_d2go_train_loader` will replace `runner.build_detection_train_loader`, currently we call `build_d2go_train_loader` from `runner.build_detection_train_loader` since some runner has there own implementation, we will solve those cases and remove the `runner.build_detection_train_loader` API.

Currently `build_d2go_train_loader` uses `_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY` to support different versions between OSS and FB, not sure if this is a good pattern or not, please comment in the diff if you have better idea.

Reviewed By: zhanghang1989

Differential Revision: D27505681

fbshipit-source-id: b5caf7280a88c2ebccb498097c0b7af51c966fc6
parent 1850a632
......@@ -10,13 +10,21 @@ from typing import Dict
import torch
from d2go.config import CfgNode
from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.utils import ClipLengthGroupedDataset
from detectron2.data import build_batch_data_loader, get_detection_dataset_dicts
from detectron2.data import (
build_batch_data_loader,
build_detection_train_loader,
get_detection_dataset_dicts,
)
from detectron2.data.build import worker_init_reset_seed
from detectron2.data.common import MapDataset, DatasetFromList
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import RepeatFactorTrainingSampler
from detectron2.utils.comm import get_world_size
from mobile_cv.common.misc.registry import Registry
logger = logging.getLogger(__name__)
def add_weighted_training_sampler_default_configs(cfg: CfgNode):
......@@ -44,7 +52,6 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
assert not unrecognized, f"unrecognized datasets: {unrecognized}"
logger = logging.getLogger(__name__)
logger.info(f"Found repeat factors: {list(name_to_weight.items())}")
# pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
......@@ -83,7 +90,6 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper)
logger = logging.getLogger(__name__)
logger.info(
"Using WeightedTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
......@@ -130,3 +136,36 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
worker_init_fn=worker_init_reset_seed,
) # yield individual mapped dict
return ClipLengthGroupedDataset(data_loader, batch_size)
_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY = Registry("MAPPED_TRAIN_LOADER_BUILDER")
@_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.register("oss")
def build_mapped_train_loader(cfg, mapper):
if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
else:
data_loader = build_detection_train_loader(cfg, mapper=mapper)
return data_loader
def build_d2go_train_loader(cfg, mapper=None):
"""
Build the dataloader for training in D2Go. This is the main entry and customizations
will be done by using Registry.
This interface is currently experimental.
"""
logger.info("Building D2Go's train loader ...")
# TODO: disallow passing mapper and use registry for all mapper registering
mapper = mapper or build_dataset_mapper(cfg, is_train=True)
logger.info("Using dataset mapper:\n{}".format(mapper))
data_loader = (
_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.get("internal", is_raise=False)
or _MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.get("oss")
)(cfg, mapper)
# TODO: decide if move vis_wrapper inside this interface
return data_loader
......@@ -15,9 +15,7 @@ import detectron2.utils.comm as comm
import mock
import torch
from d2go.config import CfgNode as CN, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost, get_cfg_diff_table
from d2go.data.build import (
build_weighted_detection_train_loader,
)
from d2go.data.build import build_d2go_train_loader
from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.transforms.build import build_transform_gen
......@@ -476,17 +474,8 @@ class Detectron2GoRunner(BaseRunner):
@classmethod
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
logger.info("Building detection train loader ...")
mapper = mapper or cls.get_mapper(cfg, is_train=True)
logger.info("Using dataset mapper:\n{}".format(mapper))
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
if sampler_name == "WeightedTrainingSampler":
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
else:
data_loader = d2_build_detection_train_loader(
cfg, *args, mapper=mapper, **kwargs
)
data_loader = build_d2go_train_loader(cfg, mapper)
if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment