"llm/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "f3604534e5bdab1666777adaecfac9963f83347e"
Commit 43164533 authored by icecraft

feat: inference with iter style

parent ce212da1
```diff
@@ -107,50 +107,13 @@ def batch_build_dataset(pdf_paths, k, lang=None):
-    pdf_info = []
-    total_pages = 0
+    results = []
     for pdf_path in pdf_paths:
         try:
-            doc = fitz.open(pdf_path)
-            num_pages = len(doc)
-            pdf_info.append((pdf_path, num_pages))
-            total_pages += num_pages
-            doc.close()
+            with open(pdf_path, 'rb') as f:
+                bits = f.read()
+            results.append(PymuDocDataset(bits, lang))
         except Exception as e:
             print(f'Error opening {pdf_path}: {e}')
-    # Partition the jobs based on page count. Each job has 1 page.
-    partitions = partition_array_greedy(pdf_info, k)
-    # Process each partition in parallel
-    all_images_h = {}
-    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
-        # Submit one task per partition
-        futures = []
-        for sn, partition in enumerate(partitions):
-            # Get the jobs for this partition
-            partition_jobs = [pdf_info[idx] for idx in partition]
-            # Submit the task
-            future = executor.submit(
-                process_pdf_batch,
-                partition_jobs,
-                sn
-            )
-            futures.append(future)
-        # Process results as they complete
-        for i, future in enumerate(concurrent.futures.as_completed(futures)):
-            try:
-                idx, images = future.result()
-                all_images_h[idx] = images
-            except Exception as e:
-                print(f'Error processing partition: {e}')
-    results = [None] * len(pdf_paths)
-    for i in range(len(partitions)):
-        partition = partitions[i]
-        for j in range(len(partition)):
-            with open(pdf_info[partition[j]][0], 'rb') as f:
-                pdf_bytes = f.read()
-            dataset = PymuDocDataset(pdf_bytes, lang=lang)
-            dataset.set_images(all_images_h[i][j])
-            results[partition[j]] = dataset
     return results
```
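For context on what was removed: the old path balanced whole PDFs across `k` worker processes by page count before pre-rendering their images. A minimal greedy partitioner in that spirit (an illustrative stand-in, not the actual `partition_array_greedy` from magic_pdf) could look like:

```python
import heapq

def greedy_partition(jobs, k):
    """Split (path, num_pages) jobs into k partitions with roughly equal
    total page counts: give the next-largest job to the lightest partition.

    Illustrative sketch only; the real partition_array_greedy may differ.
    """
    # Min-heap of (total_pages_in_partition, partition_index).
    heap = [(0, i) for i in range(k)]
    heapq.heapify(heap)
    partitions = [[] for _ in range(k)]
    # Placing the largest jobs first tends to give a tighter balance.
    order = sorted(range(len(jobs)), key=lambda i: jobs[i][1], reverse=True)
    for idx in order:
        pages, part = heapq.heappop(heap)
        partitions[part].append(idx)  # store job indices, as the caller does
        heapq.heappush(heap, (pages + jobs[idx][1], part))
    return partitions
```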
```diff
@@ -342,17 +342,8 @@ class Doc(PageableData):
             height: int
         }
         """
-        if self._img is None:
-            self._img = fitz_doc_to_image(self._doc)
-        return self._img
-
-    def set_image(self, img):
-        """
-        Args:
-            img (np.ndarray): the image
-        """
-        if self._img is None:
-            self._img = img
+        return fitz_doc_to_image(self._doc)

     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.
```
```diff
@@ -138,30 +138,31 @@ def doc_analyze(
     )
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
+    batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
+    images_with_extra_info = []
+    results = []
     for index in range(len(dataset)):
         if start_page_id <= index <= end_page_id:
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
+            if lang is None or lang == 'auto':
+                images_with_extra_info.append((images[index], ocr, dataset._lang))
+            else:
+                images_with_extra_info.append((images[index], ocr, lang))
+            if len(images_with_extra_info) == batch_size:
+                _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
+                results.extend(result)
+                images_with_extra_info = []
-    if lang is None or lang == 'auto':
-        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]
-    else:
-        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(images))]
-    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        batch_size = MIN_BATCH_INFERENCE_SIZE
-        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
-    else:
-        batch_images = [images_with_extra_info]
-    results = []
-    for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable)
-        results.extend(result)
+    if len(images_with_extra_info) > 0:
+        _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
+        results.extend(result)
+        images_with_extra_info = []

     model_json = []
     for index in range(len(dataset)):
```
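The new loop is a standard accumulate-and-flush streaming pattern: buffer pages until `batch_size` is reached, analyze, reset the buffer, and flush whatever remains after the loop. The same idea expressed as a reusable generator (illustrative names, not part of magic_pdf):

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')

def iter_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield lists of up to batch_size items, flushing the remainder,
    holding at most one batch in memory at a time."""
    buf: List[T] = []
    for item in items:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf

# The loop above is then roughly equivalent to:
#   for batch in iter_batches(pages_with_extra_info, batch_size):
#       _, result = may_batch_image_analyze(batch, 0, ocr, show_log, ...)
#       results.extend(result)
```

Unlike the old `batch_images` list of slices, this keeps peak memory proportional to one batch rather than to the whole document.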
```diff
@@ -193,6 +194,7 @@ def batch_doc_analyze(
     batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
+    results = []
     images_with_extra_info = []
     for dataset in datasets:
```
```diff
@@ -211,11 +213,15 @@
             else:
                 images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
+            if len(images_with_extra_info) == batch_size:
+                _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
+                results.extend(result)
+                images_with_extra_info = []
-    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
-    results = []
-    for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
-        results.extend(result)
+    if len(images_with_extra_info) > 0:
+        _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
+        results.extend(result)
+        images_with_extra_info = []

     infer_results = []
     from magic_pdf.operators.models import InferenceResult
```
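As a quick self-contained check of that flush pattern, here is the same control flow with a stub standing in for `may_batch_image_analyze` (its `(sn, per-image results)` return shape is inferred from the hunks above): every page is processed exactly once, including the final partial batch.

```python
def stub_analyze(batch, sn):
    # Stand-in for may_batch_image_analyze: returns (sn, one result per image).
    return sn, [f'result-{img}' for img, _ocr, _lang in batch]

pages = [(i, True, 'en') for i in range(7)]  # 7 pages, batch size of 3
batch_size, results, buffer = 3, [], []
for page in pages:
    buffer.append(page)
    if len(buffer) == batch_size:
        _, res = stub_analyze(buffer, 0)
        results.extend(res)
        buffer = []
if buffer:  # flush the final partial batch (2 pages here)
    _, res = stub_analyze(buffer, 0)
    results.extend(res)
assert results == [f'result-{i}' for i in range(7)]
```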