Commit 40325555 authored by yhcao6's avatar yhcao6
Browse files

merge new master

parents cfdd8050 c95c6373
...@@ -194,7 +194,7 @@ Here is an example. ...@@ -194,7 +194,7 @@ Here is an example.
'bboxes': <np.ndarray> (n, 4), 'bboxes': <np.ndarray> (n, 4),
'labels': <np.ndarray> (n, ), 'labels': <np.ndarray> (n, ),
'bboxes_ignore': <np.ndarray> (k, 4), 'bboxes_ignore': <np.ndarray> (k, 4),
'labels_ignore': <np.ndarray> (k, 4) (optional field) 'labels_ignore': <np.ndarray> (k, ) (optional field)
} }
}, },
... ...
...@@ -206,12 +206,12 @@ There are two ways to work with custom datasets. ...@@ -206,12 +206,12 @@ There are two ways to work with custom datasets.
- online conversion - online conversion
You can write a new Dataset class inherited from `CustomDataset`, and overwrite two methods You can write a new Dataset class inherited from `CustomDataset`, and overwrite two methods
`load_annotations(self, ann_file)` and `get_ann_info(self, idx)`, like [CocoDataset](mmdet/datasets/coco.py). `load_annotations(self, ann_file)` and `get_ann_info(self, idx)`, like [CocoDataset](mmdet/datasets/coco.py) and [VOCDataset](mmdet/datasets/voc.py).
- offline conversion - offline conversion
You can convert the annotation format to the expected format above and save it to You can convert the annotation format to the expected format above and save it to
a pickle file, like [pascal_voc.py](tools/convert_datasets/pascal_voc.py). a pickle or json file, like [pascal_voc.py](tools/convert_datasets/pascal_voc.py).
Then you can simply use `CustomDataset`. Then you can simply use `CustomDataset`.
## Technical details ## Technical details
......
...@@ -6,8 +6,8 @@ import torch ...@@ -6,8 +6,8 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook, from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalmAPHook) CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader from mmdet.datasets import build_dataloader
from mmdet.models import RPN from mmdet.models import RPN
from .env import get_root_logger from .env import get_root_logger
...@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False): ...@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks # register eval hooks
if validate: if validate:
if isinstance(model.module, RPN): if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
elif cfg.data.val.type == 'CocoDataset': else:
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) if cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
else:
runner.register_hook(DistEvalmAPHook(cfg.data.val))
if cfg.resume_from: if cfg.resume_from:
runner.resume(cfg.resume_from) runner.resume(cfg.resume_from)
......
...@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes, ...@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes, coco_classes, dataset_aliases, imagenet_vid_classes, coco_classes, dataset_aliases,
get_classes) get_classes)
from .coco_utils import coco_eval, fast_eval_recall, results2json from .coco_utils import coco_eval, fast_eval_recall, results2json
from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook, from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook) CocoDistEvalmAPHook)
from .mean_ap import average_precision, eval_map, print_map_summary from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall, from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
...@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall, ...@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__ = [ __all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval', 'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
'fast_eval_recall', 'results2json', 'DistEvalHook', 'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision', 'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary', 'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall' 'plot_num_recall', 'plot_iou_recall'
......
...@@ -63,18 +63,18 @@ def imagenet_vid_classes(): ...@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def coco_classes(): def coco_classes():
return [ return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush' 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
] ]
......
...@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval ...@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset from torch.utils.data import Dataset
from .coco_utils import results2json, fast_eval_recall from .coco_utils import results2json, fast_eval_recall
from .mean_ap import eval_map
from mmdet import datasets from mmdet import datasets
...@@ -102,6 +103,44 @@ class DistEvalHook(Hook): ...@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise NotImplementedError raise NotImplementedError
class DistEvalmAPHook(DistEvalHook):
    """Distributed evaluation hook that computes mean AP on any dataset
    exposing the `CustomDataset` annotation interface (``get_ann_info``).
    """

    def evaluate(self, runner, results):
        """Evaluate detection results and log mAP.

        Args:
            runner: mmcv Runner; ``mAP`` is written into its log buffer.
            results (list): Per-image detection results, aligned with the
                dataset order.
        """
        gt_bboxes = []
        gt_labels = []
        # Only collect ignore masks when the dataset provides crowd/ignore
        # annotations; ``None`` tells eval_map to skip ignore handling.
        gt_ignore = [] if self.dataset.with_crowd else None
        for i in range(len(self.dataset)):
            ann = self.dataset.get_ann_info(i)
            bboxes = ann['bboxes']
            labels = ann['labels']
            if gt_ignore is not None:
                # Append ignore boxes after the regular ones and mark them
                # True in the mask. Use the builtin ``bool`` dtype —
                # ``np.bool`` was deprecated in NumPy 1.20 and removed in
                # NumPy 1.24, so the original spelling raises AttributeError
                # on modern NumPy.
                ignore = np.concatenate([
                    np.zeros(bboxes.shape[0], dtype=bool),
                    np.ones(ann['bboxes_ignore'].shape[0], dtype=bool)
                ])
                gt_ignore.append(ignore)
                bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
                labels = np.concatenate([labels, ann['labels_ignore']])
            gt_bboxes.append(bboxes)
            gt_labels.append(labels)
        # If the dataset is VOC2007, then use 11 points mAP evaluation.
        if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
            ds_name = 'voc07'
        else:
            ds_name = self.dataset.CLASSES
        mean_ap, eval_results = eval_map(
            results,
            gt_bboxes,
            gt_labels,
            gt_ignore=gt_ignore,
            scale_ranges=None,
            iou_thr=0.5,
            dataset=ds_name,
            print_summary=True)
        runner.log_buffer.output['mAP'] = mean_ap
        runner.log_buffer.ready = True
class CocoDistEvalRecallHook(DistEvalHook): class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self, def __init__(self,
......
from .custom import CustomDataset from .custom import CustomDataset
from .xml_style import XMLDataset
from .coco import CocoDataset from .coco import CocoDataset
from .voc import VOCDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann, get_dataset from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
...@@ -7,7 +9,8 @@ from .repeat_dataset import RepeatDataset ...@@ -7,7 +9,8 @@ from .repeat_dataset import RepeatDataset
from .extra_aug import ExtraAugmentation from .extra_aug import ExtraAugmentation
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
'get_dataset', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation' 'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset',
'ExtraAugmentation'
] ]
...@@ -6,6 +6,21 @@ from .custom import CustomDataset ...@@ -6,6 +6,21 @@ from .custom import CustomDataset
class CocoDataset(CustomDataset): class CocoDataset(CustomDataset):
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
self.coco = COCO(ann_file) self.coco = COCO(ann_file)
self.cat_ids = self.coco.getCatIds() self.cat_ids = self.coco.getCatIds()
......
...@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset ...@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset): class ConcatDataset(_ConcatDataset):
""" """A wrapper of concatenated dataset.
Same as torch.utils.data.dataset.ConcatDataset, but
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio. concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
""" """
def __init__(self, datasets): def __init__(self, datasets):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super(ConcatDataset, self).__init__(datasets) super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
if hasattr(datasets[0], 'flag'): if hasattr(datasets[0], 'flag'):
flags = [] flags = []
for i in range(0, len(datasets)): for i in range(0, len(datasets)):
......
...@@ -33,6 +33,8 @@ class CustomDataset(Dataset): ...@@ -33,6 +33,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing. The `ann` field is optional for testing.
""" """
CLASSES = None
def __init__(self, def __init__(self,
ann_file, ann_file,
img_prefix, img_prefix,
...@@ -48,6 +50,9 @@ class CustomDataset(Dataset): ...@@ -48,6 +50,9 @@ class CustomDataset(Dataset):
test_mode=False, test_mode=False,
extra_aug=None, extra_aug=None,
resize_keep_ratio=True): resize_keep_ratio=True):
# prefix of images path
self.img_prefix = img_prefix
# load annotations (and proposals) # load annotations (and proposals)
self.img_infos = self.load_annotations(ann_file) self.img_infos = self.load_annotations(ann_file)
if proposal_file is not None: if proposal_file is not None:
...@@ -61,8 +66,6 @@ class CustomDataset(Dataset): ...@@ -61,8 +66,6 @@ class CustomDataset(Dataset):
if self.proposals is not None: if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds] self.proposals = [self.proposals[i] for i in valid_inds]
# prefix of images path
self.img_prefix = img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...] # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self.img_scales = img_scale if isinstance(img_scale, self.img_scales = img_scale if isinstance(img_scale,
list) else [img_scale] list) else [img_scale]
......
...@@ -6,12 +6,14 @@ class RepeatDataset(object): ...@@ -6,12 +6,14 @@ class RepeatDataset(object):
def __init__(self, dataset, times): def __init__(self, dataset, times):
self.dataset = dataset self.dataset = dataset
self.times = times self.times = times
self.CLASSES = dataset.CLASSES
if hasattr(self.dataset, 'flag'): if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, times) self.flag = np.tile(self.dataset.flag, times)
self._original_length = len(self.dataset)
self._ori_len = len(self.dataset)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.dataset[idx % self._original_length] return self.dataset[idx % self._ori_len]
def __len__(self): def __len__(self):
return self.times * self._original_length return self.times * self._ori_len
from .xml_style import XMLDataset
class VOCDataset(XMLDataset):
    """Pascal VOC dataset (2007 or 2012) backed by XML-style annotations."""

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        super(VOCDataset, self).__init__(**kwargs)
        # Infer the dataset year from the image path (e.g. ".../VOC2007/...");
        # VOC2007 is checked first, matching the original precedence.
        for marker, year in (('VOC2007', 2007), ('VOC2012', 2012)):
            if marker in self.img_prefix:
                self.year = year
                break
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .custom import CustomDataset
class XMLDataset(CustomDataset):
    """Dataset whose annotations are Pascal-VOC-style per-image XML files.

    Expects a subclass to define ``CLASSES``. Images are looked up under
    ``<img_prefix>/JPEGImages`` and annotations under
    ``<img_prefix>/Annotations``.
    """

    def __init__(self, **kwargs):
        super(XMLDataset, self).__init__(**kwargs)
        # Map class name -> 1-based label index (0 is left for background).
        self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}

    def load_annotations(self, ann_file):
        """Build image info dicts from a split file of image ids.

        Args:
            ann_file (str): Text file with one image id per line.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``,
                ``width`` and ``height``.
        """
        img_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = 'JPEGImages/{}.jpg'.format(img_id)
            xml_path = osp.join(self.img_prefix, 'Annotations',
                                '{}.xml'.format(img_id))
            tree = ET.parse(xml_path)
            root = tree.getroot()
            # Image dimensions are read from the XML <size> element rather
            # than from the image file itself.
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            img_infos.append(
                dict(id=img_id, filename=filename, width=width, height=height))
        return img_infos

    def get_ann_info(self, idx):
        """Parse the XML annotation of the ``idx``-th image.

        Objects flagged ``difficult`` are routed to the ignore arrays; all
        others become regular ground truth.

        Returns:
            dict: ``bboxes`` (float32, (n, 4)), ``labels`` (int64, (n,)),
                ``bboxes_ignore`` (float32, (k, 4)) and ``labels_ignore``
                (int64, (k,)).
        """
        img_id = self.img_infos[idx]['id']
        xml_path = osp.join(self.img_prefix, 'Annotations',
                            '{}.xml'.format(img_id))
        tree = ET.parse(xml_path)
        root = tree.getroot()
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            label = self.cat2label[name]
            difficult = int(obj.find('difficult').text)
            bnd_box = obj.find('bndbox')
            bbox = [
                int(bnd_box.find('xmin').text),
                int(bnd_box.find('ymin').text),
                int(bnd_box.find('xmax').text),
                int(bnd_box.find('ymax').text)
            ]
            if difficult:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
        else:
            # VOC coordinates are 1-based; subtract 1 to make them 0-based.
            bboxes = np.array(bboxes, ndmin=2) - 1
            labels = np.array(labels)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
            labels_ignore = np.array(labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann
...@@ -99,11 +99,12 @@ class BaseDetector(nn.Module): ...@@ -99,11 +99,12 @@ class BaseDetector(nn.Module):
if isinstance(dataset, str): if isinstance(dataset, str):
class_names = get_classes(dataset) class_names = get_classes(dataset)
elif isinstance(dataset, list): elif isinstance(dataset, (list, tuple)) or dataset is None:
class_names = dataset class_names = dataset
else: else:
raise TypeError('dataset must be a valid dataset name or a list' raise TypeError(
' of class names, not {}'.format(type(dataset))) 'dataset must be a valid dataset name or a sequence'
' of class names, not {}'.format(type(dataset)))
for img, img_meta in zip(imgs, img_metas): for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape'] h, w, _ = img_meta['img_shape']
......
...@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors ...@@ -14,15 +14,16 @@ from mmdet.models import build_detector, detectors
def single_test(model, data_loader, show=False): def single_test(model, data_loader, show=False):
model.eval() model.eval()
results = [] results = []
prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader): for i, data in enumerate(data_loader):
with torch.no_grad(): with torch.no_grad():
result = model(return_loss=False, rescale=not show, **data) result = model(return_loss=False, rescale=not show, **data)
results.append(result) results.append(result)
if show: if show:
model.module.show_result(data, result, model.module.show_result(data, result, dataset.img_norm_cfg,
data_loader.dataset.img_norm_cfg) dataset.CLASSES)
batch_size = data['img'][0].size(0) batch_size = data['img'][0].size(0)
for _ in range(batch_size): for _ in range(batch_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment