Commit 6b25743a authored by wangg12's avatar wangg12
Browse files

fix flake8

parent 7cbdbc78
......@@ -5,6 +5,7 @@ from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset
__all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset'
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset'
]
......@@ -5,7 +5,7 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
Same as torch.utils.data.dataset.ConcatDataset, but
concat the group flag for image aspect ratio.
"""
def __init__(self, datasets):
......@@ -13,7 +13,7 @@ class ConcatDataset(_ConcatDataset):
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super(ConcatDataset, self).__init__(datasets)
super(ConcatDataset, self).__init__(datasets)
if hasattr(datasets[0], 'flag'):
flags = []
for i in range(0, len(datasets)):
......@@ -27,4 +27,3 @@ class ConcatDataset(_ConcatDataset):
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
......@@ -9,6 +9,7 @@ import numpy as np
from .concat_dataset import ConcatDataset
from .. import datasets
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
......@@ -72,7 +73,8 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg):
if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple):
if isinstance(data_cfg['ann_file'], list) or \
isinstance(data_cfg['ann_file'], tuple):
ann_files = data_cfg['ann_file']
dsets = []
for ann_file in ann_files:
......@@ -81,9 +83,9 @@ def get_dataset(data_cfg):
dset = obj_from_dict(data_info, datasets)
dsets.append(dset)
if len(dsets) > 1:
dset = ConcatDataset(dsets)
dset = ConcatDataset(dsets)
else:
dset = dsets[0]
else:
dset = obj_from_dict(data_cfg, datasets)
return dset
\ No newline at end of file
return dset
......@@ -2,7 +2,6 @@ from __future__ import division
import argparse
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__
from mmdet.apis import (train_detector, init_dist, get_root_logger,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment