Commit c829f054 authored by Demetris Marnerides's avatar Demetris Marnerides Committed by Kai Chen
Browse files

Added Registry use for validation datasets (distributed) (#1058)

* Added Registry use for validation datasets (distributed)

* Allowing for default_args for build_dataset

* Using build_dataset instead of build_from_cfg
parent 3dc0047c
......@@ -11,7 +11,7 @@ from mmdet import datasets
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook,
Fp16OptimizerHook)
from mmdet.datasets import build_dataloader
from mmdet.datasets import build_dataloader, DATASETS
from mmdet.models import RPN
from .env import get_root_logger
......@@ -174,7 +174,7 @@ def _dist_train(model, dataset, cfg, validate=False):
runner.register_hook(
CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
else:
dataset_type = getattr(datasets, val_dataset_cfg.type)
dataset_type = DATASETS.get(val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(
CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
......
......@@ -5,7 +5,7 @@ import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import Hook, obj_from_dict
from mmcv.runner import Hook
from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset
......@@ -21,8 +21,7 @@ class DistEvalHook(Hook):
if isinstance(dataset, Dataset):
self.dataset = dataset
elif isinstance(dataset, dict):
self.dataset = obj_from_dict(dataset, datasets,
{'test_mode': True})
self.dataset = datasets.build_dataset(dataset, {'test_mode': True})
else:
raise TypeError(
'dataset must be a Dataset object or a dict, not {}'.format(
......
......@@ -5,7 +5,7 @@ from .dataset_wrappers import ConcatDataset, RepeatDataset
from .registry import DATASETS
def _concat_dataset(cfg):
def _concat_dataset(cfg, default_args=None):
ann_files = cfg['ann_file']
img_prefixes = cfg.get('img_prefix', None)
seg_prefixes = cfg.get('seg_prefixes', None)
......@@ -22,17 +22,18 @@ def _concat_dataset(cfg):
data_cfg['seg_prefix'] = seg_prefixes[i]
if isinstance(proposal_files, (list, tuple)):
data_cfg['proposal_file'] = proposal_files[i]
datasets.append(build_dataset(data_cfg))
datasets.append(build_dataset(data_cfg, default_args))
return ConcatDataset(datasets)
def build_dataset(cfg):
def build_dataset(cfg, default_args=None):
if cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
cfg['times'])
elif isinstance(cfg['ann_file'], (list, tuple)):
dataset = _concat_dataset(cfg)
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS)
dataset = build_from_cfg(cfg, DATASETS, default_args)
return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment