Unverified Commit 77374343 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'opendatalab:dev' into dev

parents c91f918f a6870016
......@@ -488,46 +488,58 @@ class MagicModel:
OBJ_IDX_OFFSET = 10000
SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
seen_idx = set()
seen_sub_idx = set()
while N > len(seen_sub_idx):
candidates = []
candidates = []
for idx, kind, x0, y0 in all_boxes_with_idx:
if idx in seen_idx:
continue
continue
candidates.append((idx, kind, x0, y0))
if len(candidates) == 0:
break
left_x = min([v[2] for v in candidates])
top_y = min([v[3] for v in candidates])
candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
fst_idx, fst_kind, left_x, top_y = candidates[0]
candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
nxt = None
for i in range(1, len(candidates)):
if candidates[i][1] ^ fst_kind == 1:
nxt = candidates[i]
break
break
if nxt is None:
break
seen_idx.add(fst_idx)
seen_idx.add(nxt[0])
if fst_kind == SUB_BIT_KIND:
seen_sub_idx.add(fst_idx)
sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
else:
seen_sub_idx.add(nxt[0])
sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
nearest_dis = float('inf')
for i in range(N):
if i in seen_idx:continue
nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
if pair_dis >= 3*nearest_dis:
seen_idx.add(sub_idx)
continue
seen_idx.add(sub_idx)
seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
seen_sub_idx.add(sub_idx)
ret.append(
{
'sub_bbox': {
......@@ -543,7 +555,7 @@ class MagicModel:
for i in range(len(subjects)):
if i in seen_sub_idx:
continue
continue
ret.append(
{
'sub_bbox': {
......@@ -554,8 +566,8 @@ class MagicModel:
'sub_idx': i,
}
)
return ret
......
......@@ -24,7 +24,7 @@ def test_convert_middle_json_to_layout_elements():
assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 3
assert len(res[0].extra.element_relation) >= 2
# teardown
shutil.rmtree(temp_output_dir)
......@@ -51,7 +51,7 @@ def test_inference():
assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 3
assert len(res[0].extra.element_relation) >= 2
# teardown
shutil.rmtree(temp_output_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment