Commit 3a2f86a1 authored by icecraft

feat: add parallel evaluation

parent 734ae27b
import os
import glob
import threading
import concurrent.futures

import fitz  # PyMuPDF

from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.data.dataset import PymuDocDataset
def partition_array_greedy(arr, k):
    """
    Partition an array into k parts using a simple greedy approach.

    Parameters:
    -----------
    arr : list of tuples
        The input array of (pdf_path, page_count) tuples; the page count
        is used as the weight
    k : int
        Number of partitions to create

    Returns:
    --------
    partitions : list of lists
        The k partitions, each a list of indices into arr
    """
    # Handle edge cases
    if k <= 0:
        raise ValueError("k must be a positive integer")
    if k > len(arr):
        k = len(arr)  # Adjust k if it's too large
    if k == 1:
        return [list(range(len(arr)))]
    if k == len(arr):
        return [[i] for i in range(len(arr))]

    # Sort indices by weight (page count) in descending order
    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)

    # Initialize k empty partitions
    partitions = [[] for _ in range(k)]
    partition_sums = [0] * k

    # Assign each element to the partition with the smallest current sum
    for idx in sorted_indices:
        # Find the partition with the smallest sum
        min_sum_idx = partition_sums.index(min(partition_sums))
        # Add the element to this partition
        partitions[min_sum_idx].append(idx)  # Store the original index
        partition_sums[min_sum_idx] += arr[idx][1]

    return partitions
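
# Illustrative usage of the greedy partitioner (the paths and page counts
# below are made up; only the page count, the second tuple element, drives
# the split):
#
#     jobs = [("a.pdf", 100), ("b.pdf", 40), ("c.pdf", 35), ("d.pdf", 30)]
#     parts = partition_array_greedy(jobs, 2)
#     # Items are placed largest-first into the lightest bucket:
#     # parts == [[0], [1, 2, 3]]  ->  weights 100 vs 105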
def process_pdf_batch(pdf_jobs, idx):
    """
    Render every page of a batch of PDFs to images.

    Parameters:
    -----------
    pdf_jobs : list of tuples
        List of (pdf_path, page_count) tuples
    idx : int
        Index of this batch, returned so the caller can reorder results

    Returns:
    --------
    (idx, images) : tuple
        The batch index and, for each PDF, the list of rendered page images
    """
    images = []
    for pdf_path, _ in pdf_jobs:
        doc = fitz.open(pdf_path)
        tmp = []
        for page_num in range(len(doc)):
            page = doc[page_num]
            tmp.append(fitz_doc_to_image(page))
        doc.close()
        images.append(tmp)
    return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None):
    """
    Build datasets for multiple PDFs by partitioning them into k balanced
    parts (by page count) and rendering each part in a separate process.

    Parameters:
    -----------
    pdf_paths : list
        List of paths to PDF files
    k : int
        Number of partitions (worker processes) to create
    lang : str or None
        Language hint passed through to PymuDocDataset

    Returns:
    --------
    results : list
        One PymuDocDataset per input path, in the original order
    """
    # Get page counts for each PDF
    pdf_info = []
    total_pages = 0
    for pdf_path in pdf_paths:
        try:
            doc = fitz.open(pdf_path)
            num_pages = len(doc)
            pdf_info.append((pdf_path, num_pages))
            total_pages += num_pages
            doc.close()
        except Exception as e:
            print(f"Error opening {pdf_path}: {e}")

    # Partition the PDFs into k parts balanced by page count
    partitions = partition_array_greedy(pdf_info, k)
    for i, partition in enumerate(partitions):
        print(f"Partition {i+1}: {len(partition)} pdfs")

    # Process each partition in parallel
    all_images_h = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
        # Submit one task per partition
        futures = []
        for sn, partition in enumerate(partitions):
            # Get the jobs for this partition
            partition_jobs = [pdf_info[idx] for idx in partition]
            # Submit the task
            future = executor.submit(process_pdf_batch, partition_jobs, sn)
            futures.append(future)
        # Collect results as they complete
        for future in concurrent.futures.as_completed(futures):
            try:
                idx, images = future.result()
                print(f"Partition {idx+1} completed: processed {len(images)} pdfs")
                all_images_h[idx] = images
            except Exception as e:
                print(f"Error processing partition: {e}")

    # Rebuild one dataset per input path, attaching the pre-rendered images
    results = [None] * len(pdf_paths)
    for i in range(len(partitions)):
        partition = partitions[i]
        for j in range(len(partition)):
            with open(pdf_info[partition[j]][0], "rb") as f:
                pdf_bytes = f.read()
            dataset = PymuDocDataset(pdf_bytes, lang=lang)
            dataset.set_images(all_images_h[i][j])
            results[partition[j]] = dataset
    return results
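
# For illustration, a minimal smoke test (placeholder paths; k=2 is arbitrary):
if __name__ == "__main__":
    datasets = batch_build_dataset(["/tmp/a.pdf", "/tmp/b.pdf", "/tmp/c.pdf"], k=2)
    print([len(ds) for ds in datasets])  # page count per document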
@@ -154,6 +154,7 @@ class PymuDocDataset(Dataset):
        else:
            self._lang = lang
            logger.info(f"lang: {lang}")

    def __len__(self) -> int:
        """The number of pages in the pdf."""
        return len(self._records)
@@ -224,6 +225,9 @@ class PymuDocDataset(Dataset):
        """
        return PymuDocDataset(self._raw_data)

    def set_images(self, images):
        for i in range(len(self._records)):
            self._records[i].set_image(images[i])

class ImageDataset(Dataset):
    def __init__(self, bits: bytes):
@@ -304,12 +308,17 @@ class ImageDataset(Dataset):
        """clone this dataset
        """
        return ImageDataset(self._raw_data)

    def set_images(self, images):
        for i in range(len(self._records)):
            self._records[i].set_image(images[i])

class Doc(PageableData):
    """Initialized with pymudoc object."""

    def __init__(self, doc: fitz.Page):
        self._doc = doc
        self._img = None

    def get_image(self):
        """Return the image info.
@@ -321,7 +330,17 @@ class Doc(PageableData):
            height: int
        }
        """
        if self._img is None:
            self._img = fitz_doc_to_image(self._doc)
        return self._img

    def set_image(self, img):
        """
        Args:
            img (np.ndarray): the image
        """
        if self._img is None:
            self._img = img
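
# Usage sketch (illustrative): set_image lets batch_build_dataset attach
# pre-rendered pages, after which get_image returns the cached image dict
# instead of re-rendering the page:
#
#     page = dataset.get_page(0)
#     page.set_image(img_dict)       # cache {img, width, height}
#     same = page.get_image()        # served from the cache, no re-render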
    def get_doc(self) -> fitz.Page:
        """Get the pymudoc object.
...
import multiprocessing as mp
import threading

import fitz  # PyMuPDF
import numpy as np
from loguru import logger

from magic_pdf.utils.annotations import ImportPIL
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

@ImportPIL
@@ -65,3 +68,105 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
        images.append(img_dict)
    return images
def convert_page(bytes_page):
    # Each worker re-opens the page from raw PDF bytes, because fitz
    # objects cannot be pickled across process boundaries
    pdfs = fitz.open('pdf', bytes_page)
    page = pdfs[0]
    return fitz_doc_to_image(page)

def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
    """Process PDF pages in parallel with a serialization-safe approach."""
    if num_workers is None:
        num_workers = mp.cpu_count()
    # Process the extracted page data in parallel
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Process the page data
        results = list(
            executor.map(convert_page, pages)
        )
    return results
def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
    """
    Process all pages of a PDF using multiple threads.

    Parameters:
    -----------
    pdf_path : str
        Path to the PDF file
    num_threads : int
        Number of threads to use
    **kwargs :
        Additional arguments for fitz_doc_to_image

    Returns:
    --------
    results : list
        List of processed images, in page order
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    num_pages = len(doc)

    # Create a list to store results in the correct order
    results = [None] * num_pages

    # Create a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {}
        for page_num in range(num_pages):
            page = doc[page_num]
            future = executor.submit(fitz_doc_to_image, page, **kwargs)
            futures[future] = page_num
        # Collect results as they complete
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f"Error processing page {page_num}: {e}")
                results[page_num] = None

    # Close the document
    doc.close()
    return results
if __name__ == "__main__":
    # Split the PDF into single-page documents, serialize each to bytes,
    # then rasterize them across 16 worker processes
    pdf = fitz.open('/tmp/[MS-DOC].pdf')
    pdf_page = [fitz.open() for i in range(pdf.page_count)]
    for i in range(pdf.page_count):
        pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i)
    pdf_page = [v.tobytes() for v in pdf_page]
    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
import os
import time

import torch
import numpy as np
import multiprocessing as mp
import concurrent.futures as fut

os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks

# disable paddle's signal handler
import paddle
paddle.disable_signal_handler()

from loguru import logger

from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram

try:
@@ -30,8 +29,9 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
# from magic_pdf.operators.models import InferenceResult

MIN_BATCH_INFERENCE_SIZE = 100
class ModelSingleton:
    _instance = None
@@ -72,9 +72,7 @@ def custom_model_init(
    formula_enable=None,
    table_enable=None,
):
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +130,6 @@ def custom_model_init(
    return custom_model
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
@@ -143,14 +140,166 @@ def doc_analyze(
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )
    parallel_count = None
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])

    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # should check the GPU memory first!
        # split images into parallel_count batches
        if parallel_count > 1:
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        parallel_count = len(batch_images)  # adjust to the real parallel count
        # alternative implementation using concurrent.futures:
        """
        results = []
        result_history = {}
        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
            for future in fut.as_completed(futures):
                sn, result = future.result()
                result_history[sn] = result
        for key in sorted(result_history.keys()):
            results.extend(result_history[key])
        """
        results = []
        with mp.Pool(processes=parallel_count) as pool:
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                 for sn, batch_image in enumerate(batch_images)])
        for _sn, result in mapped_results:
            results.extend(result)
    else:
        _, results = may_batch_image_analyze(
            images,
            0,
            ocr,
            show_log,
            lang, layout_model, formula_enable, table_enable)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0
        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)
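
# Illustrative check of the ceil-division split used above: 578 pages with
# parallel_count=4 gives batch_size = (578 + 3) // 4 = 145, so batches of
# 145/145/145/143 and the real parallel count stays 4.
#
#     images = list(range(578))
#     batch_size = (len(images) + 4 - 1) // 4
#     batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
#     assert [len(b) for b in batches] == [145, 145, 145, 143]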
def batch_doc_analyze(
    datasets: list[Dataset],
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
    parallel_count = None
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])

    images = []
    page_wh_list = []
    for dataset in datasets:
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # should check the GPU memory first!
        # split images into parallel_count batches
        if parallel_count > 1:
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        parallel_count = len(batch_images)  # adjust to the real parallel count
        # alternative implementation using concurrent.futures:
        """
        results = []
        result_history = {}
        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
            for future in fut.as_completed(futures):
                sn, result = future.result()
                result_history[sn] = result
        for key in sorted(result_history.keys()):
            results.extend(result_history[key])
        """
        results = []
        with mp.Pool(processes=parallel_count) as pool:
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                 for sn, batch_image in enumerate(batch_images)])
        for _sn, result in mapped_results:
            results.extend(result)
    else:
        _, results = may_batch_image_analyze(
            images,
            0,
            ocr,
            show_log,
            lang, layout_model, formula_enable, table_enable)

    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results
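
# Usage sketch (illustrative): analyze several prebuilt datasets in one
# batched call; MINERU_PARALLEL_INFERENCE_COUNT, when set, fixes the number
# of worker processes.
#
#     os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'  # optional
#     infer_results = batch_doc_analyze(datasets, lang='en')
#     model_json = infer_results[0].get_infer_res()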
def may_batch_image_analyze(
        images: list[np.ndarray],
        idx: int,
        ocr: bool = False,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    # disable paddle's signal handler
    import paddle
    paddle.disable_signal_handler()

    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
@@ -181,12 +330,10 @@ def doc_analyze(
    logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
    batch_analyze = True

    doc_analyze_start = time.time()

    if batch_analyze:
        """# batch analyze
        images = []
        page_wh_list = []
        for index in range(len(dataset)):
@@ -195,9 +342,10 @@ def doc_analyze(
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
        """
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
        results = batch_model(images)
        """
        for index in range(len(dataset)):
            if start_page_id <= index <= end_page_id:
                result = analyze_result.pop(0)
@@ -210,10 +358,10 @@ def doc_analyze(
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        """
    else:
        # single analyze
        """
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
@@ -230,6 +378,13 @@ def doc_analyze(
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        """
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
@@ -237,10 +392,10 @@ def doc_analyze(
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    logger.info(
        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)
import os

import torch
from loguru import logger
...
@@ -8,10 +8,13 @@ from pathlib import Path
import magic_pdf.model as model_config
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__
from magic_pdf.tools.common import do_parse, parse_pdf_methods, batch_do_parse
from magic_pdf.utils.office_to_pdf import convert_file_to_pdf

pdf_suffixes = ['.pdf']
ms_office_suffixes = ['.ppt', '.pptx', '.doc', '.docx']
image_suffixes = ['.png', '.jpeg', '.jpg']
@@ -110,14 +113,17 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
        disk_rw = FileBasedDataReader(os.path.dirname(fn))
        return disk_rw.read(os.path.basename(fn))

    def parse_doc(doc_path: Path, dataset: Dataset | None = None):
        try:
            file_name = str(Path(doc_path).stem)
            if dataset is None:
                pdf_data_or_dataset = read_fn(doc_path)
            else:
                pdf_data_or_dataset = dataset
            do_parse(
                output_dir,
                file_name,
                pdf_data_or_dataset,
                [],
                method,
                debug_able,
@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
            logger.exception(e)

    if os.path.isdir(path):
        # Directory mode: build all datasets up front (4 render workers),
        # then parse the whole batch in one shot
        doc_paths = []
        for doc_path in Path(path).glob('*'):
            if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
                doc_paths.append(doc_path)
        datasets = batch_build_dataset(doc_paths, 4, lang)
        batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths], datasets, method, debug_able, lang=lang)
    else:
        parse_doc(Path(path))
...
@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset, Dataset
from magic_pdf.libs.draw_bbox import draw_char_bbox
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze, batch_doc_analyze

# from io import BytesIO
# from pypdf import PdfReader, PdfWriter
@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
    return output_bytes

def _do_parse(
    output_dir,
    pdf_file_name,
    pdf_bytes_or_dataset,
    model_list,
    parse_method,
    debug_able,
@@ -92,16 +92,21 @@ def do_parse(
    formula_enable=None,
    table_enable=None,
):
    from magic_pdf.operators.models import InferenceResult
    if debug_able:
        logger.warning('debug mode is on')
        f_draw_model_bbox = True
        f_draw_line_sort_bbox = True
        # f_draw_char_bbox = True

    # Accept either raw PDF bytes or a prebuilt dataset
    if isinstance(pdf_bytes_or_dataset, bytes):
        pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
            pdf_bytes_or_dataset, start_page_id, end_page_id
        )
        ds = PymuDocDataset(pdf_bytes, lang=lang)
    else:
        ds = pdf_bytes_or_dataset
        pdf_bytes = ds._raw_data

    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)

    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -109,8 +114,6 @@ def do_parse(
    )
    image_dir = str(os.path.basename(local_image_dir))

    if len(model_list) == 0:
        if model_config.__use_inside_model__:
            if parse_method == 'auto':
@@ -241,5 +244,79 @@ do_parse(
    logger.info(f'local output dir is {local_md_dir}')
def do_parse(
    output_dir,
    pdf_file_name,
    pdf_bytes_or_dataset,
    model_list,
    parse_method,
    debug_able,
    f_draw_span_bbox=True,
    f_draw_layout_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_json=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    f_draw_model_bbox=False,
    f_draw_line_sort_bbox=False,
    f_draw_char_bbox=False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    parallel_count = 1
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])

    if parallel_count > 1:
        if isinstance(pdf_bytes_or_dataset, bytes):
            pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
                pdf_bytes_or_dataset, start_page_id, end_page_id
            )
            ds = PymuDocDataset(pdf_bytes, lang=lang)
        else:
            ds = pdf_bytes_or_dataset
        batch_do_parse(
            output_dir, [pdf_file_name], [ds], parse_method, debug_able,
            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
            f_draw_char_bbox=f_draw_char_bbox)
    else:
        _do_parse(
            output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list,
            parse_method, debug_able, start_page_id=start_page_id,
            end_page_id=end_page_id, lang=lang, layout_model=layout_model,
            formula_enable=formula_enable, table_enable=table_enable,
            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
            f_draw_char_bbox=f_draw_char_bbox)

def batch_do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_or_datasets: list[bytes | Dataset],
    parse_method,
    debug_able,
    f_draw_span_bbox=True,
    f_draw_layout_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_json=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    f_draw_model_bbox=False,
    f_draw_line_sort_bbox=False,
    f_draw_char_bbox=False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    dss = []
    for v in pdf_bytes_or_datasets:
        if isinstance(v, bytes):
            dss.append(PymuDocDataset(v, lang=lang))
        else:
            dss.append(v)
    infer_results = batch_doc_analyze(
        dss, lang=lang, layout_model=layout_model,
        formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
    for idx, infer_result in enumerate(infer_results):
        _do_parse(
            output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(),
            parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox,
            f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md,
            f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json,
            f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list,
            f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox,
            f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
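
# Usage sketch (illustrative): parse a directory of PDFs end to end with the
# batched pipeline (paths, output dir, and worker count are placeholders).
#
#     from pathlib import Path
#     from magic_pdf.data.batch_build_dataset import batch_build_dataset
#     paths = ['/tmp/a.pdf', '/tmp/b.pdf']
#     datasets = batch_build_dataset(paths, 4, lang=None)
#     batch_do_parse('/tmp/out', [Path(p).stem for p in paths], datasets, 'auto', False)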
parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])