Commit 7cbdbc78 authored by wangg12's avatar wangg12
Browse files

move the function to datasets.utils

parent 7906bd20
from .custom import CustomDataset
from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann
from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset
__all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset'
]
from collections import Sequence
import copy
import mmcv
from mmcv.runner import obj_from_dict
import torch
import matplotlib.pyplot as plt
import numpy as np
from .concat_dataset import ConcatDataset
from .. import datasets
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
......@@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info):
plt.axis('off')
coco.showAnns(ann_info)
plt.show()
def get_dataset(data_cfg):
if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple):
ann_files = data_cfg['ann_file']
dsets = []
for ann_file in ann_files:
data_info = copy.deepcopy(data_cfg)
data_info['ann_file'] = ann_file
dset = obj_from_dict(data_info, datasets)
dsets.append(dset)
if len(dsets) > 1:
dset = ConcatDataset(dsets)
else:
dset = dsets[0]
else:
dset = obj_from_dict(data_cfg, datasets)
return dset
\ No newline at end of file
from __future__ import division
import argparse
import copy
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__
from mmdet.datasets import ConcatDataset
from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed)
from mmdet.models import build_detector
......@@ -38,24 +36,6 @@ def parse_args():
return args
def get_train_dataset(cfg):
if isinstance(cfg.data.train['ann_file'], list) or isinstance(cfg.data.train['ann_file'], tuple):
ann_files = cfg.data.train['ann_file']
train_datasets = []
for ann_file in ann_files:
data_info = copy.deepcopy(cfg.data.train)
data_info['ann_file'] = ann_file
train_dset = obj_from_dict(data_info, datasets)
train_datasets.append(train_dset)
if len(train_datasets) > 1:
train_dataset = ConcatDataset(train_datasets)
else:
train_dataset = train_datasets[0]
else:
train_dataset = obj_from_dict(cfg.data.train, datasets)
return train_dataset
def main():
args = parse_args()
......@@ -87,7 +67,7 @@ def main():
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = get_train_dataset(cfg)
train_dataset = datasets.get_dataset(cfg.data.train)
train_detector(
model,
train_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment