# matcher.py -- box-matching utilities for table structure / OCR cell matching.
# Recovered from a GitLab blame view (committed by WenmuZhou); the page's
# line-number gutter and commit markers have been removed.
import json
def distance(box_1, box_2):
    """Return a corner-weighted L1 distance between two boxes.

    Each box is unpacked as four numbers ``(x1, y1, x2, y2)``. The result is
    the L1 distance summed over all four coordinates, plus the smaller of the
    two per-corner L1 distances (first corner ``(x1, y1)`` vs. second corner
    ``(x2, y2)``).
    """
    ax1, ay1, ax2, ay2 = box_1
    bx1, by1, bx2, by2 = box_2
    # Keep the same summation order as a plain four-term sum so float results match.
    total = abs(bx1 - ax1) + abs(by1 - ay1) + abs(bx2 - ax2) + abs(by2 - ay2)
    first_corner = abs(bx1 - ax1) + abs(by1 - ay1)
    second_corner = abs(bx2 - ax2) + abs(by2 - ay2)
    return total + min(first_corner, second_corner)

def compute_iou(rec1, rec2):
    """Compute the intersection-over-union (IoU) of two axis-aligned boxes.

    Indices 0 and 2 bound one axis and indices 1 and 3 bound the other, i.e.
    ``(y0, x0, y1, x1)`` = ``(top, left, bottom, right)``. IoU is symmetric
    in the two axes, so ``(x1, y1, x2, y2)`` boxes (the layout unpacked by
    ``distance``) give the same result.

    :param rec1: first box, ``(y0, x0, y1, x1)``
    :param rec2: second box, same layout as *rec1*
    :return: IoU as a float; ``0.0`` when the boxes do not overlap (boxes
        that only touch along an edge count as not overlapping)
    """
    area_1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area_2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # Edges of the intersection rectangle.
    left = max(rec1[1], rec2[1])
    right = min(rec1[3], rec2[3])
    top = max(rec1[0], rec2[0])
    bottom = min(rec1[2], rec2[2])

    if left >= right or top >= bottom:
        return 0.0

    intersect = (right - left) * (bottom - top)
    # A positive intersection forces both areas to be positive, so the
    # union below is > 0 and true division already yields a float.
    return intersect / (area_1 + area_2 - intersect)
 


def matcher_merge(ocr_bboxes, pred_bboxes):
    """Assign every OCR box to its best-matching predicted cell box.

    For each OCR box, the predicted box with the smallest ``1 - IoU``
    (``compute_iou``) is selected. Ties on IoU are broken by the smaller
    ``distance``, and any remaining tie goes to the earliest position in
    *pred_bboxes*.

    :param ocr_bboxes: sequence of OCR boxes
    :param pred_bboxes: sequence of predicted cell boxes
    :return: dict mapping a *pred_bboxes* index to the list of *ocr_bboxes*
        indices assigned to it, in increasing order. ``{}`` when either input
        is empty.
    """
    # len() rather than truthiness so array-like inputs also work.
    if len(pred_bboxes) == 0:
        return {}

    matched = {}
    for i, ocr_box in enumerate(ocr_bboxes):
        # min() returns the first of several equal minima, i.e. the earliest pred index.
        best = min(
            range(len(pred_bboxes)),
            key=lambda j: (
                1. - compute_iou(ocr_box, pred_bboxes[j]),
                distance(ocr_box, pred_bboxes[j]),
            ),
        )
        matched.setdefault(best, []).append(i)
    return matched


def complex_num(pred_bboxes):
    """Return the mean IoU between each box and its nearest neighbour.

    For every box, the nearest *other* box is found by ``distance``; on ties
    the earliest one wins. The IoU (``compute_iou``) of the two boxes is
    recorded, and the result is the average of these IoUs over all boxes.
    This is presumably a measure of how crowded/overlapping the layout is.

    Neighbours are excluded by position, not by value. Two boxes with
    identical coordinates are therefore each other's neighbour (IoU 1.0)
    rather than being skipped.

    :param pred_bboxes: sequence of boxes
    :return: float mean IoU; ``0.0`` when fewer than two boxes are given,
        since there is then no neighbour to compare against
    """
    if len(pred_bboxes) < 2:
        return 0.0

    nearest_ious = []
    for i, bbox in enumerate(pred_bboxes):
        nearest = min(
            (other for j, other in enumerate(pred_bboxes) if j != i),
            key=lambda other, ref=bbox: distance(ref, other),
        )
        nearest_ious.append(compute_iou(bbox, nearest))
    return sum(nearest_ious) / len(nearest_ious)

def get_rows(pred_bboxes, y_tol=2):
    """Split the leading row off a list of boxes.

    The first box is the row anchor. Boxes are taken from the front while
    ``bbox[1] - anchor[1] <= y_tol`` and ``bbox[0] >= anchor[0]``. The first
    box failing either test ends the row. The anchor itself always qualifies,
    so a non-empty input always yields a non-empty row.

    The taken boxes are removed from the front of *pred_bboxes* in place,
    and that same list object is returned as the remainder.

    :param pred_bboxes: mutable list of boxes, presumably ``(x1, y1, x2, y2)``
        in reading order -- TODO confirm with callers
    :param y_tol: largest allowed increase of ``[1]`` relative to the anchor
        (default 2, the previously hard-coded threshold)
    :return: ``(row, remaining)``; ``([], pred_bboxes)`` for an empty list
    """
    if len(pred_bboxes) == 0:
        return [], pred_bboxes

    anchor = pred_bboxes[0]
    row = []
    for bbox in pred_bboxes:
        if bbox[1] - anchor[1] > y_tol or bbox[0] - anchor[0] < 0:
            break
        row.append(bbox)

    # One slice deletion (single O(n) shift) instead of repeated pop(0).
    del pred_bboxes[:len(row)]
    return row, pred_bboxes
def refine_rows(pred_bboxes):  # Fine-tune a row's boxes so they lie on one horizontal line.
    """Snap all boxes of one row onto a common horizontal band.

    Every box's ``[1]`` coordinate is replaced by the row's mean ``[1]``, and
    every ``[3]`` coordinate by the row's mean ``[3]`` (presumably the
    top/bottom edges). Boxes are modified in place, so they must be mutable
    (e.g. lists).

    :param pred_bboxes: the boxes of one row
    :return: a new list holding the same (now modified) box objects in their
        original order; ``[]`` for an empty row
    """
    if len(pred_bboxes) == 0:
        return []

    mean_y1 = sum(box[1] for box in pred_bboxes) / len(pred_bboxes)
    mean_y2 = sum(box[3] for box in pred_bboxes) / len(pred_bboxes)
    for box in pred_bboxes:
        box[1] = mean_y1
        box[3] = mean_y2
    return list(pred_bboxes)
    
def matcher_refine_row(gt_bboxes, pred_bboxes):
    """Match ground-truth boxes to predicted boxes after row alignment.

    The predicted boxes are grouped into rows with ``get_rows``, taken from
    the front, so they are presumably in reading order (TODO confirm with
    callers). Each row is snapped to a common band with ``refine_rows``.
    Every gt box is then assigned to the refined predicted box with the
    smallest ``distance``, with ties going to the earliest index.

    Refinement works on copies of the predicted boxes, so the caller's boxes
    are not modified.

    :param gt_bboxes: sequence of ground-truth boxes
    :param pred_bboxes: sequence of predicted boxes
    :return: dict mapping a pred index to the list of gt indices assigned to
        it. Rows are consumed from the front and re-appended in order, so pred
        indices refer to positions in *pred_bboxes*. ``{}`` when
        *pred_bboxes* is empty.
    """
    if len(pred_bboxes) == 0:
        return {}

    # Copy each box: refine_rows overwrites coordinates in place, and a
    # shallow copy of the outer list would leak those edits back to the caller.
    remaining = [list(box) for box in pred_bboxes]
    refined = []
    while len(remaining) != 0:
        row, remaining = get_rows(remaining)
        refined.extend(refine_rows(row))

    matched = {}
    for i, gt_box in enumerate(gt_bboxes):
        distances = [distance(gt_box, box) for box in refined]
        best = distances.index(min(distances))
        matched.setdefault(best, []).append(i)
    return matched



# First pick out one row, then match within it.
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    """Match gt boxes to predicted boxes one gt row at a time.

    The gt boxes are split into rows with ``get_rows``, and each row is
    sorted by its ``[0]`` coordinate. For every gt row, the next entry of
    *pred_bboxes_rows* (if any) is added to a pool of candidate predictions.
    Each gt box in the row is matched to the pooled box with the smallest
    ``distance``, with ties going to the earliest pool position. The pool is
    never cleared, so rows added earlier stay eligible.

    Side effect: *pred_bboxes_rows* is consumed -- one row is popped from its
    front per gt row.

    :param gt_bboxes: list of gt boxes; not modified (a copy is split)
    :param pred_bboxes_rows: list of rows, each an iterable of predicted boxes
    :param pred_bboxes: flat list of all predicted boxes; returned keys index
        into this list
    :return: dict mapping a *pred_bboxes* index to the list of gt indices
        matched to it. Gt indices count boxes in the row-by-row, ``[0]``-sorted
        traversal order, which equals the input order only if *gt_bboxes* is
        already sorted that way. Gt boxes reached while the pool is still
        empty are left unmatched (they still consume a gt index).
    """
    remaining_gt = gt_bboxes.copy()
    pool = []
    matched = {}
    gt_index = 0
    while len(remaining_gt) != 0:
        row, remaining_gt = get_rows(remaining_gt)
        if len(pred_bboxes_rows) > 0:
            pool.extend(pred_bboxes_rows.pop(0))
        for gt_box in sorted(row, key=lambda box: box[0]):
            if len(pool) > 0:
                pool_distances = [distance(gt_box, box) for box in pool]
                best_dist = min(pool_distances)
                best_box = list(pool[pool_distances.index(best_dist)])
                # Map the chosen pool box back to pred_bboxes by value. Looking
                # the index up by distance alone is ambiguous: another prediction
                # (possibly not even in the pool) may be equally far away.
                index = next(
                    (k for k, box in enumerate(pred_bboxes) if list(box) == best_box),
                    None,
                )
                if index is None:
                    # Pool box absent from pred_bboxes: fall back to the first
                    # prediction at the same distance.
                    index = [distance(gt_box, box) for box in pred_bboxes].index(best_dist)
                matched.setdefault(index, []).append(gt_index)
            gt_index += 1
    return matched

def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    '''
    Match sorted gt boxes to predicted boxes, adding predicted rows progressively.

    gt_bboxes: gt boxes, already sorted (presumably reading order -- TODO confirm).
    pred_bboxes_rows: list of rows of predicted boxes. Consumed in place: rows
        are popped from the front as they join the candidate pool.
    pred_bboxes: flat list of all predicted boxes; returned keys index into it.

    The pool starts with the first predicted row. While walking the gt boxes,
    another row is added in three cases:
    - the current box's ``[0]`` is smaller than the previous box's ``[2]``
      (presumably a wrap to a new line);
    - exactly one row remains;
    - the pool is still empty.
    Each gt box is matched to the pooled box with the smallest ``distance``.

    Returns a dict mapping a *pred_bboxes* index to the list of gt indices
    (positions in *gt_bboxes*) matched to it. If both the pool and the
    remaining rows are empty, matching stops and later gt boxes stay
    unmatched. Empty *gt_bboxes* gives ``{}``.
    '''
    if len(gt_bboxes) == 0:
        return {}

    matched = {}
    pool = []
    if len(pred_bboxes_rows) > 0:
        pool.extend(pred_bboxes_rows.pop(0))
    pre_bbox = gt_bboxes[0]
    for i, gt_box in enumerate(gt_bboxes):
        # NOTE(review): at i == 0 gt_box is pre_bbox, so this test holds for any
        # box with [0] < [2] and a second row joins the pool immediately.
        if gt_box[0] - pre_bbox[2] < 0 and len(pred_bboxes_rows) > 0:
            pool.extend(pred_bboxes_rows.pop(0))
        if len(pred_bboxes_rows) == 1:
            pool.extend(pred_bboxes_rows.pop(0))
        if len(pool) == 0 and len(pred_bboxes_rows) > 0:
            pool.extend(pred_bboxes_rows.pop(0))
        if len(pool) == 0 and len(pred_bboxes_rows) == 0:
            break

        pool_distances = [distance(gt_box, box) for box in pool]
        best_dist = min(pool_distances)
        best_box = list(pool[pool_distances.index(best_dist)])
        # Map the chosen pool box back to pred_bboxes by value. Looking the
        # index up by distance alone is ambiguous: another prediction
        # (possibly not even in the pool) may be equally far away.
        index = next(
            (k for k, box in enumerate(pred_bboxes) if list(box) == best_box),
            None,
        )
        if index is None:
            # Pool box absent from pred_bboxes: fall back to the first
            # prediction at the same distance.
            index = [distance(gt_box, box) for box in pred_bboxes].index(best_dist)
        matched.setdefault(index, []).append(i)
        pre_bbox = gt_box
    return matched