Commit daf0593b authored by icecraft's avatar icecraft
Browse files

fix: caption match algorithm

parent e9203f91
......@@ -450,11 +450,120 @@ class MagicModel:
)
return ret
def __tie_up_category_by_distance_v3(
    self,
    page_no: int,
    subject_category_id: int,
    object_category_id: int,
    priority_pos: PosRelationEnum,
):
    """Greedily pair subject boxes with object boxes by top-left distance.

    Detections of the two categories are read from
    ``self.__model_list[page_no]['layout_dets']``, passed through
    ``self.__reduct_overlap`` and sorted by the squared distance of their
    ``bbox[0]``/``bbox[1]`` corner from the origin.  Then, repeatedly:

    1. among all boxes not yet paired, pick the one closest to the
       (min ``bbox[0]``, min ``bbox[1]``) corner of the remaining boxes;
    2. pair it with the nearest remaining box of the *other* kind,
       measuring distance between ``bbox[0]``/``bbox[1]`` corners.

    Pairing stops once every subject is paired or no box of the other kind
    is left.  Each subject receives at most one object; objects left
    unpaired are not reported.

    Args:
        page_no: Index into ``self.__model_list``.
        subject_category_id: ``category_id`` of the subject detections.
        object_category_id: ``category_id`` of the object detections.
        priority_pos: Accepted for signature compatibility; this
            implementation does not use it.

    Returns:
        A list with one dict per subject: ``'sub_bbox'`` (``bbox`` and
        ``score`` of the subject), ``'obj_bboxes'`` (a list holding the
        paired object, or empty when none was paired) and ``'sub_idx'``
        (the subject's index in the sorted subject list).  Paired subjects
        come first, in pairing order, followed by unpaired subjects in
        index order.
    """

    def collect(category_id):
        # Keep only bbox/score of the requested category, then de-overlap.
        return self.__reduct_overlap(
            [
                {'bbox': det['bbox'], 'score': det['score']}
                for det in self.__model_list[page_no]['layout_dets']
                if det['category_id'] == category_id
            ]
        )

    def origin_dist_sq(det):
        return det['bbox'][0] ** 2 + det['bbox'][1] ** 2

    subjects = collect(subject_category_id)
    objects = collect(object_category_id)
    subjects.sort(key=origin_dist_sq)
    objects.sort(key=origin_dist_sq)

    SUBJECT, OBJECT = 0, 1
    # Each entry is (kind, index, x0, y0).  Boxes are identified by the
    # (kind, index) pair, so subject and object indices can never collide
    # regardless of how many boxes there are.
    boxes = [
        (SUBJECT, i, sub['bbox'][0], sub['bbox'][1])
        for i, sub in enumerate(subjects)
    ] + [
        (OBJECT, i, obj['bbox'][0], obj['bbox'][1])
        for i, obj in enumerate(objects)
    ]

    ret = []
    paired = set()  # (kind, index) of every box already consumed
    paired_sub_idx = set()
    while len(paired_sub_idx) < len(subjects):
        pending = [box for box in boxes if box[:2] not in paired]
        if not pending:
            break

        # Anchor: the pending box nearest the top-left corner of what is left.
        left_x = min(box[2] for box in pending)
        top_y = min(box[3] for box in pending)
        pending.sort(key=lambda b: (b[2] - left_x) ** 2 + (b[3] - top_y) ** 2)
        first_kind, first_idx, first_x, first_y = pending[0]

        # Re-rank by distance to the anchor.  The sort is stable, so ties keep
        # the order of the previous ranking and the anchor stays at index 0.
        pending.sort(key=lambda b: (b[2] - first_x) ** 2 + (b[3] - first_y) ** 2)
        partner = next((b for b in pending[1:] if b[0] != first_kind), None)
        if partner is None:
            # Nothing of the other kind remains; the rest stays unpaired.
            break

        paired.add((first_kind, first_idx))
        paired.add(partner[:2])
        if first_kind == SUBJECT:
            sub_idx, obj_idx = first_idx, partner[1]
        else:
            sub_idx, obj_idx = partner[1], first_idx
        paired_sub_idx.add(sub_idx)

        ret.append(
            {
                'sub_bbox': {
                    'bbox': subjects[sub_idx]['bbox'],
                    'score': subjects[sub_idx]['score'],
                },
                'obj_bboxes': [
                    {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
                ],
                'sub_idx': sub_idx,
            }
        )

    # Subjects that never got a partner are still reported, with no objects.
    for i, sub in enumerate(subjects):
        if i in paired_sub_idx:
            continue
        ret.append(
            {
                'sub_bbox': {
                    'bbox': sub['bbox'],
                    'score': sub['score'],
                },
                'obj_bboxes': [],
                'sub_idx': i,
            }
        )

    return ret
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v2(
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 3, 4, PosRelationEnum.BOTTOM
)
with_footnotes = self.__tie_up_category_by_distance_v2(
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
)
ret = []
......@@ -470,10 +579,10 @@ class MagicModel:
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v2(
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 5, 6, PosRelationEnum.UP
)
with_footnotes = self.__tie_up_category_by_distance_v2(
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 5, 7, PosRelationEnum.ALL
)
ret = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment