Commit c829f054 authored by Demetris Marnerides's avatar Demetris Marnerides Committed by Kai Chen
Browse files

Added Registry use for validation datasets (distributed) (#1058)

* Added Registry use for validation datasets (distributed)

* Allowing for default_args for build_dataset

* Using build_dataset instead of build_from_cfg
parent 3dc0047c
@@ -11,7 +11,7 @@
 from mmdet import datasets
 from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
                         CocoDistEvalRecallHook, CocoDistEvalmAPHook,
                         Fp16OptimizerHook)
-from mmdet.datasets import build_dataloader
+from mmdet.datasets import build_dataloader, DATASETS
 from mmdet.models import RPN
 from .env import get_root_logger
@@ -174,7 +174,7 @@ def _dist_train(model, dataset, cfg, validate=False):
             runner.register_hook(
                 CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
         else:
-            dataset_type = getattr(datasets, val_dataset_cfg.type)
+            dataset_type = DATASETS.get(val_dataset_cfg.type)
             if issubclass(dataset_type, datasets.CocoDataset):
                 runner.register_hook(
                     CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
...
@@ -5,7 +5,7 @@
 import mmcv
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.runner import Hook, obj_from_dict
+from mmcv.runner import Hook
 from mmcv.parallel import scatter, collate
 from pycocotools.cocoeval import COCOeval
 from torch.utils.data import Dataset
@@ -21,8 +21,7 @@ class DistEvalHook(Hook):
         if isinstance(dataset, Dataset):
             self.dataset = dataset
         elif isinstance(dataset, dict):
-            self.dataset = obj_from_dict(dataset, datasets,
-                                         {'test_mode': True})
+            self.dataset = datasets.build_dataset(dataset, {'test_mode': True})
         else:
             raise TypeError(
                 'dataset must be a Dataset object or a dict, not {}'.format(
...
@@ -5,7 +5,7 @@
 from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .registry import DATASETS


-def _concat_dataset(cfg):
+def _concat_dataset(cfg, default_args=None):
     ann_files = cfg['ann_file']
     img_prefixes = cfg.get('img_prefix', None)
     seg_prefixes = cfg.get('seg_prefixes', None)
@@ -22,17 +22,18 @@ def _concat_dataset(cfg):
             data_cfg['seg_prefix'] = seg_prefixes[i]
         if isinstance(proposal_files, (list, tuple)):
             data_cfg['proposal_file'] = proposal_files[i]
-        datasets.append(build_dataset(data_cfg))
+        datasets.append(build_dataset(data_cfg, default_args))
     return ConcatDataset(datasets)


-def build_dataset(cfg):
+def build_dataset(cfg, default_args=None):
     if cfg['type'] == 'RepeatDataset':
-        dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
+        dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
+                                cfg['times'])
     elif isinstance(cfg['ann_file'], (list, tuple)):
-        dataset = _concat_dataset(cfg)
+        dataset = _concat_dataset(cfg, default_args)
     else:
-        dataset = build_from_cfg(cfg, DATASETS)
+        dataset = build_from_cfg(cfg, DATASETS, default_args)
     return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment