Commit d2fc9dab authored by myhloli

refactor(model): optimize batch processing and inference

- Update batch processing logic for improved efficiency
- Refactor image analysis and inference methods
- Optimize dataset handling and image retrieval
- Improve error handling and logging in batch processes
parent 930bc47f
@@ -103,17 +103,65 @@ def batch_build_dataset(pdf_paths, k, lang=None):
all_images : list
List of all processed images
"""
# Get page counts for each PDF
pdf_info = []
total_pages = 0
results = []
for pdf_path in pdf_paths:
try:
with open(pdf_path, 'rb') as f:
bits = f.read()
results.append(PymuDocDataset(bits, lang))
except Exception as e:
print(f'Error opening {pdf_path}: {e}')
pdf_bytes = f.read()
dataset = PymuDocDataset(pdf_bytes, lang=lang)
results.append(dataset)
return results
#
# # Get page counts for each PDF
# pdf_info = []
# total_pages = 0
#
# for pdf_path in pdf_paths:
# try:
# doc = fitz.open(pdf_path)
# num_pages = len(doc)
# pdf_info.append((pdf_path, num_pages))
# total_pages += num_pages
# doc.close()
# except Exception as e:
# print(f'Error opening {pdf_path}: {e}')
#
# # Partition the jobs based on page count; each job has 1 page
# partitions = partition_array_greedy(pdf_info, k)
#
# # Process each partition in parallel
# all_images_h = {}
#
# with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
# # Submit one task per partition
# futures = []
# for sn, partition in enumerate(partitions):
# # Get the jobs for this partition
# partition_jobs = [pdf_info[idx] for idx in partition]
#
# # Submit the task
# future = executor.submit(
# process_pdf_batch,
# partition_jobs,
# sn
# )
# futures.append(future)
# # Process results as they complete
# for i, future in enumerate(concurrent.futures.as_completed(futures)):
# try:
# idx, images = future.result()
# all_images_h[idx] = images
# except Exception as e:
# print(f'Error processing partition: {e}')
# results = [None] * len(pdf_paths)
# for i in range(len(partitions)):
# partition = partitions[i]
# for j in range(len(partition)):
# with open(pdf_info[partition[j]][0], 'rb') as f:
# pdf_bytes = f.read()
# dataset = PymuDocDataset(pdf_bytes, lang=lang)
# dataset.set_images(all_images_h[i][j])
# results[partition[j]] = dataset
# return results
\ No newline at end of file
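The refactored `batch_build_dataset` now simply reads each PDF into memory and wraps it in a `PymuDocDataset`, leaving page rendering to happen lazily later. Below is a minimal usage sketch; the import path and the sample PDF file names are assumptions for illustration, not taken from this diff.

```python
# Usage sketch, assuming batch_build_dataset lives in magic_pdf.data.batch_build_dataset
from magic_pdf.data.batch_build_dataset import batch_build_dataset

pdf_paths = ['doc_a.pdf', 'doc_b.pdf']  # hypothetical input files
datasets = batch_build_dataset(pdf_paths, k=4, lang=None)

for path, ds in zip(pdf_paths, datasets):
    # Each entry is a PymuDocDataset built from the raw PDF bytes;
    # page images are rendered on first access (see Doc.get_image below).
    print(path, len(ds), 'pages')
```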
@@ -150,7 +150,7 @@ class PymuDocDataset(Dataset):
elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import \
auto_detect_lang
self._lang = auto_detect_lang(bits)
self._lang = auto_detect_lang(self._data_bits)
logger.info(f'lang: {lang}, detect_lang: {self._lang}')
else:
self._lang = lang
@@ -342,8 +342,17 @@ class Doc(PageableData):
height: int
}
"""
return fitz_doc_to_image(self._doc)
if self._img is None:
self._img = fitz_doc_to_image(self._doc)
return self._img
def set_image(self, img):
"""
Args:
img (np.ndarray): the image
"""
if self._img is None:
self._img = img
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
......
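With this change, `Doc.get_image` memoizes the rendered page and `set_image` lets a precomputed image seed the cache exactly once. The snippet below is a self-contained sketch of the same lazy-caching pattern; the `LazyPage` class and its stand-in renderer are illustrative names, not the project's API.

```python
import numpy as np

class LazyPage:
    """Illustrative stand-in for Doc: render once, then reuse the cached image."""

    def __init__(self, render_fn):
        self._render_fn = render_fn  # e.g. a wrapper around fitz_doc_to_image
        self._img = None             # cache starts empty

    def get_image(self):
        # Render only on first access; later calls return the cached result.
        if self._img is None:
            self._img = self._render_fn()
        return self._img

    def set_image(self, img):
        # Allow an externally rendered image to seed the cache,
        # but never overwrite one that is already set.
        if self._img is None:
            self._img = img

page = LazyPage(lambda: {'img': np.zeros((10, 10, 3), dtype=np.uint8), 'width': 10, 'height': 10})
page.set_image({'img': np.ones((10, 10, 3), dtype=np.uint8), 'width': 10, 'height': 10})
assert page.get_image()['width'] == 10  # the seeded image is reused, no re-render
```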
@@ -30,8 +30,14 @@ class BatchAnalyze:
images_layout_res = []
layout_start_time = time.time()
_, fst_ocr, fst_lang = images_with_extra_info[0]
self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
self.model = self.model_manager.get_model(
ocr=True,
show_log=self.show_log,
lang=None,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
images = [image for image, _, _ in images_with_extra_info]
......
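`BatchAnalyze` now requests one shared model from the model manager with fixed keyword arguments rather than per-batch positional ones. The sketch below shows the general idea of a manager that caches one model instance per configuration; this `ModelManager` class and its internals are assumptions for illustration, not the project's actual implementation.

```python
class ModelManager:
    """Hypothetical manager that memoizes models per configuration."""

    def __init__(self):
        self._cache = {}

    def get_model(self, ocr, show_log, lang, layout_model, formula_enable, table_enable):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._cache:
            self._cache[key] = object()  # placeholder for expensive model construction
        return self._cache[key]

manager = ModelManager()
m1 = manager.get_model(ocr=True, show_log=False, lang=None,
                       layout_model=None, formula_enable=True, table_enable=True)
m2 = manager.get_model(ocr=True, show_log=False, lang=None,
                       layout_model=None, formula_enable=True, table_enable=True)
assert m1 is m2  # the same configuration reuses the cached instance
```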
@@ -138,31 +138,27 @@ def doc_analyze(
)
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
batch_size = MIN_BATCH_INFERENCE_SIZE
images = []
page_wh_list = []
images_with_extra_info = []
results = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if lang is None or lang == 'auto':
images_with_extra_info.append((images[index], ocr, dataset._lang))
else:
images_with_extra_info.append((images[index], ocr, lang))
if len(images_with_extra_info) == batch_size:
_, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
images_with_extra_info = []
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
else:
batch_images = [images_with_extra_info]
if len(images_with_extra_info) > 0:
_, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
results = []
for batch_image in batch_images:
result = may_batch_image_analyze(batch_image, ocr, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
images_with_extra_info = []
model_json = []
for index in range(len(dataset)):
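Both `doc_analyze` and `batch_doc_analyze` now collect all page tuples first and then split them into chunks of `MINERU_MIN_BATCH_INFERENCE_SIZE` with a slice-based comprehension. A minimal sketch of that chunking step, using dummy page records and a simplified condition:

```python
import os

# Same env-driven batch size as in the diff (default 200).
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))

images_with_extra_info = [(f'img_{i}', True, 'en') for i in range(450)]  # dummy pages

if len(images_with_extra_info) >= MIN_BATCH_INFERENCE_SIZE:
    batch_size = MIN_BATCH_INFERENCE_SIZE
    batch_images = [images_with_extra_info[i:i + batch_size]
                    for i in range(0, len(images_with_extra_info), batch_size)]
else:
    batch_images = [images_with_extra_info]

print([len(b) for b in batch_images])  # e.g. [200, 200, 50] with the default size
```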
@@ -183,7 +179,7 @@ def doc_analyze(
def batch_doc_analyze(
datasets: list[Dataset],
parse_method: str,
parse_method: str = 'auto',
show_log: bool = False,
lang=None,
layout_model=None,
@@ -192,36 +188,35 @@ def batch_doc_analyze(
):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
batch_size = MIN_BATCH_INFERENCE_SIZE
images = []
page_wh_list = []
results = []
images_with_extra_info = []
for dataset in datasets:
for index in range(len(dataset)):
if lang is None or lang == 'auto':
ocr = False
if parse_method == 'auto':
if dataset.classify() == SupportedPdfParseMethod.TXT:
ocr = False
elif dataset.classify() == SupportedPdfParseMethod.OCR:
ocr = True
elif parse_method == 'ocr':
ocr = True
elif parse_method == 'txt':
ocr = False
_lang = dataset._lang
else:
_lang = lang
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if parse_method == 'auto':
images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
else:
images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
if len(images_with_extra_info) == batch_size:
_, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
images_with_extra_info = []
images_with_extra_info.append((img_dict['img'], ocr, _lang))
if len(images_with_extra_info) > 0:
_, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
results = []
for batch_image in batch_images:
result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
images_with_extra_info = []
infer_results = []
from magic_pdf.operators.models import InferenceResult
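In `batch_doc_analyze`, the OCR flag is now resolved once per dataset from `parse_method`, falling back to `dataset.classify()` when the method is `auto`. The helper below sketches that decision; the function name is hypothetical and the classifier result is passed in as a plain boolean to keep the example self-contained.

```python
def resolve_ocr_flag(parse_method, classified_as_ocr):
    """Hypothetical helper mirroring the per-dataset OCR decision in the diff.

    parse_method: 'auto', 'ocr' or 'txt'
    classified_as_ocr: True if dataset.classify() == SupportedPdfParseMethod.OCR
    """
    if parse_method == 'auto':
        return classified_as_ocr   # trust the text/OCR classifier
    return parse_method == 'ocr'   # explicit override from the caller

assert resolve_ocr_flag('auto', classified_as_ocr=True) is True
assert resolve_ocr_flag('txt', classified_as_ocr=True) is False
assert resolve_ocr_flag('ocr', classified_as_ocr=False) is True
```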
@@ -240,7 +235,6 @@ def may_batch_image_analyze(
def may_batch_image_analyze(
images_with_extra_info: list[(np.ndarray, bool, str)],
idx: int,
ocr: bool,
show_log: bool = False,
layout_model=None,
@@ -298,4 +292,4 @@ def may_batch_image_analyze(
# f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
# f' speed: {doc_analyze_speed} pages/second'
# )
return idx, results
return results
\ No newline at end of file