"vscode:/vscode.git/clone" did not exist on "2875315d6b6b4f5a375f04b6673ef2a57483edfa"
Unverified commit 41d96cd8, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #2065 from opendatalab/release-1.3.0

Release 1.3.0
parents c3d43e52 dd96663c
@@ -97,10 +97,10 @@ class Dataset(ABC):
    @abstractmethod
    def dump_to_file(self, file_path: str):
        """Dump the file.

        Args:
            file_path (str): the file path
        """
        pass

@@ -119,7 +119,7 @@ class Dataset(ABC):
    @abstractmethod
    def classify(self) -> SupportedPdfParseMethod:
        """classify the dataset.

        Returns:
            SupportedPdfParseMethod: _description_

@@ -128,8 +128,7 @@ class Dataset(ABC):
    @abstractmethod
    def clone(self):
        """clone this dataset."""
        pass
@@ -144,16 +143,19 @@ class PymuDocDataset(Dataset):
        self._records = [Doc(v) for v in self._raw_fitz]
        self._data_bits = bits
        self._raw_data = bits
        self._classify_result = None

        if lang == '':
            self._lang = None
        elif lang == 'auto':
            from magic_pdf.model.sub_modules.language_detection.utils import \
                auto_detect_lang
            self._lang = auto_detect_lang(bits)
            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
        else:
            self._lang = lang
            logger.info(f'lang: {lang}')

    def __len__(self) -> int:
        """The page number of the pdf."""
        return len(self._records)
@@ -186,12 +188,12 @@ class PymuDocDataset(Dataset):
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Dump the file.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)

@@ -212,18 +214,22 @@ class PymuDocDataset(Dataset):
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """classify the dataset.

        Returns:
            SupportedPdfParseMethod: _description_
        """
        if self._classify_result is None:
            self._classify_result = classify(self._data_bits)
        return self._classify_result

    def clone(self):
        """clone this dataset."""
        return PymuDocDataset(self._raw_data)

    def set_images(self, images):
        for i in range(len(self._records)):
            self._records[i].set_image(images[i])
class ImageDataset(Dataset):
    def __init__(self, bits: bytes):

@@ -270,10 +276,10 @@ class ImageDataset(Dataset):
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Dump the file.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        if dir_name not in ('', '.', '..'):

@@ -293,7 +299,7 @@ class ImageDataset(Dataset):
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """classify the dataset.

        Returns:
            SupportedPdfParseMethod: _description_

@@ -301,15 +307,19 @@ class ImageDataset(Dataset):
        return SupportedPdfParseMethod.OCR

    def clone(self):
        """clone this dataset."""
        return ImageDataset(self._raw_data)

    def set_images(self, images):
        for i in range(len(self._records)):
            self._records[i].set_image(images[i])
class Doc(PageableData):
    """Initialized with pymudoc object."""

    def __init__(self, doc: fitz.Page):
        self._doc = doc
        self._img = None

    def get_image(self):
        """Return the image info.

@@ -321,7 +331,17 @@ class Doc(PageableData):
            height: int
        }
        """
        if self._img is None:
            self._img = fitz_doc_to_image(self._doc)
        return self._img

    def set_image(self, img):
        """
        Args:
            img (np.ndarray): the image
        """
        if self._img is None:
            self._img = img

    def get_doc(self) -> fitz.Page:
        """Get the pymudoc object.

...
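With these changes, `Doc.get_image` renders a page once and caches the result, and `set_images` lets callers pre-fill that cache with images computed elsewhere. A minimal usage sketch (the PDF path is a placeholder):

```python
from magic_pdf.data.dataset import PymuDocDataset

with open('example.pdf', 'rb') as f:   # placeholder path
    ds = PymuDocDataset(f.read())

page = ds.get_page(0)
first = page.get_image()    # rendered via fitz_doc_to_image and cached on the Doc
again = page.get_image()    # served from the cache, no re-render
assert first is again
```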
import multiprocessing as mp
import threading
from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
                                as_completed)

import fitz
import numpy as np
from loguru import logger


def fitz_doc_to_image(doc, dpi=200) -> dict:
    """Convert fitz.Document to image, then convert the image to a numpy array.

@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
    Returns:
        dict: {'img': numpy array, 'width': width, 'height': height}
    """
    mat = fitz.Matrix(dpi / 72, dpi / 72)
    pm = doc.get_pixmap(matrix=mat, alpha=False)

@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
    if pm.width > 4500 or pm.height > 4500:
        pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

    # Convert pixmap samples directly to numpy array
    img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)

    img_dict = {'img': img, 'width': pm.width, 'height': pm.height}

    return img_dict
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
    images = []
    with fitz.open('pdf', pdf_bytes) as doc:
        pdf_page_num = doc.page_count

@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
            if pm.width > 4500 or pm.height > 4500:
                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

            # Convert pixmap samples directly to numpy array
            img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
            img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
        else:
            img_dict = {'img': [], 'width': 0, 'height': 0}

        images.append(img_dict)
    return images
def convert_page(bytes_page):
    pdfs = fitz.open('pdf', bytes_page)
    page = pdfs[0]
    return fitz_doc_to_image(page)


def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
    """Process PDF pages in parallel with a serialization-safe approach."""
    if num_workers is None:
        num_workers = mp.cpu_count()

    # Process the extracted page data in parallel
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(convert_page, pages))

    return results


def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
    """Process all pages of a PDF using multiple threads.

    Parameters:
    -----------
    pdf_path : str
        Path to the PDF file
    num_threads : int
        Number of threads to use
    **kwargs :
        Additional arguments for fitz_doc_to_image

    Returns:
    --------
    images : list
        List of processed images, in page order
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    num_pages = len(doc)

    # Create a list to store results in the correct order
    results = [None] * num_pages

    # Create a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {}
        for page_num in range(num_pages):
            page = doc[page_num]
            future = executor.submit(fitz_doc_to_image, page, **kwargs)
            futures[future] = page_num

        # Collect results as they complete
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f'Error processing page {page_num}: {e}')
                results[page_num] = None

    # Close the document and return the pages in order
    doc.close()
    return results


if __name__ == '__main__':
    pdf = fitz.open('/tmp/[MS-DOC].pdf')
    pdf_page = [fitz.open() for i in range(pdf.page_count)]
    [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
    pdf_page = [v.tobytes() for v in pdf_page]
    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.Title:
        para_content = {
            'type': 'text',
            'text': merge_para_with_text(para_block),
        }
        title_level = get_title_level(para_block)
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.InterlineEquation:
        para_content = {
            'type': 'equation',

@@ -319,5 +320,5 @@ def get_title_level(block):
    if title_level > 4:
        title_level = 4
    elif title_level < 1:
        title_level = 0
    return title_level
\ No newline at end of file
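The effect of the new level-0 guard on the emitted content list, with illustrative values (not from the repo):

```python
# A title block whose computed level clamps to 0 now emits plain text:
{'type': 'text', 'text': 'Some heading'}                      # no 'text_level' key
# A genuine level-2 title still carries its level:
{'type': 'text', 'text': 'Some heading', 'text_level': 2}
```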
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
    # Crop the image
    pix = page.get_pixmap(clip=rect, matrix=zoom)

    if mode == "cv2":
        # Convert directly to a numpy array for cv2
        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
        # PyMuPDF delivers RGB order, while cv2 expects BGR
        if pix.n == 3 or pix.n == 4:
            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        else:
            image_result = img_array
    elif mode == "pillow":
        # Wrap the PNG bytes in a file-like object
        image_file = BytesIO(pix.tobytes(output='png'))
        # Open the image with Pillow
        image_result = Image.open(image_file)
    else:
        raise ValueError(f"mode: {mode} is not supported.")

...
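A usage sketch of the two modes (the PDF path and bbox values are placeholders; `cut_image_to_pil_image` is the function in the hunk above):

```python
import fitz

doc = fitz.open('example.pdf')             # placeholder path
page = doc[0]
bbox = (72, 72, 300, 200)                  # (x0, y0, x1, y1) in PDF points

bgr_array = cut_image_to_pil_image(bbox, page, mode="cv2")     # np.ndarray, BGR channel order
pil_image = cut_image_to_pil_image(bbox, page, mode="pillow")  # PIL.Image, decoded from PNG bytes
```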
@@ -48,7 +48,18 @@ def measure_time(func):
        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time

        # Build a more specific identifier for the function
        if hasattr(func, "__self__"):  # bound instance method
            class_name = func.__self__.__class__.__name__
            full_name = f"{class_name}.{func.__name__}"
        elif hasattr(func, "__qualname__"):  # class or static method
            full_name = func.__qualname__
        else:
            module_name = func.__module__
            full_name = f"{module_name}.{func.__name__}"

        PerformanceStats.add_execution_time(full_name, execution_time)
        return result

    return wrapper
\ No newline at end of file
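A minimal usage sketch of the sharpened decorator (the decorated function below is hypothetical; `measure_time` and `PerformanceStats` are from this file):

```python
@measure_time
def render_page(page_bytes):
    ...  # hypothetical work to be timed

# Timings now accumulate under func.__qualname__, e.g. 'render_page' for a
# plain function or 'Renderer.render_page' for a method defined on a class,
# instead of the bare func.__name__.
```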
...
import os
import time

import numpy as np
import torch

os.environ['FLAGS_npu_jit_compile'] = '0'  # Disable paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # Allow MPS ops to fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Stop albumentations from checking for updates

# Disable paddle's signal handler
import paddle
paddle.disable_signal_handler()

from loguru import logger

from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.config.enums import SupportedPdfParseMethod
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory

@@ -30,8 +22,6 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
class ModelSingleton:
    _instance = None

@@ -72,9 +62,7 @@ def custom_model_init(
    formula_enable=None,
    table_enable=None,
):
    model = None

    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '

@@ -132,7 +120,6 @@ def custom_model_init(
    return custom_model
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,

@@ -143,102 +130,160 @@ def doc_analyze(
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if lang is None or lang == 'auto':
        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
    else:
        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(dataset))]

    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    else:
        batch_images = [images_with_extra_info]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable)
        results.extend(result)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)


def batch_doc_analyze(
    datasets: list[Dataset],
    parse_method: str,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
    batch_size = MIN_BATCH_INFERENCE_SIZE
    images = []
    page_wh_list = []

    images_with_extra_info = []
    for dataset in datasets:
        for index in range(len(dataset)):
            if lang is None or lang == 'auto':
                _lang = dataset._lang
            else:
                _lang = lang

            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
            if parse_method == 'auto':
                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
            else:
                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))

    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
        results.extend(result)

    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images_with_extra_info: list[(np.ndarray, bool, str)],
        idx: int,
        ocr: bool,
        show_log: bool = False,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)

    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    # images = [image for image, _, _ in images_with_extra_info]
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
        if gpu_memory is not None:
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')

    # doc_analyze_start = time.time()

    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    # gc_start = time.time()
    clean_memory(get_device())
    # gc_time = round(time.time() - gc_start, 2)
    # logger.debug(f'gc time: {gc_time}')

    # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    # logger.debug(
    #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
    #     f' speed: {doc_analyze_speed} pages/second'
    # )
    return idx, results
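As a sanity check on the batching threshold, a small sketch of how `MINERU_MIN_BATCH_INFERENCE_SIZE` shapes the batches in `doc_analyze` (page counts are illustrative):

```python
import os

# Override the default threshold of 200 before calling doc_analyze.
os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '100'

# With 250 pages and threshold 100: 250 >= 100, so the pages are chunked into
# batches of 100 -> sizes [100, 100, 50], each passed to may_batch_image_analyze.
# With 80 pages: 80 < 100, so everything runs as a single batch.
```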
@@ -3,28 +3,18 @@ import os
import time

import cv2
import torch
import yaml
from loguru import logger

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Stop albumentations from checking for updates

from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)
@@ -120,7 +110,7 @@ class CustomPEKModel:
            atom_model_name=AtomicModel.MFR,
            mfr_weight_dir=mfr_weight_dir,
            mfr_cfg_path=mfr_cfg_path,
            device=self.device,
        )

        # Initialize the layout model

@@ -174,11 +164,6 @@ class CustomPEKModel:
        logger.info('DocAnalysis init done!')

    def __call__(self, image):
        # Layout detection
        layout_start = time.time()
        layout_res = []

@@ -186,24 +171,6 @@ class CustomPEKModel:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)

@@ -234,11 +201,11 @@ class CustomPEKModel:
        ocr_start = time.time()
        # Process each area that requires OCR processing
        for res in ocr_res_list:
            new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

            # OCR recognition
            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

            if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]

@@ -260,7 +227,7 @@ class CustomPEKModel:
        if self.apply_table:
            table_start = time.time()
            for res in table_res_list:
                new_image, _ = crop_img(res, image)
                single_table_start_time = time.time()
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:

...
@@ -3,8 +3,6 @@ import os
from pathlib import Path

import yaml

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Stop albumentations from checking for updates

from magic_pdf.config.constants import MODEL_NAME

@@ -42,7 +40,7 @@ def get_text_images(simple_images):
    )
    text_images = []
    for simple_image in simple_images:
        image = simple_image['img']
        layout_res = temp_layout_model.predict(image)
        # Crop each text block
        for res in layout_res:

@@ -51,7 +49,7 @@ def get_text_images(simple_images):
            # Initial filtering (skip blocks with both width and height under 100)
            if x2 - x1 < 100 and y2 - y1 < 100:
                continue
            text_images.append(image[y1:y2, x1:x2])
    return text_images

...
@@ -2,9 +2,9 @@
import time
from collections import Counter
from uuid import uuid4

import cv2
import numpy as np
import torch
from loguru import logger
from ultralytics import YOLO

@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
    if result_images is None:
        result_images = []

    height, width = image.shape[:2]
    long_side = max(width, height)  # length of the longer side

    if long_side <= 400:

@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
            # Skip if the crop region would fall outside the image
            if x + new_long_side > width:
                continue
            sub_image = image[0:height, x:x + new_long_side]
            sub_images.append(sub_image)
    else:  # height is the longer side
        for y in range(0, height, new_long_side):
            # Skip if the crop region would fall outside the image
            if y + new_long_side > height:
                continue
            sub_image = image[y:y + new_long_side, 0:width]
            sub_images.append(sub_image)

    for sub_image in sub_images:

@@ -64,24 +62,32 @@ def split_images(image, result_images=None):

def resize_images_to_224(image):
    """Pad images smaller than 224 onto a black background up to 224x224; resize
    images at or above 224 down to 224x224. Works directly with NumPy arrays.
    """
    try:
        height, width = image.shape[:2]
        if width < 224 or height < 224:
            # Create black background
            new_image = np.zeros((224, 224, 3), dtype=np.uint8)
            # Calculate paste position (ensure they're not negative)
            paste_x = max(0, (224 - width) // 2)
            paste_y = max(0, (224 - height) // 2)

            # Make sure we don't exceed the boundaries of new_image
            paste_width = min(width, 224)
            paste_height = min(height, 224)

            # Paste original image onto black background
            new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
            image = new_image
        else:
            # Resize using cv2
            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)

        # uuid = str(uuid4())
        # image.save(f"/tmp/{uuid}.jpg")
        return image
    except Exception as e:
        logger.exception(f"Error in resize_images_to_224: {e}")
        return None


class YOLOv11LangDetModel(object):

@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
    def do_detect(self, images: list):
        all_images = []
        for image in images:
            height, width = image.shape[:2]
            if width < 100 and height < 100:
                continue
            temp_images = split_images(image)

...
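A quick shape check of the padding branch of `resize_images_to_224` (array contents are arbitrary):

```python
import numpy as np

small = np.ones((100, 150, 3), dtype=np.uint8)  # height=100, width=150
out = resize_images_to_224(small)

assert out.shape == (224, 224, 3)
# (224-100)//2 = 62 and (224-150)//2 = 37, so the original pixels land at
# rows 62:162 and columns 37:187; everything else stays black (zeros).
```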
from doclayout_yolo import YOLOv10
from tqdm import tqdm


class DocLayoutYOLOModel(object):

@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
    def batch_predict(self, images: list, batch_size: int) -> list:
        images_layout_res = []
        # for index in range(0, len(images), batch_size):
        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
            doclayout_yolo_res = [
                image_res.cpu()
                for image_res in self.model.predict(

...
from tqdm import tqdm
from ultralytics import YOLO

@@ -14,7 +15,8 @@ class YOLOv8MFDModel(object):
    def batch_predict(self, images: list, batch_size: int) -> list:
        images_mfd_res = []
        # for index in range(0, len(images), batch_size):
        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
            mfd_res = [
                image_res.cpu()
                for image_res in self.mfd_model.predict(

...
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class MathDataset(Dataset):

@@ -20,55 +12,24 @@ class MathDataset(Dataset):
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        raw_image = self.image_paths[idx]
        if self.transform:
            image = self.transform(raw_image)
            return image


class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        from .unimernet_hf import UnimernetModel
        if _device_.startswith("mps"):
            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
        else:
            self.model = UnimernetModel.from_pretrained(weight_dir)
        self.device = _device_
        self.model.to(_device_)
        if not _device_.startswith("cpu"):
            self.model = self.model.to(dtype=torch.float16)
        self.model.eval()

    def predict(self, mfd_res, image):
        formula_list = []

@@ -84,62 +45,22 @@ class UnimernetModel(object):
                "latex": "",
            }
            formula_list.append(new_item)
            bbox_img = image[ymin:ymax, xmin:xmax]
            mf_image_list.append(bbox_img)

        dataset = MathDataset(mf_image_list, transform=self.model.transform)
        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(dtype=self.model.dtype)
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["fixed_str"])
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex
        return formula_list
    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        images_formula_list = []
        mf_image_list = []

@@ -149,7 +70,7 @@ class UnimernetModel(object):
        # Collect images with their original indices
        for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
            np_array_image = images[image_index]
            formula_list = []

            for idx, (xyxy, conf, cla) in enumerate(zip(

@@ -163,7 +84,7 @@ class UnimernetModel(object):
                    "latex": "",
                }
                formula_list.append(new_item)
                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                area = (xmax - xmin) * (ymax - ymin)

                curr_idx = len(mf_image_list)

@@ -182,22 +103,30 @@ class UnimernetModel(object):
        index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}

        # Create dataset with sorted images
        dataset = MathDataset(sorted_images, transform=self.model.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

        # Process batches and store results
        mfr_res = []
        # for mf_img in dataloader:
        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
            for index, mf_img in enumerate(dataloader):
                mf_img = mf_img.to(dtype=self.model.dtype)
                mf_img = mf_img.to(self.device)
                with torch.no_grad():
                    output = self.model.generate({"image": mf_img})
                mfr_res.extend(output["fixed_str"])
                # Advance the bar by batch_size; the last batch may be smaller
                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
                pbar.update(current_batch_size)

        # Restore original order
        unsorted_results = [""] * len(mfr_res)
        for new_idx, latex in enumerate(mfr_res):
            original_idx = index_mapping[new_idx]
            unsorted_results[original_idx] = latex

        # Fill results back
        for res, latex in zip(backfill_list, unsorted_results):

...
from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
from .modeling_unimernet import UnimernetModel

__all__ = [
    "UnimerSwinConfig",
    "UnimerSwinModel",
    "UnimerSwinImageProcessor",
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
    "UnimernetModel",
]
import os
import re
import warnings
from typing import Optional
import torch
from ftfy import fix_text
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
# TODO: rewrite tokenizer
class TokenizerWrapper:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def __len__(self):
        return len(self.tokenizer)

    def tokenize(self, text, **kwargs):
        return self.tokenizer(
            text,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            **kwargs,
        )

    def token2str(self, tokens) -> list:
        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        generated_text = [fix_text(text) for text in generated_text]
        return generated_text

    def detokenize(self, tokens):
        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ''
                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
                    del toks[b][i]
        return toks
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            break
    return s
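For intuition, an illustrative call (input string chosen for this example):

```python
s = r'\mathrm {d} x ^ {2} + y _ {1}'
print(latex_rm_whitespace(s))
# -> \mathrm{d}x^{2}+y_{1}
```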
class UnimernetModel(VisionEncoderDecoderModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # VisionEncoderDecoderModel's config check logs spurious warnings; silence it temporarily.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            base_model_logger.disabled = False

        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")

        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()

    def _post_check(self):
        tokenizer = self.tokenizer

        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings

        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))

        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)

        # Load model weights
        model_file_path = os.path.join(model_path, model_filename)
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model

    def forward_bak(self, samples):
        pixel_values, text = samples["image"], samples["text_input"]

        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
        pixel_values = samples["image"]
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p

        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )
        outputs = outputs[:, 1:].cpu().numpy()
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
from .configuration_unimer_mbart import UnimerMBartConfig
from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM

__all__ = [
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
]
# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UnimerMBART model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class UnimerMBartConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MBART
[facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50265):
Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
qk_squeeze (`int`, *optional*, defaults to 2):
Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
thereby accelerating the computation of attention.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import MBartConfig, MBartModel
>>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
>>> configuration = MBartConfig()
>>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
>>> model = MBartModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "unimer-mbart"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
qk_squeeze=2,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.qk_squeeze = qk_squeeze
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
from .configuration_unimer_swin import UnimerSwinConfig
from .modeling_unimer_swin import UnimerSwinModel
from .image_processing_unimer_swin import UnimerSwinImageProcessor

__all__ = [
    "UnimerSwinConfig",
    "UnimerSwinModel",
    "UnimerSwinImageProcessor",
]