Commit 07c4e54c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

merge internal data build files

Reviewed By: ppwwyyxx

Differential Revision: D31035247

fbshipit-source-id: 7340e6f6bb813e284416e37060d0d511c5c79e03
parent f4fcff31
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.data.dataset_mappers import build_dataset_mapper from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.utils import ClipLengthGroupedDataset from d2go.data.utils import ClipLengthGroupedDataset
from d2go.utils.misc import fb_overwritable
from detectron2.data import ( from detectron2.data import (
build_batch_data_loader, build_batch_data_loader,
build_detection_train_loader, build_detection_train_loader,
...@@ -22,7 +23,6 @@ from detectron2.data.common import MapDataset, DatasetFromList ...@@ -22,7 +23,6 @@ from detectron2.data.common import MapDataset, DatasetFromList
from detectron2.data.dataset_mapper import DatasetMapper from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import RepeatFactorTrainingSampler from detectron2.data.samplers import RepeatFactorTrainingSampler
from detectron2.utils.comm import get_world_size from detectron2.utils.comm import get_world_size
from mobile_cv.common.misc.registry import Registry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -146,10 +146,7 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work ...@@ -146,10 +146,7 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
return ClipLengthGroupedDataset(data_loader, batch_size) return ClipLengthGroupedDataset(data_loader, batch_size)
_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY = Registry("MAPPED_TRAIN_LOADER_BUILDER") @fb_overwritable()
@_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.register("oss")
def build_mapped_train_loader(cfg, mapper): def build_mapped_train_loader(cfg, mapper):
if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler": if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper) data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
...@@ -170,10 +167,7 @@ def build_d2go_train_loader(cfg, mapper=None): ...@@ -170,10 +167,7 @@ def build_d2go_train_loader(cfg, mapper=None):
mapper = mapper or build_dataset_mapper(cfg, is_train=True) mapper = mapper or build_dataset_mapper(cfg, is_train=True)
logger.info("Using dataset mapper:\n{}".format(mapper)) logger.info("Using dataset mapper:\n{}".format(mapper))
data_loader = ( data_loader = build_mapped_train_loader(cfg, mapper)
_MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.get("internal", is_raise=False)
or _MAPPED_TRAIN_LOADER_BUILDER_REGISTRY.get("oss")
)(cfg, mapper)
# TODO: decide if move vis_wrapper inside this interface # TODO: decide if move vis_wrapper inside this interface
return data_loader return data_loader
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment