Commit 825cfa0c authored by Kai Chen's avatar Kai Chen
Browse files

add class attribute CLASSES to Dataset

parent 826a5613
......@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def coco_classes():
return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
]
......
......@@ -6,6 +6,21 @@ from .custom import CustomDataset
class CocoDataset(CustomDataset):
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
self.cat_ids = self.coco.getCatIds()
......
......@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def __init__(self, datasets):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
if hasattr(datasets[0], 'flag'):
flags = []
for i in range(0, len(datasets)):
......
......@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
"""
CLASSES = None
def __init__(self,
ann_file,
img_prefix,
......@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
with_crowd=True,
with_label=True,
test_mode=False):
# prefix of images path
self.img_prefix = img_prefix
# load annotations (and proposals)
self.img_infos = self.load_annotations(ann_file)
if proposal_file is not None:
......@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# prefix of images path
self.img_prefix = img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self.img_scales = img_scale if isinstance(img_scale,
list) else [img_scale]
......
......@@ -6,12 +6,14 @@ class RepeatDataset(object):
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, times)
self._original_length = len(self.dataset)
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
return self.dataset[idx % self._original_length]
return self.dataset[idx % self._ori_len]
def __len__(self):
return self.times * self._original_length
return self.times * self._ori_len
......@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
if isinstance(dataset, str):
class_names = get_classes(dataset)
elif isinstance(dataset, list):
elif isinstance(dataset, (list, tuple)) or dataset is None:
class_names = dataset
else:
raise TypeError('dataset must be a valid dataset name or a list'
' of class names, not {}'.format(type(dataset)))
raise TypeError(
'dataset must be a valid dataset name or a sequence'
' of class names, not {}'.format(type(dataset)))
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
......
......@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def single_test(model, data_loader, show=False):
model.eval()
results = []
prog_bar = mmcv.ProgressBar(len(data_loader.dataset))
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=not show, **data)
results.append(result)
if show:
model.module.show_result(data, result,
data_loader.dataset.img_norm_cfg)
model.module.show_result(data, result, dataset.img_norm_cfg,
dataset.CLASSES)
batch_size = data['img'][0].size(0)
for _ in range(batch_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment