"vscode:/vscode.git/clone" did not exist on "2366716f0164f18a89a1e041d588a5687455f8bd"
train_freq_prior.py 3.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import json, pickle, os, argparse

def parse_args(argv=None):
    """Parse command-line options for the frequency-prior training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so existing callers are unaffected;
            passing a list makes the parser usable from tests.

    Returns:
        argparse.Namespace with ``overlap`` (bool, False unless ``--overlap``
        is given) and ``json_path`` (str; ``~`` is NOT expanded here).
    """
    parser = argparse.ArgumentParser(description='Train the Frequency Prior For RelDN.')
    parser.add_argument('--overlap', action='store_true',
                        help="Only count overlapping box pairs as background "
                             "(falls back to all pairs when an image has none).")
    parser.add_argument('--json-path', type=str, default='~/.mxnet/datasets/visualgenome',
                        help="Directory containing rel_annotations_train.json.")
    return parser.parse_args(argv)

# Resolve command-line options once at start-up; the counting code below reads
# these module-level names.
args = parse_args()
use_overlap, PATH_TO_DATASETS = args.overlap, os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, 'rel_annotations_train.json')

def with_overlap(boxA, boxB):
    """Return 1 if two boxes strictly overlap, else 0.

    Boxes are indexed as ``(y1, y2, x1, x2)``. Boxes that only touch along an
    edge do not count as overlapping, because both comparisons are strict.
    """
    # Horizontal extent of the intersection.
    left = max(boxA[2], boxB[2])
    right = min(boxA[3], boxB[3])
    if not right > left:
        return 0

    # Vertical extent of the intersection.
    top = max(boxA[0], boxB[0])
    bottom = min(boxA[1], boxB[1])
    return 1 if bottom > top else 0

def box_ious(boxes):
    """Build a symmetric pairwise overlap-indicator matrix.

    Entry ``[i, j]`` is ``with_overlap(boxes[i], boxes[j])`` (0 or 1) for
    ``i != j``. The diagonal stays 0. Despite the name, these are not IoU values.
    The result is a float ``numpy`` array of shape ``(len(boxes), len(boxes))``.
    """
    count = len(boxes)
    overlap = np.zeros((count, count))
    for i, box_i in enumerate(boxes):
        # Only the upper triangle is evaluated; the value is mirrored below it.
        for j in range(i + 1, count):
            overlap[i, j] = overlap[j, i] = with_overlap(box_i, boxes[j])
    return overlap

# Prior dimensions: object categories, and predicate slots. Slot 0 holds the
# background (box-pair co-occurrence) count; annotated predicate ``p`` is
# stored in slot ``p + 1``.
NUM_OBJ_CLASSES = 150
NUM_PRED_SLOTS = 51

with open(path_to_json, 'r', encoding='utf-8') as f:
    # assumes a mapping of image key -> list of relationship dicts carrying
    # 'subject'/'object' ({'bbox', 'category'}) and 'predicate', which are the
    # fields read below -- TODO confirm against the dataset docs.
    train_data = json.load(f)

# fg_matrix[s, o, p + 1]: times predicate p links subject class s to object class o.
fg_matrix = np.zeros((NUM_OBJ_CLASSES, NUM_OBJ_CLASSES, NUM_PRED_SLOTS), dtype=np.int64)
# bg_matrix[s, o]: times an ordered pair of distinct boxes with classes (s, o)
# appears in the same image (optionally restricted to overlapping pairs).
bg_matrix = np.zeros((NUM_OBJ_CLASSES, NUM_OBJ_CLASSES), dtype=np.int64)

for item in train_data.values():
    # Unique ground-truth boxes of this image -> class label. A box seen with
    # several labels keeps the first one encountered.
    gt_box_to_label = {}
    for rel in item:
        sub_class = rel['subject']['category']
        ob_class = rel['object']['category']
        rel_class = rel['predicate']
        gt_box_to_label.setdefault(tuple(rel['subject']['bbox']), sub_class)
        gt_box_to_label.setdefault(tuple(rel['object']['bbox']), ob_class)
        fg_matrix[sub_class, ob_class, rel_class + 1] += 1

    gt_boxes = list(gt_box_to_label)
    # Explicit integer dtype so indexing still works for an image with no boxes.
    gt_classes = np.array(list(gt_box_to_label.values()), dtype=np.int64)

    # Candidate background pairs: every ordered pair of distinct boxes.
    pair_mask = ~np.eye(len(gt_boxes), dtype=bool)
    if use_overlap:
        overlap_mask = box_ious(gt_boxes).astype(bool)
        # Keep only overlapping pairs, but fall back to all pairs when the
        # image has no overlapping pair at all.
        if overlap_mask.any():
            pair_mask = overlap_mask

    first_idx, second_idx = np.nonzero(pair_mask)
    # np.add.at accumulates repeated (class, class) indices correctly; plain
    # fancy-index ``+=`` would count each duplicate index only once.
    np.add.at(bg_matrix, (gt_classes[first_idx], gt_classes[second_idx]), 1)


eps = 1e-3
# Add one to every background count so slot 0 is never empty and every
# (subject, object) row sums to at least 1 before normalisation.
bg_matrix += 1
fg_matrix[:, :, 0] = bg_matrix
# Log of the per-(subject, object) predicate distribution; eps guards both
# the division and the log against zeros.
pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)

# Written to the current working directory.
out_path = 'freq_prior_overlap.pkl' if use_overlap else 'freq_prior.pkl'
with open(out_path, 'wb') as f:
    pickle.dump(pred_dist, f)