Commit 6ec440d6 authored by myhloli's avatar myhloli

feat(pdf_parse): implement multi-threaded page processing

- Add ThreadPoolExecutor to process PDF pages in parallel
- Create separate function for page processing to improve readability and maintainability
- Include error handling for individual page processing tasks
- Log total page processing time for performance monitoring (a pattern sketch follows below)
parent df1b8f59
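
For orientation, here is a minimal, self-contained sketch of the fan-out/fan-in pattern the bullets above describe. The `parse_page` worker is hypothetical and stands in for the real `process_page` in the diff below; submission, in-order collection, per-page error isolation, and total-time logging mirror the four points of the commit message.

```python
# Hypothetical sketch of the commit's pattern; parse_page stands in for the
# real per-page worker (process_page in the diff below).
import time
import logging
from concurrent.futures import ThreadPoolExecutor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_page(page_id, page):
    # Placeholder for the real per-page parsing work.
    return page_id, {'content': page}

def parse_all(pages, max_workers=2):
    start_time = time.time()
    pdf_info = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Fan out: one task per page, remembering each future's page id.
        futures = {executor.submit(parse_page, i, p): i for i, p in enumerate(pages)}
        # Fan in: walk the futures in page order; result() blocks until done.
        for future, page_id in sorted(futures.items(), key=lambda kv: kv[1]):
            try:
                pid, info = future.result()
                pdf_info[f'page_{pid}'] = info
            except Exception:
                # A failed page is logged but does not abort the whole batch.
                logger.exception(f'Error processing page {page_id}')
    logger.info(f'page_process_time: {round(time.time() - start_time, 2)}')
    return pdf_info

if __name__ == '__main__':
    print(parse_all(['a', 'b', 'c']))
```

Collecting results by page id rather than with `as_completed` keeps the output order deterministic no matter which worker finishes first.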
@@ -24,6 +24,8 @@ from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
from concurrent.futures import ThreadPoolExecutor

try:
    import torchtext
@@ -937,16 +939,33 @@ def pdf_parse_union(
    """Initialize the start time."""
    start_time = time.time()
    # for page_id, page in enumerate(dataset):
    #     """In debug mode, log the time spent parsing each page."""
    #     if debug_mode:
    #         time_now = time.time()
    #         logger.info(
    #             f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
    #         )
    #         start_time = time_now
    #
    #     """Parse each page of the PDF."""
    #     if start_page_id <= page_id <= end_page_id:
    #         page_info = parse_page_core(
    #             page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
    #         )
    #     else:
    #         page_info = page.get_page_info()
    #         page_w = page_info.w
    #         page_h = page_info.h
    #         page_info = ocr_construct_page_component_v2(
    #             [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
    #         )
    #     pdf_info_dict[f'page_{page_id}'] = page_info

    def process_page(page_id, page, dataset_len, start_page_id, end_page_id, magic_model, pdf_bytes_md5, imageWriter,
                     parse_mode, lang, debug_mode, start_time):
        if debug_mode:
            time_now = time.time()
            logger.info(
                f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
            )
            start_time = time_now
        """Parse each page of the PDF."""
        if start_page_id <= page_id <= end_page_id:
            page_info = parse_page_core(
                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
@@ -958,7 +977,43 @@ def pdf_parse_union(
            page_info = ocr_construct_page_component_v2(
                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
            )
        return page_id, page_info

    # Limit the worker count to avoid excessive resource usage
    max_workers = 2
    pdf_info_dict = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                process_page,
                page_id,
                page,
                len(dataset),
                start_page_id,
                end_page_id,
                magic_model,
                pdf_bytes_md5,
                imageWriter,
                parse_mode,
                lang,
                debug_mode,
                time.time()
            ): page_id
            for page_id, page in enumerate(dataset)
        }
        # Collect results in page order so pdf_info_dict stays deterministic
        for page_id in range(len(dataset)):
            future = [f for f in futures if futures[f] == page_id][0]
            try:
                page_id, page_info = future.result()
                pdf_info_dict[f'page_{page_id}'] = page_info
            except Exception as e:
                logger.exception(f"Error processing page {page_id}: {e}")
    logger.info(
        f'page_process_time: {round(time.time() - start_time, 2)}'
    )
    """Paragraph splitting."""
    para_split(pdf_info_dict)
...
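
A design note on the collection loop in the hunk above: `[f for f in futures if futures[f] == page_id][0]` scans the whole futures dict once per page, which is O(n²) overall. When per-page exception isolation is not required, `executor.map` yields the same in-order results with no future bookkeeping. A minimal sketch under that assumption, not the committed code (`parse_one` is a hypothetical stand-in for the worker):

```python
# Sketch only: map() yields results in input order, so pages can be consumed
# directly without a per-page future lookup. Caveat: map() re-raises a
# worker's exception when its slot is reached, so the committed
# submit()/result() form is still needed for per-page error isolation.
from concurrent.futures import ThreadPoolExecutor

def parse_one(page_id, page):
    return page_id, {'content': page}  # placeholder for real parsing

pages = ['a', 'b', 'c']
with ThreadPoolExecutor(max_workers=2) as executor:
    for page_id, page_info in executor.map(parse_one, range(len(pages)), pages):
        print(f'page_{page_id}', page_info)
```

Note also that `ThreadPoolExecutor` only speeds this up if the per-page work releases the GIL (I/O, or C-extension code such as the model inference behind `parse_page_core`); pure-Python CPU-bound parsing would not parallelize across threads.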