"tests/vscode:/vscode.git/clone" did not exist on "2bfa5a61fb21e03cb3e70b0cdace7bd8466a2817"
Commit 9ce72d78 authored by myhloli's avatar myhloli
Browse files

Merge remote-tracking branch 'origin/dev' into dev

parents 59435d88 27281c92
@@ -7,19 +7,15 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch==2.3.1
+torchvision==0.18.1
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
 paddlepaddle==3.0.0rc1
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2
@@ -40,5 +40,5 @@
       "enable": false
     }
   },
-  "config_version": "1.1.1"
+  "config_version": "1.2.0"
 }
\ No newline at end of file
import concurrent.futures
import fitz
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image # PyMuPDF
def partition_array_greedy(arr, k):
    """Partition an array into k parts using a simple greedy approach.

    Parameters:
    -----------
    arr : list of tuples
        The input array of (pdf_path, page_count) tuples
    k : int
        Number of partitions to create

    Returns:
    --------
    partitions : list of lists
        The k partitions, each holding indices into arr
    """
    # Handle edge cases
    if k <= 0:
        raise ValueError('k must be a positive integer')
    if k > len(arr):
        k = len(arr)  # Adjust k if it's too large
    if k == 1:
        return [list(range(len(arr)))]
    if k == len(arr):
        return [[i] for i in range(len(arr))]

    # Sort the indices by page count in descending order
    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)

    # Initialize k empty partitions
    partitions = [[] for _ in range(k)]
    partition_sums = [0] * k

    # Assign each element to the partition with the smallest current sum
    for idx in sorted_indices:
        # Find the partition with the smallest sum
        min_sum_idx = partition_sums.index(min(partition_sums))
        # Add the element to this partition
        partitions[min_sum_idx].append(idx)  # Store the original index
        partition_sums[min_sum_idx] += arr[idx][1]

    return partitions
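A quick sanity check of the greedy partitioner, using hypothetical page counts and `partition_array_greedy` as defined above:

jobs = [('a.pdf', 30), ('b.pdf', 5), ('c.pdf', 25), ('d.pdf', 10), ('e.pdf', 20)]
parts = partition_array_greedy(jobs, k=2)
# Largest jobs are placed first, each into the currently lightest partition:
print(parts)  # [[0, 3, 1], [2, 4]] -> page sums 45 and 45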
def process_pdf_batch(pdf_jobs, idx):
    """Render every page of a batch of PDFs to images.

    Parameters:
    -----------
    pdf_jobs : list of tuples
        List of (pdf_path, page_count) tuples
    idx : int
        Index of this batch, returned alongside the results

    Returns:
    --------
    (idx, images) : tuple
        The batch index and, per PDF, the list of rendered page images
    """
    images = []

    for pdf_path, _ in pdf_jobs:
        doc = fitz.open(pdf_path)
        tmp = []
        for page_num in range(len(doc)):
            page = doc[page_num]
            tmp.append(fitz_doc_to_image(page))
        images.append(tmp)
    return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None):
    """Process multiple PDFs by partitioning them into k balanced parts and
    processing each part in parallel.

    Parameters:
    -----------
    pdf_paths : list
        List of paths to PDF files
    k : int
        Number of partitions (worker processes) to create
    lang : str or None
        Language hint forwarded to PymuDocDataset

    Returns:
    --------
    results : list
        One PymuDocDataset per input path, in the original order
    """
    # Get page counts for each PDF
    pdf_info = []
    total_pages = 0

    for pdf_path in pdf_paths:
        try:
            doc = fitz.open(pdf_path)
            num_pages = len(doc)
            pdf_info.append((pdf_path, num_pages))
            total_pages += num_pages
            doc.close()
        except Exception as e:
            print(f'Error opening {pdf_path}: {e}')

    # Partition the jobs based on page count
    partitions = partition_array_greedy(pdf_info, k)

    # Process each partition in parallel
    all_images_h = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
        # Submit one task per partition
        futures = []
        for sn, partition in enumerate(partitions):
            # Get the jobs for this partition
            partition_jobs = [pdf_info[idx] for idx in partition]

            # Submit the task
            future = executor.submit(
                process_pdf_batch,
                partition_jobs,
                sn
            )
            futures.append(future)

        # Process results as they complete
        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            try:
                idx, images = future.result()
                all_images_h[idx] = images
            except Exception as e:
                print(f'Error processing partition: {e}')

    results = [None] * len(pdf_paths)
    for i in range(len(partitions)):
        partition = partitions[i]
        for j in range(len(partition)):
            with open(pdf_info[partition[j]][0], 'rb') as f:
                pdf_bytes = f.read()
            dataset = PymuDocDataset(pdf_bytes, lang=lang)
            dataset.set_images(all_images_h[i][j])
            results[partition[j]] = dataset
    return results
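A minimal usage sketch. The paths are hypothetical placeholders, and the import path is an assumption based on where this file appears to live:

from magic_pdf.data.batch_build_dataset import batch_build_dataset  # assumed module path

if __name__ == '__main__':
    datasets = batch_build_dataset(['/tmp/a.pdf', '/tmp/b.pdf', '/tmp/c.pdf'], k=2, lang='en')
    print(len(datasets), 'datasets;', len(datasets[0]), 'pages in the first')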
@@ -97,7 +97,7 @@ class Dataset(ABC):

     @abstractmethod
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
@@ -119,7 +119,7 @@ class Dataset(ABC):

     @abstractmethod
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -128,8 +128,7 @@ class Dataset(ABC):

     @abstractmethod
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         pass
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
         if lang == '':
             self._lang = None
         elif lang == 'auto':
-            from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
+            from magic_pdf.model.sub_modules.language_detection.utils import \
+                auto_detect_lang
             self._lang = auto_detect_lang(bits)
-            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
-            logger.info(f"lang: {lang}")
+            logger.info(f'lang: {lang}')

     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -186,7 +187,7 @@ class PymuDocDataset(Dataset):
         return self._records[page_id]

     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
         return proc(self, *args, **kwargs)

     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
         return classify(self._data_bits)

     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return PymuDocDataset(self._raw_data)

+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])

 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -270,7 +273,7 @@ class ImageDataset(Dataset):
         return self._records[page_id]

     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
         return proc(self, *args, **kwargs)

     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
         return SupportedPdfParseMethod.OCR

     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return ImageDataset(self._raw_data)

+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])

 class Doc(PageableData):
     """Initialized with pymudoc object."""

     def __init__(self, doc: fitz.Page):
         self._doc = doc
+        self._img = None

     def get_image(self):
         """Return the image info.
@@ -321,7 +328,17 @@ class Doc(PageableData):
             height: int
         }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img

     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.
......
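The new `set_image`/`set_images` hooks let pre-rendered page images (for example, those produced by `batch_build_dataset` above) be attached to a dataset, so later `get_image()` calls return the cached value instead of re-rendering the page. A minimal sketch, with a hypothetical path:

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image

with open('/tmp/a.pdf', 'rb') as f:            # hypothetical path
    ds = PymuDocDataset(f.read())
images = [fitz_doc_to_image(ds.get_page(i).get_doc()) for i in range(len(ds))]
ds.set_images(images)                          # each Doc caches the dict it is handed
img_dict = ds.get_page(0).get_image()          # served from the cache, no re-render
print(img_dict['width'], img_dict['height'])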
+import multiprocessing as mp
+import threading
+from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
+                                as_completed)
 import fitz
 import numpy as np
 from loguru import logger
-from magic_pdf.utils.annotations import ImportPIL

-@ImportPIL
 def fitz_doc_to_image(doc, dpi=200) -> dict:
     """Convert fitz.Document to image, Then convert the image to numpy array.
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     Returns:
         dict: {'img': numpy array, 'width': width, 'height': height }
     """
-    from PIL import Image
     mat = fitz.Matrix(dpi / 72, dpi / 72)
     pm = doc.get_pixmap(matrix=mat, alpha=False)
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     if pm.width > 4500 or pm.height > 4500:
         pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-    img = np.array(img)
+    # Convert pixmap samples directly to numpy array
+    img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)

     img_dict = {'img': img, 'width': pm.width, 'height': pm.height}

     return img_dict
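The replacement decodes the pixmap buffer straight into a NumPy array instead of routing it through PIL. A small equivalence check, assuming an RGB pixmap with no alpha (which `alpha=False` guarantees) and a hypothetical sample file:

import fitz
import numpy as np
from PIL import Image

doc = fitz.open('/tmp/sample.pdf')  # hypothetical sample file
pm = doc[0].get_pixmap(matrix=fitz.Matrix(200 / 72, 200 / 72), alpha=False)
via_pil = np.array(Image.frombytes('RGB', (pm.width, pm.height), pm.samples))
direct = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
assert (via_pil == direct).all()  # identical pixels, one copy and one PIL object fewer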
-@ImportPIL
 def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
-    from PIL import Image
     images = []
     with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
             if pm.width > 4500 or pm.height > 4500:
                 pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-            img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-            img = np.array(img)
+            # Convert pixmap samples directly to numpy array
+            img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
             img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
         else:
             img_dict = {'img': [], 'width': 0, 'height': 0}

         images.append(img_dict)
     return images
def convert_page(bytes_page):
    pdfs = fitz.open('pdf', bytes_page)
    page = pdfs[0]
    return fitz_doc_to_image(page)


def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
    """Process PDF pages in parallel with a serialization-safe approach."""
    if num_workers is None:
        num_workers = mp.cpu_count()

    # Process the extracted page data in parallel
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Process the page data
        results = list(
            executor.map(convert_page, pages)
        )

    return results


def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
    """Process all pages of a PDF using multiple threads.

    Parameters:
    -----------
    pdf_path : str
        Path to the PDF file
    num_threads : int
        Number of threads to use
    **kwargs :
        Additional arguments for fitz_doc_to_image

    Returns:
    --------
    images : list
        List of processed images, in page order
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    num_pages = len(doc)

    # Create a list to store results in the correct order
    results = [None] * num_pages

    # Create a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {}
        for page_num in range(num_pages):
            page = doc[page_num]
            future = executor.submit(fitz_doc_to_image, page, **kwargs)
            futures[future] = page_num

        # Process results as they complete
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f'Error processing page {page_num}: {e}')
                results[page_num] = None

    # Close the document
    doc.close()
    return results


if __name__ == '__main__':
    pdf = fitz.open('/tmp/[MS-DOC].pdf')
    pdf_page = [fitz.open() for i in range(pdf.page_count)]
    [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
    pdf_page = [v.tobytes() for v in pdf_page]
    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
             'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.Title:
-        title_level = get_title_level(para_block)
         para_content = {
             'type': 'text',
             'text': merge_para_with_text(para_block),
-            'text_level': title_level,
         }
+        title_level = get_title_level(para_block)
+        if title_level != 0:
+            para_content['text_level'] = title_level
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
@@ -319,5 +320,5 @@ def get_title_level(block):
     if title_level > 4:
         title_level = 4
     elif title_level < 1:
-        title_level = 1
+        title_level = 0
     return title_level
\ No newline at end of file
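With the fix, `text_level` is only emitted for real titles; level-0 blocks fall back to plain text entries. A small sketch of the new branch with hypothetical values:

para_content = {'type': 'text', 'text': 'Looks like a heading'}
title_level = 0  # e.g. a title whose font size ranked below level 1
if title_level != 0:
    para_content['text_level'] = title_level
assert 'text_level' not in para_content  # level-0 titles stay plain text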
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
     # Crop the image region
     pix = page.get_pixmap(clip=rect, matrix=zoom)

+    if mode == "cv2":
+        # Convert directly to a numpy array for cv2
+        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
+        # PyMuPDF uses RGB channel order, while cv2 expects BGR
+        if pix.n == 3 or pix.n == 4:
+            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+        else:
+            image_result = img_array
+    elif mode == "pillow":
         # Wrap the raw bytes in a file-like object
         image_file = BytesIO(pix.tobytes(output='png'))
         # Open the image with Pillow
-        pil_image = Image.open(image_file)
-    if mode == "cv2":
-        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
-    elif mode == "pillow":
-        image_result = pil_image
+        image_result = Image.open(image_file)
     else:
         raise ValueError(f"mode: {mode} is not supported.")
......
 import time

 import cv2
-import numpy as np
 import torch
 from loguru import logger
-from PIL import Image

 from magic_pdf.config.constants import MODEL_NAME
-# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-# from magic_pdf.data.dataset import Dataset
-# from magic_pdf.libs.clean_memory import clean_memory
-# from magic_pdf.libs.config_reader import get_device
-# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-# from magic_pdf.operators.models import InferenceResult

 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -31,7 +23,6 @@ class BatchAnalyze:
     def __call__(self, images: list) -> list:
         images_layout_res = []

         layout_start_time = time.time()
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
@@ -41,36 +32,14 @@ class BatchAnalyze:
         elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
             layout_images = []
-            modified_images = []
             for image_index, image in enumerate(images):
-                pil_img = Image.fromarray(image)
-                # width, height = pil_img.size
-                # if height > width:
-                #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                #     new_image, useful_list = crop_img(
-                #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                #     )
-                #     layout_images.append(new_image)
-                #     modified_images.append([image_index, useful_list])
-                # else:
-                layout_images.append(pil_img)
+                layout_images.append(image)

             images_layout_res += self.model.layout_model.batch_predict(
                 # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )

-            for image_index, useful_list in modified_images:
-                for res in images_layout_res[image_index]:
-                    for i in range(len(res['poly'])):
-                        if i % 2 == 0:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[0] + useful_list[2]
-                            )
-                        else:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[1] + useful_list[3]
-                            )

         logger.info(
             f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
         )
@@ -111,7 +80,7 @@ class BatchAnalyze:
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
             layout_res = images_layout_res[index]
-            pil_img = Image.fromarray(images[index])
+            np_array_img = images[index]

             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -121,14 +90,14 @@ class BatchAnalyze:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
-                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     single_page_mfdetrec_res, useful_list
                 )

                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                 if self.model.apply_ocr:
                     ocr_res = self.model.ocr_model.ocr(
@@ -150,7 +119,7 @@ class BatchAnalyze:
             if self.model.apply_table:
                 table_start = time.time()
                 for res in table_res_list:
-                    new_image, _ = crop_img(res, pil_img)
+                    new_image, _ = crop_img(res, np_array_img)
                     single_table_start_time = time.time()
                     html_code = None
                     if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
@@ -197,83 +166,3 @@ class BatchAnalyze:
         logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')

         return images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
+import concurrent.futures as fut
+import multiprocessing as mp
 import os
 import time

+import numpy as np
 import torch

 os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle JIT compilation
 os.environ['FLAGS_use_stride_kernel'] = '0'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks

-# disable paddle's signal handler
-import paddle
-paddle.disable_signal_handler()

 from loguru import logger

-from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram

 try:
@@ -30,8 +31,8 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult

 class ModelSingleton:
     _instance = None
@@ -72,9 +73,7 @@ def custom_model_init(
     formula_enable=None,
     table_enable=None,
 ):
     model = None
     if model_config.__model_mode__ == 'lite':
         logger.warning(
             'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +131,6 @@ def custom_model_init(
     return custom_model

 def doc_analyze(
     dataset: Dataset,
     ocr: bool = False,
@@ -143,14 +141,112 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-) -> InferenceResult:
+):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
+
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+    else:
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
+    return InferenceResult(model_json, dataset)

+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+    else:
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
+
+    infer_results = []
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results

+def may_batch_image_analyze(
+        images: list[np.ndarray],
+        idx: int,
+        ocr: bool = False,
+        show_log: bool = False,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+
+    # disable paddle's signal handler
+    import paddle
+    paddle.disable_signal_handler()
+
+    from magic_pdf.model.batch_analyze import BatchAnalyze
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
@@ -160,33 +256,32 @@ def doc_analyze(
     batch_ratio = 1
     device = get_device()

-    npu_support = False
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
-            npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)

-    if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
-        if gpu_memory is not None and gpu_memory >= 8:
+    if str(device).startswith('npu') or str(device).startswith('cuda'):
+        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
+        if gpu_memory is not None:
             if gpu_memory >= 20:
                 batch_ratio = 16
             elif gpu_memory >= 15:
                 batch_ratio = 8
             elif gpu_memory >= 10:
                 batch_ratio = 4
-            else:
+            elif gpu_memory >= 7:
                 batch_ratio = 2
+            else:
+                batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
         batch_analyze = True
+    elif str(device).startswith('mps'):
+        batch_analyze = True

-    model_json = []
     doc_analyze_start = time.time()

     if batch_analyze:
-        # batch analyze
+        """# batch analyze
         images = []
         page_wh_list = []
         for index in range(len(dataset)):
@@ -195,9 +290,10 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
         batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
+        results = batch_model(images)
+        """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:
                 result = analyze_result.pop(0)
@@ -210,10 +306,10 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """

     else:
         # single analyze
+        """
         for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
@@ -230,6 +326,13 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)

     gc_start = time.time()
     clean_memory(get_device())
@@ -237,10 +340,9 @@ def doc_analyze(
     logger.info(f'gc time: {gc_time}')

     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
     logger.info(
         f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
         f' speed: {doc_analyze_speed} pages/second'
     )
-    return InferenceResult(model_json, dataset)
+    return (idx, results)
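The batch-ratio selection above is a simple VRAM ladder. A standalone sketch of that mapping (the helper name `pick_batch_ratio` is hypothetical, for illustration only):

def pick_batch_ratio(gpu_memory_gb: int) -> int:
    """Mirror of the VRAM -> batch_ratio ladder in may_batch_image_analyze."""
    if gpu_memory_gb >= 20:
        return 16
    if gpu_memory_gb >= 15:
        return 8
    if gpu_memory_gb >= 10:
        return 4
    if gpu_memory_gb >= 7:
        return 2
    return 1

assert [pick_batch_ratio(g) for g in (6, 8, 12, 16, 24)] == [1, 2, 4, 8, 16]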
@@ -3,11 +3,9 @@ import os
 import time

 import cv2
-import numpy as np
 import torch
 import yaml
 from loguru import logger
-from PIL import Image

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks
@@ -120,7 +118,7 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device='cpu' if str(self.device).startswith("mps") else self.device,
+                device=self.device,
             )

         # initialize the layout model
@@ -174,11 +172,6 @@ class CustomPEKModel:
         logger.info('DocAnalysis init done!')

     def __call__(self, image):
-        pil_img = Image.fromarray(image)
-        width, height = pil_img.size
-        # logger.info(f'width: {width}, height: {height}')
-
         # layout detection
         layout_start = time.time()
         layout_res = []
@@ -186,24 +179,6 @@ class CustomPEKModel:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
-            # doclayout_yolo
-            # if height > width:
-            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
-            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-            #     layout_res = self.layout_model.predict(new_image)
-            #     for res in layout_res:
-            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-            #         p1 = p1 - paste_x + xmin
-            #         p2 = p2 - paste_y + ymin
-            #         p3 = p3 - paste_x + xmin
-            #         p4 = p4 - paste_y + ymin
-            #         p5 = p5 - paste_x + xmin
-            #         p6 = p6 - paste_y + ymin
-            #         p7 = p7 - paste_x + xmin
-            #         p8 = p8 - paste_y + ymin
-            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            # else:
             layout_res = self.layout_model.predict(image)

         layout_cost = round(time.time() - layout_start, 2)
@@ -234,11 +209,11 @@ class CustomPEKModel:
             ocr_start = time.time()
             # Process each area that requires OCR processing
             for res in ocr_res_list:
-                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+                new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                 if self.apply_ocr:
                     ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
@@ -260,7 +235,7 @@ class CustomPEKModel:
             if self.apply_table:
                 table_start = time.time()
                 for res in table_res_list:
-                    new_image, _ = crop_img(res, pil_img)
+                    new_image, _ = crop_img(res, image)
                     single_table_start_time = time.time()
                     html_code = None
                     if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......
@@ -3,8 +3,6 @@ import os
 from pathlib import Path

 import yaml
-from PIL import Image

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks

 from magic_pdf.config.constants import MODEL_NAME
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
     )
     text_images = []
     for simple_image in simple_images:
-        image = Image.fromarray(simple_image['img'])
+        image = simple_image['img']
         layout_res = temp_layout_model.predict(image)
         # crop each text block
         for res in layout_res:
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
             # initial filtering (skip regions whose width and height are both under 100)
             if x2 - x1 < 100 and y2 - y1 < 100:
                 continue
-            text_images.append(image.crop((x1, y1, x2, y2)))
+            text_images.append(image[y1:y2, x1:x2])
     return text_images
......
@@ -2,9 +2,9 @@
 import time
 from collections import Counter
 from uuid import uuid4

+import cv2
+import numpy as np
 import torch
-from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
     if result_images is None:
         result_images = []

-    width, height = image.size
+    height, width = image.shape[:2]
     long_side = max(width, height)  # length of the longer side

     if long_side <= 400:
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
             # skip crop regions that would fall outside the image
             if x + new_long_side > width:
                 continue
-            box = (x, 0, x + new_long_side, height)
-            sub_image = image.crop(box)
+            sub_image = image[0:height, x:x + new_long_side]
             sub_images.append(sub_image)
     else:  # if height is the longer side
         for y in range(0, height, new_long_side):
             # skip crop regions that would fall outside the image
             if y + new_long_side > height:
                 continue
-            box = (0, y, width, y + new_long_side)
-            sub_image = image.crop(box)
+            sub_image = image[y:y + new_long_side, 0:width]
             sub_images.append(sub_image)

     for sub_image in sub_images:
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):

 def resize_images_to_224(image):
     """
-    Pad images smaller than 224 to 224*224 on a black background, resize larger ones to 224*224, and save them to the output folder.
+    Pad images smaller than 224 to 224*224 on a black background, resize larger ones to 224*224.
+    Works directly with NumPy arrays.
     """
     try:
-        width, height = image.size
+        height, width = image.shape[:2]
         if width < 224 or height < 224:
-            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
-            paste_x = (224 - width) // 2
-            paste_y = (224 - height) // 2
-            new_image.paste(image, (paste_x, paste_y))
+            # Create black background
+            new_image = np.zeros((224, 224, 3), dtype=np.uint8)
+            # Calculate paste position (ensure they're not negative)
+            paste_x = max(0, (224 - width) // 2)
+            paste_y = max(0, (224 - height) // 2)
+            # Make sure we don't exceed the boundaries of new_image
+            paste_width = min(width, 224)
+            paste_height = min(height, 224)
+            # Paste original image onto black background
+            new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
            image = new_image
         else:
-            image = image.resize((224, 224), Image.Resampling.LANCZOS)
-        # uuid = str(uuid4())
-        # image.save(f"/tmp/{uuid}.jpg")
+            # Resize using cv2
+            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)

         return image
     except Exception as e:
-        logger.exception(e)
+        logger.exception(f"Error in resize_images_to_224: {e}")
+        return None
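A quick check of the new padding path, using a synthetic 100x50 test image (hypothetical data; `resize_images_to_224` as patched above):

import numpy as np

small = np.full((50, 100, 3), 255, dtype=np.uint8)   # white 100x50 test image
padded = resize_images_to_224(small)
assert padded.shape == (224, 224, 3)
# The original pixels sit centered on the black canvas:
assert padded[(224 - 50) // 2, (224 - 100) // 2].tolist() == [255, 255, 255]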
 class YOLOv11LangDetModel(object):
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
     def do_detect(self, images: list):
         all_images = []
         for image in images:
-            width, height = image.size
-            # logger.info(f"image size: {width} x {height}")
+            height, width = image.shape[:2]
             if width < 100 and height < 100:
                 continue
             temp_images = split_images(image)
......
-import argparse
-import os
-import re
-
 import torch
-import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor

 class MathDataset(Dataset):
@@ -20,55 +11,25 @@ class MathDataset(Dataset):
         return len(self.image_paths)

     def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
+        raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
             return image

-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s

 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )

     def predict(self, mfd_res, image):
         formula_list = []
@@ -84,62 +45,22 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-            pil_img = Image.fromarray(image)
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)

-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
@@ -149,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []

             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)

                 curr_idx = len(mf_image_list)
@@ -182,22 +103,23 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}

         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

         # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])

         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex

         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):
......
from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
from .modeling_unimernet import UnimernetModel

__all__ = [
    "UnimerSwinConfig",
    "UnimerSwinModel",
    "UnimerSwinImageProcessor",
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
    "UnimernetModel",
]
import os
import re
import warnings
from typing import Optional
import torch
from ftfy import fix_text
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
# TODO: rewrite tokenizer
class TokenizerWrapper:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.pad_token_id = self.tokenizer.pad_token_id
self.bos_token_id = self.tokenizer.bos_token_id
self.eos_token_id = self.tokenizer.eos_token_id
def __len__(self):
return len(self.tokenizer)
def tokenize(self, text, **kwargs):
return self.tokenizer(
text,
return_token_type_ids=False,
return_tensors="pt",
padding="longest",
truncation=True,
**kwargs,
)
def token2str(self, tokens) -> list:
generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
generated_text = [fix_text(text) for text in generated_text]
return generated_text
def detokenize(self, tokens):
toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
for b in range(len(toks)):
for i in reversed(range(len(toks[b]))):
if toks[b][i] is None:
toks[b][i] = ''
toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
if toks[b][i] in (self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token):
del toks[b][i]
return toks
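# A minimal usage sketch (illustrative, not part of the original source;
# `model_path` is a hypothetical directory containing the tokenizer files):
#
#   wrapper = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
#   batch = wrapper.tokenize(["a^{2}+b^{2}=c^{2}"])   # BatchEncoding with input_ids / attention_mask
#   texts = wrapper.token2str(batch["input_ids"])     # decoded strings, mojibake repaired by ftfy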
def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code.
"""
text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
letter = r'[a-zA-Z]'
noletter = r'[\W_^\d]'
names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
if news == s:
break
return s
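# A minimal sketch of the behavior (illustrative, not part of the original
# source): spaces adjacent to non-letter tokens are collapsed iteratively
# until a fixed point is reached, e.g.
#
#   latex_rm_whitespace(r"a ^ { 2 } + b _ { 1 }")  # -> "a^{2}+b_{1}"
#
# Arguments of \operatorname/\mathrm/\text/\mathbf have all internal spaces
# removed in a first pass, which the general letter/non-letter rules alone
# would not do.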
class UnimernetModel(VisionEncoderDecoderModel):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
# VisionEncoderDecoderModel's config consistency check logs spurious warnings; disable its logger temporarily.
base_model_logger.disabled = True
try:
super().__init__(config, encoder, decoder)
finally:
base_model_logger.disabled = False
if not config or not hasattr(config, "_name_or_path"):
raise RuntimeError("config._name_or_path is required by UnimernetModel.")
model_path = config._name_or_path
self.transform = UnimerSwinImageProcessor()
self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
self._post_check()
def _post_check(self):
tokenizer = self.tokenizer
if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
warnings.warn(
f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
assert self.config.decoder.vocab_size == len(tokenizer)
assert self.config.decoder_start_token_id == tokenizer.bos_token_id
assert self.config.pad_token_id == tokenizer.pad_token_id
@classmethod
def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
config = VisionEncoderDecoderConfig.from_pretrained(model_path)
config._name_or_path = model_path
config.encoder = UnimerSwinConfig(**vars(config.encoder))
config.decoder = UnimerMBartConfig(**vars(config.decoder))
encoder = UnimerSwinModel(config.encoder)
decoder = UnimerMBartForCausalLM(config.decoder)
model = cls(config, encoder, decoder)
# load model weights
model_file_path = os.path.join(model_path, model_filename)
checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
if not state_dict:
raise RuntimeError("state_dict is empty.")
if state_dict_strip_prefix:
state_dict = {
k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
for k, v in state_dict.items()
}
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if len(unexpected_keys) > 0:
warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
return model
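# Note (illustrative, not part of the original source): with the default
# state_dict_strip_prefix="model.model.", a checkpoint key such as
# "model.model.encoder.embeddings.weight" is loaded as
# "encoder.embeddings.weight", matching this VisionEncoderDecoderModel layout.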
def forward_bak(self, samples):
pixel_values, text = samples["image"], samples["text_input"]
text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = pixel_values.repeat(1, 3, 1, 1)
labels = decoder_input_ids * 1  # multiply by 1 to copy the tensor, leaving decoder_input_ids untouched
labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)  # ignore padding positions in the loss
loss = self.model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids[:, :-1],
decoder_attention_mask=decoder_attention_mask[:, :-1],
labels=labels[:, 1:],
).loss
return {"loss": loss}
def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
pixel_values = samples["image"]
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = pixel_values.repeat(1, 3, 1, 1)
kwargs = {}
if do_sample:
kwargs["temperature"] = temperature
kwargs["top_p"] = top_p
outputs = super().generate(
pixel_values=pixel_values,
max_new_tokens=self.tokenizer.tokenizer.model_max_length, # required
decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
do_sample=do_sample,
**kwargs,
)
outputs = outputs[:, 1:].cpu().numpy()
pred_tokens = self.tokenizer.detokenize(outputs)
pred_str = self.tokenizer.token2str(outputs)
fixed_str = [latex_rm_whitespace(s) for s in pred_str]
return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
from .configuration_unimer_mbart import UnimerMBartConfig
from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM
__all__ = [
"UnimerMBartConfig",
"UnimerMBartModel",
"UnimerMBartForCausalLM",
]
# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UnimerMBART model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class UnimerMBartConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`UnimerMBartModel`]. It is used to instantiate a
UnimerMBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the MBART
[facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50265):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`UnimerMBartModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
qk_squeeze (`int`, *optional*, defaults to 2):
Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
thereby accelerating the computation of attention.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> # Initializing a UnimerMBART configuration with MBART-style defaults
>>> configuration = UnimerMBartConfig()
>>> # Initializing a model (with random weights) from that configuration
>>> model = UnimerMBartModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "unimer-mbart"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
qk_squeeze=2,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.qk_squeeze = qk_squeeze
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
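# Illustrative sketch (not part of the original source): per the qk_squeeze
# docstring above, queries and keys are projected down to
# d_model // qk_squeeze dimensions before the attention product.
#
#   config = UnimerMBartConfig()                          # d_model=1024, qk_squeeze=2
#   squeezed_dim = config.d_model // config.qk_squeeze    # -> 512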