Unverified Commit 8ccfff6f authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1281 from IMSUVEN/dev

feat: add batch prediction methods for YOLOv8 and Unimernet models
parents d0a3058b 7ce9edc6
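The new batch_predict methods added below share one contract: take a list of page images (RGB numpy arrays), slice it into chunks of batch_size, run one model call per chunk, and return one result per input image. The intended call pattern looks roughly like the sketch below; the weight and config paths, the device string, and the batch sizes are illustrative placeholders, not part of this diff.

# Sketch only: paths, device and batch sizes are placeholders.
layout_model = DocLayoutYOLOModel(weight="doclayout_yolo.pt", device="cuda")  # constructor signature assumed
mfd_model = YOLOv8MFDModel(weight="yolov8_mfd.pt", device="cuda")
mfr_model = UnimernetModel(weight_dir="unimernet_weights/", cfg_path="unimernet.yaml", _device_="cuda")

images = [page_img_1, page_img_2]  # RGB numpy arrays, one per page
layout_res_per_page = layout_model.batch_predict(images, batch_size=8)
mfd_res_per_page = mfd_model.batch_predict(images, batch_size=2)
formulas_per_page = mfr_model.batch_predict(mfd_res_per_page, images, batch_size=32)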
@@ -30,3 +30,10 @@ class EmptyData(Exception):
    def __str__(self):
        return f'Empty data: {self.msg}'


class CUDA_NOT_AVAILABLE(Exception):
    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return f'CUDA not available: {self.msg}'
\ No newline at end of file
import time
import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.operators import InferenceResult
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram,
    crop_img,
    get_res_list_from_layout_res,
)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res,
    get_ocr_result_list,
)
YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
    def __init__(self, model: CustomPEKModel, batch_ratio: int):
        self.model = model
        self.batch_ratio = batch_ratio

    def __call__(self, images: list) -> list:
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            images_layout_res = []
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            images_layout_res = self.model.layout_model.batch_predict(
                images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
            )

        if self.model.apply_formula:
            # formula detection
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
            )
            # formula recognition
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]

        # free GPU memory
        clean_vram(self.model.device, vram_threshold=8)

        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
        for index in range(len(images)):
            layout_res = images_layout_res[index]
            pil_img = Image.fromarray(images[index])

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            # OCR
            ocr_start = time.time()
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(
                    res, pil_img, crop_paste_x=50, crop_paste_y=50
                )
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    single_page_mfdetrec_res, useful_list
                )

                # OCR recognition
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                if self.model.apply_ocr:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res
                    )[0]
                else:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                # Integrate results
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                    layout_res.extend(ocr_result_list)

            ocr_cost = round(time.time() - ocr_start, 2)
            if self.model.apply_ocr:
                logger.info(f"ocr time: {ocr_cost}")
            else:
                logger.info(f"det time: {ocr_cost}")

            # table recognition
            if self.model.apply_table:
                table_start = time.time()
                for res in table_res_list:
                    new_image, _ = crop_img(res, pil_img)
                    single_table_start_time = time.time()
                    html_code = None
                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                        with torch.no_grad():
                            table_result = self.model.table_model.predict(
                                new_image, "html"
                            )
                            if len(table_result) > 0:
                                html_code = table_result[0]
                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                        html_code = self.model.table_model.img2html(new_image)
                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
                        html_code, table_cell_bboxes, elapse = (
                            self.model.table_model.predict(new_image)
                        )
                    run_time = time.time() - single_table_start_time
                    if run_time > self.model.table_max_time:
                        logger.warning(
                            f"table recognition processing exceeds max time {self.model.table_max_time}s"
                        )
                    # check whether a valid result was returned
                    if html_code:
                        expected_ending = html_code.strip().endswith(
                            "</html>"
                        ) or html_code.strip().endswith("</table>")
                        if expected_ending:
                            res["html"] = html_code
                        else:
                            logger.warning(
                                "table recognition processing fails, not found expected HTML table end"
                            )
                    else:
                        logger.warning(
                            "table recognition processing fails, not get html return"
                        )
                logger.info(f"table time: {round(time.time() - table_start, 2)}")

        return images_layout_res
def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """Perform batch analysis on a document dataset.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze up to the last page.
        lang (str, optional): Language for OCR. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available; batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
    """
    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")

    lang = None if lang == "" else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    end_page_id = end_page_id if end_page_id else len(dataset)

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    model_json = []

    # batch analyze
    images = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict["img"])
    analyze_result = batch_model(images)

    for index in range(len(dataset)):
        page_data = dataset.get_page(index)
        img_dict = page_data.get_image()
        page_width = img_dict["width"]
        page_height = img_dict["height"]
        if start_page_id <= index <= end_page_id:
            result = analyze_result.pop(0)
        else:
            result = []

        page_info = {"page_no": index, "height": page_height, "width": page_width}
        page_dict = {"layout_dets": result, "page_info": page_info}
        model_json.append(page_dict)

    # TODO: clean memory when gpu memory is not enough
    clean_memory()

    return InferenceResult(model_json, dataset)
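For context, a minimal usage sketch of the new entry point, assuming a CUDA device and magic_pdf's PymuDocDataset loader (the file path is a placeholder):

# Sketch only: raises CUDA_NOT_AVAILABLE on CPU-only machines.
from magic_pdf.data.dataset import PymuDocDataset  # assumed loader from magic_pdf's data API

with open("sample.pdf", "rb") as f:  # placeholder input
    dataset = PymuDocDataset(f.read())

infer_result = doc_batch_analyze(dataset, ocr=True, batch_ratio=2)
# infer_result wraps the per-page layout detections (model_json) together with the dataset.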
@@ -8,14 +8,48 @@ class DocLayoutYOLOModel(object):
    def predict(self, image):
        layout_res = []
        doclayout_yolo_res = self.model.predict(
            image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device
        )[0]
        for xyxy, conf, cla in zip(
            doclayout_yolo_res.boxes.xyxy.cpu(),
            doclayout_yolo_res.boxes.conf.cpu(),
            doclayout_yolo_res.boxes.cls.cpu(),
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            }
            layout_res.append(new_item)
        return layout_res
\ No newline at end of file
    def batch_predict(self, images: list, batch_size: int) -> list:
        images_layout_res = []
        for index in range(0, len(images), batch_size):
            doclayout_yolo_res = self.model.predict(
                images[index : index + batch_size],
                imgsz=1024,
                conf=0.25,
                iou=0.45,
                verbose=True,
                device=self.device,
            ).cpu()
            for image_res in doclayout_yolo_res:
                layout_res = []
                for xyxy, conf, cla in zip(
                    image_res.boxes.xyxy,
                    image_res.boxes.conf,
                    image_res.boxes.cls,
                ):
                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                    new_item = {
                        "category_id": int(cla.item()),
                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                        "score": round(float(conf.item()), 3),
                    }
                    layout_res.append(new_item)
                images_layout_res.append(layout_res)
        return images_layout_res
@@ -2,11 +2,27 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object):
    def __init__(self, weight, device="cpu"):
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        mfd_res = self.mfd_model.predict(
            image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device
        )[0]
        return mfd_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        images_mfd_res = []
        for index in range(0, len(images), batch_size):
            mfd_res = self.mfd_model.predict(
                images[index : index + batch_size],
                imgsz=1888,
                conf=0.25,
                iou=0.45,
                verbose=True,
                device=self.device,
            ).cpu()
            for image_res in mfd_res:
                images_mfd_res.append(image_res)
        return images_mfd_res
import argparse
import os
import re

import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = "[a-zA-Z]"
    noletter = "[\W_^\d]"
    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
        if news == s:
            break
    return s
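For reference, the normalization above is expected to behave roughly as follows (illustrative example, not a test from this diff):

# latex_rm_whitespace(r"\mathrm { softmax } ( x )")  ->  r"\mathrm{softmax}(x)"
# Spaces inside \operatorname/\mathrm/\text/\mathbf groups are stripped first; the while-loop
# then removes whitespace between adjacent symbols until the string stops changing.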
class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose(
            [
                vis_processor,
            ]
        )

    def predict(self, mfd_res, image):
        formula_list = []
        mf_image_list = []
        for xyxy, conf, cla in zip(
            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": 13 + int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            }
            formula_list.append(new_item)
            pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
                mfr_res.extend(output["pred_str"])
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(
        self, images_mfd_res: list, images: list, batch_size: int = 64
    ) -> list:
        images_formula_list = []
        mf_image_list = []
        backfill_list = []
        for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
            pil_img = Image.fromarray(images[image_index])
            formula_list = []
            for xyxy, conf, cla in zip(
                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
            ):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    "category_id": 13 + int(cla.item()),
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
                formula_list.append(new_item)
                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                mf_image_list.append(bbox_img)

            images_formula_list.append(formula_list)
            backfill_list += formula_list

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
                mfr_res.extend(output["pred_str"])
        for res, latex in zip(backfill_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return images_formula_list