Commit c20e9a1e authored by myhloli's avatar myhloli
Browse files

feat(layout): improve title block handling and layout detection

- Merge title blocks that are close to each other horizontally
- Adjust line insertion logic for title blocks- Increase image size and decrease confidence threshold for layout detection
- Update DocLayoutYOLO model weights
- Refactor drawing of bounding boxes for different block types
parent ee9340ea
......@@ -17,7 +17,7 @@ paddlepaddle==3.0.0b1
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
doclayout_yolo==0.0.2b1
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
......
......@@ -16,7 +16,7 @@ paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
doclayout_yolo==0.0.2b1
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
......
......@@ -16,7 +16,7 @@ paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
doclayout_yolo==0.0.2b1
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
......
......@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
if block['type'] in [BlockType.Text]:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in [BlockType.Image, BlockType.Table]:
elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
if 'virtual_lines' in block:
if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
for line in block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
......
......@@ -144,7 +144,7 @@ class CustomPEKModel:
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device=self.device,
device='cpu' if str(self.device).startswith("mps") else self.device,
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
......@@ -192,24 +192,24 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
if height > width:
input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image)
for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
else:
layout_res = self.layout_model.predict(image)
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}')
......
......@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
......@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False,
device=self.device,
......
......@@ -12,7 +12,7 @@ from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list
......@@ -365,10 +365,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else:
# 使用xycut排序
block_bboxes = []
......@@ -417,7 +418,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox
if line_height * 3 < block_height:
if line_height * 2 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
......@@ -425,16 +426,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3
line_height = (y1 - y0) / lines
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1
else: # 判断长宽比
if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
lines = 2
line_height = (y1 - y0) / lines
# 确定从哪个y位置开始绘制线条
current_y = y0
......@@ -453,30 +454,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks:
if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
if len(block['lines']) == 0:
bbox = block['bbox']
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
add_lines_to_block(block)
elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox']
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
block['real_lines'] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = []
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
add_lines_to_block(block)
if len(page_line_list) > 200: # layoutreader最高支持512line
return None
......@@ -663,12 +666,68 @@ def parse_page_core(
discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
page_w, page_h = magic_model.get_page_size(page_id)
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
def merge_two_blocks(b1, b2):
# 合并两个标题块的边界框
x_min = min(b1['bbox'][0], b2['bbox'][0])
y_min = min(b1['bbox'][1], b2['bbox'][1])
x_max = max(b1['bbox'][2], b2['bbox'][2])
y_max = max(b1['bbox'][3], b2['bbox'][3])
merged_bbox = (x_min, y_min, x_max, y_max)
# 合并两个标题块的文本内容
merged_score = (b1['score'] + b2['score']) / 2
return {'bbox': merged_bbox, 'score': merged_score}
# 按 y 轴重叠度聚集标题块
y_overlapping_blocks = []
while blocks:
block1 = blocks.pop(0)
current_row = [block1]
to_remove = []
for block2 in blocks:
if __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9):
current_row.append(block2)
to_remove.append(block2)
for b in to_remove:
blocks.remove(b)
y_overlapping_blocks.append(current_row)
# 按x轴坐标排序并合并标题块
merged_blocks = []
for row in y_overlapping_blocks:
if len(row) == 1:
merged_blocks.append(row[0])
continue
# 按x轴坐标排序
row.sort(key=lambda x: x['bbox'][0])
merged_block = row[0]
for i in range(1, len(row)):
left_block = merged_block
right_block = row[i]
left_height = left_block['bbox'][3] - left_block['bbox'][1]
right_height = right_block['bbox'][3] - right_block['bbox'][1]
if right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold and left_height * 0.95 < right_height < left_height * 1.05:
merged_block = merge_two_blocks(merged_block, right_block)
else:
merged_blocks.append(merged_block)
merged_block = right_block
merged_blocks.append(merged_block)
return merged_blocks
"""同一行被断开的titile合并"""
title_blocks = merge_title_blocks(title_blocks)
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = []
......
weights:
layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable
......
......@@ -48,7 +48,7 @@ if __name__ == '__main__':
"struct-eqtable==0.3.2", # 表格解析
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo
"doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table==0.3.0", # rapid_table
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment