Unverified Commit 516f4926 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2873 from myhloli/dev

Dev
parents 6a242ada 7d8f68cb
......@@ -152,9 +152,6 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
......
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType
......@@ -13,7 +13,62 @@ class MagicModel:
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
"""将部分tbale_footnote修正为image_footnote"""
self.__fix_footnote()
"""处理重叠的image_body和table_body"""
self.__fix_by_remove_overlap_image_table_body()
def __fix_by_remove_overlap_image_table_body(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
image_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
))
table_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
))
def add_need_remove_block(blocks):
for i in range(len(blocks)):
for j in range(i + 1, len(blocks)):
block1 = blocks[i]
block2 = blocks[j]
overlap_box = get_minbox_if_overlap_by_ratio(
block1['bbox'], block2['bbox'], 0.8
)
if overlap_box is not None:
# 判断哪个区块的面积更小,移除较小的区块
area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
if area1 <= area2:
block_to_remove = block1
large_block = block2
else:
block_to_remove = block2
large_block = block1
if block_to_remove not in need_remove_list:
# 扩展大区块的边界框
x1, y1, x2, y2 = large_block['bbox']
sx1, sy1, sx2, sy2 = block_to_remove['bbox']
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block['bbox'] = [x1, y1, x2, y2]
need_remove_list.append(block_to_remove)
# 处理图像-图像重叠
add_need_remove_block(image_blocks)
# 处理表格-表格重叠
add_need_remove_block(table_blocks)
# 从布局中移除标记的区块
for need_remove in need_remove_list:
if need_remove in layout_dets:
layout_dets.remove(need_remove)
def __fix_axis(self):
need_remove_list = []
......@@ -46,42 +101,46 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
layout_dets = list(filter(
lambda x: x['category_id'] in [
CategoryId.Title,
CategoryId.Text,
CategoryId.ImageBody,
CategoryId.ImageCaption,
CategoryId.TableBody,
CategoryId.TableCaption,
CategoryId.TableFootnote,
CategoryId.InterlineEquation_Layout,
CategoryId.InterlineEquationNumber_Layout,
], self.__page_model_info['layout_dets']
)
)
for i in range(len(layout_dets)):
for j in range(i + 1, len(layout_dets)):
layout_det1 = layout_dets[i]
layout_det2 = layout_dets[j]
if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else:
continue
else:
continue
if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
self.__page_model_info['layout_dets'].remove(need_remove)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
footnotes = []
figures = []
tables = []
for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == 7:
if obj['category_id'] == CategoryId.TableFootnote:
footnotes.append(obj)
elif obj['category_id'] == 3:
elif obj['category_id'] == CategoryId.ImageBody:
figures.append(obj)
elif obj['category_id'] == 5:
elif obj['category_id'] == CategoryId.TableBody:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
......@@ -314,10 +373,10 @@ class MagicModel:
def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3(
3, 4
CategoryId.ImageBody, CategoryId.ImageCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
3, CategoryId.ImageFootnote
CategoryId.ImageBody, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
......@@ -333,10 +392,10 @@ class MagicModel:
def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3(
5, 6
CategoryId.TableBody, CategoryId.TableCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
5, 7
CategoryId.TableBody, CategoryId.TableFootnote
)
ret = []
for v in with_captions:
......@@ -385,20 +444,21 @@ class MagicModel:
all_spans = []
layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
allow_category_id_list = [
CategoryId.ImageBody,
CategoryId.TableBody,
CategoryId.InlineEquation,
CategoryId.InterlineEquation_YOLO,
CategoryId.OcrText,
]
"""当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets:
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
if category_id == CategoryId.ImageBody:
span['type'] = ContentType.IMAGE
elif category_id == 5:
elif category_id == CategoryId.TableBody:
# 获取table模型结果
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
......@@ -407,13 +467,13 @@ class MagicModel:
elif html:
span['html'] = html
span['type'] = ContentType.TABLE
elif category_id == 13:
elif category_id == CategoryId.InlineEquation:
span['content'] = layout_det['latex']
span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14:
elif category_id == CategoryId.InterlineEquation_YOLO:
span['content'] = layout_det['latex']
span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15:
elif category_id == CategoryId.OcrText:
span['content'] = layout_det['text']
span['type'] = ContentType.TEXT
all_spans.append(span)
......@@ -438,4 +498,4 @@ class MagicModel:
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
return blocks
\ No newline at end of file
......@@ -34,10 +34,10 @@ async def parse_pdf(
formula_enable: bool = Form(True),
table_enable: bool = Form(True),
server_url: Optional[str] = Form(None),
reuturn_md: bool = Form(True),
reuturn_middle_json: bool = Form(False),
return_md: bool = Form(True),
return_middle_json: bool = Form(False),
return_model_output: bool = Form(False),
reuturn_content_list: bool = Form(False),
return_content_list: bool = Form(False),
return_images: bool = Form(False),
start_page_id: int = Form(0),
end_page_id: int = Form(99999),
......@@ -98,11 +98,11 @@ async def parse_pdf(
server_url=server_url,
f_draw_layout_bbox=False,
f_draw_span_bbox=False,
f_dump_md=reuturn_md,
f_dump_middle_json=reuturn_middle_json,
f_dump_md=return_md,
f_dump_middle_json=return_middle_json,
f_dump_model_output=return_model_output,
f_dump_orig_pdf=False,
f_dump_content_list=reuturn_content_list,
f_dump_content_list=return_content_list,
start_page_id=start_page_id,
end_page_id=end_page_id,
)
......@@ -128,16 +128,16 @@ async def parse_pdf(
if os.path.exists(parse_dir):
if reuturn_md:
if return_md:
data["md_content"] = get_infer_result(".md")
if reuturn_middle_json:
if return_middle_json:
data["middle_json"] = get_infer_result("_middle.json")
if return_model_output:
if backend.startswith("pipeline"):
data["model_output"] = get_infer_result("_model.json")
else:
data["model_output"] = get_infer_result("_model_output.txt")
if reuturn_content_list:
if return_content_list:
data["content_list"] = get_infer_result("_content_list.json")
if return_images:
image_paths = glob(f"{parse_dir}/images/*.jpg")
......
......@@ -90,8 +90,8 @@ def prepare_block_bboxes(
"""经过以上处理后,还存在大框套小框的情况,则删除小框"""
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
"""将剩余的bbox做分离处理,防止后面分layout时出错"""
# all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
"""粗排序后返回"""
all_bboxes.sort(key=lambda x: x[0]+x[1])
return all_bboxes, all_discarded_blocks, footnote_blocks
......@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes):
# 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
# 删除重叠blocks中较小的那些
need_remove = []
for block1 in all_bboxes:
for block2 in all_bboxes:
if block1 != block2:
block1_bbox = block1[:4]
block2_bbox = block2[:4]
overlap_box = get_minbox_if_overlap_by_ratio(
block1_bbox, block2_bbox, 0.8
)
if overlap_box is not None:
block_to_remove = next(
(block for block in all_bboxes if block[:4] == overlap_box),
None,
)
if (
block_to_remove is not None
and block_to_remove not in need_remove
):
large_block = block1 if block1 != block_to_remove else block2
x1, y1, x2, y2 = large_block[:4]
sx1, sy1, sx2, sy2 = block_to_remove[:4]
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block[:4] = [x1, y1, x2, y2]
need_remove.append(block_to_remove)
if len(need_remove) > 0:
for block in need_remove:
for i in range(len(all_bboxes)):
for j in range(i + 1, len(all_bboxes)):
block1 = all_bboxes[i]
block2 = all_bboxes[j]
block1_bbox = block1[:4]
block2_bbox = block2[:4]
overlap_box = get_minbox_if_overlap_by_ratio(
block1_bbox, block2_bbox, 0.8
)
if overlap_box is not None:
# 判断哪个区块的面积更小,移除较小的区块
area1 = (block1[2] - block1[0]) * (block1[3] - block1[1])
area2 = (block2[2] - block2[0]) * (block2[3] - block2[1])
if area1 <= area2:
block_to_remove = block1
large_block = block2
else:
block_to_remove = block2
large_block = block1
if block_to_remove not in need_remove:
x1, y1, x2, y2 = large_block[:4]
sx1, sy1, sx2, sy2 = block_to_remove[:4]
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block[:4] = [x1, y1, x2, y2]
need_remove.append(block_to_remove)
for block in need_remove:
if block in all_bboxes:
all_bboxes.remove(block)
return all_bboxes
\ No newline at end of file
......@@ -43,7 +43,7 @@ vlm = [
"pydantic",
]
sglang = [
"sglang[all]>=0.4.7,<0.4.9",
"sglang[all]>=0.4.8,<0.4.9",
]
pipeline = [
"matplotlib>=3.10,<4",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment