Commit 0a899f1a authored by myhloli's avatar myhloli
Browse files

feat: add batch processing for OCR detection and implement new client and common utilities

parent cbba27b4
import re
from ..utils.enum_class import MakeMode, BlockType, ContentType
def merge_para_with_text(para_block):
    """Concatenate the text of every span in a paragraph block.

    Each span's ``content`` is stripped of surrounding whitespace. Spans that
    are empty after stripping are skipped; the rest are joined with no
    separator, in line/span order.

    Args:
        para_block: Mapping with a ``lines`` list; every line has a ``spans``
            list whose items carry a string ``content``.

    Returns:
        The concatenated text, or an empty string when nothing remains.
    """
    stripped_contents = (
        span['content'].strip()
        for line in para_block['lines']
        for span in line['spans']
    )
    return ''.join(piece for piece in stripped_contents if piece)
def _merge_sub_blocks(sub_blocks, block_type, prefix='', suffix=''):
    """Join the merged text of every sub-block of ``block_type``, each wrapped in prefix/suffix."""
    return ''.join(
        f'{prefix}{merge_para_with_text(sub_block)}{suffix}'
        for sub_block in sub_blocks
        if sub_block['type'] == block_type
    )


def _image_body_markdown(sub_blocks, img_buket_path):
    """Render a markdown image link for every image span of the image-body sub-blocks."""
    links = []
    for sub_block in sub_blocks:
        if sub_block['type'] != BlockType.IMAGE_BODY:
            continue
        for line in sub_block['lines']:
            for span in line['spans']:
                # Spans without a stored image path produce no output.
                if span['type'] == ContentType.IMAGE and span.get('image_path', ''):
                    links.append(f"![]({img_buket_path}/{span['image_path']})")
    return ''.join(links)


def _table_body_markdown(sub_blocks, img_buket_path):
    """Render the table-body sub-blocks, preferring a span's HTML over its image."""
    parts = []
    for sub_block in sub_blocks:
        if sub_block['type'] != BlockType.TABLE_BODY:
            continue
        for line in sub_block['lines']:
            for span in line['spans']:
                if span['type'] != ContentType.TABLE:
                    continue
                if span.get('html', ''):
                    # The span carries HTML (set when a table model processed it).
                    parts.append(f"\n{span['html']}\n")
                elif span.get('image_path', ''):
                    parts.append(f"![]({img_buket_path}/{span['image_path']})")
    return ''.join(parts)


def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
    """Convert the paragraph blocks of one page into a list of markdown fragments.

    Text-like blocks (text, list, index, title, interline equation) become
    their merged span text. Image and table blocks are rendered only when
    ``make_mode`` is ``MakeMode.MM_MD``; in any other mode they yield nothing.

    Layout of the composite blocks:
      * image with a footnote:    caption, image body, footnote
      * image without a footnote: image body, caption
      * table:                    caption, table body, footnote

    Args:
        para_blocks: Iterable of block dicts, each with a ``type`` and either
            ``lines`` (text-like blocks) or ``blocks`` (image/table blocks).
        make_mode: One of the ``MakeMode`` values.
        img_buket_path: Prefix placed before every image path in the links.

    Returns:
        A list with one stripped markdown string per non-empty block.
    """
    page_markdown = []
    for para_block in para_blocks:
        para_type = para_block['type']
        para_text = ''

        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX,
                         BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
            para_text = merge_para_with_text(para_block)

        elif para_type == BlockType.IMAGE:
            if make_mode != MakeMode.MM_MD:
                continue
            sub_blocks = para_block['blocks']
            image_body = _image_body_markdown(sub_blocks, img_buket_path)
            has_image_footnote = any(
                sub_block['type'] == BlockType.IMAGE_FOOTNOTE for sub_block in sub_blocks
            )
            if has_image_footnote:
                # With a footnote the caption goes first and the footnote last.
                para_text = (
                    _merge_sub_blocks(sub_blocks, BlockType.IMAGE_CAPTION, suffix=' \n')
                    + image_body
                    + _merge_sub_blocks(sub_blocks, BlockType.IMAGE_FOOTNOTE, prefix=' \n')
                )
            else:
                # Without a footnote the caption is placed after the image.
                para_text = image_body + _merge_sub_blocks(
                    sub_blocks, BlockType.IMAGE_CAPTION, prefix=' \n'
                )

        elif para_type == BlockType.TABLE:
            if make_mode != MakeMode.MM_MD:
                continue
            sub_blocks = para_block['blocks']
            para_text = (
                _merge_sub_blocks(sub_blocks, BlockType.TABLE_CAPTION, suffix=' \n')
                + _table_body_markdown(sub_blocks, img_buket_path)
                + _merge_sub_blocks(sub_blocks, BlockType.TABLE_FOOTNOTE, prefix='\n', suffix=' ')
            )

        # Blocks that produced only whitespace are dropped from the output.
        stripped_text = para_text.strip()
        if stripped_text:
            page_markdown.append(stripped_text)
    return page_markdown
def count_leading_hashes(text):
    """Return how many consecutive ``#`` characters ``text`` starts with (0 if none)."""
    return len(text) - len(text.lstrip('#'))
def strip_leading_hashes(text):
    """Remove the leading run of ``#`` characters and any whitespace right after it.

    Text that does not start with ``#`` is returned unchanged.
    """
    leading = re.match(r'#+\s*', text)
    if leading is None:
        return text
    return text[leading.end():]
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one paragraph block into a content-list entry.

    The entry shape depends on the block type:
      * text / list / index: ``{'type': 'text', 'text': ...}``
      * title: as text, with the leading ``#`` run removed and, when that run
        is non-empty, its length stored under ``text_level``
      * interline equation: ``{'type': 'equation', 'text': ..., 'text_format': 'latex'}``
      * image: ``img_path``, ``img_caption`` list and ``img_footnote`` list
      * table: ``img_path``, ``table_caption`` list, ``table_footnote`` list,
        plus ``table_body`` when a table span carries HTML

    Any other block type yields an entry holding only ``page_idx``.

    Args:
        para_block: Block dict with a ``type`` and ``lines`` or ``blocks``.
        img_buket_path: Prefix placed before every image path.
        page_idx: Page index stored on the entry.

    Returns:
        The entry dict, always including ``page_idx``.
    """
    block_kind = para_block['type']
    entry = {}

    if block_kind in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        entry = {'type': 'text', 'text': merge_para_with_text(para_block)}

    elif block_kind == BlockType.TITLE:
        raw_title = merge_para_with_text(para_block)
        hash_count = count_leading_hashes(raw_title)
        entry = {'type': 'text', 'text': strip_leading_hashes(raw_title)}
        if hash_count != 0:
            entry['text_level'] = hash_count

    elif block_kind == BlockType.INTERLINE_EQUATION:
        entry = {
            'type': 'equation',
            'text': merge_para_with_text(para_block),
            'text_format': 'latex',
        }

    elif block_kind == BlockType.IMAGE:
        entry = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
        for child in para_block['blocks']:
            if child['type'] == BlockType.IMAGE_BODY:
                image_spans = (
                    span
                    for line in child['lines']
                    for span in line['spans']
                    if span['type'] == ContentType.IMAGE
                )
                for span in image_spans:
                    # The last image span with a path wins.
                    if span.get('image_path', ''):
                        entry['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if child['type'] == BlockType.IMAGE_CAPTION:
                entry['img_caption'].append(merge_para_with_text(child))
            if child['type'] == BlockType.IMAGE_FOOTNOTE:
                entry['img_footnote'].append(merge_para_with_text(child))

    elif block_kind == BlockType.TABLE:
        entry = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
        for child in para_block['blocks']:
            if child['type'] == BlockType.TABLE_BODY:
                table_spans = (
                    span
                    for line in child['lines']
                    for span in line['spans']
                    if span['type'] == ContentType.TABLE
                )
                for span in table_spans:
                    # HTML and image path are independent: a span may set both.
                    if span.get('html', ''):
                        entry['table_body'] = f"{span['html']}"
                    if span.get('image_path', ''):
                        entry['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if child['type'] == BlockType.TABLE_CAPTION:
                entry['table_caption'].append(merge_para_with_text(child))
            if child['type'] == BlockType.TABLE_FOOTNOTE:
                entry['table_footnote'].append(merge_para_with_text(child))

    entry['page_idx'] = page_idx
    return entry
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Build the final output for a document from its per-page info.

    Pages whose ``para_blocks`` entry is missing or empty are skipped.

    Args:
        pdf_info_dict: List of page dicts providing ``para_blocks`` and
            ``page_idx``.
        make_mode: ``MakeMode.MM_MD`` / ``MakeMode.NLP_MD`` for markdown, or
            ``MakeMode.STANDARD_FORMAT`` for the content list.
        img_buket_path: Prefix placed before every image path.

    Returns:
        For the markdown modes, all fragments joined by blank lines; for
        ``STANDARD_FORMAT``, the list of content entries; ``None`` for any
        other mode.
    """
    markdown_modes = [MakeMode.MM_MD, MakeMode.NLP_MD]
    collected = []

    for page_info in pdf_info_dict:
        page_blocks = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not page_blocks:
            continue

        if make_mode in markdown_modes:
            collected += mk_blocks_to_markdown(page_blocks, make_mode, img_buket_path)
        elif make_mode == MakeMode.STANDARD_FORMAT:
            collected += [
                make_blocks_to_content_list(block, img_buket_path, page_idx)
                for block in page_blocks
            ]

    if make_mode in markdown_modes:
        return '\n\n'.join(collected)
    if make_mode == MakeMode.STANDARD_FORMAT:
        return collected
    return None
def get_title_level(block):
    """Return the block's title level, limited to the range the caller expects.

    The ``level`` entry defaults to 1 when absent. Values above 4 are capped
    at 4 and values below 1 are mapped to 0.
    """
    level = block.get('level', 1)
    if level < 1:
        return 0
    return 4 if level > 4 else level
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton
from ...utils.model_utils import crop_img, get_res_list_from_layout_res, get_coords_and_area
......@@ -12,11 +14,12 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable):
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = formula_enable
self.table_enable = table_enable
self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
......@@ -89,48 +92,160 @@ class BatchAnalyze:
'table_img':table_img,
})
# 文本框检测
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition
if self.table_enable:
......
......@@ -3,7 +3,7 @@ import re
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5
from mineru.utils.magic_model import fix_two_layer_blocks
from mineru.utils.vlm_magic_model import fix_two_layer_blocks
from mineru.version import __version__
......@@ -113,7 +113,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
# 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "page_size": [width, height], "page_idx": page_index}
page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
return page_info
......
# Copyright (c) Opendatalab. All rights reserved.
import os
import click
from pathlib import Path
from loguru import logger
from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command()
@click.version_option(__version__,
                      '--version',
                      '-v',
                      help='display the version and exit')
@click.option(
    '-p',
    '--path',
    'input_path',
    type=click.Path(exists=True),
    required=True,
    help='local filepath or directory. support pdf, png, jpg, jpeg files',
)
@click.option(
    '-o',
    '--output-dir',
    'output_dir',
    type=click.Path(),
    required=True,
    help='output local directory',
)
@click.option(
    '-b',
    '--backend',
    'backend',
    type=click.Choice(['pipeline', 'vlm-huggingface', 'vlm-sglang-engine', 'vlm-sglang-client']),
    help="""the backend for parsing pdf:
    pipeline: More general.
    vlm-huggingface: More general.
    vlm-sglang-engine: Faster(engine).
    vlm-sglang-client: Faster(client).
    without method specified, pipeline will be used by default.""",
    default='pipeline',
)
@click.option(
    '-u',
    '--url',
    'server_url',
    type=str,
    help="""
    When the backend is `vlm-sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
    """,
    default=None,
)
@click.option(
    '-s',
    '--start',
    'start_page_id',
    type=int,
    help='The starting page for PDF parsing, beginning from 0.',
    default=0,
)
@click.option(
    '-e',
    '--end',
    'end_page_id',
    type=int,
    help='The ending page for PDF parsing, beginning from 0.',
    default=None,
)
def main(input_path, output_dir, backend, server_url, start_page_id, end_page_id):
    # Command-line entry point: parse a single file, or every supported file
    # directly inside a directory, and write the results under output_dir.
    # Documented with comments rather than a docstring on purpose: click uses
    # the function docstring as the command description in --help.
    os.makedirs(output_dir, exist_ok=True)

    def parse_doc(path: Path):
        """Parse one document; failures are logged so a batch run continues."""
        try:
            file_name = str(Path(path).stem)
            pdf_bits = read_fn(path)
            # backend and server_url go by keyword: the fifth positional
            # parameter of do_parse is model_path, so a positional server_url
            # would be bound to the wrong argument.
            do_parse(output_dir, file_name, pdf_bits,
                     backend=backend, server_url=server_url,
                     start_page_id=start_page_id, end_page_id=end_page_id)
        except Exception as e:
            logger.exception(e)

    if os.path.isdir(input_path):
        # Only direct children with a supported suffix are processed.
        for doc_path in Path(input_path).glob('*'):
            if doc_path.suffix in pdf_suffixes + image_suffixes:
                parse_doc(Path(doc_path))
    else:
        parse_doc(Path(input_path))
# Run the command-line interface when this module is executed as a script.
if __name__ == '__main__':
    main()
# Copyright (c) Opendatalab. All rights reserved.
import io
import json
import os
from pathlib import Path
import pypdfium2 as pdfium
from loguru import logger
from ..api.vlm_middle_json_mkcontent import union_make
from ..backend.vlm.vlm_analyze import doc_analyze
from ..data.data_reader_writer import FileBasedDataWriter
from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from ..utils.enum_class import MakeMode
from ..utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
def read_fn(path: Path):
    """Read a supported input file and return its content as PDF bytes.

    Image files are converted with ``images_bytes_to_pdf_bytes``; PDF files
    are returned as read. The suffix comparison is case-insensitive, so
    ``.PDF`` or ``.JPG`` are accepted as well.

    Args:
        path: Path of the file to read (a plain string is accepted too).

    Returns:
        The PDF bytes for the file.

    Raises:
        ValueError: If the suffix is neither a known image nor PDF suffix.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    with open(str(path), "rb") as input_file:
        file_bytes = input_file.read()
    if suffix in image_suffixes:
        return images_bytes_to_pdf_bytes(file_bytes)
    if suffix in pdf_suffixes:
        return file_bytes
    # ValueError is still an Exception, so callers catching Exception keep working.
    raise ValueError(f"Unknown file suffix: {path.suffix}")
def prepare_env(output_dir, pdf_file_name):
    """Create the output directories for one document.

    The markdown directory is ``<output_dir>/<pdf_file_name>`` and the image
    directory is its ``images`` sub-directory. Both are created if missing.

    Returns:
        A ``(image_dir, markdown_dir)`` tuple of paths.
    """
    markdown_dir = os.path.join(output_dir, pdf_file_name)
    image_dir = os.path.join(str(markdown_dir), "images")
    for directory in (image_dir, markdown_dir):
        os.makedirs(directory, exist_ok=True)
    return image_dir, markdown_dir
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
    """Return a new PDF, as bytes, holding only the requested page range.

    Args:
        pdf_bytes: Bytes of the source PDF.
        start_page_id: Zero-based index of the first page to keep.
        end_page_id: Zero-based index of the last page to keep (inclusive).
            ``None`` or a negative value means the last page; a value past
            the end is clamped to the last page with a warning.

    Returns:
        The bytes of the newly written PDF.
    """
    # Load the source document from memory.
    pdf = pdfium.PdfDocument(pdf_bytes)
    try:
        # Resolve the last page to keep.
        last_page_id = len(pdf) - 1
        if end_page_id is None or end_page_id < 0:
            end_page_id = last_page_id
        elif end_page_id > last_page_id:
            logger.warning("end_page_id is out of range, use pdf_docs length")
            end_page_id = last_page_id

        output_pdf = pdfium.PdfDocument.new()
        try:
            # NOTE(review): start_page_id is not validated; a start beyond the
            # end gives an empty index list - confirm how import_pages treats it.
            page_indices = list(range(start_page_id, end_page_id + 1))
            output_pdf.import_pages(pdf, page_indices)

            # Serialise the new document into an in-memory buffer.
            output_buffer = io.BytesIO()
            output_pdf.save(output_buffer)
            return output_buffer.getvalue()
        finally:
            # Release the document explicitly, including on error.
            output_pdf.close()
    finally:
        pdf.close()
def do_parse(
    output_dir,
    pdf_file_name,
    pdf_bytes,
    backend="pipeline",
    model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=False,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    start_page_id=0,
    end_page_id=None,
):
    """Analyse one PDF and write the selected artefacts to the output directory.

    The PDF is first cut down to ``start_page_id``..``end_page_id``, then
    analysed with ``doc_analyze``. Each ``f_*`` flag switches one output on or
    off: layout/span overlay PDFs, the trimmed original PDF, the markdown
    file, the content-list JSON, the middle JSON and the raw model output.
    For the ``pipeline`` backend the span overlay is always drawn.

    Returns:
        The inference result returned by ``doc_analyze``.
    """
    if backend == 'pipeline':
        f_draw_span_bbox = True

    pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name)
    image_writer = FileBasedDataWriter(local_image_dir)
    md_writer = FileBasedDataWriter(local_md_dir)

    middle_json, infer_result = doc_analyze(
        pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
    )
    pdf_info = middle_json["pdf_info"]

    # Image links are written relative to the markdown directory.
    image_dir = str(os.path.basename(local_image_dir))

    if f_draw_layout_bbox:
        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

    if f_draw_span_bbox:
        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

    if f_dump_orig_pdf:
        md_writer.write(f"{pdf_file_name}_origin.pdf", pdf_bytes)

    if f_dump_md:
        md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
        md_writer.write_string(f"{pdf_file_name}.md", md_content_str)

    if f_dump_content_list:
        content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
        content_list_json = json.dumps(content_list, ensure_ascii=False, indent=4)
        md_writer.write_string(f"{pdf_file_name}_content_list.json", content_list_json)

    if f_dump_middle_json:
        middle_json_text = json.dumps(middle_json, ensure_ascii=False, indent=4)
        md_writer.write_string(f"{pdf_file_name}_middle.json", middle_json_text)

    if f_dump_model_output:
        # Individual results are separated by a dashed rule.
        separator = "\n" + "-" * 50 + "\n"
        md_writer.write_string(f"{pdf_file_name}_model_output.txt", separator.join(infer_result))

    logger.info(f"local output dir is {local_md_dir}")
    return infer_result
# Manual smoke run: parse the demo PDF and print the inference result.
if __name__ == "__main__":
    pdf_path = "../../demo/demo2.pdf"
    with open(pdf_path, "rb") as f:
        demo_pdf_bytes = f.read()
    try:
        result = do_parse("./output", Path(pdf_path).stem, demo_pdf_bytes)
    except Exception as e:
        logger.exception(e)
    else:
        # Print only on success; after a failure `result` would be unbound.
        # Convert the result to JSON for display.
        print(json.dumps(result, ensure_ascii=False, indent=4))
......@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self.net.eval()
self.net.to(self.device)
def _batch_process_same_size(self, img_list):
    """
    Run text detection on a list of equally sized images as one batch.

    Args:
        img_list: List of images that all have the same size.

    Returns:
        batch_results: One ``(dt_boxes, elapse)`` tuple per image. If the
            preprocessing of any image fails, every entry is ``(None, 0)``.
        total_elapse: Time spent on preprocessing and inference
            (post-processing is not included), or 0 on preprocessing failure.
    """
    starttime = time.time()

    # Preprocess all images. Only the original shapes are needed later (to
    # clip the detected boxes), so keep those instead of copying each image.
    batch_data = []
    shape_lists = []
    ori_shapes = []
    for img in img_list:
        ori_shapes.append(img.shape)
        data = transform({'image': img}, self.preprocess_op)
        if data is None:
            # Preprocessing failed: report an empty result for the whole batch.
            return [(None, 0) for _ in img_list], 0
        img_processed, shape_list = data
        batch_data.append(img_processed)
        shape_lists.append(shape_list)

    # Stack into batch arrays.
    try:
        batch_tensor = np.stack(batch_data, axis=0)
        batch_shapes = np.stack(shape_lists, axis=0)
    except Exception:
        # Stacking failed (e.g. the preprocessed sizes differ): fall back to
        # processing the images one by one.
        batch_results = [self.__call__(img) for img in img_list]
        batch_results = [(dt_boxes, elapse) for dt_boxes, elapse in batch_results]
        return batch_results, time.time() - starttime

    # Batched inference.
    with torch.no_grad():
        inp = torch.from_numpy(batch_tensor)
        inp = inp.to(self.device)
        outputs = self.net(inp)

    # Move the outputs required by the configured algorithm to numpy.
    preds = {}
    if self.det_algorithm == "EAST":
        preds['f_geo'] = outputs['f_geo'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
    elif self.det_algorithm == 'SAST':
        preds['f_border'] = outputs['f_border'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
        preds['f_tco'] = outputs['f_tco'].cpu().numpy()
        preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
    elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
        preds['maps'] = outputs['maps'].cpu().numpy()
    elif self.det_algorithm == 'FCE':
        for level, output in enumerate(outputs.values()):
            preds['level_{}'.format(level)] = output.cpu().numpy()
    else:
        raise NotImplementedError

    # The elapsed time is sampled before post-processing and shared evenly.
    total_elapse = time.time() - starttime
    per_image_elapse = total_elapse / len(img_list)

    # Polygon outputs are only clipped; other outputs are also filtered.
    clip_only = (
        (self.det_algorithm == "SAST" and self.det_sast_polygon)
        or (self.det_algorithm in ["PSE", "FCE"] and self.postprocess_op.box_type == 'poly')
    )

    # Post-process the prediction of each image separately.
    batch_results = []
    for idx, ori_shape in enumerate(ori_shapes):
        # Slice out this image's prediction, keeping the batch dimension.
        single_preds = {
            key: value[idx:idx + 1] if isinstance(value, np.ndarray) else value
            for key, value in preds.items()
        }

        post_result = self.postprocess_op(single_preds, batch_shapes[idx:idx + 1])
        dt_boxes = post_result[0]['points']

        # Filter and clip the detected boxes against the original image size.
        if clip_only:
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        batch_results.append((dt_boxes, per_image_elapse))

    return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
    """
    Detect text in several images, processing them in fixed-size chunks.

    Args:
        img_list: List of images.
        max_batch_size: Maximum number of images handed to one batch.

    Returns:
        batch_results: List with one ``(dt_boxes, elapse)`` entry per image,
            in input order. An empty input yields an empty list.
    """
    if not img_list:
        return []

    collected = []
    # Each chunk is expected to contain images of identical size.
    for offset in range(0, len(img_list), max_batch_size):
        chunk = img_list[offset:offset + max_batch_size]
        chunk_results, _ = self._batch_process_same_size(chunk)
        collected += chunk_results
    return collected
def order_points_clockwise(self, pts):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
......
......@@ -4,7 +4,7 @@ from io import BytesIO
from PyPDF2 import PdfReader, PdfWriter
from reportlab.pdfgen import canvas
from .enum_class import BlockType
from .enum_class import BlockType, ContentType
def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
......@@ -54,7 +54,7 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
# dropped_bbox_list = []
dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], []
......@@ -65,7 +65,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
lists_list = []
indexs_list = []
for page in pdf_info:
# page_dropped_list = []
page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
......@@ -74,9 +74,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
lists = []
indices = []
# for dropped_bbox in page['discarded_blocks']:
# page_dropped_list.append(dropped_bbox['bbox'])
# dropped_bbox_list.append(page_dropped_list)
for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list)
for block in page["para_blocks"]:
bbox = block["bbox"]
if block["type"] == BlockType.TABLE:
......@@ -164,7 +164,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
# 使用原始PDF的尺寸创建canvas
c = canvas.Canvas(packet, pagesize=custom_page_size)
# c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)
c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True)
......@@ -190,6 +190,114 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
output_pdf.write(f)
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
    """Write a copy of the PDF with every span's bounding box outlined.

    Spans are collected per page from ``preproc_blocks`` and grouped by
    content type, each type drawn in its own colour. Spans of discarded
    blocks are drawn in grey. Text and inline-equation spans flagged with
    ``cross_page`` are drawn on the page that follows the one they are
    stored on.

    Args:
        pdf_info: List of page dicts with ``discarded_blocks`` and
            ``preproc_blocks``.
        pdf_bytes: Bytes of the PDF the boxes are drawn onto.
        out_path: Directory the result is written to.
        filename: Name of the output file inside ``out_path``.
    """
    text_list = []
    inline_equation_list = []
    interline_equation_list = []
    image_list = []
    table_list = []
    dropped_list = []

    # Spans held back for the next page.
    carried_text = []
    carried_inline_equation = []

    def collect_span(span, buckets):
        """Append the span's bbox to the bucket matching its content type."""
        if span['type'] == ContentType.TEXT:
            target = carried_text if span.get('cross_page', False) else buckets['text']
        elif span['type'] == ContentType.INLINE_EQUATION:
            if span.get('cross_page', False):
                target = carried_inline_equation
            else:
                target = buckets['inline_equation']
        elif span['type'] == ContentType.INTERLINE_EQUATION:
            target = buckets['interline_equation']
        elif span['type'] == ContentType.IMAGE:
            target = buckets['image']
        elif span['type'] == ContentType.TABLE:
            target = buckets['table']
        else:
            # Other span types are not drawn.
            return
        target.append(span['bbox'])

    text_like_types = [
        BlockType.TEXT,
        BlockType.TITLE,
        BlockType.INTERLINE_EQUATION,
        BlockType.LIST,
        BlockType.INDEX,
    ]

    for page in pdf_info:
        # Start each page with the spans carried over from the previous one.
        buckets = {
            'text': list(carried_text),
            'inline_equation': list(carried_inline_equation),
            'interline_equation': [],
            'image': [],
            'table': [],
        }
        carried_text.clear()
        carried_inline_equation.clear()

        # Spans of discarded blocks.
        dropped_list.append([
            span['bbox']
            for block in page['discarded_blocks']
            if block['type'] == BlockType.DISCARDED
            for line in block['lines']
            for span in line['spans']
        ])

        # Remaining spans: taken from the blocks before paragraph merging.
        for block in page['preproc_blocks']:
            if block['type'] in text_like_types:
                line_groups = [block['lines']]
            elif block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
                line_groups = [sub_block['lines'] for sub_block in block['blocks']]
            else:
                continue
            for lines in line_groups:
                for line in lines:
                    for span in line['spans']:
                        collect_span(span, buckets)

        text_list.append(buckets['text'])
        inline_equation_list.append(buckets['inline_equation'])
        interline_equation_list.append(buckets['interline_equation'])
        image_list.append(buckets['image'])
        table_list.append(buckets['table'])

    # Drawing order and colour for each group of boxes.
    overlays = (
        (text_list, (255, 0, 0)),
        (inline_equation_list, (0, 255, 0)),
        (interline_equation_list, (0, 0, 255)),
        (image_list, (255, 204, 0)),
        (table_list, (204, 0, 255)),
        (dropped_list, (158, 158, 158)),
    )

    pdf_docs = PdfReader(BytesIO(pdf_bytes))
    output_pdf = PdfWriter()
    for page_no, pdf_page in enumerate(pdf_docs.pages):
        # Create the overlay canvas with the size of the original page.
        custom_page_size = (float(pdf_page.cropbox[2]), float(pdf_page.cropbox[3]))
        packet = BytesIO()
        c = canvas.Canvas(packet, pagesize=custom_page_size)

        for bbox_lists, rgb in overlays:
            draw_bbox_without_number(page_no, bbox_lists, pdf_page, c, list(rgb), False)

        c.save()
        packet.seek(0)
        overlay_pdf = PdfReader(packet)
        pdf_page.merge_page(overlay_pdf.pages[0])
        output_pdf.add_page(pdf_page)

    # Save the PDF
    with open(f"{out_path}/{filename}", "wb") as f:
        output_pdf.write(f)
if __name__ == "__main__":
# 读取PDF文件
pdf_path = "examples/demo1.pdf"
......
......@@ -12,6 +12,7 @@ class BlockType:
INTERLINE_EQUATION = 'interline_equation'
LIST = 'list'
INDEX = 'index'
DISCARDED = 'discarded'
class ContentType:
......@@ -19,6 +20,7 @@ class ContentType:
TABLE = 'table'
TEXT = 'text'
INTERLINE_EQUATION = 'interline_equation'
INLINE_EQUATION = 'inline_equation'
class MakeMode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment