Commit 0a899f1a authored by myhloli's avatar myhloli
Browse files

feat: add batch processing for OCR detection and implement new client and common utilities

parent cbba27b4
import re
from ..utils.enum_class import MakeMode, BlockType, ContentType
def merge_para_with_text(para_block):
para_text = ''
for line in para_block['lines']:
for span in line['spans']:
content = span['content']
content = content.strip()
if content:
para_text += content
else:
continue
return para_text
def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
page_markdown = []
for para_block in para_blocks:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.IMAGE:
if make_mode == MakeMode.NLP_MD:
continue
elif make_mode == MakeMode.MM_MD:
# 检测是否存在图片脚注
has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
# 如果存在图片脚注,则将图片脚注拼接到图片正文后面
if has_image_footnote:
for block in para_block['blocks']: # 1st.拼image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼image_footnote
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_text += ' \n' + merge_para_with_text(block)
else:
for block in para_block['blocks']: # 1st.拼image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += ' \n' + merge_para_with_text(block)
elif para_type == BlockType.TABLE:
if make_mode == MakeMode.NLP_MD:
continue
elif make_mode == MakeMode.MM_MD:
for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TABLE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
# if processed by table model
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' '
if para_text.strip() == '':
continue
else:
# page_markdown.append(para_text.strip() + ' ')
page_markdown.append(para_text.strip())
return page_markdown
def count_leading_hashes(text):
match = re.match(r'^(#+)', text)
return len(match.group(1)) if match else 0
def strip_leading_hashes(text):
# 去除开头的#和紧随其后的空格
return re.sub(r'^#+\s*', '', text)
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
para_type = para_block['type']
para_content = {}
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.TITLE:
title_content = merge_para_with_text(para_block)
title_level = count_leading_hashes(title_content)
para_content = {
'type': 'text',
'text': strip_leading_hashes(title_content),
}
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.INTERLINE_EQUATION:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text_format': 'latex',
}
elif para_type == BlockType.IMAGE:
para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.IMAGE_CAPTION:
para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.TABLE:
para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
if span.get('html', ''):
para_content['table_body'] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.TABLE_CAPTION:
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_content['table_footnote'].append(merge_para_with_text(block))
para_content['page_idx'] = page_idx
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
img_buket_path: str = '',
):
output_content = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get('page_idx')
if not paras_of_layout:
continue
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
elif make_mode == MakeMode.STANDARD_FORMAT:
return output_content
return None
def get_title_level(block):
title_level = block.get('level', 1)
if title_level > 4:
title_level = 4
elif title_level < 1:
title_level = 0
return title_level
import cv2 import cv2
from loguru import logger from loguru import logger
from tqdm import tqdm from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton from .model_init import AtomModelSingleton
from ...utils.model_utils import crop_img, get_res_list_from_layout_res, get_coords_and_area from ...utils.model_utils import crop_img, get_res_list_from_layout_res, get_coords_and_area
...@@ -12,11 +14,12 @@ MFR_BASE_BATCH_SIZE = 16 ...@@ -12,11 +14,12 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze: class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable): def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio self.batch_ratio = batch_ratio
self.formula_enable = formula_enable self.formula_enable = formula_enable
self.table_enable = table_enable self.table_enable = table_enable
self.model_manager = model_manager self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list: def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0: if len(images_with_extra_info) == 0:
...@@ -89,48 +92,160 @@ class BatchAnalyze: ...@@ -89,48 +92,160 @@ class BatchAnalyze:
'table_img':table_img, 'table_img':table_img,
}) })
# 文本框检测 # OCR检测处理
if self.enable_ocr_det_batch:
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): # 批处理模式 - 按语言和分辨率分组
# Process each area that requires OCR processing # 收集所有需要OCR检测的裁剪图像
_lang = ocr_res_list_dict['lang'] all_cropped_images_info = []
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model( for ocr_res_list_dict in ocr_res_list_all_page:
atom_model_name='ocr', _lang = ocr_res_list_dict['lang']
det_db_box_thresh=0.3,
lang=_lang for res in ocr_res_list_dict['ocr_res_list']:
) new_image, useful_list = crop_img(
for res in ocr_res_list_dict['ocr_res_list']: res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
new_image, useful_list = crop_img( )
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
) ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res( )
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
) # BGR转换
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
# OCR-det # 按分辨率分组并同时完成padding
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR) resolution_groups = defaultdict(list)
ocr_res = ocr_model.ocr( for crop_info in lang_crop_list:
new_image, mfd_res=adjusted_mfdetrec_res, rec=False cropped_img = crop_info[0]
)[0] h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# Integration results # 将尺寸标准化到32的倍数
if ocr_res: normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang) normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
if res["category_id"] == 3: resolution_groups[group_key].append(crop_info)
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item) # 对每个分辨率组进行批处理
# 求ocr_res_area和res的面积的比值 for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
res_area = get_coords_and_area(res)[4] raw_images = [crop_info[0] for crop_info in group_crops]
if res_area > 0:
ratio = ocr_res_area / res_area # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
if ratio > 0.25: max_h = max(img.shape[0] for img in raw_images)
res["category_id"] = 1 max_w = max(img.shape[1] for img in raw_images)
else: target_h = ((max_h + 32 - 1) // 32) * 32
continue target_w = ((max_w + 32 - 1) // 32) * 32
ocr_res_list_dict['layout_res'].extend(ocr_result_list) # 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition # 表格识别 table recognition
if self.table_enable: if self.table_enable:
......
...@@ -3,7 +3,7 @@ import re ...@@ -3,7 +3,7 @@ import re
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
from mineru.utils.magic_model import fix_two_layer_blocks from mineru.utils.vlm_magic_model import fix_two_layer_blocks
from mineru.version import __version__ from mineru.version import __version__
...@@ -113,7 +113,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -113,7 +113,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
# 对page_blocks根据index的值进行排序 # 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"]) page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "page_size": [width, height], "page_idx": page_index} page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
return page_info return page_info
......
# Copyright (c) Opendatalab. All rights reserved.
import os
import click
from pathlib import Path
from loguru import logger
from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command()
@click.version_option(__version__,
'--version',
'-v',
help='display the version and exit')
@click.option(
'-p',
'--path',
'input_path',
type=click.Path(exists=True),
required=True,
help='local filepath or directory. support pdf, png, jpg, jpeg files',
)
@click.option(
'-o',
'--output-dir',
'output_dir',
type=click.Path(),
required=True,
help='output local directory',
)
@click.option(
'-b',
'--backend',
'backend',
type=click.Choice(['pipeline', 'vlm-huggingface', 'vlm-sglang-engine', 'vlm-sglang-client']),
help="""the backend for parsing pdf:
pipeline: More general.
vlm-huggingface: More general.
vlm-sglang-engine: Faster(engine).
vlm-sglang-client: Faster(client).
without method specified, huggingface will be used by default.""",
default='pipeline',
)
@click.option(
'-u',
'--url',
'server_url',
type=str,
help="""
When the backend is `sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
""",
default=None,
)
@click.option(
'-s',
'--start',
'start_page_id',
type=int,
help='The starting page for PDF parsing, beginning from 0.',
default=0,
)
@click.option(
'-e',
'--end',
'end_page_id',
type=int,
help='The ending page for PDF parsing, beginning from 0.',
default=None,
)
def main(input_path, output_dir, backend, server_url, start_page_id, end_page_id):
os.makedirs(output_dir, exist_ok=True)
def parse_doc(path: Path):
try:
file_name = str(Path(path).stem)
pdf_bits = read_fn(path)
do_parse(output_dir, file_name, pdf_bits, backend, server_url,
start_page_id=start_page_id, end_page_id=end_page_id)
except Exception as e:
logger.exception(e)
if os.path.isdir(input_path):
for doc_path in Path(input_path).glob('*'):
if doc_path.suffix in pdf_suffixes + image_suffixes:
parse_doc(Path(doc_path))
else:
parse_doc(Path(input_path))
if __name__ == '__main__':
main()
# Copyright (c) Opendatalab. All rights reserved.
import io
import json
import os
from pathlib import Path
import pypdfium2 as pdfium
from loguru import logger
from ..api.vlm_middle_json_mkcontent import union_make
from ..backend.vlm.vlm_analyze import doc_analyze
from ..data.data_reader_writer import FileBasedDataWriter
from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from ..utils.enum_class import MakeMode
from ..utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
def read_fn(path: Path):
with open(str(path), "rb") as input_file:
file_bytes = input_file.read()
if path.suffix in image_suffixes:
return images_bytes_to_pdf_bytes(file_bytes)
elif path.suffix in pdf_suffixes:
return file_bytes
else:
raise Exception(f"Unknown file suffix: {path.suffix}")
def prepare_env(output_dir, pdf_file_name):
local_parent_dir = os.path.join(output_dir, pdf_file_name)
local_image_dir = os.path.join(str(local_parent_dir), "images")
local_md_dir = local_parent_dir
os.makedirs(local_image_dir, exist_ok=True)
os.makedirs(local_md_dir, exist_ok=True)
return local_image_dir, local_md_dir
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
# 从字节数据加载PDF
pdf = pdfium.PdfDocument(pdf_bytes)
# 确定结束页
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf) - 1
if end_page_id > len(pdf) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length")
end_page_id = len(pdf) - 1
# 创建一个新的PDF文档
output_pdf = pdfium.PdfDocument.new()
# 选择要导入的页面索引
page_indices = list(range(start_page_id, end_page_id + 1))
# 从原PDF导入页面到新PDF
output_pdf.import_pages(pdf, page_indices)
# 将新PDF保存到内存缓冲区
output_buffer = io.BytesIO()
output_pdf.save(output_buffer)
# 获取字节数据
output_bytes = output_buffer.getvalue()
return output_bytes
def do_parse(
output_dir,
pdf_file_name,
pdf_bytes,
backend="pipeline",
model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415", # TODO: change to formal path after release.
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=False,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
):
if backend == 'pipeline':
f_draw_span_bbox = True
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
pdf_info = middle_json["pdf_info"]
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
if f_dump_md:
image_dir = str(os.path.basename(local_image_dir))
md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list:
image_dir = str(os.path.basename(local_image_dir))
content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output:
model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
model_output,
)
logger.info(f"local output dir is {local_md_dir}")
return infer_result
if __name__ == "__main__":
pdf_path = "../../demo/demo2.pdf"
with open(pdf_path, "rb") as f:
try:
result = do_parse("./output", Path(pdf_path).stem, f.read())
except Exception as e:
logger.exception(e)
# dict转成json
print(json.dumps(result, ensure_ascii=False, indent=4))
...@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20): ...@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self.net.eval() self.net.eval()
self.net.to(self.device) self.net.to(self.device)
def _batch_process_same_size(self, img_list):
"""
对相同尺寸的图像进行批处理
Args:
img_list: 相同尺寸的图像列表
Returns:
batch_results: 批处理结果列表
total_elapse: 总耗时
"""
starttime = time.time()
# 预处理所有图像
batch_data = []
batch_shapes = []
ori_imgs = []
for img in img_list:
ori_im = img.copy()
ori_imgs.append(ori_im)
data = {'image': img}
data = transform(data, self.preprocess_op)
if data is None:
# 如果预处理失败,返回空结果
return [(None, 0) for _ in img_list], 0
img_processed, shape_list = data
batch_data.append(img_processed)
batch_shapes.append(shape_list)
# 堆叠成批处理张量
try:
batch_tensor = np.stack(batch_data, axis=0)
batch_shapes = np.stack(batch_shapes, axis=0)
except Exception as e:
# 如果堆叠失败,回退到逐个处理
batch_results = []
for img in img_list:
dt_boxes, elapse = self.__call__(img)
batch_results.append((dt_boxes, elapse))
return batch_results, time.time() - starttime
# 批处理推理
with torch.no_grad():
inp = torch.from_numpy(batch_tensor)
inp = inp.to(self.device)
outputs = self.net(inp)
# 处理输出
preds = {}
if self.det_algorithm == "EAST":
preds['f_geo'] = outputs['f_geo'].cpu().numpy()
preds['f_score'] = outputs['f_score'].cpu().numpy()
elif self.det_algorithm == 'SAST':
preds['f_border'] = outputs['f_border'].cpu().numpy()
preds['f_score'] = outputs['f_score'].cpu().numpy()
preds['f_tco'] = outputs['f_tco'].cpu().numpy()
preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs['maps'].cpu().numpy()
elif self.det_algorithm == 'FCE':
for i, (k, output) in enumerate(outputs.items()):
preds['level_{}'.format(i)] = output.cpu().numpy()
else:
raise NotImplementedError
# 后处理每个图像的结果
batch_results = []
total_elapse = time.time() - starttime
for i in range(len(img_list)):
# 提取单个图像的预测结果
single_preds = {}
for key, value in preds.items():
if isinstance(value, np.ndarray):
single_preds[key] = value[i:i + 1] # 保持批次维度
else:
single_preds[key] = value
# 后处理
post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
dt_boxes = post_result[0]['points']
# 过滤和裁剪检测框
if (self.det_algorithm == "SAST" and
self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
batch_results.append((dt_boxes, total_elapse / len(img_list)))
return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
"""
批处理预测方法,支持多张图像同时检测
Args:
img_list: 图像列表
max_batch_size: 最大批处理大小
Returns:
batch_results: 批处理结果列表,每个元素为(dt_boxes, elapse)
"""
if not img_list:
return []
batch_results = []
# 分批处理
for i in range(0, len(img_list), max_batch_size):
batch_imgs = img_list[i:i + max_batch_size]
# assert尺寸一致
batch_dt_boxes, batch_elapse = self._batch_process_same_size(batch_imgs)
batch_results.extend(batch_dt_boxes)
return batch_results
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
......
...@@ -4,7 +4,7 @@ from io import BytesIO ...@@ -4,7 +4,7 @@ from io import BytesIO
from PyPDF2 import PdfReader, PdfWriter from PyPDF2 import PdfReader, PdfWriter
from reportlab.pdfgen import canvas from reportlab.pdfgen import canvas
from .enum_class import BlockType from .enum_class import BlockType, ContentType
def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config): def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
...@@ -54,7 +54,7 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b ...@@ -54,7 +54,7 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
# dropped_bbox_list = [] dropped_bbox_list = []
tables_list, tables_body_list = [], [] tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], [] tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], [] imgs_list, imgs_body_list, imgs_caption_list = [], [], []
...@@ -65,7 +65,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -65,7 +65,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
lists_list = [] lists_list = []
indexs_list = [] indexs_list = []
for page in pdf_info: for page in pdf_info:
# page_dropped_list = [] page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], [] tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], [] imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = [] titles = []
...@@ -74,9 +74,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -74,9 +74,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
lists = [] lists = []
indices = [] indices = []
# for dropped_bbox in page['discarded_blocks']: for dropped_bbox in page['discarded_blocks']:
# page_dropped_list.append(dropped_bbox['bbox']) page_dropped_list.append(dropped_bbox['bbox'])
# dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
for block in page["para_blocks"]: for block in page["para_blocks"]:
bbox = block["bbox"] bbox = block["bbox"]
if block["type"] == BlockType.TABLE: if block["type"] == BlockType.TABLE:
...@@ -164,7 +164,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -164,7 +164,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
# 使用原始PDF的尺寸创建canvas # 使用原始PDF的尺寸创建canvas
c = canvas.Canvas(packet, pagesize=custom_page_size) c = canvas.Canvas(packet, pagesize=custom_page_size)
# c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True) c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True) c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True) c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)
c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True) c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True)
...@@ -190,6 +190,114 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -190,6 +190,114 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
output_pdf.write(f) output_pdf.write(f)
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
text_list = []
inline_equation_list = []
interline_equation_list = []
image_list = []
table_list = []
dropped_list = []
next_page_text_list = []
next_page_inline_equation_list = []
def get_span_info(span):
if span['type'] == ContentType.TEXT:
if span.get('cross_page', False):
next_page_text_list.append(span['bbox'])
else:
page_text_list.append(span['bbox'])
elif span['type'] == ContentType.INLINE_EQUATION:
if span.get('cross_page', False):
next_page_inline_equation_list.append(span['bbox'])
else:
page_inline_equation_list.append(span['bbox'])
elif span['type'] == ContentType.INTERLINE_EQUATION:
page_interline_equation_list.append(span['bbox'])
elif span['type'] == ContentType.IMAGE:
page_image_list.append(span['bbox'])
elif span['type'] == ContentType.TABLE:
page_table_list.append(span['bbox'])
for page in pdf_info:
page_text_list = []
page_inline_equation_list = []
page_interline_equation_list = []
page_image_list = []
page_table_list = []
page_dropped_list = []
# 将跨页的span放到移动到下一页的列表中
if len(next_page_text_list) > 0:
page_text_list.extend(next_page_text_list)
next_page_text_list.clear()
if len(next_page_inline_equation_list) > 0:
page_inline_equation_list.extend(next_page_inline_equation_list)
next_page_inline_equation_list.clear()
# 构造dropped_list
for block in page['discarded_blocks']:
if block['type'] == BlockType.DISCARDED:
for line in block['lines']:
for span in line['spans']:
page_dropped_list.append(span['bbox'])
dropped_list.append(page_dropped_list)
# 构造其余useful_list
# for block in page['para_blocks']: # span直接用分段合并前的结果就可以
for block in page['preproc_blocks']:
if block['type'] in [
BlockType.TEXT,
BlockType.TITLE,
BlockType.INTERLINE_EQUATION,
BlockType.LIST,
BlockType.INDEX,
]:
for line in block['lines']:
for span in line['spans']:
get_span_info(span)
elif block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
for sub_block in block['blocks']:
for line in sub_block['lines']:
for span in line['spans']:
get_span_info(span)
text_list.append(page_text_list)
inline_equation_list.append(page_inline_equation_list)
interline_equation_list.append(page_interline_equation_list)
image_list.append(page_image_list)
table_list.append(page_table_list)
pdf_bytes_io = BytesIO(pdf_bytes)
pdf_docs = PdfReader(pdf_bytes_io)
output_pdf = PdfWriter()
for i, page in enumerate(pdf_docs.pages):
# 获取原始页面尺寸
page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
custom_page_size = (page_width, page_height)
packet = BytesIO()
# 使用原始PDF的尺寸创建canvas
c = canvas.Canvas(packet, pagesize=custom_page_size)
# 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, c,[255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, c, [0, 255, 0], False)
draw_bbox_without_number(i, interline_equation_list, page, c, [0, 0, 255], False)
draw_bbox_without_number(i, image_list, page, c, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, c, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, c, [158, 158, 158], False)
c.save()
packet.seek(0)
overlay_pdf = PdfReader(packet)
page.merge_page(overlay_pdf.pages[0])
output_pdf.add_page(page)
# Save the PDF
with open(f"{out_path}/{filename}", "wb") as f:
output_pdf.write(f)
if __name__ == "__main__": if __name__ == "__main__":
# 读取PDF文件 # 读取PDF文件
pdf_path = "examples/demo1.pdf" pdf_path = "examples/demo1.pdf"
......
...@@ -12,6 +12,7 @@ class BlockType: ...@@ -12,6 +12,7 @@ class BlockType:
INTERLINE_EQUATION = 'interline_equation' INTERLINE_EQUATION = 'interline_equation'
LIST = 'list' LIST = 'list'
INDEX = 'index' INDEX = 'index'
DISCARDED = 'discarded'
class ContentType: class ContentType:
...@@ -19,6 +20,7 @@ class ContentType: ...@@ -19,6 +20,7 @@ class ContentType:
TABLE = 'table' TABLE = 'table'
TEXT = 'text' TEXT = 'text'
INTERLINE_EQUATION = 'interline_equation' INTERLINE_EQUATION = 'interline_equation'
INLINE_EQUATION = 'inline_equation'
class MakeMode: class MakeMode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment