Unverified Commit ecdd162f authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1910 from icecraft/fix/parallel_split

Fix/parallel split
parents 734ae27b c67a4793
import concurrent.futures
import fitz  # PyMuPDF
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image
def partition_array_greedy(arr, k):
"""Partition an array into k parts using a simple greedy approach.
Parameters:
-----------
arr : list
The input array of integers
k : int
Number of partitions to create
Returns:
--------
partitions : list of lists
The k partitions of the array
"""
# Handle edge cases
if k <= 0:
raise ValueError('k must be a positive integer')
if k > len(arr):
k = len(arr) # Adjust k if it's too large
if k == 1:
return [list(range(len(arr)))]
if k == len(arr):
return [[i] for i in range(len(arr))]
# Sort indices by page count in descending order
sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
# Initialize k empty partitions
partitions = [[] for _ in range(k)]
partition_sums = [0] * k
# Assign each element to the partition with the smallest current sum
for idx in sorted_indices:
# Find the partition with the smallest sum
min_sum_idx = partition_sums.index(min(partition_sums))
# Add the element to this partition
partitions[min_sum_idx].append(idx) # Store the original index
partition_sums[min_sum_idx] += arr[idx][1]
return partitions
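# A minimal usage sketch (hypothetical file names): partition four PDFs,
# weighted by page count, into two balanced groups. The return value holds
# indices into the input list, not the tuples themselves:
#
#   jobs = [('a.pdf', 100), ('b.pdf', 40), ('c.pdf', 60), ('d.pdf', 10)]
#   partition_array_greedy(jobs, 2)
#   # -> [[0, 3], [2, 1]]  (page sums 110 and 100)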
def process_pdf_batch(pdf_jobs, idx):
"""Process a batch of PDF pages using multiple threads.
Parameters:
-----------
pdf_jobs : list of tuples
List of (pdf_path, page_num) tuples
output_dir : str or None
Directory to save images to
num_threads : int
Number of threads to use
**kwargs :
Additional arguments for process_pdf_page
Returns:
--------
images : list
List of processed images
"""
images = []
for pdf_path, _ in pdf_jobs:
doc = fitz.open(pdf_path)
tmp = []
for page_num in range(len(doc)):
page = doc[page_num]
tmp.append(fitz_doc_to_image(page))
images.append(tmp)
return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None):
"""Process multiple PDFs by partitioning them into k balanced parts and
processing each part in parallel.
Parameters:
-----------
pdf_paths : list
List of paths to PDF files
k : int
Number of partitions to create
output_dir : str or None
Directory to save images to
threads_per_worker : int
Number of threads to use per worker
**kwargs :
Additional arguments for process_pdf_page
Returns:
--------
all_images : list
List of all processed images
"""
# Get page counts for each PDF
pdf_info = []
total_pages = 0
for pdf_path in pdf_paths:
try:
doc = fitz.open(pdf_path)
num_pages = len(doc)
pdf_info.append((pdf_path, num_pages))
total_pages += num_pages
doc.close()
except Exception as e:
print(f'Error opening {pdf_path}: {e}')
# Partition the jobs based on page count
partitions = partition_array_greedy(pdf_info, k)
# Process each partition in parallel
all_images_h = {}
with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
# Submit one task per partition
futures = []
for sn, partition in enumerate(partitions):
# Get the jobs for this partition
partition_jobs = [pdf_info[idx] for idx in partition]
# Submit the task
future = executor.submit(
process_pdf_batch,
partition_jobs,
sn
)
futures.append(future)
# Process results as they complete
for future in concurrent.futures.as_completed(futures):
try:
idx, images = future.result()
all_images_h[idx] = images
except Exception as e:
print(f'Error processing partition: {e}')
results = [None] * len(pdf_paths)
for i in range(len(partitions)):
partition = partitions[i]
for j in range(len(partition)):
with open(pdf_info[partition[j]][0], 'rb') as f:
pdf_bytes = f.read()
dataset = PymuDocDataset(pdf_bytes, lang=lang)
dataset.set_images(all_images_h[i][j])
results[partition[j]] = dataset
return results
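# End-to-end sketch (hypothetical paths): render two PDFs with up to four
# worker processes and get back ready-to-use datasets in input order:
#
#   datasets = batch_build_dataset(['/tmp/a.pdf', '/tmp/b.pdf'], 4)
#   img_dict = datasets[0].get_page(0).get_image()  # served from the cache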
......@@ -97,10 +97,10 @@ class Dataset(ABC):
@abstractmethod
def dump_to_file(self, file_path: str):
"""Dump the file
"""Dump the file.
Args:
file_path (str): the file path
Args:
file_path (str): the file path
"""
pass
......@@ -119,7 +119,7 @@ class Dataset(ABC):
@abstractmethod
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
......@@ -128,8 +128,7 @@ class Dataset(ABC):
@abstractmethod
def clone(self):
"""clone this dataset
"""
"""clone this dataset."""
pass
......@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
if lang == '':
self._lang = None
elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
from magic_pdf.model.sub_modules.language_detection.utils import \
auto_detect_lang
self._lang = auto_detect_lang(bits)
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
logger.info(f'lang: {lang}, detect_lang: {self._lang}')
else:
self._lang = lang
logger.info(f"lang: {lang}")
logger.info(f'lang: {lang}')
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
......@@ -186,12 +187,12 @@ class PymuDocDataset(Dataset):
return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
"""Dump the file.
Args:
file_path (str): the file path
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
......@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
......@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
return classify(self._data_bits)
def clone(self):
"""clone this dataset
"""
"""clone this dataset."""
return PymuDocDataset(self._raw_data)
def set_images(self, images):
for i in range(len(self._records)):
self._records[i].set_image(images[i])
class ImageDataset(Dataset):
def __init__(self, bits: bytes):
......@@ -270,10 +273,10 @@ class ImageDataset(Dataset):
return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
"""Dump the file.
Args:
file_path (str): the file path
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
......@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
......@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
return SupportedPdfParseMethod.OCR
def clone(self):
"""clone this dataset
"""
"""clone this dataset."""
return ImageDataset(self._raw_data)
def set_images(self, images):
for i in range(len(self._records)):
self._records[i].set_image(images[i])
class Doc(PageableData):
"""Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page):
self._doc = doc
self._img = None
def get_image(self):
"""Return the image info.
......@@ -321,7 +328,17 @@ class Doc(PageableData):
height: int
}
"""
return fitz_doc_to_image(self._doc)
if self._img is None:
self._img = fitz_doc_to_image(self._doc)
return self._img
def set_image(self, img):
"""
Args:
img (np.ndarray): the image
"""
if self._img is None:
self._img = img
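# Caching sketch: get_image renders lazily on first access, while set_image
# only seeds an empty cache, so pages rendered in worker processes (see
# batch_build_dataset) are never re-rendered in the parent:
#
#   page.set_image(precomputed_img)  # stored because the cache is empty
#   img = page.get_image()           # cache hit, no fitz rendering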
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
......
import multiprocessing as mp
import threading
from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
as_completed)
import fitz
import numpy as np
from loguru import logger
......@@ -65,3 +70,101 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
images.append(img_dict)
return images
def convert_page(bytes_page):
# Re-open the single-page PDF from its bytes inside the worker, then render it
pdfs = fitz.open('pdf', bytes_page)
page = pdfs[0]
return fitz_doc_to_image(page)
def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
"""Process PDF pages in parallel with serialization-safe approach."""
if num_workers is None:
num_workers = mp.cpu_count()
# Process the extracted page data in parallel
with ProcessPoolExecutor(max_workers=num_workers) as executor:
# Process the page data
results = list(
executor.map(convert_page, pages)
)
return results
def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
"""Process all pages of a PDF using multiple threads.
Parameters:
-----------
pdf_path : str
Path to the PDF file
num_threads : int
Number of threads to use
**kwargs :
Additional arguments for fitz_doc_to_image
Returns:
--------
images : list
List of processed images, in page order
"""
# Open the PDF
doc = fitz.open(pdf_path)
num_pages = len(doc)
# Create a list to store results in the correct order
results = [None] * num_pages
# Create a thread pool
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = {}
for page_num in range(num_pages):
page = doc[page_num]
future = executor.submit(fitz_doc_to_image, page, **kwargs)
futures[future] = page_num
# Process results as they complete with progress bar
for future in as_completed(futures):
page_num = futures[future]
try:
results[page_num] = future.result()
except Exception as e:
print(f'Error processing page {page_num}: {e}')
results[page_num] = None
# Close the document
doc.close()
return results
if __name__ == '__main__':
pdf = fitz.open('/tmp/[MS-DOC].pdf')
# Split the document into single-page PDFs serialized to bytes for the workers
pdf_page = []
for i in range(pdf.page_count):
single = fitz.open()
single.insert_pdf(pdf, from_page=i, to_page=i)
pdf_page.append(single.tobytes())
results = parallel_process_pdf_safe(pdf_page, num_workers=16)
# threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
import concurrent.futures as fut
import multiprocessing as mp
import os
import time
import numpy as np
import torch
os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks
# Disable Paddle's signal handlers
import paddle
paddle.disable_signal_handler()
from loguru import logger
from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram
try:
......@@ -30,8 +31,10 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
from magic_pdf.operators.models import InferenceResult
# from magic_pdf.operators.models import InferenceResult
MIN_BATCH_INFERENCE_SIZE = 100
class ModelSingleton:
_instance = None
......@@ -72,9 +75,7 @@ def custom_model_init(
formula_enable=None,
table_enable=None,
):
model = None
if model_config.__model_mode__ == 'lite':
logger.warning(
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
......@@ -132,7 +133,6 @@ def custom_model_init(
return custom_model
def doc_analyze(
dataset: Dataset,
ocr: bool = False,
......@@ -143,13 +143,165 @@ def doc_analyze(
layout_model=None,
formula_enable=None,
table_enable=None,
) -> InferenceResult:
one_shot: bool = True,
):
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
parallel_count = None
if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
images = []
page_wh_list = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
if parallel_count is None:
parallel_count = 2  # TODO: should be derived from available GPU memory
# split images into parallel_count batches
if parallel_count > 1:
batch_size = (len(images) + parallel_count - 1) // parallel_count
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
batch_images = [images]
results = []
parallel_count = len(batch_images) # adjust to real parallel count
# Alternative concurrent.futures implementation, kept for reference:
"""
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
for future in fut.as_completed(futures):
sn, result = future.result()
result_history[sn] = result
for key in sorted(result_history.keys()):
results.extend(result_history[key])
"""
with mp.Pool(processes=parallel_count) as pool:
mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
for sn, result in mapped_results:
results.extend(result)
else:
_, results = may_batch_image_analyze(
images,
0,
ocr,
show_log,
lang, layout_model, formula_enable, table_enable)
model_json = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
result = results.pop(0)
page_width, page_height = page_wh_list.pop(0)
else:
result = []
page_height = 0
page_width = 0
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
from magic_pdf.operators.models import InferenceResult
return InferenceResult(model_json, dataset)
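# Usage sketch: parallel inference is opt-in via the environment variable
# read at the top of doc_analyze; with one_shot=True and at least
# MIN_BATCH_INFERENCE_SIZE pages, the rendered images are split across
# worker processes:
#
#   os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '4'
#   infer_result = doc_analyze(dataset, ocr=False)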
def batch_doc_analyze(
datasets: list[Dataset],
ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
one_shot: bool = True,
):
parallel_count = None
if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
images = []
page_wh_list = []
for dataset in datasets:
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
if parallel_count is None:
parallel_count = 2  # TODO: should be derived from available GPU memory
# split images into parallel_count batches
if parallel_count > 1:
batch_size = (len(images) + parallel_count - 1) // parallel_count
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
batch_images = [images]
results = []
parallel_count = len(batch_images) # adjust to real parallel count
# Alternative concurrent.futures implementation, kept for reference:
"""
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
for future in fut.as_completed(futures):
sn, result = future.result()
result_history[sn] = result
for key in sorted(result_history.keys()):
results.extend(result_history[key])
"""
with mp.Pool(processes=parallel_count) as pool:
mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
for sn, result in mapped_results:
results.extend(result)
else:
_, results = may_batch_image_analyze(
images,
0,
ocr,
show_log,
lang, layout_model, formula_enable, table_enable)
infer_results = []
from magic_pdf.operators.models import InferenceResult
for index in range(len(datasets)):
dataset = datasets[index]
model_json = []
for i in range(len(dataset)):
result = results.pop(0)
page_width, page_height = page_wh_list.pop(0)
page_info = {'page_no': i, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
infer_results.append(InferenceResult(model_json, dataset))
return infer_results
def may_batch_image_analyze(
images: list[np.ndarray],
idx: int,
ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
# Disable Paddle's signal handlers
import paddle
paddle.disable_signal_handler()
from magic_pdf.model.batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
custom_model = model_manager.get_model(
......@@ -161,14 +313,14 @@ def doc_analyze(
device = get_device()
npu_support = False
if str(device).startswith("npu"):
if str(device).startswith('npu'):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True
torch.npu.set_compile_mode(jit_compile=False)
if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8:
if gpu_memory >= 20:
batch_ratio = 16
......@@ -181,12 +333,10 @@ def doc_analyze(
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_analyze = True
model_json = []
doc_analyze_start = time.time()
if batch_analyze:
# batch analyze
"""# batch analyze
images = []
page_wh_list = []
for index in range(len(dataset)):
......@@ -195,9 +345,10 @@ def doc_analyze(
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
analyze_result = batch_model(images)
results = batch_model(images)
"""
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
......@@ -210,10 +361,10 @@ def doc_analyze(
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
else:
# single analyze
"""
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
......@@ -230,6 +381,13 @@ def doc_analyze(
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
results = []
for img_idx, img in enumerate(images):
inference_start = time.time()
result = custom_model(img)
logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
results.append(result)
gc_start = time.time()
clean_memory(get_device())
......@@ -237,10 +395,9 @@ def doc_analyze(
logger.info(f'gc time: {gc_time}')
doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
logger.info(
f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
f' speed: {doc_analyze_speed} pages/second'
)
return InferenceResult(model_json, dataset)
return (idx, results)
import os
import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
try:
from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
from magic_pdf_ascend_plugin.libs.license_verifier import (
LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
load_license)
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
RapidTableModel
license_key = load_license()
logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
f' License expired at {license_key["payload"]["date"]["end_date"]}')
......@@ -20,21 +29,24 @@ except Exception as e:
if isinstance(e, ImportError):
pass
elif isinstance(e, LicenseFormatError):
logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
elif isinstance(e, LicenseSignatureError):
logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
elif isinstance(e, LicenseExpiredError):
logger.error("Ascend Plugin: License has expired. Please renew your license.")
logger.error('Ascend Plugin: License has expired. Please renew your license.')
elif isinstance(e, FileNotFoundError):
logger.error("Ascend Plugin: Not found License file.")
logger.error('Ascend Plugin: License file not found.')
else:
logger.error(f"Ascend Plugin: {e}")
logger.error(f'Ascend Plugin: {e}')
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
......@@ -55,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
def mfd_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
if str(device).startswith('npu'):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
......@@ -72,14 +84,14 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def langdetect_model_init(langdetect_model_weight, device='cpu'):
if str(device).startswith("npu"):
if str(device).startswith('npu'):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weight, device)
return model
......
......@@ -5,6 +5,7 @@ import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import alpha_to_color, binarize_img
......
......@@ -2,6 +2,7 @@ import os
import cv2
import numpy as np
from paddleocr import PaddleOCR
from ppstructure.table.predict_table import TableSystem
from ppstructure.utility import init_args
from PIL import Image
......
import os
import shutil
import tempfile
from pathlib import Path
import click
import fitz
from loguru import logger
from pathlib import Path
import magic_pdf.model as model_config
from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__
from magic_pdf.tools.common import do_parse, parse_pdf_methods
from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
pdf_suffixes = ['.pdf']
......@@ -94,30 +97,33 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
def read_fn(path: Path):
if path.suffix in ms_office_suffixes:
convert_file_to_pdf(str(path), temp_dir)
fn = os.path.join(temp_dir, f"{path.stem}.pdf")
fn = os.path.join(temp_dir, f'{path.stem}.pdf')
elif path.suffix in image_suffixes:
with open(str(path), 'rb') as f:
bits = f.read()
pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
fn = os.path.join(temp_dir, f"{path.stem}.pdf")
fn = os.path.join(temp_dir, f'{path.stem}.pdf')
with open(fn, 'wb') as f:
f.write(pdf_bytes)
elif path.suffix in pdf_suffixes:
fn = str(path)
else:
raise Exception(f"Unknown file suffix: {path.suffix}")
raise Exception(f'Unknown file suffix: {path.suffix}')
disk_rw = FileBasedDataReader(os.path.dirname(fn))
return disk_rw.read(os.path.basename(fn))
def parse_doc(doc_path: Path):
def parse_doc(doc_path: Path, dataset: Dataset | None = None):
try:
file_name = str(Path(doc_path).stem)
pdf_data = read_fn(doc_path)
if dataset is None:
pdf_data_or_dataset = read_fn(doc_path)
else:
pdf_data_or_dataset = dataset
do_parse(
output_dir,
file_name,
pdf_data,
pdf_data_or_dataset,
[],
method,
debug_able,
......@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
logger.exception(e)
if os.path.isdir(path):
doc_paths = []
for doc_path in Path(path).glob('*'):
if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
parse_doc(doc_path)
doc_paths.append(doc_path)
datasets = batch_build_dataset(doc_paths, 4, lang)
batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths], datasets, method, debug_able, lang=lang)
else:
parse_doc(Path(path))
......
......@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.dataset import Dataset, PymuDocDataset
from magic_pdf.libs.draw_bbox import draw_char_bbox
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.operators.models import InferenceResult
from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
doc_analyze)
# from io import BytesIO
# from pypdf import PdfReader, PdfWriter
......@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
return output_bytes
def do_parse(
def _do_parse(
output_dir,
pdf_file_name,
pdf_bytes,
pdf_bytes_or_dataset,
model_list,
parse_method,
debug_able,
......@@ -92,16 +92,21 @@ def do_parse(
formula_enable=None,
table_enable=None,
):
from magic_pdf.operators.models import InferenceResult
if debug_able:
logger.warning('debug mode is on')
f_draw_model_bbox = True
f_draw_line_sort_bbox = True
# f_draw_char_bbox = True
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes, start_page_id, end_page_id
)
if isinstance(pdf_bytes_or_dataset, bytes):
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes_or_dataset, start_page_id, end_page_id
)
ds = PymuDocDataset(pdf_bytes, lang=lang)
else:
ds = pdf_bytes_or_dataset
pdf_bytes = ds._raw_data
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
......@@ -109,8 +114,6 @@ def do_parse(
)
image_dir = str(os.path.basename(local_image_dir))
ds = PymuDocDataset(pdf_bytes, lang=lang)
if len(model_list) == 0:
if model_config.__use_inside_model__:
if parse_method == 'auto':
......@@ -241,5 +244,79 @@ def do_parse(
logger.info(f'local output dir is {local_md_dir}')
def do_parse(
output_dir,
pdf_file_name,
pdf_bytes_or_dataset,
model_list,
parse_method,
debug_able,
f_draw_span_bbox=True,
f_draw_layout_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_json=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
f_draw_model_bbox=False,
f_draw_line_sort_bbox=False,
f_draw_char_bbox=False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
parallel_count = 1
if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
if parallel_count > 1:
if isinstance(pdf_bytes_or_dataset, bytes):
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes_or_dataset, start_page_id, end_page_id
)
ds = PymuDocDataset(pdf_bytes, lang=lang)
else:
ds = pdf_bytes_or_dataset
batch_do_parse(output_dir, [pdf_file_name], [ds], parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
else:
_do_parse(output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list, parse_method, debug_able, start_page_id=start_page_id, end_page_id=end_page_id, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
def batch_do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_or_datasets: list[bytes | Dataset],
parse_method,
debug_able,
f_draw_span_bbox=True,
f_draw_layout_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_json=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
f_draw_model_bbox=False,
f_draw_line_sort_bbox=False,
f_draw_char_bbox=False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
dss = []
for v in pdf_bytes_or_datasets:
if isinstance(v, bytes):
dss.append(PymuDocDataset(v, lang=lang))
else:
dss.append(v)
infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
for idx, infer_result in enumerate(infer_results):
_do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
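# Usage sketch (hypothetical paths), mirroring the CLI directory branch
# above: build datasets once, run a single batched inference pass, then
# dump per-file outputs:
#
#   datasets = batch_build_dataset(paths, 4, lang=None)
#   batch_do_parse('/tmp/out', [p.stem for p in paths], datasets, 'auto', False)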
parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])