"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division

import json
import logging
import os
import pickle
import warnings
from collections import Counter

import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)

import dgl


class VGRelation(VisionDataset):
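    """Visual Genome visual relationship dataset.

    Each sample is a ``(graph, image)`` pair: a ``dgl.DGLGraph`` whose nodes
    are the annotated objects and whose edges carry relation (predicate)
    labels, together with the Faster R-CNN-preprocessed image.
    """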
    def __init__(
        self,
        root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
        split="train",
    ):
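        """Load the ``train`` or ``val`` split of the relation annotations from ``root``."""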
        super(VGRelation, self).__init__(root)
        self._root = os.path.expanduser(root)
        self._img_path = os.path.join(self._root, "VG_100K", "{}")

        if split == "train":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_train.json"
            )
        elif split == "val":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_val.json"
            )
        else:
            raise NotImplementedError
        with open(self._dict_path) as f:
            self._dict = json.load(f)

        self._predicates_path = os.path.join(self._root, "predicates.json")
        with open(self._predicates_path, "r") as f:
            self.rel_classes = json.load(f)
        # +1 reserves relation class 0 for "no relation" edges
        self.num_rel_classes = len(self.rel_classes) + 1

        self._objects_path = os.path.join(self._root, "objects.json")
        with open(self._objects_path, "r") as f:
            self.obj_classes = json.load(f)
        self.num_obj_classes = len(self.obj_classes)

        if split == "val":
            self.img_transform = FasterRCNNDefaultValTransform(
                short=600, max_size=1000
            )
        else:
            self.img_transform = FasterRCNNDefaultTrainTransform(
                short=600, max_size=1000
            )
        self.split = split

    def __len__(self):
        return len(self._dict)

    def _hash_bbox(self, obj):
        """Identify an annotated object by its category id and box coordinates."""
        num_list = [obj["category"]] + obj["bbox"]
        return "_".join([str(num) for num in num_list])

    def __getitem__(self, idx):
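        """Return a ``(graph, image)`` pair for the ``idx``-th image.

        The graph has one node per unique annotated object (with ``bbox``,
        ``node_class`` and ``node_class_vec`` node features) and one edge per
        ordered object pair, where ``rel_class`` 0 means "no relation".
        """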
        img_id = list(self._dict)[idx]
        img_path = self._img_path.format(img_id)
        img = mx.image.imread(img_path)

        item = self._dict[img_id]
        n_edges = len(item)

        # edge to node ids
        sub_node_hash = []
        ob_node_hash = []
        for i, it in enumerate(item):
            sub_node_hash.append(self._hash_bbox(it["subject"]))
            ob_node_hash.append(self._hash_bbox(it["object"]))
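        # deduplicate boxes: every unique (category, bbox) pair becomes one graph node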
        node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
        n_nodes = len(node_set)
        node_to_id = {}
        for i, node in enumerate(node_set):
            node_to_id[node] = i
        sub_id = []
        ob_id = []
        for i in range(n_edges):
            sub_id.append(node_to_id[sub_node_hash[i]])
            ob_id.append(node_to_id[ob_node_hash[i]])

        # node features
        bbox = mx.nd.zeros((n_nodes, 4))
        node_class_ids = mx.nd.zeros((n_nodes, 1))
        node_visited = [False for i in range(n_nodes)]
        for i, it in enumerate(item):
            if not node_visited[sub_id[i]]:
                ind = sub_id[i]
                sub = it["subject"]
                node_class_ids[ind] = sub["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = sub["bbox"][2]
                bbox[ind, 1] = sub["bbox"][0]
                bbox[ind, 2] = sub["bbox"][3]
                bbox[ind, 3] = sub["bbox"][1]

                node_visited[ind] = True

            if not node_visited[ob_id[i]]:
                ind = ob_id[i]
                ob = it["object"]
                node_class_ids[ind] = ob["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = ob["bbox"][2]
                bbox[ind, 1] = ob["bbox"][0]
                bbox[ind, 2] = ob["bbox"][3]
                bbox[ind, 3] = ob["bbox"][1]

                node_visited[ind] = True

        eta = 0.1
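        # label-smoothed one-hot encoding of the object classes (eta is the smoothing factor)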
        node_class_vec = node_class_ids[:, 0].one_hot(
            self.num_obj_classes,
            on_value=1 - eta + eta / self.num_obj_classes,
            off_value=eta / self.num_obj_classes,
        )

        # augmentation
        if self.split == "val":
            img, bbox, _ = self.img_transform(img, bbox)
        else:
            img, bbox = self.img_transform(img, bbox)

        # build the graph
        g = dgl.DGLGraph()
        g.add_nodes(n_nodes)
        adjmat = np.zeros((n_nodes, n_nodes))
        predicate = []
        for i, it in enumerate(item):
            adjmat[sub_id[i], ob_id[i]] = 1
            predicate.append(it["predicate"])
        predicate = mx.nd.array(predicate).expand_dims(1)
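        # shift predicate ids by one so that class 0 can denote "no relation"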
        g.add_edges(sub_id, ob_id, {"rel_class": predicate + 1})
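        # connect every remaining ordered node pair with a "no relation" (class 0) edge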
        empty_edge_list = []
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j and adjmat[i, j] == 0:
                    empty_edge_list.append((i, j))
        if len(empty_edge_list) > 0:
            src, dst = tuple(zip(*empty_edge_list))
            g.add_edges(
                src, dst, {"rel_class": mx.nd.zeros((len(empty_edge_list), 1))}
            )

        # assign features
        g.ndata["bbox"] = bbox
        g.ndata["node_class"] = node_class_ids
        g.ndata["node_class_vec"] = node_class_vec

        return g, img
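

# Minimal usage sketch (not part of the original module); it assumes the
# Visual Genome images and the rel_annotations/predicates/objects json files
# are already present under the default root.
if __name__ == "__main__":
    dataset = VGRelation(split="val")
    print("number of images:", len(dataset))
    g, img = dataset[0]
    print("nodes:", g.number_of_nodes(),
          "edges:", g.number_of_edges(),
          "image shape:", img.shape)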