Commit 0f21495a authored by myhloli's avatar myhloli
Browse files

refactor: enhance block processing and sorting utilities for improved span management

parent ae7b0a6e
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.block_pre_proc import prepare_block_bboxes from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.pipeline_magic_model import MagicModel from mineru.utils.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract
from mineru.version import __version__ from mineru.version import __version__
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, lang=None, ocr=False): def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr=False):
scale = image_dict["scale"] scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"] page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"]) page_img_md5 = str_md5(image_dict["img_base64"])
...@@ -54,6 +59,57 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -54,6 +59,57 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
page_w, page_w,
page_h, page_h,
) )
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if ocr:
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract(page, spans, page_pil_img, scale)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
return None
"""对image和table截图"""
for span in spans:
if span['type'] in ['image', 'table']:
span = cut_image_and_table(
span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
"""构造page_info"""
page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False): def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
...@@ -62,23 +118,20 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N ...@@ -62,23 +118,20 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
page = pdf_doc[page_index] page = pdf_doc[page_index]
image_dict = images_list[page_index] image_dict = images_list[page_index]
page_info = page_model_info_to_page_info( page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, lang=lang, ocr=ocr page_model_info, image_dict, page, image_writer, page_index, ocr=ocr
) )
if page_info is None:
page_w, page_h = map(int, page.get_size())
page_info = make_page_info_dict([], page_index, page_w, page_h, [])
middle_json["pdf_info"].append(page_info) middle_json["pdf_info"].append(page_info)
return middle_json return middle_json
def process_groups(groups, body_key, caption_key, footnote_key): def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
body_blocks = [] return_dict = {
caption_blocks = [] 'preproc_blocks': blocks,
footnote_blocks = [] 'page_idx': page_id,
for i, group in enumerate(groups): 'page_size': [page_w, page_h],
group[body_key]['group_id'] = i 'discarded_blocks': discarded_blocks,
body_blocks.append(group[body_key]) }
for caption_block in group[caption_key]: return return_dict
caption_block['group_id'] = i \ No newline at end of file
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks
\ No newline at end of file
...@@ -8,6 +8,22 @@ from mineru.utils.boxbase import ( ...@@ -8,6 +8,22 @@ from mineru.utils.boxbase import (
from mineru.utils.enum_class import BlockType from mineru.utils.enum_class import BlockType
def process_groups(groups, body_key, caption_key, footnote_key):
    """Flatten grouped layout blocks into three parallel lists.

    Every block dict taken from the group at position ``i`` is tagged in
    place with ``group_id = i`` so the grouping can be rebuilt later.

    Args:
        groups: Sequence of dicts. Each maps ``body_key`` to one block dict
            and ``caption_key`` / ``footnote_key`` to iterables of block dicts.
        body_key: Key of the body block inside each group.
        caption_key: Key of the caption blocks inside each group.
        footnote_key: Key of the footnote blocks inside each group.

    Returns:
        Tuple ``(body_blocks, caption_blocks, footnote_blocks)``.
    """
    bodies, captions, footnotes = [], [], []
    for group_id, group in enumerate(groups):
        body = group[body_key]
        body['group_id'] = group_id
        bodies.append(body)
        # Captions first, then footnotes, each collected into its own list.
        for key, bucket in ((caption_key, captions), (footnote_key, footnotes)):
            for member in group[key]:
                member['group_id'] = group_id
                bucket.append(member)
    return bodies, captions, footnotes
def prepare_block_bboxes( def prepare_block_bboxes(
img_body_blocks, img_body_blocks,
img_caption_blocks, img_caption_blocks,
......
# Copyright (c) Opendatalab. All rights reserved.
import copy
import os
import statistics
import warnings
from typing import List
import torch
from loguru import logger
from mineru.backend.pipeline.config_reader import get_device, get_local_layoutreader_model_dir
from mineru.utils.enum_class import BlockType
def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
    """Assign an ``index`` to every block and return the blocks sorted by it.

    Args:
        blocks: Block dicts of one page (already carrying ``lines``).
        page_w: Page width, forwarded to the line-sorting step.
        page_h: Page height, forwarded to the line-sorting step.
        footnote_blocks: Extra boxes forwarded to ``sort_lines_by_model``.

    Returns:
        List of block dicts ordered by ``index``; image/table parts are
        regrouped and their inner ``blocks`` are ordered by ``index`` as well.
    """
    # Median height of the text-like lines on the page.
    line_height = get_line_height(blocks)
    # Collect all line boxes and order them (None selects the fallback path).
    sorted_bboxes = sort_lines_by_model(blocks, page_w, page_h, line_height, footnote_blocks)
    # Derive every block's index from the order of its lines.
    indexed_blocks = cal_block_index(blocks, sorted_bboxes)
    # Fold image/table body, caption and footnote blocks back into groups.
    grouped_blocks = revert_group_blocks(indexed_blocks)
    ordered = sorted(grouped_blocks, key=lambda item: item['index'])
    # Order the members (body / captions / footnotes) inside each group too.
    for item in ordered:
        if item['type'] not in (BlockType.IMAGE, BlockType.TABLE):
            continue
        item['blocks'] = sorted(item['blocks'], key=lambda child: child['index'])
    return ordered
def get_line_height(blocks):
    """Return the median line height of the text-like blocks on a page.

    Only text, title, and image/table caption and footnote blocks contribute.
    Each line height is ``int(bbox[3] - bbox[1])``.

    Args:
        blocks: Block dicts with ``type`` and ``lines`` (each line has ``bbox``).

    Returns:
        The median height, or ``10`` when no qualifying line exists.
    """
    heights = [
        int(line['bbox'][3] - line['bbox'][1])
        for block in blocks
        if block['type'] in (
            BlockType.TEXT, BlockType.TITLE,
            BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
        )
        for line in block['lines']
    ]
    if not heights:
        return 10
    return statistics.median(heights)
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height, footnote_blocks):
    """Collect all line boxes of a page and order them with the layoutreader model.

    Blocks are modified in place: blocks without lines, tall single-line
    titles, and image/table/interline-equation bodies get their ``lines``
    replaced by boxes from ``insert_lines_into_block``. Where lines are
    replaced, the previous ones are saved under ``real_lines`` (not for
    text-like blocks that had no lines).

    Args:
        fix_blocks: Block dicts with ``type``, ``bbox`` and ``lines``.
        page_w: Page width used for slicing and for scaling to 0..1000.
        page_h: Page height used for slicing and for scaling to 0..1000.
        line_height: Reference line height used when slicing blocks.
        footnote_blocks: Sequences whose first four items are a bbox; they
            add line boxes to the model input only.

    Returns:
        The line boxes reordered by the model's prediction, or ``None`` when
        more than 200 line boxes were collected.
    """
    text_like_types = (
        BlockType.TEXT, BlockType.TITLE,
        BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
        BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
    )
    sliced_types = (BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION)
    page_line_list = []

    def add_lines_to_block(b):
        # Replace the block's lines with sliced boxes and register them page-wide.
        line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
        b['lines'] = [{'bbox': line_bbox, 'spans': []} for line_bbox in line_bboxes]
        page_line_list.extend(line_bboxes)

    for block in fix_blocks:
        block_type = block['type']
        if block_type in text_like_types:
            existing_lines = block['lines']
            if len(existing_lines) == 0:
                add_lines_to_block(block)
            elif (
                block_type in [BlockType.TITLE]
                and len(existing_lines) == 1
                and (block['bbox'][3] - block['bbox'][1]) > line_height * 2
            ):
                # A one-line title much taller than a text line is sliced as well.
                block['real_lines'] = copy.deepcopy(existing_lines)
                add_lines_to_block(block)
            else:
                page_line_list.extend(line['bbox'] for line in existing_lines)
        elif block_type in sliced_types:
            block['real_lines'] = copy.deepcopy(block['lines'])
            add_lines_to_block(block)

    for footnote in footnote_blocks:
        add_lines_to_block({'bbox': footnote[:4]})

    # NOTE(review): the original comment says layoutreader supports up to 512
    # lines, while the cut-off used here is 200 — confirm which is intended.
    if len(page_line_list) > 200:
        return None

    # Scale every box into the 0..1000 range expected by the model.
    x_scale = 1000.0 / page_w
    y_scale = 1000.0 / page_h
    boxes = []
    for left, top, right, bottom in page_line_list:
        # Clamp coordinates that fall outside the page, logging each case.
        if left < 0:
            logger.warning(
                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            left = 0
        if right > page_w:
            logger.warning(
                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            right = page_w
        if top < 0:
            logger.warning(
                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            top = 0
        if bottom > page_h:
            logger.warning(
                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            bottom = page_h

        left = round(left * x_scale)
        top = round(top * y_scale)
        right = round(right * x_scale)
        bottom = round(bottom * y_scale)
        assert (
            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
        boxes.append([left, top, right, bottom])

    model = ModelSingleton().get_model('layoutreader')
    with torch.no_grad():
        orders = do_predict(boxes, model)
    return [page_line_list[position] for position in orders]
def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
    """Slice a block's bounding box into stacked, equally tall line boxes.

    Args:
        block_bbox: ``(x0, y0, x1, y1)`` of the block; lines are stacked from
            ``y0`` towards ``y1``.
        line_height: Reference height of one text line.
        page_w: Page width, used to judge how wide the block is.
        page_h: Page height, used to judge how tall the block is.

    Returns:
        A list of ``[x0, y, x1, y + step]`` boxes. The list holds just the
        block's own box when the block is no taller than two text lines, is
        narrow and elongated (height / width > 1.2), has zero width, or when
        a line count would have to be derived from a non-positive
        ``line_height``.
    """
    x0, y0, x1, y1 = block_bbox
    block_height = y1 - y0
    block_width = x1 - x0
    whole_block = [[x0, y0, x1, y1]]

    # Blocks no taller than two text lines are not sliced.
    if line_height * 2 >= block_height:
        return whole_block

    lines = None  # None means "derive the count from line_height"
    if block_height > page_h * 0.25 and page_w * 0.5 > block_width > page_w * 0.25:
        # Possibly one column of a two-column layout: slice finely.
        pass
    elif block_width > page_w * 0.4:
        # Wide block in a complex layout: three slices, figures must not be cut too finely.
        lines = 3
    elif block_width > page_w * 0.25:
        # Possibly one column of a three-column layout: slice finely.
        pass
    elif block_width == 0 or block_height / block_width > 1.2:
        # Narrow elongated block (a zero-width box counts as one): do not slice.
        # The explicit zero check avoids a ZeroDivisionError on degenerate boxes.
        return whole_block
    else:
        lines = 2

    if lines is None:
        if line_height <= 0:
            # No usable reference height to derive a line count from.
            return whole_block
        lines = int(block_height / line_height)

    step = (y1 - y0) / lines
    current_y = y0
    lines_positions = []
    for _ in range(lines):
        lines_positions.append([x0, current_y, x1, current_y + step])
        current_y += step
    return lines_positions
def model_init(model_name: str):
    """Load the reading-order model identified by ``model_name``.

    Only ``'layoutreader'`` is supported. The weights are read from the local
    directory returned by ``get_local_layoutreader_model_dir()`` when it
    exists, otherwise from the ``'hantian/layoutreader'`` hub id. The model is
    moved to the device returned by ``get_device()``, switched to eval mode,
    and cast to bfloat16 when the device supports it.

    Args:
        model_name: Name of the model to load.

    Returns:
        The initialised ``LayoutLMv3ForTokenClassification`` model.

    Raises:
        SystemExit: With status 1 when ``model_name`` is not supported.
    """
    import sys

    # Reject unknown names before paying for the transformers import.
    if model_name != 'layoutreader':
        logger.error('model name not allow')
        # sys.exit rather than the interactive-only ``exit`` builtin, which is
        # injected by the ``site`` module and may be absent.
        sys.exit(1)

    from transformers import LayoutLMv3ForTokenClassification

    device_name = get_device()
    bf_16_support = False
    if device_name.startswith("cuda"):
        bf_16_support = torch.cuda.is_bf16_supported()
    elif device_name.startswith("mps"):
        bf_16_support = True
    device = torch.device(device_name)

    # Prefer the locally cached model directory when it exists.
    layoutreader_model_dir = get_local_layoutreader_model_dir()
    if os.path.exists(layoutreader_model_dir):
        model = LayoutLMv3ForTokenClassification.from_pretrained(
            layoutreader_model_dir
        )
    else:
        logger.warning(
            'local layoutreader model not exists, use online model from huggingface'
        )
        model = LayoutLMv3ForTokenClassification.from_pretrained(
            'hantian/layoutreader'
        )

    if bf_16_support:
        model.to(device).eval().bfloat16()
    else:
        model.to(device).eval()
    return model
class ModelSingleton:
    """Singleton holding lazily initialised models keyed by name.

    Every construction returns the same instance; models are stored in a
    class-level dict and created on first request through ``model_init``.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def get_model(self, model_name: str):
        """Return the cached model for ``model_name``, initialising it if needed."""
        cache = self._models
        if model_name in cache:
            return cache[model_name]
        cache[model_name] = model_init(model_name=model_name)
        return cache[model_name]
def do_predict(boxes: List[List[int]], model) -> List[int]:
    """Run ``model`` on ``boxes`` and return the parsed prediction.

    ``FutureWarning``s raised from ``transformers`` are suppressed while the
    inputs are prepared and the forward pass runs.

    Args:
        boxes: Boxes as ``[left, top, right, bottom]`` integer lists.
        model: Model called with the prepared inputs; its ``logits`` are read.

    Returns:
        The result of ``parse_logits`` for ``len(boxes)`` boxes (used by the
        caller as a list of positions into the box list).
    """
    from mineru.model.reading_order.layout_reader import (
        boxes2inputs, parse_logits, prepare_inputs)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
        model_inputs = prepare_inputs(boxes2inputs(boxes), model)
        logits = model(**model_inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))
def cal_block_index(fix_blocks, sorted_bboxes):
    """Assign a reading-order ``index`` to every block (and to its lines).

    With ``sorted_bboxes`` given, every line's index is its position in that
    list and a block's index is the median of its line indices. With
    ``sorted_bboxes`` set to ``None``, blocks are ordered by recursive XY-cut
    on their (non-negative) bboxes and lines are numbered from 1 in that order.

    In both modes, image/table body, title and interline-equation blocks that
    carry ``real_lines`` get them restored into ``lines``; the lines they
    replace are kept under ``virtual_lines``.

    Args:
        fix_blocks: Block dicts with ``type``, ``bbox`` and ``lines``.
        sorted_bboxes: Ordered line boxes, or ``None`` to use XY-cut.

    Returns:
        ``fix_blocks`` (the same list, modified in place).
    """

    def restore_real_lines(block):
        # Swap the virtual lines back for the real ones saved earlier.
        if block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
            if 'real_lines' in block:
                block['virtual_lines'] = copy.deepcopy(block['lines'])
                block['lines'] = copy.deepcopy(block['real_lines'])
                del block['real_lines']

    if sorted_bboxes is None:
        # Fallback: order whole blocks with XY-cut.
        block_bboxes = []
        for block in fix_blocks:
            # Negative coordinates are clamped to 0.
            clamped = [max(0, value) for value in block['bbox']]
            block['bbox'] = clamped
            block_bboxes.append(clamped)
            restore_real_lines(block)

        import numpy as np
        from mineru.model.reading_order.xycut import recursive_xy_cut

        random_boxes = np.array(block_bboxes)
        np.random.shuffle(random_boxes)
        res = []
        recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
        assert len(res) == len(block_bboxes)
        sorted_boxes = random_boxes[np.array(res)].tolist()

        for block in fix_blocks:
            block['index'] = sorted_boxes.index(block['bbox'])

        # Number the lines consecutively, following the block order.
        blocks_in_order = sorted(fix_blocks, key=lambda b: b['index'])
        lines_in_order = (line for block in blocks_in_order for line in block['lines'])
        for line_index, line in enumerate(lines_in_order, start=1):
            line['index'] = line_index
        return fix_blocks

    # Ordering from the layoutreader result.
    for block in fix_blocks:
        if len(block['lines']) == 0:
            block['index'] = sorted_bboxes.index(block['bbox'])
        else:
            for line in block['lines']:
                line['index'] = sorted_bboxes.index(line['bbox'])
            block['index'] = statistics.median([line['index'] for line in block['lines']])
        restore_real_lines(block)
    return fix_blocks
def revert_group_blocks(blocks):
    """Fold image/table body, caption and footnote blocks back into groups.

    Blocks of other types are passed through unchanged and come first in the
    result; one merged block per image ``group_id`` follows, then one per
    table ``group_id``.

    Args:
        blocks: Block dicts; image/table related ones carry ``group_id``.

    Returns:
        A new list of blocks with image/table parts merged through
        ``process_block_list``.
    """
    image_groups = {}
    table_groups = {}
    new_blocks = []
    for block in blocks:
        if block['type'] in (BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE):
            image_groups.setdefault(block['group_id'], []).append(block)
        elif block['type'] in (BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE):
            table_groups.setdefault(block['group_id'], []).append(block)
        else:
            new_blocks.append(block)

    for members in image_groups.values():
        new_blocks.append(process_block_list(members, BlockType.IMAGE_BODY, BlockType.IMAGE))
    for members in table_groups.values():
        new_blocks.append(process_block_list(members, BlockType.TABLE_BODY, BlockType.TABLE))
    return new_blocks
def process_block_list(blocks, body_type, block_type):
    """Wrap the member blocks of one image/table group into a single block.

    Args:
        blocks: Member block dicts, each with ``index`` (and ``bbox`` for the
            body member).
        body_type: ``type`` value identifying the body member.
        block_type: ``type`` value given to the resulting group block.

    Returns:
        Dict with ``type``, ``bbox`` (bbox of the first member whose type is
        ``body_type``, or ``[]`` if none), ``blocks`` (the input list) and
        ``index`` (median of the members' indices).
    """
    median_index = statistics.median(member['index'] for member in blocks)
    body_bbox = []
    for member in blocks:
        if member.get('type') == body_type:
            body_bbox = member['bbox']
            break
    return {
        'type': block_type,
        'bbox': body_bbox,
        'blocks': blocks,
        'index': median_index,
    }
\ No newline at end of file
...@@ -3,14 +3,14 @@ from loguru import logger ...@@ -3,14 +3,14 @@ from loguru import logger
from .pdf_image_tools import cut_image from .pdf_image_tools import cut_image
def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, imageWriter, scale=2): def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, image_writer, scale=2):
def return_path(path_type): def return_path(path_type):
return f"{path_type}/{page_img_md5}" return f"{path_type}/{page_img_md5}"
span_type = span["type"] span_type = span["type"]
if not check_img_bbox(span["bbox"]) or not imageWriter: if not check_img_bbox(span["bbox"]) or not image_writer:
span["image_path"] = "" span["image_path"] = ""
else: else:
span["image_path"] = cut_image( span["image_path"] = cut_image(
......
...@@ -54,7 +54,7 @@ def load_images_from_pdf( ...@@ -54,7 +54,7 @@ def load_images_from_pdf(
return images_list, pdf_doc return images_list, pdf_doc
def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=3): def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=2):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地, """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是: 图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。""" {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
......
# Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.ocr_utils import __is_overlaps_y_exceeds_threshold
def fill_spans_in_blocks(blocks, spans, radio):
    """Distribute spans into the blocks they overlap.

    A span goes into a block when the overlap ratio computed by
    ``calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox)``
    exceeds ``radio`` and the span type is compatible with the block type.
    Assigned spans are removed from ``spans`` in place, so each span lands in
    at most one block.

    Args:
        blocks: Sequences where items 0..3 are the bbox, item 7 is the block
            type and the last item is the group id (read only for
            image/table related types).
        spans: Span dicts with ``bbox`` and ``type``; mutated in place.
        radio: Overlap threshold a span must exceed.

    Returns:
        Tuple ``(block_with_spans, spans)``: the new block dicts (``type``,
        ``bbox``, optional ``group_id``, ``spans``) and the leftover spans.
    """
    block_with_spans = []
    for block in blocks:
        block_type = block[7]
        block_bbox = block[0:4]
        block_dict = {'type': block_type, 'bbox': block_bbox}
        if block_type in (
            BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
        ):
            block_dict['group_id'] = block[-1]

        matched = [
            span for span in spans
            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > radio
            and span_block_type_compatible(span['type'], block_type)
        ]
        block_dict['spans'] = matched
        block_with_spans.append(block_dict)

        # Spans already placed in this block are no longer available to others.
        for span in matched:
            spans.remove(span)
    return block_with_spans, spans
def span_block_type_compatible(span_type, block_type):
    """Tell whether a span of ``span_type`` may be placed in a ``block_type`` block.

    Args:
        span_type: A ``ContentType`` value.
        block_type: A ``BlockType`` value.

    Returns:
        ``True`` when the pairing is allowed:
        text and inline-equation spans go into text, title, image/table
        caption, image/table footnote and discarded blocks;
        interline-equation spans go into interline-equation and text blocks;
        image spans go into image bodies; table spans go into table bodies.
        ``False`` for every other combination.
    """
    # The first branch must test INLINE_EQUATION: listing INTERLINE_EQUATION
    # here made the dedicated interline branch below unreachable and left
    # inline-equation spans without any compatible block.
    if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
        return block_type in [
            BlockType.TEXT,
            BlockType.TITLE,
            BlockType.IMAGE_CAPTION,
            BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_CAPTION,
            BlockType.TABLE_FOOTNOTE,
            BlockType.DISCARDED
        ]
    elif span_type == ContentType.INTERLINE_EQUATION:
        return block_type in [BlockType.INTERLINE_EQUATION, BlockType.TEXT]
    elif span_type == ContentType.IMAGE:
        return block_type in [BlockType.IMAGE_BODY]
    elif span_type == ContentType.TABLE:
        return block_type in [BlockType.TABLE_BODY]
    else:
        return False
def fix_discarded_block(discarded_block_with_spans):
    """Convert the spans of every discarded block into lines.

    Args:
        discarded_block_with_spans: Block dicts carrying ``spans``.

    Returns:
        A new list with each block processed by ``fix_text_block``.
    """
    return [fix_text_block(block) for block in discarded_block_with_spans]
def fix_text_block(block):
    """Turn the ``spans`` of a text-like block into sorted ``lines``.

    Interline-equation spans inside a text block are re-typed as inline
    equations. The spans are then merged into lines, each line is sorted left
    to right, and the ``spans`` key is removed.

    Args:
        block: Block dict with ``spans``; modified in place.

    Returns:
        The same block dict, now carrying ``lines`` instead of ``spans``.
    """
    block_spans = block['spans']
    for span in block_spans:
        if span['type'] == ContentType.INTERLINE_EQUATION:
            span['type'] = ContentType.INLINE_EQUATION
    block['lines'] = line_sort_spans_by_left_to_right(merge_spans_to_line(block_spans))
    del block['spans']
    return block
def merge_spans_to_line(spans, threshold=0.6):
    """Group spans into lines by vertical overlap.

    ``spans`` is sorted in place by the top coordinate (``bbox[1]``). A span
    joins the current line when ``__is_overlaps_y_exceeds_threshold`` accepts
    it against the line's last span. Interline-equation, image and table spans
    always sit on a line of their own.

    Args:
        spans: Span dicts with ``bbox`` and ``type``.
        threshold: Overlap threshold forwarded to the overlap check.

    Returns:
        A list of lines, each a list of span dicts; ``[]`` for empty input.
    """
    if not spans:
        return []

    standalone_types = (
        ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
        ContentType.TABLE,
    )
    spans.sort(key=lambda span: span['bbox'][1])

    lines = []
    current_line = [spans[0]]
    for span in spans[1:]:
        # A standalone span, or a line that already holds one, forces a break.
        forces_break = span['type'] in standalone_types or any(
            member['type'] in standalone_types for member in current_line
        )
        if not forces_break and __is_overlaps_y_exceeds_threshold(
            span['bbox'], current_line[-1]['bbox'], threshold
        ):
            current_line.append(span)
        else:
            lines.append(current_line)
            current_line = [span]

    # The line still being built is never empty here.
    lines.append(current_line)
    return lines
# Sort the spans of every line from left to right.
def line_sort_spans_by_left_to_right(lines):
    """Sort each line's spans by ``bbox[0]`` and compute the line's bbox.

    Args:
        lines: Lists of span dicts (each with ``bbox``); sorted in place.

    Returns:
        One dict per line with ``bbox`` (the box enclosing all of the line's
        spans) and ``spans`` (the same, now sorted, list).
    """
    line_objects = []
    for line in lines:
        line.sort(key=lambda span: span['bbox'][0])
        boxes = [span['bbox'] for span in line]
        enclosing_bbox = [
            min(box[0] for box in boxes),  # x0
            min(box[1] for box in boxes),  # y0
            max(box[2] for box in boxes),  # x1
            max(box[3] for box in boxes),  # y1
        ]
        line_objects.append({'bbox': enclosing_bbox, 'spans': line})
    return line_objects
def fix_block_spans(block_with_spans):
    """Convert the spans of every supported block into lines.

    Text-like blocks (text, title, image/table captions and footnotes) go
    through ``fix_text_block``; interline-equation, image-body and table-body
    blocks go through ``fix_interline_block``. Blocks of any other type are
    dropped.

    Args:
        block_with_spans: Block dicts with ``type`` and ``spans``.

    Returns:
        List of the processed blocks, in input order.
    """
    fix_blocks = []
    for block in block_with_spans:
        block_type = block['type']
        # IMAGE_FOOTNOTE belongs here: the list used to name IMAGE_CAPTION
        # twice, which silently discarded every image footnote block.
        if block_type in (
            BlockType.TEXT, BlockType.TITLE,
            BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
        ):
            block = fix_text_block(block)
        elif block_type in (BlockType.INTERLINE_EQUATION, BlockType.IMAGE_BODY, BlockType.TABLE_BODY):
            block = fix_interline_block(block)
        else:
            continue
        fix_blocks.append(block)
    return fix_blocks
def fix_interline_block(block):
    """Turn the ``spans`` of a block into sorted ``lines`` without re-typing spans.

    Args:
        block: Block dict with ``spans``; modified in place.

    Returns:
        The same block dict, now carrying ``lines`` instead of ``spans``.
    """
    merged_lines = merge_spans_to_line(block['spans'])
    block['lines'] = line_sort_spans_by_left_to_right(merged_lines)
    del block['spans']
    return block
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import cv2
import numpy as np
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.pdf_image_tools import get_crop_img
def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
    """Keep only the spans that sit inside a block of a matching kind.

    A span is kept when it overlaps a discarded block by more than 0.4.
    Otherwise it is kept when it overlaps, by more than 0.5, an image-body
    block (image spans), a table-body block (table spans) or any other block
    (all remaining span types).

    Args:
        spans: Span dicts with ``bbox`` and ``type``.
        all_bboxes: Block sequences; items 0..3 are the bbox, item 7 the type.
        all_discarded_blocks: Discarded block sequences, same layout.

    Returns:
        A new list with the spans that were kept, in input order.
    """

    def get_block_bboxes(blocks, block_type_list):
        return [block[0:4] for block in blocks if block[7] in block_type_list]

    def overlaps_any(span_bbox, candidates, minimum):
        return any(
            calculate_overlap_area_in_bbox1_area_ratio(span_bbox, candidate) > minimum
            for candidate in candidates
        )

    image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
    table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
    # Every string-valued BlockType attribute except the two body types.
    other_block_type = [
        value for value in BlockType.__dict__.values()
        if isinstance(value, str) and value not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]
    ]
    other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
    discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])

    new_spans = []
    for span in spans:
        span_bbox = span['bbox']
        span_type = span['type']
        # Spans lying in discarded blocks are kept regardless of their type.
        if overlaps_any(span_bbox, discarded_block_bboxes, 0.4):
            new_spans.append(span)
            continue
        if span_type == ContentType.IMAGE:
            candidates = image_bboxes
        elif span_type == ContentType.TABLE:
            candidates = table_bboxes
        else:
            candidates = other_block_bboxes
        if overlaps_any(span_bbox, candidates, 0.5):
            new_spans.append(span)
    return new_spans
def remove_overlaps_low_confidence_spans(spans):
    """Drop the lower-scored span of every heavily overlapping pair.

    Two unequal spans whose ``calculate_iou`` exceeds 0.9 form a pair; the one
    with the strictly lower ``score`` is dropped, and on a tie the second of
    the pair is dropped. Spans already marked as dropped take no further part.

    Args:
        spans: Span dicts with ``bbox`` and ``score``; mutated in place.

    Returns:
        Tuple ``(spans, dropped_spans)``: the same list without the dropped
        spans, and the list of dropped spans.
    """
    dropped_spans = []
    for span1 in spans:
        for span2 in spans:
            if span1 == span2:
                continue
            # Neither member of the pair may already be dropped.
            if span1 in dropped_spans or span2 in dropped_spans:
                continue
            if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
                loser = span1 if span1['score'] < span2['score'] else span2
                # Both spans were just verified as not dropped, so the loser
                # can be recorded directly.
                dropped_spans.append(loser)

    for span in dropped_spans:
        spans.remove(span)
    return spans, dropped_spans
def remove_overlaps_min_spans(spans):
    """Drop spans identified as the minor box of an overlapping pair.

    For every pair of unequal, not yet dropped spans,
    ``get_minbox_if_overlap_by_ratio(bbox1, bbox2, 0.65)`` is consulted. When
    it returns a box, the first span in ``spans`` whose ``bbox`` equals that
    box is dropped (presumably the smaller of the two — confirm in boxbase).

    Args:
        spans: Span dicts with ``bbox``; mutated in place.

    Returns:
        Tuple ``(spans, dropped_spans)``: the same list without the dropped
        spans, and the list of dropped spans.
    """
    dropped_spans = []
    for span1 in spans:
        for span2 in spans:
            if span1 == span2:
                continue
            # Neither member of the pair may already be dropped.
            if span1 in dropped_spans or span2 in dropped_spans:
                continue
            overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
            if overlap_box is None:
                continue
            span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
            if span_need_remove is not None and span_need_remove not in dropped_spans:
                dropped_spans.append(span_need_remove)

    for span in dropped_spans:
        spans.remove(span)
    return spans, dropped_spans
def txt_spans_extract(pdf_page, spans, pil_img, scale):
    """Fill span text from the PDF text layer and prepare the rest for OCR.

    For every span the text inside its box is read from the page's text page.
    Spans that yield text get ``content`` (stripped) and ``score = 1.0``.
    Spans without text are cropped from the page image: crops whose contrast
    is at most 0.17 are removed from ``spans``; the others get an empty
    ``content``, ``score = 1.0`` and the BGR crop under ``np_img``.

    Args:
        pdf_page: Page object providing ``get_textpage``, ``get_size`` and
            ``get_cropbox``.
        spans: Span dicts with ``bbox``; mutated in place.
        pil_img: Page image passed to ``get_crop_img``.
        scale: Scale passed to ``get_crop_img``.

    Returns:
        The same ``spans`` list after the updates and removals.
    """
    textpage = pdf_page.get_textpage()
    _page_width, page_height = pdf_page.get_size()
    cropbox = pdf_page.get_cropbox()

    need_ocr_spans = []
    for span in spans:
        x0, y0, x1, y1 = span['bbox'][0], span['bbox'][1], span['bbox'][2], span['bbox'][3]
        # Flip the vertical axis against the page height and shift by the
        # crop box origin before querying the text page.
        rect_box = [
            x0 + cropbox[0],
            page_height - y1 + cropbox[1],
            x1 + cropbox[0],
            page_height - y0 + cropbox[1],
        ]
        text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
                                         right=rect_box[2], bottom=rect_box[3])
        if text:
            span['content'] = text.strip()
            span['score'] = 1.0
        else:
            need_ocr_spans.append(span)

    for span in need_ocr_spans:
        # Crop the span from the page image for OCR.
        span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
        span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
        # Crops with a contrast of 0.17 or less are not worth OCR-ing.
        if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
            spans.remove(span)
            continue
        span['content'] = ''
        span['score'] = 1.0
        span['np_img'] = span_img
    return spans
def calculate_contrast(img, img_mode) -> float:
    """Compute the contrast of an image.

    The image is converted to grayscale and the contrast is the standard
    deviation divided by the mean of the gray values (a small constant is
    added to the mean to avoid division by zero).

    :param img: image as a numpy.ndarray
    :param img_mode: channel order of the image, 'rgb' or 'bgr'
    :return: contrast value rounded to two decimals
    :raises ValueError: if ``img_mode`` is neither 'rgb' nor 'bgr'
    """
    if img_mode not in ('rgb', 'bgr'):
        raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")

    conversion = cv2.COLOR_RGB2GRAY if img_mode == 'rgb' else cv2.COLOR_BGR2GRAY
    gray_img = cv2.cvtColor(img, conversion)

    mean_value = np.mean(gray_img)
    std_dev = np.std(gray_img)
    return round(std_dev / (mean_value + 1e-6), 2)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment