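# metric.py
# Evaluation metrics (AUC, PredCls, PhrCls, SGDet, SGDet+) and helpers that
# turn ground-truth / predicted graphs into object and triplet arrays.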
import logging
import time
from operator import attrgetter, itemgetter

5
6
import mxnet as mx
import numpy as np
7
8
9
from gluoncv.data.batchify import Pad
from gluoncv.model_zoo import get_model
from mxnet import gluon, nd
10
from mxnet.gluon import nn
11
12

import dgl
13
from dgl.nn.mxnet import GraphConv
14
15
from dgl.utils import toindex

16
17
18
19
20
21
22

def iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Parameters
    ----------
    boxA, boxB : indexable of 4 numbers
        Box corners laid out as ``(x_min, y_min, x_max, y_max)``.

    Returns
    -------
    float
        ``intersection / union``.  ``0.0`` is returned when the boxes do not
        overlap or when either the intersection or the union area is smaller
        than ``1e-7`` (guards against division by a vanishing union).
    """
    # (x, y)-coordinates of the intersection rectangle
    x_left = max(boxA[0], boxB[0])
    y_top = max(boxA[1], boxB[1])
    x_right = min(boxA[2], boxB[2])
    y_bottom = min(boxA[3], boxB[3])

    # clamp each side at 0 so disjoint boxes yield an empty intersection
    inter_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    if inter_area < 1e-7:
        return 0.0

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union_area = area_a + area_b - inter_area
    if union_area < 1e-7:
        return 0.0

    return inter_area / float(union_area)

36

37
38
39
40
41
42
def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
    """Tell whether two object rows overlap by at least ``iou_thresh``.

    Entries ``1:5`` of each row are passed to ``iou`` as the box; entry 0 is
    not inspected here.  Returns ``True`` when the IoU reaches the threshold,
    ``False`` otherwise.
    """
    overlap = iou(gt_object[1:5], pred_object[1:5])
    return True if overlap >= iou_thresh else False

43

44
45
46
47
48
49
50
51
def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
    """Tell whether both boxes of a predicted triplet overlap the ground truth.

    Entries ``5:9`` (first box) and ``9:13`` (second box) of each row are
    compared with ``iou``.  ``True`` is returned only when both overlaps reach
    ``iou_thresh``; the second box is evaluated only if the first one passes.
    """
    subject_overlap = iou(gt_triplet[5:9], pred_triplet[5:9])
    if not (subject_overlap >= iou_thresh):
        return False
    object_overlap = iou(gt_triplet[9:13], pred_triplet[9:13])
    return True if object_overlap >= iou_thresh else False

52

53
@mx.metric.register
@mx.metric.alias("auc")
class AUCMetric(mx.metric.EvalMetric):
    """Area under the ROC curve, averaged over the batches passed to ``update``.

    Parameters
    ----------
    name : str
        Name handed to the ``EvalMetric`` base class.
    eps : float
        Stored on the instance as ``self.eps``; ``update`` does not read it.
    """

    def __init__(self, name="auc", eps=1e-12):
        super(AUCMetric, self).__init__(name)
        self.eps = eps

    def update(self, labels, preds):
        """Accumulate the AUC of one batch.

        Only ``labels[0]`` and ``preds[0]`` are used.  Entry ``[i][1]`` of the
        prediction array is taken as the score of sample ``i`` and a label
        equal to ``1.0`` marks a positive sample.  A batch whose labels sum to
        0 or to the number of labels is skipped without touching the running
        totals, because the AUC is undefined there.
        """
        mx.metric.check_label_shapes(labels, preds)
        label_weight = labels[0].asnumpy()
        preds = preds[0].asnumpy()

        # Bail out before doing any ranking work when only one class is present.
        label_sum = label_weight.sum()
        if label_sum == 0 or label_sum == label_weight.size:
            return

        # (label, score) pairs ranked by descending score
        ranked = sorted(
            [(label_weight[i], row[1]) for i, row in enumerate(preds)],
            key=itemgetter(1),
            reverse=True,
        )

        label_one_num = np.count_nonzero(label_weight)
        label_zero_num = len(label_weight) - label_one_num
        total_area = label_zero_num * label_one_num

        # Walk the ranking: each negative contributes the number of positives
        # ranked above it, i.e. the number of correctly ordered pairs.
        height = 0
        area = 0
        for label, _ in ranked:
            if label == 1.0:
                height += 1.0
            else:
                area += height

        self.sum_metric += area / total_area
        self.num_inst += 1

88

89
@mx.metric.register
@mx.metric.alias("predcls")
class PredCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and label

    Recall of ground-truth triplets among the ``topk`` predictions ranked by
    column 0 of the prediction rows.  A prediction hits a ground-truth row
    when column 2 agrees (compared as ints) and ``triplet_iou_thresh`` accepts
    the pair of rows at ``iou_thresh``.

    Parameters
    ----------
    topk : int
        Number of highest-scored predictions examined per call.
    iou_thresh : float
        IoU threshold forwarded to ``triplet_iou_thresh``.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PredCls, self).__init__("predcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Add the recall of one sample to the running sum.

        ``labels`` and ``preds`` are 2-D arrays of triplet rows (or ``None``).
        A sample with missing or empty labels, or missing predictions, counts
        as an instance contributing zero.
        """
        # An empty label array would otherwise divide by zero below.
        if labels is None or preds is None or labels.shape[0] == 0:
            self.num_inst += 1
            return

        # rank predictions by descending column-0 score
        preds = preds[preds[:, 0].argsort()[::-1]]
        m = min(self.topk, preds.shape[0])
        gt_edge_num = labels.shape[0]
        label_matched = [False] * gt_edge_num
        count = 0
        for i in range(m):
            pred = preds[i]
            for j in range(gt_edge_num):
                if label_matched[j]:
                    continue  # each ground-truth row is credited at most once
                label = labels[j]
                if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / gt_edge_num
        self.num_inst += 1

124

125
@mx.metric.register
@mx.metric.alias("phrcls")
class PhrCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and predicted object label from detector

    Recall of ground-truth triplets among the ``topk`` predictions ranked by
    column 1 of the prediction rows.  A prediction hits a ground-truth row
    when columns 2, 3 and 4 all agree (compared as ints) and
    ``triplet_iou_thresh`` accepts the pair of rows at ``iou_thresh``.

    Parameters
    ----------
    topk : int
        Number of highest-scored predictions examined per call.
    iou_thresh : float
        IoU threshold forwarded to ``triplet_iou_thresh``.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PhrCls, self).__init__("phrcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Add the recall of one sample to the running sum.

        ``labels`` and ``preds`` are 2-D arrays of triplet rows (or ``None``).
        A sample with missing or empty labels, or missing predictions, counts
        as an instance contributing zero.
        """
        # An empty label array would otherwise divide by zero below.
        if labels is None or preds is None or labels.shape[0] == 0:
            self.num_inst += 1
            return

        # rank predictions by descending column-1 score
        preds = preds[preds[:, 1].argsort()[::-1]]
        m = min(self.topk, preds.shape[0])
        gt_edge_num = labels.shape[0]
        label_matched = [False] * gt_edge_num
        count = 0
        for i in range(m):
            pred = preds[i]
            for j in range(gt_edge_num):
                if label_matched[j]:
                    continue  # each ground-truth row is credited at most once
                label = labels[j]
                if (
                    int(label[2]) == int(pred[2])
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                    and triplet_iou_thresh(pred, label, self.iou_thresh)
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / gt_edge_num
        self.num_inst += 1

162

163
@mx.metric.register
@mx.metric.alias("sgdet")
class SGDet(mx.metric.EvalMetric):
    """Metric with predicted object information by the detector

    Recall of ground-truth triplets among the ``topk`` predictions ranked by
    column 1 of the prediction rows.  A prediction hits a ground-truth row
    when columns 2, 3 and 4 all agree (compared as ints) and
    ``triplet_iou_thresh`` accepts the pair of rows at ``iou_thresh``.

    Parameters
    ----------
    topk : int
        Number of highest-scored predictions examined per call.
    iou_thresh : float
        IoU threshold forwarded to ``triplet_iou_thresh``.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDet, self).__init__("sgdet@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Add the recall of one sample to the running sum.

        ``labels`` and ``preds`` are 2-D arrays of triplet rows (or ``None``).
        A sample with missing or empty labels, or missing predictions, counts
        as an instance contributing zero.
        """
        # An empty label array would otherwise divide by zero below.
        if labels is None or preds is None or labels.shape[0] == 0:
            self.num_inst += 1
            return

        # rank predictions by descending column-1 score
        preds = preds[preds[:, 1].argsort()[::-1]]
        m = min(self.topk, len(preds))
        gt_edge_num = labels.shape[0]
        label_matched = [False] * gt_edge_num
        count = 0
        for i in range(m):
            pred = preds[i]
            for j in range(gt_edge_num):
                if label_matched[j]:
                    continue  # each ground-truth row is credited at most once
                label = labels[j]
                if (
                    int(label[2]) == int(pred[2])
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                    and triplet_iou_thresh(pred, label, self.iou_thresh)
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / gt_edge_num
        self.num_inst += 1

200

201
@mx.metric.register
@mx.metric.alias("sgdet+")
class SGDetPlus(mx.metric.EvalMetric):
    """Metric proposed by `Graph R-CNN for Scene Graph Generation`

    Per sample the score is ``(objects + predicates + triplets matched) / N``
    where ``N = #gt_objects + 2 * #gt_triplets`` (every ground-truth triplet
    can be credited once for its predicate and once as a full triplet).

    Parameters
    ----------
    topk : int
        Number of highest-scored predicted triplets examined per call.
    iou_thresh : float
        IoU threshold used for both object and triplet box matching.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDetPlus, self).__init__("sgdet+@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Add the score of one sample to the running sum.

        Parameters
        ----------
        labels : tuple
            ``(label_objects, label_triplets)``; object rows carry the class
            in column 0 followed by the box, triplet rows are compared on
            columns 2 (predicate), 3 and 4 (the two object classes).
        preds : tuple
            ``(pred_objects, pred_triplets)`` with the same row layouts;
            predicted triplets are ranked by column 1.

        A sample whose ground-truth or predicted objects are ``None`` counts
        as an instance contributing zero.
        """
        label_objects, label_triplets = labels
        pred_objects, pred_triplets = preds
        if label_objects is None or pred_objects is None:
            self.num_inst += 1
            return

        gt_obj_num = label_objects.shape[0]
        # `labels` is a tuple, so the size must come from the triplet array.
        gt_triplet_num = label_triplets.shape[0]
        total_entries = gt_obj_num + 2 * gt_triplet_num
        if total_entries == 0:
            # nothing to recall; avoid dividing by zero
            self.num_inst += 1
            return

        count = 0

        # count objects: every predicted object is considered (no top-k cut)
        object_matched = [False] * gt_obj_num
        for i in range(len(pred_objects)):
            pred = pred_objects[i]
            for j in range(gt_obj_num):
                if object_matched[j]:
                    continue
                label = label_objects[j]
                if int(label[0]) == int(pred[0]) and object_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    count += 1
                    object_matched[j] = True

        # count predicate and triplet
        pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]
        m = min(self.topk, len(pred_triplets))
        predicate_matched = [False] * gt_triplet_num
        triplet_matched = [False] * gt_triplet_num
        for i in range(m):
            pred = pred_triplets[i]
            for j in range(gt_triplet_num):
                if predicate_matched[j] and triplet_matched[j]:
                    continue  # this ground-truth row is fully credited
                label = label_triplets[j]
                # Both kinds of match need the same predicate and box overlap.
                if int(label[2]) != int(pred[2]):
                    continue
                if not triplet_iou_thresh(pred, label, self.iou_thresh):
                    continue
                # Per-row flag (not the whole list) and a unit increment.
                if not predicate_matched[j]:
                    count += 1
                    predicate_matched[j] = True
                if (
                    not triplet_matched[j]
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                ):
                    count += 1
                    triplet_matched[j] = True

        # compute sum
        self.sum_metric += count / total_entries
        self.num_inst += 1

265

266
def extract_gt(g, img_size):
    """extract prediction from ground truth graph

    Parameters
    ----------
    g : graph or None
        Graph exposing ``ndata["node_class"]``, ``ndata["bbox"]``,
        ``edata["rel_class"]``, ``number_of_nodes()`` and ``find_edges()``.
    img_size : indexable of 2 numbers
        ``img_size[1]`` divides the x coordinates (bbox columns 0 and 2) and
        ``img_size[0]`` divides the y coordinates (columns 1 and 3) --
        presumably ``(height, width)``; confirm against the caller.

    Returns
    -------
    (gt_objects, gt_triplets)
        ``gt_objects`` has one row per node: ``[class, bbox...]``.
        ``gt_triplets`` has one row per edge whose ``rel_class`` is positive:
        ``[1, 1, rel_class - 1, sub_class, ob_class, sub_bbox..., ob_bbox...]``
        (the two leading ones occupy the score columns of predicted triplets).
        ``(None, None)`` is returned when ``g`` is ``None``, has no nodes, or
        has no edge with a positive ``rel_class``.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None

    # only edges with a positive relation class are ground-truth relations
    gt_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
    if len(gt_eids) == 0:
        return None, None

    gt_class = g.ndata["node_class"][:, 0].asnumpy()
    gt_bbox = g.ndata["bbox"].asnumpy()
    # scale box coordinates by the image size
    gt_bbox[:, 0] /= img_size[1]
    gt_bbox[:, 1] /= img_size[0]
    gt_bbox[:, 2] /= img_size[1]
    gt_bbox[:, 3] /= img_size[0]

    gt_objects = np.vstack([gt_class, gt_bbox.T]).T

    gt_node_sub, gt_node_ob = (ids.asnumpy() for ids in g.find_edges(gt_eids)[:2])
    # stored relation classes are shifted down by one for the triplet rows
    gt_rel_class = g.edata["rel_class"][gt_eids, 0].asnumpy() - 1

    n = len(gt_eids)
    gt_triplets = np.vstack(
        [
            np.ones(n),
            np.ones(n),
            gt_rel_class,
            gt_class[gt_node_sub],
            gt_class[gt_node_ob],
            gt_bbox[gt_node_sub].T,
            gt_bbox[gt_node_ob].T,
        ]
    ).T
    return gt_objects, gt_triplets

307

308
def extract_pred(g, topk=100, joint_preds=False):
    """extract prediction from prediction graph for validation and visualization

    Parameters
    ----------
    g : graph or None
        Graph exposing ``ndata["node_class_pred"]``, ``ndata["pred_bbox"]``,
        ``edata["score_pred"]``, ``edata["score_phr"]``, ``edata["preds"]``,
        ``number_of_nodes()`` and ``find_edges()``.
    topk : int
        Number of best edges kept under each of the two edge scores; the
        union of both selections is extracted.
    joint_preds : bool
        When true the relation class is the argmax over ``preds[:, 1:]``
        (column 0 is skipped -- presumably a background class; confirm against
        the model); otherwise the argmax over all columns.

    Returns
    -------
    (pred_objects, pred_triplets)
        ``pred_objects`` has one row per node: ``[class, bbox(4)]``.
        ``pred_triplets`` has one row per kept edge:
        ``[score_pred, score_phr, rel_class, sub_class, ob_class,
        sub_bbox(4), ob_bbox(4)]``.
        ``(None, None)`` is returned when ``g`` is ``None`` or has no nodes.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None

    pred_class = g.ndata["node_class_pred"].asnumpy()
    pred_bbox = g.ndata["pred_bbox"][:, 0:4].asnumpy()
    pred_objects = np.vstack([pred_class, pred_bbox.T]).T

    score_pred = g.edata["score_pred"].asnumpy()
    score_phr = g.edata["score_phr"].asnumpy()
    # negate so that argsort ranks the highest scores first
    score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
    score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
    topk_eids = sorted(set(score_pred_topk_eids + score_phr_topk_eids))

    pred_rel_prob = g.edata["preds"][topk_eids].asnumpy()
    if joint_preds:
        pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)
    else:
        pred_rel_class = pred_rel_prob.argmax(axis=1)

    pred_node_ids = g.find_edges(topk_eids)
    pred_node_sub = pred_node_ids[0].asnumpy()
    pred_node_ob = pred_node_ids[1].asnumpy()

    pred_triplets = np.vstack(
        [
            score_pred[topk_eids],
            score_phr[topk_eids],
            pred_rel_class,
            pred_class[pred_node_sub],
            pred_class[pred_node_ob],
            pred_bbox[pred_node_sub].T,
            pred_bbox[pred_node_ob].T,
        ]
    ).T
    return pred_objects, pred_triplets