Commit c20e9a1e authored by myhloli's avatar myhloli
Browse files

feat(layout): improve title block handling and layout detection

- Merge title blocks that are close to each other horizontally
- Adjust line insertion logic for title blocks- Increase image size and decrease confidence threshold for layout detection
- Update DocLayoutYOLO model weights
- Refactor drawing of bounding boxes for different block types
parent ee9340ea
...@@ -17,7 +17,7 @@ paddlepaddle==3.0.0b1 ...@@ -17,7 +17,7 @@ paddlepaddle==3.0.0b1
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2 doclayout_yolo==0.0.2b1
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid_table==0.3.0
......
...@@ -16,7 +16,7 @@ paddleocr==2.7.3 ...@@ -16,7 +16,7 @@ paddleocr==2.7.3
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2 doclayout_yolo==0.0.2b1
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid_table==0.3.0
......
...@@ -16,7 +16,7 @@ paddleocr==2.7.3 ...@@ -16,7 +16,7 @@ paddleocr==2.7.3
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2 doclayout_yolo==0.0.2b1
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid_table==0.3.0
......
...@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for page in pdf_info: for page in pdf_info:
page_line_list = [] page_line_list = []
for block in page['preproc_blocks']: for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]: if block['type'] in [BlockType.Text]:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
index = line['index'] index = line['index']
page_line_list.append({'index': index, 'bbox': bbox}) page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in [BlockType.Image, BlockType.Table]: elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
if 'virtual_lines' in block:
if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
for line in block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']: for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]: if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None: if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
......
...@@ -144,7 +144,7 @@ class CustomPEKModel: ...@@ -144,7 +144,7 @@ class CustomPEKModel:
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml' model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
) )
), ),
device=self.device, device='cpu' if str(self.device).startswith("mps") else self.device,
) )
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model( self.layout_model = atom_model_manager.get_atom_model(
...@@ -192,24 +192,24 @@ class CustomPEKModel: ...@@ -192,24 +192,24 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[]) layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
if height > width: # if height > width:
input_res = {"poly":[0,0,width,0,width,height,0,height]} # input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0) # new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list # paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image) # layout_res = self.layout_model.predict(new_image)
for res in layout_res: # for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly'] # p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin # p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin # p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin # p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin # p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin # p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin # p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin # p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin # p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8] # res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
else: # else:
layout_res = self.layout_model.predict(image) layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}') logger.info(f'layout detection time: {layout_cost}')
......
...@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object): ...@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
def predict(self, image): def predict(self, image):
layout_res = [] layout_res = []
doclayout_yolo_res = self.model.predict( doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0] )[0]
for xyxy, conf, cla in zip( for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.xyxy.cpu(),
...@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object): ...@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
image_res.cpu() image_res.cpu()
for image_res in self.model.predict( for image_res in self.model.predict(
images[index : index + batch_size], images[index : index + batch_size],
imgsz=1024, imgsz=1280,
conf=0.25, conf=0.10,
iou=0.45, iou=0.45,
verbose=False, verbose=False,
device=self.device, device=self.device,
......
...@@ -12,7 +12,7 @@ from loguru import logger ...@@ -12,7 +12,7 @@ from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.convert_utils import dict_to_list
...@@ -365,10 +365,11 @@ def cal_block_index(fix_blocks, sorted_bboxes): ...@@ -365,10 +365,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
block['index'] = median_value block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填 # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]: if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
block['virtual_lines'] = copy.deepcopy(block['lines']) if 'real_lines' in block:
block['lines'] = copy.deepcopy(block['real_lines']) block['virtual_lines'] = copy.deepcopy(block['lines'])
del block['real_lines'] block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else: else:
# 使用xycut排序 # 使用xycut排序
block_bboxes = [] block_bboxes = []
...@@ -417,7 +418,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -417,7 +418,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0 block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox # 如果block高度小于n行正文,则直接返回block的bbox
if line_height * 3 < block_height: if line_height * 2 < block_height:
if ( if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25 block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点 ): # 可能是双列结构,可以切细点
...@@ -425,16 +426,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -425,16 +426,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else: else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细) # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4: if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3 lines = 3
line_height = (y1 - y0) / lines
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点) elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1 lines = int(block_height / line_height) + 1
else: # 判断长宽比 else: # 判断长宽比
if block_height / block_weight > 1.2: # 细长的不分 if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]] return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行 else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
lines = 2 lines = 2
line_height = (y1 - y0) / lines
# 确定从哪个y位置开始绘制线条 # 确定从哪个y位置开始绘制线条
current_y = y0 current_y = y0
...@@ -453,30 +454,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -453,30 +454,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = [] page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks: for block in fix_blocks:
if block['type'] in [ if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation, BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote BlockType.TableCaption, BlockType.TableFootnote
]: ]:
if len(block['lines']) == 0: if len(block['lines']) == 0:
bbox = block['bbox'] add_lines_to_block(block)
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
for line in lines: block['real_lines'] = copy.deepcopy(block['lines'])
block['lines'].append({'bbox': line, 'spans': []}) add_lines_to_block(block)
page_line_list.extend(lines)
else: else:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
page_line_list.append(bbox) page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]: elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
bbox = block['bbox']
block['real_lines'] = copy.deepcopy(block['lines']) block['real_lines'] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) add_lines_to_block(block)
block['lines'] = []
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
if len(page_line_list) > 200: # layoutreader最高支持512line if len(page_line_list) > 200: # layoutreader最高支持512line
return None return None
...@@ -663,12 +666,68 @@ def parse_page_core( ...@@ -663,12 +666,68 @@ def parse_page_core(
discarded_blocks = magic_model.get_discarded(page_id) discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id) text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id) title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = ( inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id) page_w, page_h = magic_model.get_page_size(page_id)
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
def merge_two_blocks(b1, b2):
# 合并两个标题块的边界框
x_min = min(b1['bbox'][0], b2['bbox'][0])
y_min = min(b1['bbox'][1], b2['bbox'][1])
x_max = max(b1['bbox'][2], b2['bbox'][2])
y_max = max(b1['bbox'][3], b2['bbox'][3])
merged_bbox = (x_min, y_min, x_max, y_max)
# 合并两个标题块的文本内容
merged_score = (b1['score'] + b2['score']) / 2
return {'bbox': merged_bbox, 'score': merged_score}
# 按 y 轴重叠度聚集标题块
y_overlapping_blocks = []
while blocks:
block1 = blocks.pop(0)
current_row = [block1]
to_remove = []
for block2 in blocks:
if __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9):
current_row.append(block2)
to_remove.append(block2)
for b in to_remove:
blocks.remove(b)
y_overlapping_blocks.append(current_row)
# 按x轴坐标排序并合并标题块
merged_blocks = []
for row in y_overlapping_blocks:
if len(row) == 1:
merged_blocks.append(row[0])
continue
# 按x轴坐标排序
row.sort(key=lambda x: x['bbox'][0])
merged_block = row[0]
for i in range(1, len(row)):
left_block = merged_block
right_block = row[i]
left_height = left_block['bbox'][3] - left_block['bbox'][1]
right_height = right_block['bbox'][3] - right_block['bbox'][1]
if right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold and left_height * 0.95 < right_height < left_height * 1.05:
merged_block = merge_two_blocks(merged_block, right_block)
else:
merged_blocks.append(merged_block)
merged_block = right_block
merged_blocks.append(merged_block)
return merged_blocks
"""同一行被断开的titile合并"""
title_blocks = merge_title_blocks(title_blocks)
"""将所有区块的bbox整理到一起""" """将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上 # interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = [] interline_equation_blocks = []
......
weights: weights:
layoutlmv3: Layout/LayoutLMv3/model_final.pth layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
......
...@@ -48,7 +48,7 @@ if __name__ == '__main__': ...@@ -48,7 +48,7 @@ if __name__ == '__main__':
"struct-eqtable==0.3.2", # 表格解析 "struct-eqtable==0.3.2", # 表格解析
"einops", # struct-eqtable依赖 "einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖 "accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle "rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime", "rapidocr_onnxruntime",
"rapid_table==0.3.0", # rapid_table "rapid_table==0.3.0", # rapid_table
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment