Commit b50f742f authored by icecraft's avatar icecraft
Browse files

feat: add parallel evaluation

parent 3a2f86a1
import os import concurrent.futures
import glob import glob
import os
import threading import threading
import concurrent.futures
import fitz import fitz
from magic_pdf.data.utils import fitz_doc_to_image # PyMuPDF
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image # PyMuPDF
def partition_array_greedy(arr, k):
    """Partition an array into k parts using a simple greedy approach.

    Items are visited from the heaviest to the lightest, and each one is
    appended to the partition whose running weight total is currently the
    smallest (the lowest-numbered partition wins ties).

    Parameters:
    -----------
    arr : list
        Sequence of items whose element at position 1 is the numeric weight
        used for balancing (the code reads ``arr[i][1]``).
    k : int
        Number of partitions to create. A value larger than ``len(arr)`` is
        reduced to ``len(arr)``.

    Returns:
    --------
    partitions : list of lists
        One list per partition, holding indices into ``arr`` (not the items
        themselves). An empty ``arr`` yields an empty list.

    Raises:
    -------
    ValueError
        If ``k`` is zero or negative.
    """
    # Handle edge cases
    if k <= 0:
        raise ValueError('k must be a positive integer')

    size = len(arr)
    k = min(k, size)  # Adjust k if it's too large

    if k == 1:
        return [list(range(size))]
    if k == size:
        return [[i] for i in range(size)]

    # Indices ordered by weight, heaviest first
    by_weight_desc = sorted(range(size), key=lambda i: arr[i][1], reverse=True)

    # Initialize k empty partitions
    partitions = [[] for _ in range(k)]
    partition_sums = [0] * k

    for idx in by_weight_desc:
        # min() returns the first minimal index, so ties go to the
        # lowest-numbered partition.
        lightest = min(range(k), key=partition_sums.__getitem__)
        partitions[lightest].append(idx)  # Store the original index
        partition_sums[lightest] += arr[idx][1]

    return partitions
def process_pdf_batch(pdf_jobs, idx): def process_pdf_batch(pdf_jobs, idx):
""" """Process a batch of PDF pages using multiple threads.
Process a batch of PDF pages using multiple threads.
Parameters: Parameters:
----------- -----------
pdf_jobs : list of tuples pdf_jobs : list of tuples
...@@ -65,14 +66,14 @@ def process_pdf_batch(pdf_jobs, idx): ...@@ -65,14 +66,14 @@ def process_pdf_batch(pdf_jobs, idx):
Number of threads to use Number of threads to use
**kwargs : **kwargs :
Additional arguments for process_pdf_page Additional arguments for process_pdf_page
Returns: Returns:
-------- --------
images : list images : list
List of processed images List of processed images
""" """
images = [] images = []
for pdf_path, _ in pdf_jobs: for pdf_path, _ in pdf_jobs:
doc = fitz.open(pdf_path) doc = fitz.open(pdf_path)
tmp = [] tmp = []
...@@ -83,9 +84,9 @@ def process_pdf_batch(pdf_jobs, idx): ...@@ -83,9 +84,9 @@ def process_pdf_batch(pdf_jobs, idx):
return (idx, images) return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None): def batch_build_dataset(pdf_paths, k, lang=None):
""" """Process multiple PDFs by partitioning them into k balanced parts and
Process multiple PDFs by partitioning them into k balanced parts and processing each part in parallel. processing each part in parallel.
Parameters: Parameters:
----------- -----------
pdf_paths : list pdf_paths : list
...@@ -98,7 +99,7 @@ def batch_build_dataset(pdf_paths, k, lang=None): ...@@ -98,7 +99,7 @@ def batch_build_dataset(pdf_paths, k, lang=None):
Number of threads to use per worker Number of threads to use per worker
**kwargs : **kwargs :
Additional arguments for process_pdf_page Additional arguments for process_pdf_page
Returns: Returns:
-------- --------
all_images : list all_images : list
...@@ -107,7 +108,7 @@ def batch_build_dataset(pdf_paths, k, lang=None): ...@@ -107,7 +108,7 @@ def batch_build_dataset(pdf_paths, k, lang=None):
# Get page counts for each PDF # Get page counts for each PDF
pdf_info = [] pdf_info = []
total_pages = 0 total_pages = 0
for pdf_path in pdf_paths: for pdf_path in pdf_paths:
try: try:
doc = fitz.open(pdf_path) doc = fitz.open(pdf_path)
...@@ -116,24 +117,24 @@ def batch_build_dataset(pdf_paths, k, lang=None): ...@@ -116,24 +117,24 @@ def batch_build_dataset(pdf_paths, k, lang=None):
total_pages += num_pages total_pages += num_pages
doc.close() doc.close()
except Exception as e: except Exception as e:
print(f"Error opening {pdf_path}: {e}") print(f'Error opening {pdf_path}: {e}')
# Partition the jobs based on page countEach job has 1 page # Partition the jobs based on page countEach job has 1 page
partitions = partition_array_greedy(pdf_info, k) partitions = partition_array_greedy(pdf_info, k)
for i, partition in enumerate(partitions): for i, partition in enumerate(partitions):
print(f"Partition {i+1}: {len(partition)} pdfs") print(f'Partition {i+1}: {len(partition)} pdfs')
# Process each partition in parallel # Process each partition in parallel
all_images_h = {} all_images_h = {}
with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor: with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
# Submit one task per partition # Submit one task per partition
futures = [] futures = []
for sn, partition in enumerate(partitions): for sn, partition in enumerate(partitions):
# Get the jobs for this partition # Get the jobs for this partition
partition_jobs = [pdf_info[idx] for idx in partition] partition_jobs = [pdf_info[idx] for idx in partition]
# Submit the task # Submit the task
future = executor.submit( future = executor.submit(
process_pdf_batch, process_pdf_batch,
...@@ -145,15 +146,15 @@ def batch_build_dataset(pdf_paths, k, lang=None): ...@@ -145,15 +146,15 @@ def batch_build_dataset(pdf_paths, k, lang=None):
for i, future in enumerate(concurrent.futures.as_completed(futures)): for i, future in enumerate(concurrent.futures.as_completed(futures)):
try: try:
idx, images = future.result() idx, images = future.result()
print(f"Partition {i+1} completed: processed {len(images)} images") print(f'Partition {i+1} completed: processed {len(images)} images')
all_images_h[idx] = images all_images_h[idx] = images
except Exception as e: except Exception as e:
print(f"Error processing partition: {e}") print(f'Error processing partition: {e}')
results = [None] * len(pdf_paths) results = [None] * len(pdf_paths)
for i in range(len(partitions)): for i in range(len(partitions)):
partition = partitions[i] partition = partitions[i]
for j in range(len(partition)): for j in range(len(partition)):
with open(pdf_info[partition[j]][0], "rb") as f: with open(pdf_info[partition[j]][0], 'rb') as f:
pdf_bytes = f.read() pdf_bytes = f.read()
dataset = PymuDocDataset(pdf_bytes, lang=lang) dataset = PymuDocDataset(pdf_bytes, lang=lang)
dataset.set_images(all_images_h[i][j]) dataset.set_images(all_images_h[i][j])
......
...@@ -97,10 +97,10 @@ class Dataset(ABC): ...@@ -97,10 +97,10 @@ class Dataset(ABC):
@abstractmethod
def dump_to_file(self, file_path: str):
    """Dump the file.

    Abstract declaration: this base version does nothing and must be
    overridden by concrete datasets.

    Args:
        file_path (str): the file path
    """
    pass
...@@ -119,7 +119,7 @@ class Dataset(ABC): ...@@ -119,7 +119,7 @@ class Dataset(ABC):
@abstractmethod @abstractmethod
def classify(self) -> SupportedPdfParseMethod: def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset """classify the dataset.
Returns: Returns:
SupportedPdfParseMethod: _description_ SupportedPdfParseMethod: _description_
...@@ -128,8 +128,7 @@ class Dataset(ABC): ...@@ -128,8 +128,7 @@ class Dataset(ABC):
@abstractmethod
def clone(self):
    """Clone this dataset.

    Abstract declaration: this base version does nothing and must be
    overridden by concrete datasets.
    """
    pass
...@@ -148,12 +147,13 @@ class PymuDocDataset(Dataset): ...@@ -148,12 +147,13 @@ class PymuDocDataset(Dataset):
if lang == '': if lang == '':
self._lang = None self._lang = None
elif lang == 'auto': elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang from magic_pdf.model.sub_modules.language_detection.utils import \
auto_detect_lang
self._lang = auto_detect_lang(bits) self._lang = auto_detect_lang(bits)
logger.info(f"lang: {lang}, detect_lang: {self._lang}") logger.info(f'lang: {lang}, detect_lang: {self._lang}')
else: else:
self._lang = lang self._lang = lang
logger.info(f"lang: {lang}") logger.info(f'lang: {lang}')
def __len__(self) -> int: def __len__(self) -> int:
"""The page number of the pdf.""" """The page number of the pdf."""
...@@ -187,12 +187,12 @@ class PymuDocDataset(Dataset): ...@@ -187,12 +187,12 @@ class PymuDocDataset(Dataset):
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str): def dump_to_file(self, file_path: str):
"""Dump the file """Dump the file.
Args: Args:
file_path (str): the file path file_path (str): the file path
""" """
dir_name = os.path.dirname(file_path) dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'): if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
...@@ -213,7 +213,7 @@ class PymuDocDataset(Dataset): ...@@ -213,7 +213,7 @@ class PymuDocDataset(Dataset):
return proc(self, *args, **kwargs) return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod: def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset """classify the dataset.
Returns: Returns:
SupportedPdfParseMethod: _description_ SupportedPdfParseMethod: _description_
...@@ -221,8 +221,7 @@ class PymuDocDataset(Dataset): ...@@ -221,8 +221,7 @@ class PymuDocDataset(Dataset):
return classify(self._data_bits) return classify(self._data_bits)
def clone(self):
    """Clone this dataset.

    Returns:
        PymuDocDataset: a new dataset built from this dataset's raw data.
    """
    # NOTE(review): only the raw data is forwarded; no `lang` argument is
    # passed, so the copy is built with the constructor's default language
    # setting — confirm this is intended.
    return PymuDocDataset(self._raw_data)
def set_images(self, images): def set_images(self, images):
...@@ -274,10 +273,10 @@ class ImageDataset(Dataset): ...@@ -274,10 +273,10 @@ class ImageDataset(Dataset):
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str): def dump_to_file(self, file_path: str):
"""Dump the file """Dump the file.
Args: Args:
file_path (str): the file path file_path (str): the file path
""" """
dir_name = os.path.dirname(file_path) dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'): if dir_name not in ('', '.', '..'):
...@@ -297,7 +296,7 @@ class ImageDataset(Dataset): ...@@ -297,7 +296,7 @@ class ImageDataset(Dataset):
return proc(self, *args, **kwargs) return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod: def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset """classify the dataset.
Returns: Returns:
SupportedPdfParseMethod: _description_ SupportedPdfParseMethod: _description_
...@@ -305,10 +304,9 @@ class ImageDataset(Dataset): ...@@ -305,10 +304,9 @@ class ImageDataset(Dataset):
return SupportedPdfParseMethod.OCR return SupportedPdfParseMethod.OCR
def clone(self):
    """Clone this dataset.

    Returns:
        ImageDataset: a new dataset built from this dataset's raw data.
    """
    return ImageDataset(self._raw_data)
def set_images(self, images):
    """Hand ``images[i]`` to the i-th record through its ``set_image`` method.

    Args:
        images: indexable collection with at least one entry per record;
            extra trailing entries are ignored.

    Raises:
        IndexError: if ``images`` has fewer entries than there are records
            (raised by the index lookup for a list-like ``images``).
    """
    # Index into `images` instead of zipping so that a too-short `images`
    # still fails loudly rather than silently leaving records without an image.
    for i, record in enumerate(self._records):
        record.set_image(images[i])
......
import multiprocessing as mp import multiprocessing as mp
import threading import threading
from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
as_completed)
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.utils.annotations import ImportPIL from magic_pdf.utils.annotations import ImportPIL
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
@ImportPIL @ImportPIL
...@@ -69,17 +71,17 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id ...@@ -69,17 +71,17 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
images.append(img_dict) images.append(img_dict)
return images return images
def convert_page(bytes_page):
    """Render the first page of an in-memory PDF with ``fitz_doc_to_image``.

    Args:
        bytes_page: PDF data handed to ``fitz.open`` as the stream argument;
            only page 0 of the resulting document is used.

    Returns:
        The value returned by ``fitz_doc_to_image`` for that page.
    """
    # Open the document in a context manager so it is closed after rendering
    # instead of being left for garbage collection.
    # NOTE(review): assumes the value returned by fitz_doc_to_image does not
    # keep a live reference to the page or document — confirm.
    with fitz.open('pdf', bytes_page) as pdfs:
        return fitz_doc_to_image(pdfs[0])
def parallel_process_pdf_safe(pages, num_workers=None, **kwargs): def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
"""Process PDF pages in parallel with serialization-safe approach""" """Process PDF pages in parallel with serialization-safe approach."""
if num_workers is None: if num_workers is None:
num_workers = mp.cpu_count() num_workers = mp.cpu_count()
# Process the extracted page data in parallel # Process the extracted page data in parallel
with ProcessPoolExecutor(max_workers=num_workers) as executor: with ProcessPoolExecutor(max_workers=num_workers) as executor:
...@@ -87,14 +89,13 @@ def parallel_process_pdf_safe(pages, num_workers=None, **kwargs): ...@@ -87,14 +89,13 @@ def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
results = list( results = list(
executor.map(convert_page, pages) executor.map(convert_page, pages)
) )
return results return results
def threaded_process_pdf(pdf_path, num_threads=4, **kwargs): def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
""" """Process all pages of a PDF using multiple threads.
Process all pages of a PDF using multiple threads
Parameters: Parameters:
----------- -----------
pdf_path : str pdf_path : str
...@@ -103,7 +104,7 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs): ...@@ -103,7 +104,7 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
Number of threads to use Number of threads to use
**kwargs : **kwargs :
Additional arguments for fitz_doc_to_image Additional arguments for fitz_doc_to_image
Returns: Returns:
-------- --------
images : list images : list
...@@ -112,10 +113,10 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs): ...@@ -112,10 +113,10 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
# Open the PDF # Open the PDF
doc = fitz.open(pdf_path) doc = fitz.open(pdf_path)
num_pages = len(doc) num_pages = len(doc)
# Create a list to store results in the correct order # Create a list to store results in the correct order
results = [None] * num_pages results = [None] * num_pages
# Create a thread pool # Create a thread pool
with ThreadPoolExecutor(max_workers=num_threads) as executor: with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks # Submit all tasks
...@@ -130,27 +131,27 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs): ...@@ -130,27 +131,27 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
try: try:
results[page_num] = future.result() results[page_num] = future.result()
except Exception as e: except Exception as e:
print(f"Error processing page {page_num}: {e}") print(f'Error processing page {page_num}: {e}')
results[page_num] = None results[page_num] = None
# Close the document # Close the document
doc.close() doc.close()
if __name__ == "__main__": if __name__ == '__main__':
pdf = fitz.open('/tmp/[MS-DOC].pdf') pdf = fitz.open('/tmp/[MS-DOC].pdf')
pdf_page = [fitz.open() for i in range(pdf.page_count)] pdf_page = [fitz.open() for i in range(pdf.page_count)]
[pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)] [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
pdf_page = [v.tobytes() for v in pdf_page] pdf_page = [v.tobytes() for v in pdf_page]
results = parallel_process_pdf_safe(pdf_page, num_workers=16) results = parallel_process_pdf_safe(pdf_page, num_workers=16)
# threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16) # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image) """ benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578 total page nums: 578
thread nums, time cost thread nums, time cost
1 7.351 sec 1 7.351 sec
2 6.334 sec 2 6.334 sec
4 5.968 sec 4 5.968 sec
...@@ -159,14 +160,11 @@ if __name__ == "__main__": ...@@ -159,14 +160,11 @@ if __name__ == "__main__":
""" """
""" benchmark results of multi-processor processing (fitz page to image) """ benchmark results of multi-processor processing (fitz page to image)
total page nums: 578 total page nums: 578
processor nums, time cost processor nums, time cost
1 17.170 sec 1 17.170 sec
2 10.170 sec 2 10.170 sec
4 7.841 sec 4 7.841 sec
8 7.900 sec 8 7.900 sec
16 7.984 sec 16 7.984 sec
""" """
import concurrent.futures as fut
import multiprocessing as mp
import os import os
import time import time
import torch
import numpy as np import numpy as np
import multiprocessing as mp import torch
import concurrent.futures as fut
os.environ['FLAGS_npu_jit_compile'] = '0'  # turn off paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back
...@@ -29,6 +31,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config, ...@@ -29,6 +31,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir, get_local_models_dir,
get_table_recog_config) get_table_recog_config)
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
# from magic_pdf.operators.models import InferenceResult # from magic_pdf.operators.models import InferenceResult
# Page-image count threshold: batch_doc_analyze only enters its one-shot
# splitting branch when `len(images) >= MIN_BATCH_INFERENCE_SIZE`.
MIN_BATCH_INFERENCE_SIZE = 100
...@@ -170,7 +173,7 @@ def doc_analyze( ...@@ -170,7 +173,7 @@ def doc_analyze(
else: else:
batch_images = [images] batch_images = [images]
results = [] results = []
parallel_count = len(batch_images) # adjust to real parallel count parallel_count = len(batch_images) # adjust to real parallel count
# using concurrent.futures to analyze # using concurrent.futures to analyze
""" """
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor: with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
...@@ -192,8 +195,8 @@ def doc_analyze( ...@@ -192,8 +195,8 @@ def doc_analyze(
_, results = may_batch_image_analyze( _, results = may_batch_image_analyze(
images, images,
0, 0,
ocr, ocr,
show_log, show_log,
lang, layout_model, formula_enable, table_enable) lang, layout_model, formula_enable, table_enable)
model_json = [] model_json = []
...@@ -234,7 +237,7 @@ def batch_doc_analyze( ...@@ -234,7 +237,7 @@ def batch_doc_analyze(
img_dict = page_data.get_image() img_dict = page_data.get_image()
images.append(img_dict['img']) images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height'])) page_wh_list.append((img_dict['width'], img_dict['height']))
if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE: if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
if parallel_count is None: if parallel_count is None:
parallel_count = 2 # should check the gpu memory firstly ! parallel_count = 2 # should check the gpu memory firstly !
...@@ -245,7 +248,7 @@ def batch_doc_analyze( ...@@ -245,7 +248,7 @@ def batch_doc_analyze(
else: else:
batch_images = [images] batch_images = [images]
results = [] results = []
parallel_count = len(batch_images) # adjust to real parallel count parallel_count = len(batch_images) # adjust to real parallel count
# using concurrent.futures to analyze # using concurrent.futures to analyze
""" """
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor: with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
...@@ -266,8 +269,8 @@ def batch_doc_analyze( ...@@ -266,8 +269,8 @@ def batch_doc_analyze(
_, results = may_batch_image_analyze( _, results = may_batch_image_analyze(
images, images,
0, 0,
ocr, ocr,
show_log, show_log,
lang, layout_model, formula_enable, table_enable) lang, layout_model, formula_enable, table_enable)
infer_results = [] infer_results = []
...@@ -286,20 +289,20 @@ def batch_doc_analyze( ...@@ -286,20 +289,20 @@ def batch_doc_analyze(
def may_batch_image_analyze( def may_batch_image_analyze(
images: list[np.ndarray], images: list[np.ndarray],
idx: int, idx: int,
ocr: bool = False, ocr: bool = False,
show_log: bool = False, show_log: bool = False,
lang=None, lang=None,
layout_model=None, layout_model=None,
formula_enable=None, formula_enable=None,
table_enable=None): table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx) # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
# 关闭paddle的信号处理 # 关闭paddle的信号处理
import paddle import paddle
paddle.disable_signal_handler() paddle.disable_signal_handler()
from magic_pdf.model.batch_analyze import BatchAnalyze from magic_pdf.model.batch_analyze import BatchAnalyze
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model( custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable ocr, show_log, lang, layout_model, formula_enable, table_enable
...@@ -310,14 +313,14 @@ def may_batch_image_analyze( ...@@ -310,14 +313,14 @@ def may_batch_image_analyze(
device = get_device() device = get_device()
npu_support = False npu_support = False
if str(device).startswith("npu"): if str(device).startswith('npu'):
import torch_npu import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
npu_support = True npu_support = True
torch.npu.set_compile_mode(jit_compile=False) torch.npu.set_compile_mode(jit_compile=False)
if torch.cuda.is_available() and device != 'cpu' or npu_support: if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device)))) gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8: if gpu_memory is not None and gpu_memory >= 8:
if gpu_memory >= 20: if gpu_memory >= 20:
batch_ratio = 16 batch_ratio = 16
...@@ -398,4 +401,3 @@ def may_batch_image_analyze( ...@@ -398,4 +401,3 @@ def may_batch_image_analyze(
f' speed: {doc_analyze_speed} pages/second' f' speed: {doc_analyze_speed} pages/second'
) )
return (idx, results) return (idx, results)
import os import os
import torch import torch
from loguru import logger from loguru import logger
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
try: try:
from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError from magic_pdf_ascend_plugin.libs.license_verifier import (
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel load_license)
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
RapidTableModel
license_key = load_license() license_key = load_license()
logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},' logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
f' License expired at {license_key["payload"]["date"]["end_date"]}') f' License expired at {license_key["payload"]["date"]["end_date"]}')
...@@ -21,21 +29,24 @@ except Exception as e: ...@@ -21,21 +29,24 @@ except Exception as e:
if isinstance(e, ImportError): if isinstance(e, ImportError):
pass pass
elif isinstance(e, LicenseFormatError): elif isinstance(e, LicenseFormatError):
logger.error("Ascend Plugin: Invalid license format. Please check the license file.") logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
elif isinstance(e, LicenseSignatureError): elif isinstance(e, LicenseSignatureError):
logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.") logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
elif isinstance(e, LicenseExpiredError): elif isinstance(e, LicenseExpiredError):
logger.error("Ascend Plugin: License has expired. Please renew your license.") logger.error('Ascend Plugin: License has expired. Please renew your license.')
elif isinstance(e, FileNotFoundError): elif isinstance(e, FileNotFoundError):
logger.error("Ascend Plugin: Not found License file.") logger.error('Ascend Plugin: Not found License file.')
else: else:
logger.error(f"Ascend Plugin: {e}") logger.error(f'Ascend Plugin: {e}')
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None): def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE: if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
...@@ -56,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr ...@@ -56,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
def mfd_model_init(weight, device='cpu'):
    """Build a ``YOLOv8MFDModel`` from the given weights.

    Args:
        weight: weights argument forwarded to ``YOLOv8MFDModel``.
        device: target device. A value whose string form starts with
            ``'npu'`` is wrapped in ``torch.device``; anything else is
            forwarded unchanged.

    Returns:
        YOLOv8MFDModel: the constructed model.
    """
    if str(device).startswith('npu'):
        device = torch.device(device)
    return YOLOv8MFDModel(weight, device)
...@@ -73,14 +84,14 @@ def layout_model_init(weight, config_file, device): ...@@ -73,14 +84,14 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'):
    """Build a ``DocLayoutYOLOModel`` from the given weights.

    Args:
        weight: weights argument forwarded to ``DocLayoutYOLOModel``.
        device: target device. A value whose string form starts with
            ``'npu'`` is wrapped in ``torch.device``; anything else is
            forwarded unchanged.

    Returns:
        DocLayoutYOLOModel: the constructed model.
    """
    if str(device).startswith('npu'):
        device = torch.device(device)
    return DocLayoutYOLOModel(weight, device)
def langdetect_model_init(langdetect_model_weight, device='cpu'):
    """Build a ``YOLOv11LangDetModel`` from the given weights.

    Args:
        langdetect_model_weight: weights argument forwarded to
            ``YOLOv11LangDetModel``.
        device: target device. A value whose string form starts with
            ``'npu'`` is wrapped in ``torch.device``; anything else is
            forwarded unchanged.

    Returns:
        YOLOv11LangDetModel: the constructed model.
    """
    if str(device).startswith('npu'):
        device = torch.device(device)
    return YOLOv11LangDetModel(langdetect_model_weight, device)
......
import os import os
import shutil import shutil
import tempfile import tempfile
from pathlib import Path
import click import click
import fitz import fitz
from loguru import logger from loguru import logger
from pathlib import Path
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.data.batch_build_dataset import batch_build_dataset from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__ from magic_pdf.libs.version import __version__
from magic_pdf.tools.common import do_parse, parse_pdf_methods, batch_do_parse from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
from magic_pdf.utils.office_to_pdf import convert_file_to_pdf from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
# File suffixes recognised by the CLI, grouped by how the input is handled.
# Matching against these lists is an exact, case-sensitive comparison.
pdf_suffixes = ['.pdf']
ms_office_suffixes = ['.ppt', '.pptx', '.doc', '.docx']
image_suffixes = ['.png', '.jpeg', '.jpg']
...@@ -97,19 +97,19 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id): ...@@ -97,19 +97,19 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
def read_fn(path: Path):
    """Read ``path`` as a PDF, converting office documents and images first.

    Office files are converted with ``convert_file_to_pdf`` and images with
    PyMuPDF; in both cases the resulting ``<stem>.pdf`` lives in ``temp_dir``
    (a name taken from the enclosing scope). PDF inputs are read in place.

    Args:
        path (Path): input file; its suffix selects the handling.

    Returns:
        The value returned by ``FileBasedDataReader.read`` for the PDF file
        (presumably its raw bytes — confirm against the reader).

    Raises:
        Exception: if the suffix is in none of the known suffix lists.
    """
    if path.suffix in ms_office_suffixes:
        convert_file_to_pdf(str(path), temp_dir)
        fn = os.path.join(temp_dir, f'{path.stem}.pdf')
    elif path.suffix in image_suffixes:
        with open(str(path), 'rb') as f:
            bits = f.read()
        # Close the temporary image document once it has been converted.
        with fitz.open(stream=bits) as image_doc:
            pdf_bytes = image_doc.convert_to_pdf()
        fn = os.path.join(temp_dir, f'{path.stem}.pdf')
        with open(fn, 'wb') as f:
            f.write(pdf_bytes)
    elif path.suffix in pdf_suffixes:
        fn = str(path)
    else:
        raise Exception(f'Unknown file suffix: {path.suffix}')

    disk_rw = FileBasedDataReader(os.path.dirname(fn))
    return disk_rw.read(os.path.basename(fn))
......
...@@ -8,10 +8,10 @@ import magic_pdf.model as model_config ...@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset, Dataset from magic_pdf.data.dataset import Dataset, PymuDocDataset
from magic_pdf.libs.draw_bbox import draw_char_bbox from magic_pdf.libs.draw_bbox import draw_char_bbox
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze, batch_doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
doc_analyze)
# from io import BytesIO # from io import BytesIO
# from pypdf import PdfReader, PdfWriter # from pypdf import PdfReader, PdfWriter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment