train_freq_prior.py 3.08 KB
Newer Older
1
2
3
4
5
import argparse
import json
import os
import pickle

6
import numpy as np
7

8
9

def parse_args(argv=None):
    """Parse the command-line options for frequency-prior training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is how the script calls it.

    Returns:
        argparse.Namespace with:
          * ``overlap`` (bool): True when ``--overlap`` is given.
          * ``json_path`` (str): directory expected to contain
            ``rel_annotations_train.json``; defaults to
            ``~/.mxnet/datasets/visualgenome`` (unexpanded).
    """
    parser = argparse.ArgumentParser(
        description="Train the Frequency Prior for RelDN."
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Only count overlap boxes."
    )
    parser.add_argument(
        "--json-path",
        type=str,
        default="~/.mxnet/datasets/visualgenome",
        help="Directory containing rel_annotations_train.json.",
    )
    return parser.parse_args(argv)

25

26
27
28
# Turn the parsed command line into the module-level settings used by the
# rest of the script: the overlap flag and the training-annotation path.
args = parse_args()
use_overlap, PATH_TO_DATASETS = (
    args.overlap,
    os.path.expanduser(args.json_path),
)
path_to_json = os.path.join(
    PATH_TO_DATASETS,
    "rel_annotations_train.json",
)
30

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
# format in y1y2x1x2
def with_overlap(boxA, boxB):
    """Return 1 if two boxes strictly overlap, else 0.

    Boxes are indexable as ``(y1, y2, x1, x2)``. Touching edges (zero-width
    intersection) do not count as overlap. The y extent is only inspected
    once the x extents are known to intersect.
    """
    left = max(boxA[2], boxB[2])
    right = min(boxA[3], boxB[3])
    # Written as `not (>)` rather than `<=` so incomparable values (e.g. NaN)
    # take the "no overlap" path exactly as before.
    if not right > left:
        return 0

    top = max(boxA[0], boxB[0])
    bottom = min(boxA[1], boxB[1])
    return 1 if bottom > top else 0

46

47
48
49
def box_ious(boxes):
    """Build a symmetric pairwise overlap-indicator matrix for ``boxes``.

    Despite the name, entries are not IoU values: ``result[i, j]`` is the 0/1
    output of ``with_overlap(boxes[i], boxes[j])`` (stored as float). The
    diagonal is left at 0.

    Args:
        boxes: sequence of boxes in the ``(y1, y2, x1, x2)`` layout that
            ``with_overlap`` expects.

    Returns:
        ``np.ndarray`` of shape ``(len(boxes), len(boxes))``.
    """
    count = len(boxes)
    overlap = np.zeros((count, count))
    for i, box_i in enumerate(boxes):
        # Upper triangle only; mirror each value into the lower triangle.
        for j in range(i + 1, count):
            overlap[i, j] = overlap[j, i] = with_overlap(box_i, boxes[j])
    return overlap

57
58

# Load the training relation annotations. JSON text is UTF-8 by specification
# (RFC 8259), so decode explicitly rather than with the platform's locale
# encoding, which differs on e.g. Windows.
with open(path_to_json, "r", encoding="utf-8") as f:
    train_data = json.load(f)

# Co-occurrence counts indexed by (subject class, object class, ...).
# fg_matrix: predicate p is counted at index p + 1; index 0 is left for the
# background counts that are copied in from bg_matrix afterwards.
# bg_matrix: counts of ordered (subject class, object class) box pairs.
# NOTE(review): 150 classes / 51 predicate slots are hard-coded — presumably
# the Visual Genome split used by RelDN; confirm against the dataset.
fg_matrix = np.zeros((150, 150, 51), dtype=np.int64)
bg_matrix = np.zeros((150, 150), dtype=np.int64)

for item in train_data.values():
    # Distinct ground-truth boxes (keyed by their bbox tuple) -> class label
    # of the first relation that mentions that box.
    gt_box_to_label = {}
    for rel in item:
        sub_bbox = rel["subject"]["bbox"]
        ob_bbox = rel["object"]["bbox"]
        sub_class = rel["subject"]["category"]
        ob_class = rel["object"]["category"]
        rel_class = rel["predicate"]

        gt_box_to_label.setdefault(tuple(sub_bbox), sub_class)
        gt_box_to_label.setdefault(tuple(ob_bbox), ob_class)

        fg_matrix[sub_class, ob_class, rel_class + 1] += 1

    # Integer dtype keeps the fancy indexing below valid even when the image
    # has no boxes (np.array([]) would otherwise default to float).
    gt_classes = np.array(list(gt_box_to_label.values()), dtype=np.int64)
    # Every ordered pair of distinct boxes.
    all_pairs = ~np.eye(len(gt_classes), dtype=bool)

    if use_overlap:
        # box_ious yields a 0/1 overlap indicator (not real IoU) per pair.
        pair_mask = box_ious(list(gt_box_to_label)).astype(bool)
        if not pair_mask.any():
            # No overlapping pair in this image: fall back to all pairs.
            pair_mask = all_pairs
    else:
        pair_mask = all_pairs

    sub_idx, ob_idx = np.nonzero(pair_mask)
    # np.add.at accumulates repeated (subject, object) class pairs; a plain
    # fancy-index `+=` would count each repeated index only once.
    np.add.at(bg_matrix, (gt_classes[sub_idx], gt_classes[ob_idx]), 1)


# Turn the counts into a smoothed log-probability over predicate slots.
eps = 1e-3
# Add one to every background cell so no (subject, object) pair has a zero
# background count, then store it in predicate slot 0.
bg_matrix += 1
fg_matrix[:, :, 0] = bg_matrix
# Per-(subject, object) total over all predicate slots, shaped to broadcast.
totals = fg_matrix.sum(axis=2, keepdims=True)
pred_dist = np.log(fg_matrix / (totals + eps) + eps)


# Write the prior to the current working directory. The file name records
# whether background pairs were restricted to overlapping boxes.
out_name = "freq_prior_overlap.pkl" if use_overlap else "freq_prior.pkl"
with open(out_name, "wb") as f:
    pickle.dump(pred_dist, f)