Unverified commit 9510c3a7, authored by Kai Chen, committed by GitHub
Browse files

Merge pull request #127 from wangg12/master

support training on dataset with multiple ann_files
parents ae4646fa 9baf0a8b
...@@ -107,3 +107,5 @@ venv.bak/ ...@@ -107,3 +107,5 @@ venv.bak/
mmdet/ops/nms/*.cpp mmdet/ops/nms/*.cpp
mmdet/version.py mmdet/version.py
data data
.vscode
.idea
from .custom import CustomDataset from .custom import CustomDataset
from .coco import CocoDataset from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann' 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset'
] ]
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
    """Concatenated dataset that also concatenates the ``flag`` arrays.

    Same as ``torch.utils.data.dataset.ConcatDataset``, but additionally
    joins the per-dataset group flag for image aspect ratio.

    flag: Images with aspect ratio greater than 1 will be set as group 1,
    otherwise group 0.
    """

    def __init__(self, datasets):
        """Concatenate ``datasets`` and, when available, their flags.

        Args:
            datasets (iterable): Datasets to concatenate, in order. If the
                first one has a ``flag`` attribute, every dataset is
                expected to have one, and ``self.flag`` is set to their
                concatenation in the same order. Otherwise no ``flag``
                attribute is created.
        """
        super(ConcatDataset, self).__init__(datasets)
        # Read from the list stored by the parent class rather than from
        # the raw argument, which may be an iterable without indexing.
        members = self.datasets
        if hasattr(members[0], 'flag'):
            self.flag = np.concatenate([member.flag for member in members])
import copy
from collections import Sequence from collections import Sequence
import mmcv import mmcv
from mmcv.runner import obj_from_dict
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from .concat_dataset import ConcatDataset
from .. import datasets
def to_tensor(data): def to_tensor(data):
...@@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info): ...@@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info):
plt.axis('off') plt.axis('off')
coco.showAnns(ann_info) coco.showAnns(ann_info)
plt.show() plt.show()
def get_dataset(data_cfg):
    """Build the dataset described by ``data_cfg``.

    One dataset is built per annotation file via ``obj_from_dict``. When
    more than one is built they are wrapped in ``ConcatDataset``;
    otherwise the single dataset is returned directly.

    Args:
        data_cfg (dict): Dataset config. The keys read here are:

            - ``ann_file``: a single value or a list/tuple of values.
            - ``proposal_file`` (optional): a single value or a list/tuple
              with one entry per annotation file. A missing key or an
              explicit ``None`` means no proposals for any dataset.
            - ``img_prefix``: a single value shared by all annotation
              files, or a list/tuple with one entry per annotation file.

            All remaining keys are passed through unchanged to every
            dataset that is built.

    Returns:
        The single dataset built, or a ``ConcatDataset`` of all of them.

    Raises:
        AssertionError: If ``proposal_file`` or ``img_prefix`` does not
            provide exactly one entry per annotation file.
    """
    ann_files = data_cfg['ann_file']
    if not isinstance(ann_files, (list, tuple)):
        ann_files = [ann_files]
    num_dset = len(ann_files)

    proposal_files = None
    if 'proposal_file' in data_cfg.keys():
        proposal_files = data_cfg['proposal_file']
    if proposal_files is None:
        # An explicit ``proposal_file=None`` is handled like a missing key,
        # so it no longer fails the length check with several ann files.
        proposal_files = [None] * num_dset
    elif not isinstance(proposal_files, (list, tuple)):
        proposal_files = [proposal_files]
    assert len(proposal_files) == num_dset, (
        'proposal_file must provide one entry per ann_file')

    img_prefixes = data_cfg['img_prefix']
    if not isinstance(img_prefixes, (list, tuple)):
        # A single prefix is shared by every annotation file.
        img_prefixes = [img_prefixes] * num_dset
    assert len(img_prefixes) == num_dset, (
        'img_prefix must provide one entry per ann_file')

    dsets = []
    for ann_file, proposal_file, img_prefix in zip(
            ann_files, proposal_files, img_prefixes):
        # Deep-copy so the per-dataset overrides never leak back into the
        # caller's config or into the other datasets.
        data_info = copy.deepcopy(data_cfg)
        data_info['ann_file'] = ann_file
        data_info['proposal_file'] = proposal_file
        data_info['img_prefix'] = img_prefix
        dsets.append(obj_from_dict(data_info, datasets))

    if len(dsets) > 1:
        return ConcatDataset(dsets)
    return dsets[0]
...@@ -2,9 +2,9 @@ from __future__ import division ...@@ -2,9 +2,9 @@ from __future__ import division
import argparse import argparse
from mmcv import Config from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__ from mmdet import __version__
from mmdet.datasets import get_dataset
from mmdet.apis import (train_detector, init_dist, get_root_logger, from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed) set_random_seed)
from mmdet.models import build_detector from mmdet.models import build_detector
...@@ -67,7 +67,7 @@ def main(): ...@@ -67,7 +67,7 @@ def main():
model = build_detector( model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets) train_dataset = get_dataset(cfg.data.train)
train_detector( train_detector(
model, model,
train_dataset, train_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment