Commit 1ee81a9a authored by 赵小蒙
Browse files

update:

1. Disable scaling when loading large images.
2. Move the channel-conversion logic in image processing.
parent f14e50e2
...@@ -21,10 +21,11 @@ def remove_duplicates_dicts(lst): ...@@ -21,10 +21,11 @@ def remove_duplicates_dicts(lst):
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
    """Render every page of a PDF into an RGB image array.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        dpi: Requested rendering resolution; each page is scaled by
            ``dpi / 72`` unless the scaled page would be too large.

    Returns:
        One dict per page with the keys ``"img"`` (``np.array`` built from
        the RGB PIL image), ``"width"`` and ``"height"`` (pixmap size).
    """
    try:
        from PIL import Image
    except ImportError:
        logger.error("Pillow not installed, please install by pip.")
        exit(1)

    images = []
    with fitz.open("pdf", pdf_bytes) as doc:
        for index in range(0, doc.page_count):
            # NOTE(review): the diff hunk boundary hides the line that binds
            # `page`; reconstructed from its use below - confirm against the file.
            page = doc[index]
            mat = fitz.Matrix(dpi / 72, dpi / 72)
            pm = page.get_pixmap(matrix=mat, alpha=False)

            # if width or height > 3000 pixels, don't enlarge the image
            # (the check must use `pm`: no variable named `pix` exists at
            # this point, so testing `pix.width` raised NameError)
            if pm.width > 3000 or pm.height > 3000:
                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

            img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
            # Kept in RGB; consumers that need BGR convert on their side.
            img = np.array(img)
            img_dict = {"img": img, "width": pm.width, "height": pm.height}
            images.append(img_dict)
    return images
...@@ -68,4 +69,6 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod ...@@ -68,4 +69,6 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
model_json.append(page_dict) model_json.append(page_dict)
# @todo 把公式识别放在后置位置,待整本全部模型结果出来之后再补公式数据
return model_json return model_json
import os
import time
import cv2
import fitz
import numpy as np
import torch
import unimernet.tasks as tasks
import yaml
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from ultralytics import YOLO
from unimernet.common.config import Config
from unimernet.processors import load_processor
class CustomPEKModel:
    """Layout detection, formula detection/recognition and OCR for one page image.

    All models and thresholds are loaded once in ``__init__`` and stored on
    the instance, so ``__call__`` can be invoked repeatedly (one call per page).
    """

    def __init__(self, ocr: bool = False, show_log: bool = False):
        """Read ``configs/model_configs.yaml`` and initialise every sub-model.

        Args:
            ocr: Whether ``__call__`` runs the OCR stage.
            show_log: Forwarded to ``ModifiedPaddleOCR``.
        """
        self.apply_ocr = ocr

        ## ======== model init ========##
        with open('configs/model_configs.yaml') as f:
            model_configs = yaml.load(f, Loader=yaml.FullLoader)
        model_args = model_configs['model_args']

        # Everything __call__ needs is kept on the instance; as plain locals
        # these values were lost as soon as __init__ returned.
        self.img_size = model_args['img_size']
        self.conf_thres = model_args['conf_thres']
        self.iou_thres = model_args['iou_thres']
        self.device = model_args['device']
        self.dpi = model_args['pdf_dpi']

        self.mfd_model = mfd_model_init(model_args['mfd_weight'])
        self.mfr_model, mfr_vis_processors = mfr_model_init(model_args['mfr_weight'], device=self.device)
        self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
        self.layout_model = layout_model_init(model_args['layout_weight'])
        self.ocr_model = ModifiedPaddleOCR(show_log=show_log)

        print(time.strftime('%Y-%m-%d %H:%M:%S'))
        print('Model init done!')
        ## ======== model init ========##

    def __call__(self, image):
        """Analyse a single page image.

        Args:
            image: Page image array. It is handed to ``Image.fromarray``
                without conversion, so RGB channel order is assumed
                - TODO confirm against the image loader.

        Returns:
            The page's list of detection dicts: the layout model's
            ``layout_dets``, plus formula items (``category_id`` 13 + class,
            with ``latex``) and, when OCR is enabled, text items
            (``category_id`` 15, with ``text``).
            NOTE(review): page number/size are not attached here because the
            page index is unknown inside this method; confirm the caller
            builds ``page_info`` itself. The script-style JSON dump that
            relied on ``args``/``single_pdf`` was dropped for the same reason.
        """
        latex_filling_list = []
        mf_image_list = []

        # Layout detection
        layout_res = self.layout_model(image, ignore_catids=[])
        layout_dets = layout_res['layout_dets']

        # Formula detection
        mfd_res = self.mfd_model.predict(
            image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=True
        )[0]
        pil_img = Image.fromarray(image)
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            layout_dets.append(new_item)
            latex_filling_list.append(new_item)
            mf_image_list.append(get_croped_image(pil_img, [xmin, ymin, xmax, ymax]))

        # Formula recognition: recognition is slow, so all crops are collected
        # first and then recognised in batches.
        mfr_start = time.time()
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
        mfr_res = []
        gpu_total_cost = 0
        for imgs in dataloader:
            imgs = imgs.to(self.device)
            gpu_start = time.time()
            output = self.mfr_model.generate({'image': imgs})
            gpu_cost = time.time() - gpu_start
            gpu_total_cost += gpu_cost
            print(f"gpu_cost: {gpu_cost}")
            mfr_res.extend(output['pred_str'])
        print(f"gpu_total_cost: {gpu_total_cost}")
        # The dicts in latex_filling_list are the same objects that were
        # appended to layout_dets, so filling them updates the result in place.
        for res, latex in zip(latex_filling_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        mfr_end = time.time()
        print("formula nums:", len(mf_image_list), "mfr time:", round(mfr_end - mfr_start, 2))

        # OCR
        if self.apply_ocr:
            # Formula boxes are handed to the OCR model together with the crop.
            single_page_mfdetrec_res = [
                {"bbox": [int(res['poly'][0]), int(res['poly'][1]),
                          int(res['poly'][4]), int(res['poly'][5])]}
                for res in layout_dets
                if int(res['category_id']) in [13, 14]
            ]

            # Collected separately so layout_dets is not grown while iterated.
            ocr_items = []
            for res in layout_dets:
                if int(res['category_id']) not in [0, 1, 2, 4, 6, 7]:  # categories that need OCR
                    continue
                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
                crop_box = [xmin, ymin, xmax, ymax]
                # Paste the region onto a white page-sized canvas so the OCR
                # boxes stay in page coordinates.
                cropped_img = Image.new('RGB', pil_img.size, 'white')
                cropped_img.paste(pil_img.crop(crop_box), crop_box)
                # Single RGB -> BGR swap for the OCR model; swapping the page
                # image beforehand as well would cancel this conversion out.
                cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
                ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
                if not ocr_res:
                    continue
                for box_ocr_res in ocr_res:
                    p1, p2, p3, p4 = box_ocr_res[0]
                    text, score = box_ocr_res[1]
                    ocr_items.append({
                        'category_id': 15,
                        'poly': p1 + p2 + p3 + p4,
                        'score': round(score, 2),
                        'text': text,
                    })
            layout_dets.extend(ocr_items)

        return layout_dets
\ No newline at end of file
...@@ -22,6 +22,13 @@ class CustomPaddleModel: ...@@ -22,6 +22,13 @@ class CustomPaddleModel:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log) self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __call__(self, img): def __call__(self, img):
try:
import cv2
except ImportError:
logger.error("opencv-python not installed, please install by pip.")
exit(1)
# 将RGB图片转换为BGR格式适配paddle
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
result = self.model(img) result = self.model(img)
spans = [] spans = []
for line in result: for line in result:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment