relation.py 5.4 KB
Newer Older
1
"""Visual Genome relation (scene-graph) dataset producing DGL graphs."""
2
3
from __future__ import absolute_import, division

4
import json
5
6
import logging
import os
7
import pickle
8
9
10
import warnings
from collections import Counter

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
11
12
import dgl

13
import mxnet as mx
14
import numpy as np
15
from gluoncv.data.base import VisionDataset
16
from gluoncv.data.transforms.presets.rcnn import (
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
17
18
19
    FasterRCNNDefaultTrainTransform,
    FasterRCNNDefaultValTransform,
)
20

21
22

class VGRelation(VisionDataset):
    """Visual Genome relationship dataset yielding one scene graph per image.

    Each item is a ``(graph, image)`` pair.

    Graph nodes are the distinct objects (identified by category and box)
    that appear as the subject or object of one of the image's annotated
    relations. Graph edges are:

    * the annotated relations, with ``rel_class = predicate + 1``;
    * one ``rel_class = 0`` ("no relation") edge for every other ordered
      pair of distinct nodes.

    Parameters
    ----------
    root : str
        Dataset directory (``~`` is expanded). It is expected to contain
        ``VG_100K/`` (images), ``rel_annotations_train.json`` /
        ``rel_annotations_val.json``, ``predicates.json`` and
        ``objects.json``.
    split : str
        ``"train"`` or ``"val"``; any other value raises
        ``NotImplementedError``.
    """

    # Label-smoothing factor applied to the one-hot node class vectors.
    _LABEL_SMOOTHING = 0.1

    def __init__(
        self,
        root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
        split="train",
    ):
        super(VGRelation, self).__init__(root)
        self._root = os.path.expanduser(root)
        self._img_path = os.path.join(self._root, "VG_100K", "{}")

        if split == "train":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_train.json"
            )
        elif split == "val":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_val.json"
            )
        else:
            raise NotImplementedError(
                "Unsupported split {!r}; expected 'train' or 'val'".format(split)
            )
        # Mapping: image file name -> list of relation annotations.
        self._dict = self._load_json(self._dict_path)
        # Materialize the image ids once so __getitem__ can index them in
        # O(1) rather than rebuilding a list of every key on each call.
        self._img_ids = list(self._dict)

        self._predicates_path = os.path.join(self._root, "predicates.json")
        self.rel_classes = self._load_json(self._predicates_path)
        # +1 reserves rel_class 0 for the "no relation" edges built in
        # __getitem__.
        self.num_rel_classes = len(self.rel_classes) + 1

        self._objects_path = os.path.join(self._root, "objects.json")
        self.obj_classes = self._load_json(self._objects_path)
        self.num_obj_classes = len(self.obj_classes)

        if split == "val":
            self.img_transform = FasterRCNNDefaultValTransform(
                short=600, max_size=1000
            )
        else:
            self.img_transform = FasterRCNNDefaultTrainTransform(
                short=600, max_size=1000
            )
        self.split = split

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        """Return the number of annotated images in this split."""
        return len(self._img_ids)

    def _hash_bbox(self, object):
        """Return a string key identifying an annotated object.

        The key joins the object's ``category`` and its ``bbox`` values with
        underscores. Annotations with the same category and box therefore
        map to the same graph node.
        """
        num_list = [object["category"]] + object["bbox"]
        return "_".join(str(num) for num in num_list)

    @staticmethod
    def _set_node_features(bbox, node_class_ids, ind, obj):
        """Write ``obj``'s category and box into row ``ind`` of the node arrays.

        Annotation boxes are stored as (y1, y2, x1, x2). They are reordered
        to (x1, y1, x2, y2) here.
        """
        node_class_ids[ind] = obj["category"]
        box = obj["bbox"]
        bbox[ind, 0] = box[2]
        bbox[ind, 1] = box[0]
        bbox[ind, 2] = box[3]
        bbox[ind, 3] = box[1]

    def __getitem__(self, idx):
        """Build the scene graph and transformed image for sample ``idx``.

        Returns
        -------
        g : dgl.DGLGraph
            Node data:

            * ``bbox``: boxes after the image transform.
            * ``node_class``: class ids, shape (N, 1).
            * ``node_class_vec``: label-smoothed one-hot classes.

            Edge data:

            * ``rel_class``: shape (E, 1); 0 means "no relation".
        img
            The image after ``self.img_transform``.
        """
        img_id = self._img_ids[idx]
        img = mx.image.imread(self._img_path.format(img_id))
        item = self._dict[img_id]

        # Map every distinct (category, bbox) annotation to a node id.
        sub_node_hash = [self._hash_bbox(it["subject"]) for it in item]
        ob_node_hash = [self._hash_bbox(it["object"]) for it in item]
        node_set = sorted(set(sub_node_hash + ob_node_hash))
        n_nodes = len(node_set)
        node_to_id = {node: i for i, node in enumerate(node_set)}
        sub_id = [node_to_id[h] for h in sub_node_hash]
        ob_id = [node_to_id[h] for h in ob_node_hash]

        # Node features, filled from the first annotation that mentions each
        # node (subject checked before object, as relations are scanned).
        bbox = mx.nd.zeros((n_nodes, 4))
        node_class_ids = mx.nd.zeros((n_nodes, 1))
        node_visited = [False] * n_nodes
        for s_ind, o_ind, it in zip(sub_id, ob_id, item):
            for ind, obj in ((s_ind, it["subject"]), (o_ind, it["object"])):
                if not node_visited[ind]:
                    self._set_node_features(bbox, node_class_ids, ind, obj)
                    node_visited[ind] = True

        # Label smoothing: the true class gets 1 - eta + eta / C, and every
        # other class gets eta / C.
        eta = self._LABEL_SMOOTHING
        node_class_vec = node_class_ids[:, 0].one_hot(
            self.num_obj_classes,
            on_value=1 - eta + eta / self.num_obj_classes,
            off_value=eta / self.num_obj_classes,
        )

        # Augmentation: the val transform returns an extra third value,
        # which is discarded.
        if self.split == "val":
            img, bbox, _ = self.img_transform(img, bbox)
        else:
            img, bbox = self.img_transform(img, bbox)

        # Build the graph: annotated relations first, with predicates
        # shifted by one so rel_class 0 means "no relation".
        g = dgl.DGLGraph()
        g.add_nodes(n_nodes)
        predicate = mx.nd.array([it["predicate"] for it in item]).expand_dims(1)
        g.add_edges(sub_id, ob_id, {"rel_class": predicate + 1})

        # Then a "no relation" edge for every ordered pair (i, j), i != j,
        # that has no annotated relation. np.nonzero yields the pairs in
        # row-major order, matching an i-outer / j-inner scan.
        related = np.eye(n_nodes, dtype=bool)
        if sub_id:
            related[sub_id, ob_id] = True
        src, dst = np.nonzero(~related)
        if len(src) > 0:
            g.add_edges(
                src.tolist(),
                dst.tolist(),
                {"rel_class": mx.nd.zeros((len(src), 1))},
            )

        # Assign node features.
        g.ndata["bbox"] = bbox
        g.ndata["node_class"] = node_class_ids
        g.ndata["node_class_vec"] = node_class_vec

        return g, img