"vscode:/vscode.git/clone" did not exist on "8a45147f9df23b12981f1e80554bbae251f594ea"
build_graph.py 11.1 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
2
import numpy as np
3
4
from mxnet import nd

5
6

def bbox_improve(bbox):
    """Append each box's area to ``bbox`` as an additional column.

    The area is ``(col2 - col0) * (col3 - col1)``, so the first four columns
    are presumably corner coordinates (xmin, ymin, xmax, ymax) -- confirm
    against the caller.

    Returns a new array made by concatenating ``bbox`` with the area column;
    ``bbox`` itself is not modified.
    """
    widths = bbox[:, 2] - bbox[:, 0]
    heights = bbox[:, 3] - bbox[:, 1]
    # expand_dims turns the per-box area vector into a column for concat.
    area_column = (widths * heights).expand_dims(1)
    return nd.concat(bbox, area_column)

11

12
def extract_edge_bbox(g):
    """Compute, for every edge of ``g``, the box enclosing both endpoints.

    Reads the node feature ``"pred_bbox"`` and, per edge (in edge-id order),
    takes the element-wise minimum of columns 0 and 1 and the element-wise
    maximum of columns 2 and 3 over the source and destination boxes.  Any
    further columns of ``"pred_bbox"`` are ignored.

    Returns an array with one row per edge and four columns, allocated on
    the same context as ``g.ndata["pred_bbox"]``.
    """
    heads, tails = g.edges(order="eid")
    num_edges = g.number_of_edges()

    node_boxes = g.ndata["pred_bbox"]
    head_boxes = node_boxes[heads.asnumpy()]
    tail_boxes = node_boxes[tails.asnumpy()]

    union = nd.zeros((num_edges, 4), ctx=node_boxes.context)
    for col in range(4):
        # Stack the two endpoint coordinates so one reduction picks the
        # outermost value: min for columns 0-1, max for columns 2-3.
        paired = nd.stack(head_boxes[:, col], tail_boxes[:, col])
        if col < 2:
            union[:, col] = paired.min(axis=0)
        else:
            union[:, col] = paired.max(axis=0)
    return union

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

def build_graph_train(
    g_slice,
    gt_bbox,
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    iou_thresh=0.5,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Build training graphs over detections, labelled via ground-truth matching.

    For every ground-truth graph in ``g_slice`` a new graph is created whose
    nodes are the kept detections and whose edges carry a ``"rel_class"``
    label derived from the ground-truth graph.

    Parameters
    ----------
    g_slice : sequence of ground-truth graphs
        Each graph must provide ``ndata["bbox"]`` (only its context is used),
        ``ndata["node_class"]`` and ``edata["rel_class"]``.
    gt_bbox, bbox : arrays indexed as ``[graph, box, coord]``
        Ground-truth and predicted boxes.  **Both are normalised in place**:
        columns 0 and 2 are divided by ``img.shape[3]`` and columns 1 and 3 by
        ``img.shape[2]``.  Calling this function twice on the same arrays
        therefore normalises them twice.
    img : array
        Only ``img.shape[2:4]`` is read.
    ids, scores : arrays indexed as ``[graph, box, 0]``
        Predicted class ids and scores.  Detections with ``score > 0`` are
        kept.
    feat_ind : array
        ``feat_ind[gi, inds].squeeze(axis=1)`` selects rows of
        ``spatial_feat[gi]`` for the kept detections.
    spatial_feat : array
        Per-ROI features.
    iou_thresh : float
        Minimum IoU for a detection to be matched to a ground-truth box.
    bbox_improvement : bool
        If true, ``"pred_bbox"`` gets an extra area column (``bbox_improve``).
    scores_top_k : int
        Keep at most this many highest-scoring detections per graph.
    overlap : bool
        If true, remove edges between detections whose boxes have
        IoU <= 1e-7.

    Returns
    -------
    ``None`` if any graph has no detection with a positive score (the whole
    slice is dropped, as before) or if no usable graph could be built.
    Otherwise the single graph when ``len(g_slice) == 1``, or
    ``dgl.batch`` of the usable graphs when ``len(g_slice) > 1``.  Every edge
    stores the index of its source graph in ``edata["batch_id"]``, so graphs
    that were skipped can be identified by the caller.

    Raises
    ------
    KeyError
        If two class-matched detections map to ground-truth nodes that have
        no edge between them in the ground-truth graph.
    """
    # Normalise coordinates (in place, see docstring).
    img_size = img.shape[2:4]
    for boxes in (gt_bbox, bbox):
        boxes[:, :, 0] /= img_size[1]
        boxes[:, :, 1] /= img_size[0]
        boxes[:, :, 2] /= img_size[1]
        boxes[:, :, 3] /= img_size[0]

    n_graph = len(g_slice)
    g_pred_batch = []
    for gi in range(n_graph):
        g = g_slice[gi]
        ctx = g.ndata["bbox"].context

        # Keep detections with a positive score, capped at the top-k scores.
        inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
        if len(inds) == 0:
            return None
        if len(inds) > scores_top_k:
            top_score_inds = (
                scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
            )
            inds = np.array(inds)[top_score_inds].tolist()

        n_nodes = len(inds)
        roi_ind = feat_ind[gi, inds].squeeze(axis=1)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[gi, inds],
                "node_feat": spatial_feat[gi, roi_ind],
                "node_class_pred": ids[gi, inds, 0],
                "node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
            },
        )

        # Greedy one-to-one IoU matching: repeatedly take the best remaining
        # (ground truth, detection) pair until the best IoU drops below the
        # threshold or one side is exhausted.
        ious = nd.contrib.box_iou(
            gt_bbox[gi], g_pred.ndata["pred_bbox"]
        ).asnumpy()
        n_gt, n_pred = ious.shape
        class_matched = [0] * n_pred
        matched_gt_id = [0] * n_pred
        for _ in range(min(n_gt, n_pred)):
            flat_ind = int(ious.argmax())
            row_ind = flat_ind // n_pred
            col_ind = flat_ind % n_pred
            if ious[row_ind, col_ind] < iou_thresh:
                break
            gt_node_class = g.ndata["node_class"][row_ind]
            pred_node_class = g_pred.ndata["node_class_pred"][col_ind]
            if gt_node_class == pred_node_class:
                class_matched[col_ind] = 1
                matched_gt_id[col_ind] = row_ind
            # Retire both the ground-truth row and the detection column.
            ious[row_ind, :] = -1
            ious[:, col_ind] = -1

        # Map (src, dst) ground-truth node pairs to all edge ids between them.
        src, dst = g.all_edges(order="eid")
        eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])
        eid_dict = {}
        for i, key in enumerate(eid_keys):
            eid_dict.setdefault(tuple(key), []).append(i)
        ori_rel_class = g.edata["rel_class"].asnumpy()

        # Label every ordered pair of distinct detections.  Pairs where both
        # ends are class-matched copy the ground-truth relation(s); all other
        # pairs get relation class 0.
        triplet = []
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i == j:
                    continue
                if class_matched[i] and class_matched[j]:
                    eids = eid_dict[(matched_gt_id[i], matched_gt_id[j])]
                    for rel in ori_rel_class[eids]:
                        triplet.append((i, j, rel))
                else:
                    triplet.append((i, j, 0))

        if not triplet:
            # A single detection yields no pairs.  Unpacking the empty zip
            # below used to raise ValueError; an edgeless graph is skipped
            # instead, like the edgeless graphs produced by overlap pruning.
            continue

        src, dst, rel_class = tuple(zip(*triplet))
        rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
        g_pred.add_edges(src, dst, data={"rel_class": rel_class})

        # other operations
        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi

        # remove non-overlapping edges
        if overlap:
            # Only the first four columns are coordinates; bbox_improve may
            # have appended an area column.
            overlap_ious = nd.contrib.box_iou(
                g_pred.ndata["pred_bbox"][:, 0:4],
                g_pred.ndata["pred_bbox"][:, 0:4],
            ).asnumpy()
            cols, rows = np.where(overlap_ious <= 1e-7)
            if cols.shape[0] > 0:
                # NOTE(review): indexing the result with [2] assumes
                # edge_ids returns a (src, dst, eid) triple -- confirm for
                # the DGL version in use.
                eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
                if len(eids):
                    g_pred.remove_edges(eids)
                    if g_pred.number_of_edges() == 0:
                        # Previously None was appended here, which broke
                        # dgl.batch for multi-graph slices.
                        continue
        g_pred_batch.append(g_pred)

    if not g_pred_batch:
        return None
    if n_graph > 1:
        return dgl.batch(g_pred_batch)
    return g_pred_batch[0]

166
167
168
169
170

def build_graph_validate_gt_obj(
    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
    """Build fully connected validation graphs from ground-truth boxes and labels.

    Parameters
    ----------
    img : array
        ``img.shape[0]`` is the batch size, ``img.shape[2:4]`` the image size
        used for normalisation, and ``img.context`` the device for new arrays.
    gt_ids : array indexed as ``[sample, box, 0]``
        Ground-truth class ids, stored as ``"node_class_pred"``.
    bbox : array indexed as ``[sample, box, coord]``
        Ground-truth boxes.  **Normalised in place**: columns 0 and 2 are
        divided by ``img.shape[3]``, columns 1 and 3 by ``img.shape[2]``.
        Boxes whose coordinates sum to a value <= 0 are ignored (presumably
        padding -- confirm against the data loader).
    spatial_feat : array indexed as ``[sample, box]``
        Per-box features, stored as ``"node_feat"``.
    bbox_improvement : bool
        If true, ``"pred_bbox"`` gets an extra area column (``bbox_improve``).
    overlap : bool
        Accepted for signature parity with the other builders; not used here.

    Returns
    -------
    ``None`` if no sample yields a graph, the single graph if exactly one
    does, otherwise ``dgl.batch`` of all graphs.  Samples with fewer than two
    valid boxes are skipped; ``edata["batch_id"]`` records the sample index
    of every edge so skipped samples can be identified.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No boxes, or a single box: there is no pair to connect.  The
            # single-box case used to crash when unpacking an empty edge list.
            continue

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": gt_ids[btc, inds, 0],
                "node_class_logit": nd.zeros_like(
                    gt_ids[btc, inds, 0], ctx=ctx
                ),
            },
        )

        # Connect every unordered pair in both directions: all i < j edges
        # first, then the reversed edges.
        edge_list = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*edge_list))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]

221
222
223
224
225
226
227
228
229
230
231
232

def build_graph_validate_gt_bbox(
    img,
    ids,
    scores,
    bbox,
    spatial_feat,
    gt_ids=None,
    bbox_improvement=True,
    overlap=False,
):
    """Build fully connected validation graphs from ground-truth boxes.

    Node classes are not taken from labels but predicted: for every sample
    the class is the argmax (and the score the max) of
    ``scores[btc][:, :, 0]`` along its first axis.

    Parameters
    ----------
    img : array
        ``img.shape[0]`` is the batch size, ``img.shape[2:4]`` the image size
        used for normalisation, and ``img.context`` the device for new arrays.
    ids : unused
        Accepted for signature parity with the other builders.
    scores : array
        Per-sample class scores; see above for how they are reduced.
    bbox : array indexed as ``[sample, box, coord]``
        Ground-truth boxes.  **Normalised in place**: columns 0 and 2 are
        divided by ``img.shape[3]``, columns 1 and 3 by ``img.shape[2]``.
        Boxes whose coordinates sum to a value <= 0 are ignored.
    spatial_feat : array indexed as ``[sample, box]``
        Per-box features, stored as ``"node_feat"``.
    gt_ids : unused
    bbox_improvement : bool
        If true, ``"pred_bbox"`` gets an extra area column (``bbox_improve``).
    overlap : bool
        Not used here.

    Returns
    -------
    ``None`` if no sample yields a graph, the single graph if exactly one
    does, otherwise ``dgl.batch`` of all graphs.  Samples with fewer than two
    valid boxes are skipped; ``edata["batch_id"]`` records the sample index
    of every edge so skipped samples can be identified.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        sample_scores = scores[btc][:, :, 0]
        id_btc = sample_scores.argmax(0)
        score_btc = sample_scores.max(0)

        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No boxes, or a single box: there is no pair to connect.  The
            # single-box case used to crash when unpacking an empty edge list.
            continue

        g_pred = dgl.DGLGraph()
        # NOTE(review): id_btc / score_btc are not filtered by `inds`, unlike
        # the boxes and features.  This only lines up when every box of the
        # sample is valid -- confirm the shape of `scores` with the caller.
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": id_btc,
                "node_class_logit": nd.log(score_btc + 1e-7),
            },
        )

        # Connect every unordered pair in both directions: all i < j edges
        # first, then the reversed edges.
        edge_list = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*edge_list))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]

283
284
285
286
287
288
289
290
291
292
293
294
295

def build_graph_validate_pred(
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Build fully connected validation graphs from predicted detections.

    Parameters
    ----------
    img : array
        ``img.shape[0]`` is the batch size, ``img.shape[2:4]`` the image size
        used for normalisation, and ``img.context`` the device for new arrays.
    ids, scores : arrays indexed as ``[sample, box, 0]``
        Predicted class ids and scores.  Detections with ``score > 0`` are
        kept.
    bbox : array indexed as ``[sample, box, coord]``
        Predicted boxes.  **Normalised in place**: columns 0 and 2 are
        divided by ``img.shape[3]``, columns 1 and 3 by ``img.shape[2]``.
    feat_ind : array
        ``feat_ind[btc, inds].squeeze(axis=1)`` selects rows of
        ``spatial_feat[btc]`` for the kept detections.
    spatial_feat : array
        Per-ROI features.
    bbox_improvement : bool
        If true, ``"pred_bbox"`` gets an extra area column (``bbox_improve``).
    scores_top_k : int
        Keep at most this many highest-scoring detections per sample.
    overlap : bool
        Accepted for signature parity with ``build_graph_train``; not used
        here.

    Returns
    -------
    ``None`` if no sample yields a graph, the single graph if exactly one
    does, otherwise ``dgl.batch`` of all graphs.  Samples with fewer than two
    kept detections are skipped; ``edata["batch_id"]`` records the sample
    index of every edge so skipped samples can be identified.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(scores[btc, :, 0].asnumpy() > 0)[0].tolist()
        if not inds:
            continue
        if len(inds) > scores_top_k:
            top_score_inds = (
                scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
            )
            inds = np.array(inds)[top_score_inds].tolist()

        n_nodes = len(inds)
        if n_nodes < 2:
            # A single detection has no pair to connect; this used to crash
            # when unpacking an empty edge list.
            continue
        roi_ind = feat_ind[btc, inds].squeeze(axis=1)

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, roi_ind],
                "node_class_pred": ids[btc, inds, 0],
                "node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
            },
        )

        # Connect every unordered pair in both directions: all i < j edges
        # first, then the reversed edges.
        edge_list = [
            (i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)
        ]
        src, dst = tuple(zip(*edge_list))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]