build_graph.py 11.1 KB
Newer Older
1
import numpy as np
2
3
4
5
from mxnet import nd

import dgl

6
7

def bbox_improve(bbox):
    """Return ``bbox`` with each box's area appended as an extra column.

    The area is ``(col2 - col0) * (col3 - col1)``; the first four columns are
    therefore presumably ``(xmin, ymin, xmax, ymax)`` -- confirm with callers.
    The input array is not modified.
    """
    width = bbox[:, 2] - bbox[:, 0]
    height = bbox[:, 3] - bbox[:, 1]
    box_area = width * height
    # expand_dims turns the 1-D area vector into a column so it can be
    # concatenated next to the existing coordinates.
    return nd.concat(bbox, box_area.expand_dims(1))

12

13
def extract_edge_bbox(g):
    """Build one 4-column box per edge that encloses both endpoint boxes.

    For every edge (in edge-id order) columns 0 and 1 are the element-wise
    minimum of the source and destination node boxes, and columns 2 and 3 are
    the element-wise maximum.  Only the first four columns of
    ``g.ndata["pred_bbox"]`` are read, so an extra trailing column is ignored.

    Returns an NDArray of shape ``(number_of_edges, 4)`` allocated on the same
    context as the node boxes.
    """
    src, dst = g.edges(order="eid")
    node_bbox = g.ndata["pred_bbox"]
    src_bbox = node_bbox[src.asnumpy()]
    dst_bbox = node_bbox[dst.asnumpy()]

    edge_bbox = nd.zeros((g.number_of_edges(), 4), ctx=node_bbox.context)
    # First two columns: smallest coordinate of the two endpoint boxes.
    for col in (0, 1):
        edge_bbox[:, col] = nd.stack(
            src_bbox[:, col], dst_bbox[:, col]
        ).min(axis=0)
    # Last two columns: largest coordinate of the two endpoint boxes.
    for col in (2, 3):
        edge_bbox[:, col] = nd.stack(
            src_bbox[:, col], dst_bbox[:, col]
        ).max(axis=0)
    return edge_bbox

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

def build_graph_train(
    g_slice,
    gt_bbox,
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    iou_thresh=0.5,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Build training graphs by matching predicted boxes to ground truth.

    For every ground-truth graph in ``g_slice`` a new graph is created whose
    nodes are the detections with a positive score (at most ``scores_top_k``
    of them, highest scores first).  Detections are greedily matched one-to-one
    to ground-truth boxes by descending IoU until the best remaining IoU drops
    below ``iou_thresh``.  Every ordered pair of distinct nodes becomes an
    edge: when both endpoints are matched with the same class as their
    ground-truth box, the edge(s) copy the ground-truth ``rel_class`` values;
    otherwise a single edge with ``rel_class`` 0 is added.

    NOTE: ``gt_bbox`` and ``bbox`` are normalised by the image size **in
    place**, so the caller's arrays are modified.

    Args:
        g_slice: sequence of ground-truth graphs; each provides
            ``ndata["bbox"]``, ``ndata["node_class"]`` and
            ``edata["rel_class"]``.
        gt_bbox: ground-truth boxes, indexed ``[graph, box, coord]``.
        img: image batch; ``img.shape[2:4]`` is used as the image size
            (presumably ``(height, width)`` -- confirm with the data loader).
        ids: predicted class ids, indexed ``[graph, box, 0]``.
        scores: predicted scores, indexed ``[graph, box, 0]``.
        bbox: predicted boxes, indexed ``[graph, box, coord]``.
        feat_ind: per-detection index into ``spatial_feat``.
        spatial_feat: per-ROI features, indexed ``[graph, roi]``.
        iou_thresh: minimum IoU for a detection to be matched.
        bbox_improvement: if true, append the box area to ``pred_bbox``.
        scores_top_k: maximum number of detections kept per graph.
        overlap: if true, remove edges whose endpoint boxes do not overlap.

    Returns:
        ``None`` if any graph has no positively scored detection, or if no
        graph ends up with at least one edge.  Otherwise the single graph
        when ``g_slice`` has one element, or ``dgl.batch`` of the graphs that
        have edges when it has several (edgeless graphs are left out; the
        ``batch_id`` edge feature still holds the original graph index).

    Raises:
        KeyError: if two matched ground-truth nodes have no edge between them
            in the ground-truth graph.
    """
    # Normalise box coordinates by the image size (modifies the inputs).
    img_size = img.shape[2:4]
    for boxes in (gt_bbox, bbox):
        boxes[:, :, 0] /= img_size[1]
        boxes[:, :, 1] /= img_size[0]
        boxes[:, :, 2] /= img_size[1]
        boxes[:, :, 3] /= img_size[0]

    n_graph = len(g_slice)
    g_pred_batch = []
    for gi in range(n_graph):
        g = g_slice[gi]
        ctx = g.ndata["bbox"].context

        # Keep positively scored detections, capped at the top-k scores.
        # The scores are copied to host memory once and reused.
        score_np = scores[gi, :, 0].asnumpy()
        inds = np.where(score_np > 0)[0]
        if len(inds) == 0:
            return None
        if len(inds) > scores_top_k:
            inds = inds[score_np[inds].argsort()[::-1][0:scores_top_k]]
        inds = inds.tolist()
        if len(inds) < 2:
            # A single node has no (subject, object) pair, so the graph would
            # have no edges; treat it like the edgeless case further below
            # instead of failing when unpacking an empty triplet list.
            g_pred_batch.append(None)
            continue

        n_nodes = len(inds)
        roi_ind = feat_ind[gi, inds].squeeze(axis=1)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[gi, inds],
                "node_feat": spatial_feat[gi, roi_ind],
                "node_class_pred": ids[gi, inds, 0],
                "node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
            },
        )

        # Greedy one-to-one IoU matching: repeatedly take the best remaining
        # (ground truth, prediction) pair and retire its row and column.
        ious = nd.contrib.box_iou(
            gt_bbox[gi], g_pred.ndata["pred_bbox"]
        ).asnumpy()
        n_gt, n_pred = ious.shape
        gt_classes = g.ndata["node_class"].asnumpy()
        pred_classes = g_pred.ndata["node_class_pred"].asnumpy()
        class_matched = [False] * n_pred
        matched_gt_id = [0] * n_pred
        for _ in range(min(n_gt, n_pred)):
            row_ind, col_ind = divmod(int(ious.argmax()), n_pred)
            if ious[row_ind, col_ind] < iou_thresh:
                break
            if gt_classes[row_ind] == pred_classes[col_ind]:
                class_matched[col_ind] = True
                matched_gt_id[col_ind] = row_ind
            ious[row_ind, :] = -1
            ious[:, col_ind] = -1

        # Map each ground-truth (src, dst) pair to the ids of its edges.
        gt_src, gt_dst = g.all_edges(order="eid")
        eid_dict = {}
        gt_pairs = zip(gt_src.asnumpy().tolist(), gt_dst.asnumpy().tolist())
        for eid, pair in enumerate(gt_pairs):
            eid_dict.setdefault(pair, []).append(eid)
        ori_rel_class = g.edata["rel_class"].asnumpy()

        # One entry per edge to create: (subject, object, relation class).
        triplet = []
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i == j:
                    continue
                if class_matched[i] and class_matched[j]:
                    eids = eid_dict[(matched_gt_id[i], matched_gt_id[j])]
                    triplet.extend((i, j, rel) for rel in ori_rel_class[eids])
                else:
                    triplet.append((i, j, 0))
        src, dst, rel_class = tuple(zip(*triplet))
        rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
        g_pred.add_edges(src, dst, data={"rel_class": rel_class})

        # Derived node / edge features.
        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi

        # Remove edges whose endpoint boxes do not overlap.
        if overlap:
            node_boxes = g_pred.ndata["pred_bbox"][:, 0:4]
            overlap_ious = nd.contrib.box_iou(node_boxes, node_boxes).asnumpy()
            cols, rows = np.where(overlap_ious <= 1e-7)
            if cols.shape[0] > 0:
                eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
                if len(eids):
                    g_pred.remove_edges(eids)
                    if g_pred.number_of_edges() == 0:
                        g_pred = None
        g_pred_batch.append(g_pred)

    # Edgeless graphs are recorded as None; dgl.batch cannot take them.
    kept = [g_pred for g_pred in g_pred_batch if g_pred is not None]
    if not kept:
        return None
    if n_graph > 1:
        return dgl.batch(kept)
    return kept[0]

167
168
169
170
171

def build_graph_validate_gt_obj(
    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
    """Build fully connected validation graphs from ground-truth boxes/labels.

    For every batch element, the boxes whose coordinates sum to a positive
    value are kept as nodes (other rows are presumably padding -- confirm
    with the data loader).  Nodes carry the ground-truth class as
    ``node_class_pred`` and an all-zero ``node_class_logit``.  Every ordered
    pair of distinct nodes is connected: first all ``i -> j`` edges with
    ``i < j``, then the reversed edges.

    NOTE: ``bbox`` is normalised by the image size **in place**, so the
    caller's array is modified.

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and
            ``img.shape[2:4]`` is used as the image size (presumably
            ``(height, width)`` -- confirm).
        gt_ids: ground-truth class ids, indexed ``[batch, box, 0]``.
        bbox: ground-truth boxes, indexed ``[batch, box, coord]``.
        spatial_feat: per-box features, indexed ``[batch, box]``.
        bbox_improvement: if true, append the box area to ``pred_bbox``.
        overlap: accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        ``None`` if no batch element has at least two valid boxes, the single
        graph if exactly one does, otherwise ``dgl.batch`` of the graphs.
        Batch elements with fewer than two valid boxes are skipped; the
        ``batch_id`` edge feature holds the original batch index.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No boxes, or a single box with no pair to connect: unpacking
            # the empty edge list below would otherwise raise ValueError.
            continue

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": gt_ids[btc, inds, 0],
                "node_class_logit": nd.zeros_like(
                    gt_ids[btc, inds, 0], ctx=ctx
                ),
            },
        )

        # Fully connect the nodes: forward (i < j) edges first, then the
        # reversed ones, which fixes the edge-id order.
        pairs = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]

222
223
224
225
226
227
228
229
230
231
232
233

def build_graph_validate_gt_bbox(
    img,
    ids,
    scores,
    bbox,
    spatial_feat,
    gt_ids=None,
    bbox_improvement=True,
    overlap=False,
):
    """Build fully connected validation graphs from ground-truth boxes.

    For every batch element, the boxes whose coordinates sum to a positive
    value are kept as nodes (other rows are presumably padding -- confirm
    with the data loader).  The node class is the arg-max over axis 0 of
    ``scores[btc][:, :, 0]`` and the node logit is the log of the
    corresponding maximum, so axis 0 is presumably the class axis and axis 1
    the box axis -- confirm with the detector output.  Every ordered pair of
    distinct nodes is connected: first all ``i -> j`` edges with ``i < j``,
    then the reversed edges.

    NOTE: ``bbox`` is normalised by the image size **in place**, so the
    caller's array is modified.

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and
            ``img.shape[2:4]`` is used as the image size (presumably
            ``(height, width)`` -- confirm).
        ids: not used by this function.
        scores: class scores; ``scores[btc]`` must be 3-dimensional.
        bbox: ground-truth boxes, indexed ``[batch, box, coord]``.
        spatial_feat: per-box features, indexed ``[batch, box]``.
        gt_ids: not used by this function.
        bbox_improvement: if true, append the box area to ``pred_bbox``.
        overlap: not used by this function.

    Returns:
        ``None`` if no batch element has at least two valid boxes, the single
        graph if exactly one does, otherwise ``dgl.batch`` of the graphs.
        Batch elements with fewer than two valid boxes are skipped; the
        ``batch_id`` edge feature holds the original batch index.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No boxes, or a single box with no pair to connect: unpacking
            # the empty edge list below would otherwise raise ValueError.
            continue

        class_scores = scores[btc][:, :, 0]
        id_btc = class_scores.argmax(0)
        score_btc = class_scores.max(0)
        if id_btc.shape[0] != n_nodes:
            # Some boxes were filtered out above; keep the class / score
            # rows aligned with the boxes and features that remain.
            id_btc = id_btc[inds]
            score_btc = score_btc[inds]

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": id_btc,
                "node_class_logit": nd.log(score_btc + 1e-7),
            },
        )

        # Fully connect the nodes: forward (i < j) edges first, then the
        # reversed ones, which fixes the edge-id order.
        pairs = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]

284
285
286
287
288
289
290
291
292
293
294
295
296

def build_graph_validate_pred(
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Build fully connected validation graphs from predicted boxes.

    For every batch element, the detections with a positive score are kept
    as nodes (at most ``scores_top_k`` of them, highest scores first).  Every
    ordered pair of distinct nodes is connected: first all ``i -> j`` edges
    with ``i < j``, then the reversed edges.

    NOTE: ``bbox`` is normalised by the image size **in place**, so the
    caller's array is modified.

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and
            ``img.shape[2:4]`` is used as the image size (presumably
            ``(height, width)`` -- confirm).
        ids: predicted class ids, indexed ``[batch, box, 0]``.
        scores: predicted scores, indexed ``[batch, box, 0]``.
        bbox: predicted boxes, indexed ``[batch, box, coord]``.
        feat_ind: per-detection index into ``spatial_feat``.
        spatial_feat: per-ROI features, indexed ``[batch, roi]``.
        bbox_improvement: if true, append the box area to ``pred_bbox``.
        scores_top_k: maximum number of detections kept per batch element.
        overlap: not used by this function.

    Returns:
        ``None`` if no batch element has at least two kept detections, the
        single graph if exactly one does, otherwise ``dgl.batch`` of the
        graphs.  Batch elements with fewer than two kept detections are
        skipped; the ``batch_id`` edge feature holds the original batch index.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        # The scores are copied to host memory once and reused for both the
        # positivity filter and the top-k selection.
        score_np = scores[btc, :, 0].asnumpy()
        inds = np.where(score_np > 0)[0]
        if len(inds) == 0:
            continue
        if len(inds) > scores_top_k:
            inds = inds[score_np[inds].argsort()[::-1][0:scores_top_k]]
        inds = inds.tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # A single detection has no pair to connect: unpacking the empty
            # edge list below would otherwise raise ValueError.
            continue
        roi_ind = feat_ind[btc, inds].squeeze(axis=1)

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, roi_ind],
                "node_class_pred": ids[btc, inds, 0],
                "node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
            },
        )

        # Fully connect the nodes: forward (i < j) edges first, then the
        # reversed ones, which fixes the edge-id order.
        pairs = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]