Unverified Commit 516f4926 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2873 from myhloli/dev

Dev
parents 6a242ada 7d8f68cb
...@@ -152,9 +152,6 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -152,9 +152,6 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
"""对block进行fix操作""" """对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans) fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序""" """对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks) sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
......
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType from mineru.utils.enum_class import CategoryId, ContentType
...@@ -13,7 +13,62 @@ class MagicModel: ...@@ -13,7 +13,62 @@ class MagicModel:
self.__fix_by_remove_low_confidence() self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个""" """删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_by_remove_high_iou_and_low_confidence()
"""将部分tbale_footnote修正为image_footnote"""
self.__fix_footnote() self.__fix_footnote()
"""处理重叠的image_body和table_body"""
self.__fix_by_remove_overlap_image_table_body()
def __fix_by_remove_overlap_image_table_body(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
image_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
))
table_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
))
def add_need_remove_block(blocks):
for i in range(len(blocks)):
for j in range(i + 1, len(blocks)):
block1 = blocks[i]
block2 = blocks[j]
overlap_box = get_minbox_if_overlap_by_ratio(
block1['bbox'], block2['bbox'], 0.8
)
if overlap_box is not None:
# 判断哪个区块的面积更小,移除较小的区块
area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
if area1 <= area2:
block_to_remove = block1
large_block = block2
else:
block_to_remove = block2
large_block = block1
if block_to_remove not in need_remove_list:
# 扩展大区块的边界框
x1, y1, x2, y2 = large_block['bbox']
sx1, sy1, sx2, sy2 = block_to_remove['bbox']
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block['bbox'] = [x1, y1, x2, y2]
need_remove_list.append(block_to_remove)
# 处理图像-图像重叠
add_need_remove_block(image_blocks)
# 处理表格-表格重叠
add_need_remove_block(table_blocks)
# 从布局中移除标记的区块
for need_remove in need_remove_list:
if need_remove in layout_dets:
layout_dets.remove(need_remove)
def __fix_axis(self): def __fix_axis(self):
need_remove_list = [] need_remove_list = []
...@@ -46,42 +101,46 @@ class MagicModel: ...@@ -46,42 +101,46 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self): def __fix_by_remove_high_iou_and_low_confidence(self):
need_remove_list = [] need_remove_list = []
layout_dets = self.__page_model_info['layout_dets'] layout_dets = list(filter(
lambda x: x['category_id'] in [
CategoryId.Title,
CategoryId.Text,
CategoryId.ImageBody,
CategoryId.ImageCaption,
CategoryId.TableBody,
CategoryId.TableCaption,
CategoryId.TableFootnote,
CategoryId.InterlineEquation_Layout,
CategoryId.InterlineEquationNumber_Layout,
], self.__page_model_info['layout_dets']
)
)
for i in range(len(layout_dets)): for i in range(len(layout_dets)):
for j in range(i + 1, len(layout_dets)): for j in range(i + 1, len(layout_dets)):
layout_det1 = layout_dets[i] layout_det1 = layout_dets[i]
layout_det2 = layout_dets[j] layout_det2 = layout_dets[j]
if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list: if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
need_remove_list.append(layout_det_need_remove)
else: layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
continue
else: if layout_det_need_remove not in need_remove_list:
continue need_remove_list.append(layout_det_need_remove)
for need_remove in need_remove_list: for need_remove in need_remove_list:
layout_dets.remove(need_remove) self.__page_model_info['layout_dets'].remove(need_remove)
def __fix_footnote(self): def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
footnotes = [] footnotes = []
figures = [] figures = []
tables = [] tables = []
for obj in self.__page_model_info['layout_dets']: for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == 7: if obj['category_id'] == CategoryId.TableFootnote:
footnotes.append(obj) footnotes.append(obj)
elif obj['category_id'] == 3: elif obj['category_id'] == CategoryId.ImageBody:
figures.append(obj) figures.append(obj)
elif obj['category_id'] == 5: elif obj['category_id'] == CategoryId.TableBody:
tables.append(obj) tables.append(obj)
if len(footnotes) * len(figures) == 0: if len(footnotes) * len(figures) == 0:
continue continue
...@@ -314,10 +373,10 @@ class MagicModel: ...@@ -314,10 +373,10 @@ class MagicModel:
def get_imgs(self): def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
3, 4 CategoryId.ImageBody, CategoryId.ImageCaption
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
3, CategoryId.ImageFootnote CategoryId.ImageBody, CategoryId.ImageFootnote
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -333,10 +392,10 @@ class MagicModel: ...@@ -333,10 +392,10 @@ class MagicModel:
def get_tables(self) -> list: def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
5, 6 CategoryId.TableBody, CategoryId.TableCaption
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
5, 7 CategoryId.TableBody, CategoryId.TableFootnote
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -385,20 +444,21 @@ class MagicModel: ...@@ -385,20 +444,21 @@ class MagicModel:
all_spans = [] all_spans = []
layout_dets = self.__page_model_info['layout_dets'] layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15] allow_category_id_list = [
CategoryId.ImageBody,
CategoryId.TableBody,
CategoryId.InlineEquation,
CategoryId.InterlineEquation_YOLO,
CategoryId.OcrText,
]
"""当成span拼接的""" """当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets: for layout_det in layout_dets:
category_id = layout_det['category_id'] category_id = layout_det['category_id']
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']} span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3: if category_id == CategoryId.ImageBody:
span['type'] = ContentType.IMAGE span['type'] = ContentType.IMAGE
elif category_id == 5: elif category_id == CategoryId.TableBody:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get('latex', None) latex = layout_det.get('latex', None)
html = layout_det.get('html', None) html = layout_det.get('html', None)
...@@ -407,13 +467,13 @@ class MagicModel: ...@@ -407,13 +467,13 @@ class MagicModel:
elif html: elif html:
span['html'] = html span['html'] = html
span['type'] = ContentType.TABLE span['type'] = ContentType.TABLE
elif category_id == 13: elif category_id == CategoryId.InlineEquation:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.INLINE_EQUATION span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14: elif category_id == CategoryId.InterlineEquation_YOLO:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.INTERLINE_EQUATION span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15: elif category_id == CategoryId.OcrText:
span['content'] = layout_det['text'] span['content'] = layout_det['text']
span['type'] = ContentType.TEXT span['type'] = ContentType.TEXT
all_spans.append(span) all_spans.append(span)
...@@ -438,4 +498,4 @@ class MagicModel: ...@@ -438,4 +498,4 @@ class MagicModel:
for col in extra_col: for col in extra_col:
block[col] = item.get(col, None) block[col] = item.get(col, None)
blocks.append(block) blocks.append(block)
return blocks return blocks
\ No newline at end of file
...@@ -34,10 +34,10 @@ async def parse_pdf( ...@@ -34,10 +34,10 @@ async def parse_pdf(
formula_enable: bool = Form(True), formula_enable: bool = Form(True),
table_enable: bool = Form(True), table_enable: bool = Form(True),
server_url: Optional[str] = Form(None), server_url: Optional[str] = Form(None),
reuturn_md: bool = Form(True), return_md: bool = Form(True),
reuturn_middle_json: bool = Form(False), return_middle_json: bool = Form(False),
return_model_output: bool = Form(False), return_model_output: bool = Form(False),
reuturn_content_list: bool = Form(False), return_content_list: bool = Form(False),
return_images: bool = Form(False), return_images: bool = Form(False),
start_page_id: int = Form(0), start_page_id: int = Form(0),
end_page_id: int = Form(99999), end_page_id: int = Form(99999),
...@@ -98,11 +98,11 @@ async def parse_pdf( ...@@ -98,11 +98,11 @@ async def parse_pdf(
server_url=server_url, server_url=server_url,
f_draw_layout_bbox=False, f_draw_layout_bbox=False,
f_draw_span_bbox=False, f_draw_span_bbox=False,
f_dump_md=reuturn_md, f_dump_md=return_md,
f_dump_middle_json=reuturn_middle_json, f_dump_middle_json=return_middle_json,
f_dump_model_output=return_model_output, f_dump_model_output=return_model_output,
f_dump_orig_pdf=False, f_dump_orig_pdf=False,
f_dump_content_list=reuturn_content_list, f_dump_content_list=return_content_list,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
) )
...@@ -128,16 +128,16 @@ async def parse_pdf( ...@@ -128,16 +128,16 @@ async def parse_pdf(
if os.path.exists(parse_dir): if os.path.exists(parse_dir):
if reuturn_md: if return_md:
data["md_content"] = get_infer_result(".md") data["md_content"] = get_infer_result(".md")
if reuturn_middle_json: if return_middle_json:
data["middle_json"] = get_infer_result("_middle.json") data["middle_json"] = get_infer_result("_middle.json")
if return_model_output: if return_model_output:
if backend.startswith("pipeline"): if backend.startswith("pipeline"):
data["model_output"] = get_infer_result("_model.json") data["model_output"] = get_infer_result("_model.json")
else: else:
data["model_output"] = get_infer_result("_model_output.txt") data["model_output"] = get_infer_result("_model_output.txt")
if reuturn_content_list: if return_content_list:
data["content_list"] = get_infer_result("_content_list.json") data["content_list"] = get_infer_result("_content_list.json")
if return_images: if return_images:
image_paths = glob(f"{parse_dir}/images/*.jpg") image_paths = glob(f"{parse_dir}/images/*.jpg")
......
...@@ -90,8 +90,8 @@ def prepare_block_bboxes( ...@@ -90,8 +90,8 @@ def prepare_block_bboxes(
"""经过以上处理后,还存在大框套小框的情况,则删除小框""" """经过以上处理后,还存在大框套小框的情况,则删除小框"""
all_bboxes = remove_overlaps_min_blocks(all_bboxes) all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks) all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
"""将剩余的bbox做分离处理,防止后面分layout时出错"""
# all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes) """粗排序后返回"""
all_bboxes.sort(key=lambda x: x[0]+x[1]) all_bboxes.sort(key=lambda x: x[0]+x[1])
return all_bboxes, all_discarded_blocks, footnote_blocks return all_bboxes, all_discarded_blocks, footnote_blocks
...@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes): ...@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes):
# 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。 # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
# 删除重叠blocks中较小的那些 # 删除重叠blocks中较小的那些
need_remove = [] need_remove = []
for block1 in all_bboxes: for i in range(len(all_bboxes)):
for block2 in all_bboxes: for j in range(i + 1, len(all_bboxes)):
if block1 != block2: block1 = all_bboxes[i]
block1_bbox = block1[:4] block2 = all_bboxes[j]
block2_bbox = block2[:4] block1_bbox = block1[:4]
overlap_box = get_minbox_if_overlap_by_ratio( block2_bbox = block2[:4]
block1_bbox, block2_bbox, 0.8 overlap_box = get_minbox_if_overlap_by_ratio(
) block1_bbox, block2_bbox, 0.8
if overlap_box is not None: )
block_to_remove = next( if overlap_box is not None:
(block for block in all_bboxes if block[:4] == overlap_box), # 判断哪个区块的面积更小,移除较小的区块
None, area1 = (block1[2] - block1[0]) * (block1[3] - block1[1])
) area2 = (block2[2] - block2[0]) * (block2[3] - block2[1])
if (
block_to_remove is not None if area1 <= area2:
and block_to_remove not in need_remove block_to_remove = block1
): large_block = block2
large_block = block1 if block1 != block_to_remove else block2 else:
x1, y1, x2, y2 = large_block[:4] block_to_remove = block2
sx1, sy1, sx2, sy2 = block_to_remove[:4] large_block = block1
x1 = min(x1, sx1)
y1 = min(y1, sy1) if block_to_remove not in need_remove:
x2 = max(x2, sx2) x1, y1, x2, y2 = large_block[:4]
y2 = max(y2, sy2) sx1, sy1, sx2, sy2 = block_to_remove[:4]
large_block[:4] = [x1, y1, x2, y2] x1 = min(x1, sx1)
need_remove.append(block_to_remove) y1 = min(y1, sy1)
x2 = max(x2, sx2)
if len(need_remove) > 0: y2 = max(y2, sy2)
for block in need_remove: large_block[:4] = [x1, y1, x2, y2]
need_remove.append(block_to_remove)
for block in need_remove:
if block in all_bboxes:
all_bboxes.remove(block) all_bboxes.remove(block)
return all_bboxes return all_bboxes
\ No newline at end of file
...@@ -43,7 +43,7 @@ vlm = [ ...@@ -43,7 +43,7 @@ vlm = [
"pydantic", "pydantic",
] ]
sglang = [ sglang = [
"sglang[all]>=0.4.7,<0.4.9", "sglang[all]>=0.4.8,<0.4.9",
] ]
pipeline = [ pipeline = [
"matplotlib>=3.10,<4", "matplotlib>=3.10,<4",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment