"""Pascal VOC object detection dataset.""" from __future__ import absolute_import from __future__ import division import os import logging import warnings import json import dgl import pickle import numpy as np import mxnet as mx from gluoncv.data.base import VisionDataset from collections import Counter from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform class VGRelation(VisionDataset): def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'visualgenome'), split='train'): super(VGRelation, self).__init__(root) self._root = os.path.expanduser(root) self._img_path = os.path.join(self._root, 'VG_100K', '{}') if split == 'train': self._dict_path = os.path.join(self._root, 'rel_annotations_train.json') elif split == 'val': self._dict_path = os.path.join(self._root, 'rel_annotations_val.json') else: raise NotImplementedError with open(self._dict_path) as f: tmp = f.read() self._dict = json.loads(tmp) self._predicates_path = os.path.join(self._root, 'predicates.json') with open(self._predicates_path, 'r') as f: tmp = f.read() self.rel_classes = json.loads(tmp) self.num_rel_classes = len(self.rel_classes) + 1 self._objects_path = os.path.join(self._root, 'objects.json') with open(self._objects_path, 'r') as f: tmp = f.read() self.obj_classes = json.loads(tmp) self.num_obj_classes = len(self.obj_classes) if split == 'val': self.img_transform = FasterRCNNDefaultValTransform(short=600, max_size=1000) else: self.img_transform = FasterRCNNDefaultTrainTransform(short=600, max_size=1000) self.split = split def __len__(self): return len(self._dict) def _hash_bbox(self, object): num_list = [object['category']] + object['bbox'] return '_'.join([str(num) for num in num_list]) def __getitem__(self, idx): img_id = list(self._dict)[idx] img_path = self._img_path.format(img_id) img = mx.image.imread(img_path) item = self._dict[img_id] n_edges = len(item) # edge to node ids sub_node_hash = [] ob_node_hash = [] for i, it in enumerate(item): sub_node_hash.append(self._hash_bbox(it['subject'])) ob_node_hash.append(self._hash_bbox(it['object'])) node_set = sorted(list(set(sub_node_hash + ob_node_hash))) n_nodes = len(node_set) node_to_id = {} for i, node in enumerate(node_set): node_to_id[node] = i sub_id = [] ob_id = [] for i in range(n_edges): sub_id.append(node_to_id[sub_node_hash[i]]) ob_id.append(node_to_id[ob_node_hash[i]]) # node features bbox = mx.nd.zeros((n_nodes, 4)) node_class_ids = mx.nd.zeros((n_nodes, 1)) node_visited = [False for i in range(n_nodes)] for i, it in enumerate(item): if not node_visited[sub_id[i]]: ind = sub_id[i] sub = it['subject'] node_class_ids[ind] = sub['category'] # y1y2x1x2 to x1y1x2y2 bbox[ind,0] = sub['bbox'][2] bbox[ind,1] = sub['bbox'][0] bbox[ind,2] = sub['bbox'][3] bbox[ind,3] = sub['bbox'][1] node_visited[ind] = True if not node_visited[ob_id[i]]: ind = ob_id[i] ob = it['object'] node_class_ids[ind] = ob['category'] # y1y2x1x2 to x1y1x2y2 bbox[ind,0] = ob['bbox'][2] bbox[ind,1] = ob['bbox'][0] bbox[ind,2] = ob['bbox'][3] bbox[ind,3] = ob['bbox'][1] node_visited[ind] = True eta = 0.1 node_class_vec = node_class_ids[:,0].one_hot(self.num_obj_classes, on_value = 1 - eta + eta / self.num_obj_classes, off_value = eta / self.num_obj_classes) # augmentation if self.split == 'val': img, bbox, _ = self.img_transform(img, bbox) else: img, bbox = self.img_transform(img, bbox) # build the graph g = dgl.DGLGraph(multigraph=True) g.add_nodes(n_nodes) adjmat = np.zeros((n_nodes, n_nodes)) predicate = [] for i, it in enumerate(item): adjmat[sub_id[i], ob_id[i]] = 1 predicate.append(it['predicate']) predicate = mx.nd.array(predicate).expand_dims(1) g.add_edges(sub_id, ob_id, {'rel_class': mx.nd.array(predicate) + 1}) empty_edge_list = [] for i in range(n_nodes): for j in range(n_nodes): if i != j and adjmat[i, j] == 0: empty_edge_list.append((i, j)) if len(empty_edge_list) > 0: src, dst = tuple(zip(*empty_edge_list)) g.add_edges(src, dst, {'rel_class': mx.nd.zeros((len(empty_edge_list), 1))}) # assign features g.ndata['bbox'] = bbox g.ndata['node_class'] = node_class_ids g.ndata['node_class_vec'] = node_class_vec return g, img