import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np

from .custom import CustomDataset


class VOCDataset(CustomDataset):
    """Detection dataset backed by PASCAL VOC style XML annotations.

    Images are expected at ``JPEGImages/<id>.jpg`` and annotations at
    ``<img_prefix>/Annotations/<id>.xml``. Mask annotations are not
    supported.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base class and build the label map.

        Raises:
            AssertionError: if ``with_mask`` is passed with a truthy value.
        """
        assert not kwargs.get('with_mask', False), \
            'VOCDataset does not support with_mask'
        super(VOCDataset, self).__init__(**kwargs)
        # Labels start at 1 (presumably 0 is reserved for background --
        # confirm against the consumer of these labels).
        self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}

    def _xml_path(self, img_id):
        """Return the path of the annotation XML file for ``img_id``."""
        return osp.join(self.img_prefix, 'Annotations',
                        '{}.xml'.format(img_id))

    @staticmethod
    def _parse_coord(text):
        """Parse a coordinate string to int, tolerating float notation.

        Valid integer strings take the plain ``int`` path; values such as
        ``'12.0'`` or ``'45.5'`` are truncated via ``float`` instead of
        raising ``ValueError``.
        """
        try:
            return int(text)
        except ValueError:
            return int(float(text))

    @staticmethod
    def _pack(bboxes, labels):
        """Convert box/label lists to ``(float32 (k, 4), int64 (k,))`` arrays.

        Box coordinates are shifted by -1 (VOC pixel indices are 1-based,
        so this presumably yields 0-based coordinates).
        """
        if not bboxes:
            return (np.zeros((0, 4), dtype=np.float32),
                    np.zeros((0, ), dtype=np.int64))
        boxes = (np.array(bboxes, ndmin=2) - 1).astype(np.float32)
        return boxes, np.array(labels).astype(np.int64)

    def load_annotations(self, ann_file):
        """Build ``self.img_infos`` from the image ids listed in ``ann_file``.

        Args:
            ann_file: path passed to ``mmcv.list_from_file``; each returned
                entry is used as an image id.

        Returns:
            list[dict]: one dict per image with keys ``id``, ``filename``,
            ``width`` and ``height`` (size is read from the XML ``<size>``
            element).
        """
        self.img_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            root = ET.parse(self._xml_path(img_id)).getroot()
            size = root.find('size')
            self.img_infos.append(
                dict(
                    id=img_id,
                    filename='JPEGImages/{}.jpg'.format(img_id),
                    width=int(size.find('width').text),
                    height=int(size.find('height').text)))
        return self.img_infos

    def get_ann_info(self, idx):
        """Parse the annotation XML of the image at index ``idx``.

        Objects whose ``<difficult>`` flag is non-zero go to the ``*_ignore``
        arrays; a missing ``<difficult>`` tag is treated as not difficult.

        Returns:
            dict: ``bboxes`` / ``bboxes_ignore`` as float32 arrays of shape
            (k, 4) in (xmin, ymin, xmax, ymax) order, and ``labels`` /
            ``labels_ignore`` as int64 arrays of shape (k,).

        Raises:
            KeyError: if an object name is not in ``CLASSES``.
        """
        img_id = self.img_infos[idx]['id']
        root = ET.parse(self._xml_path(img_id)).getroot()

        bboxes, labels = [], []
        bboxes_ignore, labels_ignore = [], []
        for obj in root.findall('object'):
            label = self.cat2label[obj.find('name').text]
            # Some annotation files omit <difficult>; find() returns None.
            difficult_node = obj.find('difficult')
            difficult = (int(difficult_node.text)
                         if difficult_node is not None else 0)
            bnd_box = obj.find('bndbox')
            bbox = [
                self._parse_coord(bnd_box.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            if difficult:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)

        bboxes, labels = self._pack(bboxes, labels)
        bboxes_ignore, labels_ignore = self._pack(bboxes_ignore,
                                                  labels_ignore)
        return dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)