Unverified Commit 77374343 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'opendatalab:dev' into dev

parents c91f918f a6870016
...@@ -519,15 +519,27 @@ class MagicModel: ...@@ -519,15 +519,27 @@ class MagicModel:
if nxt is None: if nxt is None:
break break
seen_idx.add(fst_idx)
seen_idx.add(nxt[0])
if fst_kind == SUB_BIT_KIND: if fst_kind == SUB_BIT_KIND:
seen_sub_idx.add(fst_idx)
sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
else: else:
seen_sub_idx.add(nxt[0])
sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
nearest_dis = float('inf')
for i in range(N):
if i in seen_idx:continue
nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
if pair_dis >= 3*nearest_dis:
seen_idx.add(sub_idx)
continue
seen_idx.add(sub_idx)
seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
seen_sub_idx.add(sub_idx)
ret.append( ret.append(
{ {
'sub_bbox': { 'sub_bbox': {
......
...@@ -24,7 +24,7 @@ def test_convert_middle_json_to_layout_elements(): ...@@ -24,7 +24,7 @@ def test_convert_middle_json_to_layout_elements():
assert len(res[0].layout_dets) > 0 assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0 assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 3 assert len(res[0].extra.element_relation) >= 2
# teardown # teardown
shutil.rmtree(temp_output_dir) shutil.rmtree(temp_output_dir)
...@@ -51,7 +51,7 @@ def test_inference(): ...@@ -51,7 +51,7 @@ def test_inference():
assert len(res[0].layout_dets) > 0 assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0 assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 3 assert len(res[0].extra.element_relation) >= 2
# teardown # teardown
shutil.rmtree(temp_output_dir) shutil.rmtree(temp_output_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment