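# metric.py: evaluation metrics (AUC, PredCls, PhrCls, SGDet, SGDet+) and
# helpers that extract ground-truth / predicted objects and triplets from graphs.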
import logging
import time
from operator import attrgetter, itemgetter

import dgl
import mxnet as mx
import numpy as np
from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex
from gluoncv.data.batchify import Pad
from gluoncv.model_zoo import get_model
from mxnet import gluon, nd
from mxnet.gluon import nn


def iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Parameters
    ----------
    boxA, boxB : indexable of 4 numbers
        Box corners laid out as ``(x_min, y_min, x_max, y_max)``.

    Returns
    -------
    float
        ``intersection / union``. ``0.0`` is returned when the intersection
        area or the union area is smaller than ``1e-7`` (no overlap or
        degenerate boxes), which also avoids dividing by zero.
    """
    eps = 1e-7

    # Corners of the intersection rectangle.
    x1 = max(boxA[0], boxB[0])
    y1 = max(boxA[1], boxB[1])
    x2 = min(boxA[2], boxB[2])
    y2 = min(boxA[3], boxB[3])

    # Clamp each side at zero so disjoint boxes yield an empty intersection.
    inter_area = max(0, x2 - x1) * max(0, y2 - y1)
    if inter_area < eps:
        return 0.0

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union_area = area_a + area_b - inter_area
    if union_area < eps:
        return 0.0

    return inter_area / float(union_area)


def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
    """Tell whether two object rows overlap by at least ``iou_thresh``.

    The box of each row is read from elements ``1:5``; the overlap is
    computed with ``iou``.
    """
    overlap = iou(gt_object[1:5], pred_object[1:5])
    return bool(overlap >= iou_thresh)


def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
    """Tell whether both boxes of two triplet rows overlap enough.

    Elements ``5:9`` of each row hold the first box and ``9:13`` the second
    one. The second pair is only compared when the first pair already
    reaches ``iou_thresh``.
    """
    return bool(
        iou(gt_triplet[5:9], pred_triplet[5:9]) >= iou_thresh
        and iou(gt_triplet[9:13], pred_triplet[9:13]) >= iou_thresh
    )


@mx.metric.register
@mx.metric.alias("auc")
class AUCMetric(mx.metric.EvalMetric):
    """Area under the ROC curve, accumulated once per ``update`` call.

    Parameters
    ----------
    name : str
        Metric name forwarded to ``mx.metric.EvalMetric``.
    eps : float
        Stored as ``self.eps``; ``update`` does not read it.
    """

    def __init__(self, name="auc", eps=1e-12):
        super(AUCMetric, self).__init__(name)
        self.eps = eps

    def update(self, labels, preds):
        """Add the AUC of one batch to the running sum.

        Only ``labels[0]`` and ``preds[0]`` are used. Column 1 of
        ``preds[0]`` is the ranking score. A label equal to ``1.0`` is
        treated as positive, anything else as negative.

        Batches whose labels are all zero or all one (by sum), or that have
        no positive/negative pair, are skipped without touching the
        accumulators.
        """
        mx.metric.check_label_shapes(labels, preds)
        label_weight = labels[0].asnumpy()
        scores = preds[0].asnumpy()

        # Bail out before doing any sorting work on degenerate batches.
        label_sum = label_weight.sum()
        if label_sum == 0 or label_sum == label_weight.size:
            return

        label_one_num = np.count_nonzero(label_weight)
        label_zero_num = len(label_weight) - label_one_num
        total_area = label_zero_num * label_one_num
        if total_area == 0:
            # No (positive, negative) pair exists, so the AUC is undefined.
            return

        ranked = sorted(
            ((label_weight[i], scores[i][1]) for i in range(scores.shape[0])),
            key=itemgetter(1),
            reverse=True,
        )

        # Walk the ranking from highest to lowest score: every negative adds
        # the number of positives ranked above it to the area.
        positives_seen = 0.0
        area = 0.0
        for label, _ in ranked:
            if label == 1.0:
                positives_seen += 1.0
            else:
                area += positives_seen

        self.sum_metric += area / total_area
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("predcls")
class PredCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and label

    Parameters
    ----------
    topk : int
        Number of highest-scoring predicted triplets that are evaluated.
    iou_thresh : float
        Minimum IoU required for both boxes of a triplet to count as a match.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PredCls, self).__init__("predcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate the fraction of ground-truth triplets that are matched.

        ``labels`` and ``preds`` are 2-D arrays with one triplet per row.
        Predictions are ranked by column 0 (descending) and only the first
        ``topk`` are used. A ground-truth row is matched when column 2 agrees
        (compared as ``int``) and ``triplet_iou_thresh`` accepts the boxes.
        Each ground-truth row is counted at most once.

        When either argument is ``None``, or ``labels`` has no rows, the
        sample is counted in ``num_inst`` with a contribution of zero.
        """
        if labels is None or preds is None:
            self.num_inst += 1
            return

        total = labels.shape[0]
        if total == 0:
            # Nothing to recall; avoid dividing by zero below.
            self.num_inst += 1
            return

        preds = preds[preds[:, 0].argsort()[::-1]]
        m = min(self.topk, preds.shape[0])
        count = 0
        label_matched = [False] * total
        for i in range(m):
            pred = preds[i]
            for j, label in enumerate(labels):
                if label_matched[j]:
                    continue
                if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / total
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("phrcls")
class PhrCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and predicted object label from detector

    Parameters
    ----------
    topk : int
        Number of highest-scoring predicted triplets that are evaluated.
    iou_thresh : float
        Minimum IoU required for both boxes of a triplet to count as a match.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PhrCls, self).__init__("phrcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate the fraction of ground-truth triplets that are matched.

        ``labels`` and ``preds`` are 2-D arrays with one triplet per row.
        Predictions are ranked by column 1 (descending) and only the first
        ``topk`` are used. A ground-truth row is matched when columns 2, 3
        and 4 all agree (compared as ``int``) and ``triplet_iou_thresh``
        accepts the boxes. Each ground-truth row is counted at most once.

        When either argument is ``None``, or ``labels`` has no rows, the
        sample is counted in ``num_inst`` with a contribution of zero.
        """
        if labels is None or preds is None:
            self.num_inst += 1
            return

        total = labels.shape[0]
        if total == 0:
            # Nothing to recall; avoid dividing by zero below.
            self.num_inst += 1
            return

        preds = preds[preds[:, 1].argsort()[::-1]]
        m = min(self.topk, preds.shape[0])
        count = 0
        label_matched = [False] * total
        for i in range(m):
            pred = preds[i]
            for j, label in enumerate(labels):
                if label_matched[j]:
                    continue
                if (
                    int(label[2]) == int(pred[2])
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                    and triplet_iou_thresh(pred, label, self.iou_thresh)
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / total
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("sgdet")
class SGDet(mx.metric.EvalMetric):
    """Metric with predicted object information by the detector

    Parameters
    ----------
    topk : int
        Number of highest-scoring predicted triplets that are evaluated.
    iou_thresh : float
        Minimum IoU required for both boxes of a triplet to count as a match.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDet, self).__init__("sgdet@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate the fraction of ground-truth triplets that are matched.

        ``labels`` and ``preds`` are 2-D arrays with one triplet per row.
        Predictions are ranked by column 1 (descending) and only the first
        ``topk`` are used. A ground-truth row is matched when columns 2, 3
        and 4 all agree (compared as ``int``) and ``triplet_iou_thresh``
        accepts the boxes. Each ground-truth row is counted at most once.

        When either argument is ``None``, or ``labels`` has no rows, the
        sample is counted in ``num_inst`` with a contribution of zero.
        """
        if labels is None or preds is None:
            self.num_inst += 1
            return

        total = labels.shape[0]
        if total == 0:
            # Nothing to recall; avoid dividing by zero below.
            self.num_inst += 1
            return

        preds = preds[preds[:, 1].argsort()[::-1]]
        m = min(self.topk, len(preds))
        count = 0
        label_matched = [False] * total
        for i in range(m):
            pred = preds[i]
            for j, label in enumerate(labels):
                if label_matched[j]:
                    continue
                if (
                    int(label[2]) == int(pred[2])
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                    and triplet_iou_thresh(pred, label, self.iou_thresh)
                ):
                    count += 1
                    label_matched[j] = True

        self.sum_metric += count / total
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("sgdet+")
class SGDetPlus(mx.metric.EvalMetric):
    """Metric proposed by `Graph R-CNN for Scene Graph Generation`

    The score of one sample is
    ``(matched objects + matched predicates + matched triplets) / N`` with
    ``N = #gt objects + 2 * #gt triplets``.

    Parameters
    ----------
    topk : int
        Number of highest-scoring predicted triplets that are evaluated.
        All predicted objects are evaluated regardless of ``topk``.
    iou_thresh : float
        Minimum IoU required for a box (or both boxes of a triplet) to match.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDetPlus, self).__init__("sgdet+@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate the SGDet+ score of one sample.

        Parameters
        ----------
        labels : tuple
            ``(label_objects, label_triplets)``, each a 2-D array or ``None``.
        preds : tuple
            ``(pred_objects, pred_triplets)``, each a 2-D array or ``None``.

        Objects match when column 0 agrees and ``object_iou_thresh`` accepts
        the boxes. Triplet predictions are ranked by column 1 (descending).
        A predicate match needs column 2 to agree plus ``triplet_iou_thresh``;
        a full triplet match additionally needs columns 3 and 4 to agree.
        Every ground-truth object, predicate and triplet is counted at most
        once. If any of the four arrays is ``None`` the sample is counted in
        ``num_inst`` with a contribution of zero.
        """
        label_objects, label_triplets = labels
        pred_objects, pred_triplets = preds
        if (
            label_objects is None
            or pred_objects is None
            or label_triplets is None
            or pred_triplets is None
        ):
            self.num_inst += 1
            return

        count = 0

        # count objects
        gt_obj_num = label_objects.shape[0]
        object_matched = [False] * gt_obj_num
        for pred in pred_objects:
            for j, label in enumerate(label_objects):
                if object_matched[j]:
                    continue
                if int(label[0]) == int(pred[0]) and object_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    count += 1
                    object_matched[j] = True

        # count predicate and triplet
        pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]
        m = min(self.topk, len(pred_triplets))
        gt_triplet_num = label_triplets.shape[0]
        triplet_matched = [False] * gt_triplet_num
        predicate_matched = [False] * gt_triplet_num
        for i in range(m):
            pred = pred_triplets[i]
            for j, label in enumerate(label_triplets):
                if predicate_matched[j] and triplet_matched[j]:
                    continue
                # Both kinds of match require the same predicate class and
                # box overlap, so evaluate that shared condition once.
                if int(label[2]) != int(pred[2]) or not triplet_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    continue
                if not predicate_matched[j]:
                    # Each matched predicate contributes exactly one hit.
                    count += 1
                    predicate_matched[j] = True
                if (
                    not triplet_matched[j]
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                ):
                    count += 1
                    triplet_matched[j] = True

        # compute sum: `labels` is a tuple, so the triplet count must come
        # from `label_triplets`, not from `labels.shape`.
        N = gt_obj_num + 2 * gt_triplet_num
        if N == 0:
            self.num_inst += 1
            return
        self.sum_metric += count / N
        self.num_inst += 1


def extract_gt(g, img_size):
    """extract prediction from ground truth graph

    Parameters
    ----------
    g : graph or None
        Graph exposing ``ndata["node_class"]``, ``ndata["bbox"]`` and
        ``edata["rel_class"]``. Edges whose ``rel_class`` is greater than 0
        are treated as ground-truth relations.
    img_size : indexable of 2 numbers
        Box columns 0 and 2 are divided by ``img_size[1]`` and columns 1 and
        3 by ``img_size[0]`` (presumably ``(height, width)`` -- confirm
        against the caller).

    Returns
    -------
    (gt_objects, gt_triplets) or (None, None)
        ``gt_objects`` has one row per node: ``[class, bbox...]``.
        ``gt_triplets`` has one row per ground-truth edge:
        ``[1, 1, rel_class - 1, sub_class, ob_class, sub_bbox..., ob_bbox...]``.
        ``(None, None)`` is returned when ``g`` is ``None``, has no nodes, or
        has no edge with ``rel_class > 0``.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None

    # Convert the edge labels once and reuse the array for both the edge
    # selection and the class lookup below.
    rel_class = g.edata["rel_class"].asnumpy()
    gt_eids = np.where(rel_class > 0)[0]
    if len(gt_eids) == 0:
        return None, None

    gt_class = g.ndata["node_class"][:, 0].asnumpy()
    # NOTE(review): the in-place scaling relies on asnumpy() handing back a
    # copy, so the graph's own bbox data is left untouched -- confirm.
    gt_bbox = g.ndata["bbox"].asnumpy()
    gt_bbox[:, 0:4:2] /= img_size[1]  # columns 0 and 2
    gt_bbox[:, 1:4:2] /= img_size[0]  # columns 1 and 3

    gt_objects = np.column_stack([gt_class, gt_bbox])

    gt_node_ids = g.find_edges(gt_eids)
    gt_node_sub = gt_node_ids[0].asnumpy()
    gt_node_ob = gt_node_ids[1].asnumpy()

    # Stored relation classes are shifted down by one for selected edges.
    gt_rel_class = rel_class[gt_eids, 0] - 1
    gt_sub_class = gt_class[gt_node_sub]
    gt_ob_class = gt_class[gt_node_ob]

    gt_sub_bbox = gt_bbox[gt_node_sub]
    gt_ob_bbox = gt_bbox[gt_node_ob]

    n = len(gt_eids)
    # The two leading columns of ones occupy the slots where predicted
    # triplets carry their scores.
    gt_triplets = np.column_stack(
        [
            np.ones(n),
            np.ones(n),
            gt_rel_class,
            gt_sub_class,
            gt_ob_class,
            gt_sub_bbox,
            gt_ob_bbox,
        ]
    )
    return gt_objects, gt_triplets


def extract_pred(g, topk=100, joint_preds=False):
    """extract prediction from prediction graph for validation and visualization

    Parameters
    ----------
    g : graph or None
        Graph exposing ``ndata["node_class_pred"]``, ``ndata["pred_bbox"]``,
        ``edata["score_pred"]``, ``edata["score_phr"]`` and
        ``edata["preds"]``.
    topk : int
        The ``topk`` best edges by ``score_pred`` and the ``topk`` best edges
        by ``score_phr`` are selected; their union is kept, so up to
        ``2 * topk`` triplets can be returned.
    joint_preds : bool
        When true, column 0 of ``edata["preds"]`` is skipped before taking
        the argmax, so the resulting class index is relative to column 1.
        Otherwise the argmax runs over all columns.

    Returns
    -------
    (pred_objects, pred_triplets) or (None, None)
        ``pred_objects`` has one row per node: ``[class, bbox(4)]``.
        ``pred_triplets`` has one row per kept edge, ordered by edge id:
        ``[score_pred, score_phr, rel_class, sub_class, ob_class,
        sub_bbox(4), ob_bbox(4)]``.
        ``(None, None)`` is returned when ``g`` is ``None`` or has no nodes.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None

    pred_class = g.ndata["node_class_pred"].asnumpy()
    pred_bbox = g.ndata["pred_bbox"][:, 0:4].asnumpy()
    pred_objects = np.column_stack([pred_class, pred_bbox])

    score_pred = g.edata["score_pred"].asnumpy()
    score_phr = g.edata["score_phr"].asnumpy()

    # Negate the scores so that argsort ranks the highest score first.
    score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
    score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
    topk_eids = sorted(set(score_pred_topk_eids + score_phr_topk_eids))

    pred_rel_prob = g.edata["preds"][topk_eids].asnumpy()
    if joint_preds:
        pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)
    else:
        pred_rel_class = pred_rel_prob.argmax(axis=1)

    pred_node_ids = g.find_edges(topk_eids)
    pred_node_sub = pred_node_ids[0].asnumpy()
    pred_node_ob = pred_node_ids[1].asnumpy()

    pred_triplets = np.column_stack(
        [
            score_pred[topk_eids],
            score_phr[topk_eids],
            pred_rel_class,
            pred_class[pred_node_sub],
            pred_class[pred_node_ob],
            pred_bbox[pred_node_sub],
            pred_bbox[pred_node_ob],
        ]
    )
    return pred_objects, pred_triplets