Unverified Commit 852b841a authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2189 from myhloli/dev

refactor(model): optimize batch processing and inference 
parents 1c7f41dd 54ce594b
@@ -103,17 +103,65 @@ def batch_build_dataset(pdf_paths, k, lang=None):
     all_images : list
         List of all processed images
     """
-    # Get page counts for each PDF
-    pdf_info = []
-    total_pages = 0
     results = []
     for pdf_path in pdf_paths:
-        try:
-            with open(pdf_path, 'rb') as f:
-                bits = f.read()
-            results.append(PymuDocDataset(bits, lang))
-        except Exception as e:
-            print(f'Error opening {pdf_path}: {e}')
+        with open(pdf_path, 'rb') as f:
+            pdf_bytes = f.read()
+        dataset = PymuDocDataset(pdf_bytes, lang=lang)
+        results.append(dataset)
     return results
+    #
+    # # Get page counts for each PDF
+    # pdf_info = []
+    # total_pages = 0
+    #
+    # for pdf_path in pdf_paths:
+    #     try:
+    #         doc = fitz.open(pdf_path)
+    #         num_pages = len(doc)
+    #         pdf_info.append((pdf_path, num_pages))
+    #         total_pages += num_pages
+    #         doc.close()
+    #     except Exception as e:
+    #         print(f'Error opening {pdf_path}: {e}')
+    #
+    # # Partition the jobs based on page count. Each job has 1 page.
+    # partitions = partition_array_greedy(pdf_info, k)
+    #
+    # # Process each partition in parallel
+    # all_images_h = {}
+    #
+    # with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+    #     # Submit one task per partition
+    #     futures = []
+    #     for sn, partition in enumerate(partitions):
+    #         # Get the jobs for this partition
+    #         partition_jobs = [pdf_info[idx] for idx in partition]
+    #
+    #         # Submit the task
+    #         future = executor.submit(
+    #             process_pdf_batch,
+    #             partition_jobs,
+    #             sn
+    #         )
+    #         futures.append(future)
+    #     # Process results as they complete
+    #     for i, future in enumerate(concurrent.futures.as_completed(futures)):
+    #         try:
+    #             idx, images = future.result()
+    #             all_images_h[idx] = images
+    #         except Exception as e:
+    #             print(f'Error processing partition: {e}')
+    # results = [None] * len(pdf_paths)
+    # for i in range(len(partitions)):
+    #     partition = partitions[i]
+    #     for j in range(len(partition)):
+    #         with open(pdf_info[partition[j]][0], 'rb') as f:
+    #             pdf_bytes = f.read()
+    #         dataset = PymuDocDataset(pdf_bytes, lang=lang)
+    #         dataset.set_images(all_images_h[i][j])
+    #         results[partition[j]] = dataset
+    # return results
\ No newline at end of file
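Note on the disabled path above: it balanced whole PDFs across k worker processes by page count via partition_array_greedy, pre-rendered pages in parallel, then seeded each dataset with its images. A minimal sketch of that greedy balancing, assuming the (path, num_pages) job tuples shown above; the real helper's signature and tie-breaking may differ:

# Sketch only: largest-first greedy assignment to the lightest bucket (LPT).
# Returns k lists of indices into pdf_info, roughly balanced by page count.
def partition_array_greedy(pdf_info, k):
    partitions = [[] for _ in range(k)]
    loads = [0] * k
    order = sorted(range(len(pdf_info)), key=lambda i: pdf_info[i][1], reverse=True)
    for idx in order:
        lightest = loads.index(min(loads))
        partitions[lightest].append(idx)
        loads[lightest] += pdf_info[idx][1]
    return partitions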
@@ -150,7 +150,7 @@ class PymuDocDataset(Dataset):
         elif lang == 'auto':
             from magic_pdf.model.sub_modules.language_detection.utils import \
                 auto_detect_lang
-            self._lang = auto_detect_lang(bits)
+            self._lang = auto_detect_lang(self._data_bits)
             logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
@@ -342,8 +342,17 @@ class Doc(PageableData):
             height: int
         }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img

     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.
@@ -396,4 +405,4 @@ class Doc(PageableData):
             fontsize (int): font size of the text
             color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
         """
-        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
\ No newline at end of file
+        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
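The get_image change memoizes the fitz render, and set_image lets an external pre-renderer seed that cache (first write wins). A hypothetical usage sketch; 'prerendered' is assumed to be the same {'img', 'width', 'height'} dict that fitz_doc_to_image produces:

page = dataset.get_page(0)
page.set_image(prerendered)   # seeds the cache; no-op if something is already cached
img_dict = page.get_image()   # returns the cached dict, no second fitz render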
@@ -30,8 +30,14 @@ class BatchAnalyze:
         images_layout_res = []

         layout_start_time = time.time()
-        _, fst_ocr, fst_lang = images_with_extra_info[0]
-        self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
+        self.model = self.model_manager.get_model(
+            ocr=True,
+            show_log=self.show_log,
+            lang=None,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
+        )

         images = [image for image, _, _ in images_with_extra_info]
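Instead of keying the model on the first image's OCR flag and language, every batch now fetches one fixed configuration (ocr=True, lang=None). A sketch of the keyed-singleton pattern such a get_model typically implements; the actual manager in magic_pdf may cache on different fields, and _build_model here is hypothetical:

class ModelManager:
    def __init__(self):
        self._cache = {}

    def get_model(self, ocr, show_log, lang, layout_model, formula_enable, table_enable):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._cache:
            # Expensive weight loading happens once per distinct configuration.
            self._cache[key] = self._build_model(ocr, show_log, lang,
                                                 layout_model, formula_enable, table_enable)
        return self._cache[key]

Pinning the kwargs means every batch reuses a single cached instance instead of thrashing between per-language model configurations.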
@@ -138,31 +138,27 @@ def doc_analyze(
     )

     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
-    batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
-    images_with_extra_info = []
-    results = []
     for index in range(len(dataset)):
         if start_page_id <= index <= end_page_id:
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-            if lang is None or lang == 'auto':
-                images_with_extra_info.append((images[index], ocr, dataset._lang))
-            else:
-                images_with_extra_info.append((images[index], ocr, lang))
-
-            if len(images_with_extra_info) == batch_size:
-                _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
-                results.extend(result)
-                images_with_extra_info = []
+    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]

-    if len(images_with_extra_info) > 0:
-        _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    else:
+        batch_images = [images_with_extra_info]
+
+    results = []
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, ocr, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)
-        images_with_extra_info = []

     model_json = []
     for index in range(len(dataset)):
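The slicing comprehension above is the standard fixed-size chunking idiom; shown standalone:

def chunk(items, batch_size):
    # Consecutive slices of at most batch_size items; the last may be shorter.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

chunk([1, 2, 3, 4, 5], 2)  # -> [[1, 2], [3, 4], [5]]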
@@ -183,7 +179,7 @@ def doc_analyze(

 def batch_doc_analyze(
     datasets: list[Dataset],
-    parse_method: str,
+    parse_method: str = 'auto',
     show_log: bool = False,
     lang=None,
     layout_model=None,
@@ -192,36 +188,35 @@ def batch_doc_analyze(
 ):
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
-    images = []
     page_wh_list = []
-    results = []
     images_with_extra_info = []
     for dataset in datasets:
-        for index in range(len(dataset)):
-            if lang is None or lang == 'auto':
-                _lang = dataset._lang
-            else:
-                _lang = lang
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-            page_wh_list.append((img_dict['width'], img_dict['height']))
-            if parse_method == 'auto':
-                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
-            else:
-                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
-            if len(images_with_extra_info) == batch_size:
-                _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
-                results.extend(result)
-                images_with_extra_info = []
-
-    if len(images_with_extra_info) > 0:
-        _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
-        results.extend(result)
-        images_with_extra_info = []
+        ocr = False
+        if parse_method == 'auto':
+            if dataset.classify() == SupportedPdfParseMethod.TXT:
+                ocr = False
+            elif dataset.classify() == SupportedPdfParseMethod.OCR:
+                ocr = True
+        elif parse_method == 'ocr':
+            ocr = True
+        elif parse_method == 'txt':
+            ocr = False
+
+        _lang = dataset._lang
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+            images_with_extra_info.append((img_dict['img'], ocr, _lang))
+
+    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    results = []
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
+        results.extend(result)

     infer_results = []
     from magic_pdf.operators.models import InferenceResult
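OCR mode and language are now resolved once per dataset (classify() runs per document rather than per page), and the explicit lang argument no longer overrides dataset._lang. The branch ladder above condenses to a hypothetical helper, assuming the same SupportedPdfParseMethod enum:

def resolve_ocr(dataset, parse_method):
    # 'auto' defers to the document classifier; otherwise the caller decides.
    if parse_method == 'auto':
        return dataset.classify() == SupportedPdfParseMethod.OCR
    return parse_method == 'ocr'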
@@ -240,7 +235,6 @@ def batch_doc_analyze(

 def may_batch_image_analyze(
         images_with_extra_info: list[(np.ndarray, bool, str)],
-        idx: int,
         ocr: bool,
         show_log: bool = False,
         layout_model=None,
@@ -298,4 +292,4 @@ def may_batch_image_analyze(
     #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
     #     f' speed: {doc_analyze_speed} pages/second'
     # )
-    return idx, results
\ No newline at end of file
+    return results
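With the callers now iterating batches sequentially, results come back in submission order, so the idx round-trip (which would matter for out-of-order completion, e.g. under as_completed) is dead weight. The ordering guarantee in miniature:

batches = [[0, 1], [2, 3], [4]]
flat = []
for batch in batches:                    # in-order iteration
    flat.extend(x * 10 for x in batch)   # stand-in for per-batch inference
assert flat == [0, 10, 20, 30, 40]       # global page order preserved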
@@ -109,9 +109,7 @@ def _do_parse(
         pdf_bytes = ds._raw_data
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)

-    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
-        local_md_dir
-    )
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
     image_dir = str(os.path.basename(local_image_dir))

     if len(model_list) == 0:
@@ -317,7 +315,26 @@ def batch_do_parse(
     infer_results = batch_doc_analyze(dss, parse_method, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
-        _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
+        _do_parse(
+            output_dir=output_dir,
+            pdf_file_name=pdf_file_names[idx],
+            pdf_bytes_or_dataset=dss[idx],
+            model_list=infer_result.get_infer_res(),
+            parse_method=parse_method,
+            debug_able=debug_able,
+            f_draw_span_bbox=f_draw_span_bbox,
+            f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md,
+            f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json,
+            f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list,
+            f_make_md_mode=MakeMode.MM_MD,
+            f_draw_model_bbox=f_draw_model_bbox,
+            f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox,
+            lang=lang,
+        )

 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
@@ -7,7 +7,7 @@ numpy>=1.21.6
 pydantic>=2.7.2,<2.11
 PyMuPDF>=1.24.9,<1.25.0
 scikit-learn>=1.0.2
-torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torch>=2.2.2,!=2.5.0,!=2.5.1
 torchvision
 transformers>=4.49.0,!=4.51.0,<5.0.0
 pdfminer.six==20231228
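The torch pin drops its <=2.6.0 ceiling while keeping the 2.5.0/2.5.1 exclusions. One way to sanity-check the loosened specifier, using the packaging library (not part of this diff):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=2.2.2,!=2.5.0,!=2.5.1")
print("2.5.1" in spec)  # False: still excluded
print("2.7.0" in spec)  # True: previously blocked by <=2.6.0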