Commit 4101c357 authored by zhaoxiaomeng
Browse files

refactor(model): update init methods and improve model loading logic

parent b6df9b18
__use_inside_model__ = True
__use_inside_model__ = False
__model_mode__ = "full"
import os
import time
import cv2
import numpy as np
import yaml
from PIL import Image
from ultralytics import YOLO
import time
import argparse
import numpy as np
import torch
from loguru import logger
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from paddleocr import draw_ocr
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
import argparse
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
def layout_model_init(weight, config_file, device):
    """Construct the layout predictor.

    Args:
        weight: Forwarded as the first positional argument of
            ``Layoutlmv3_Predictor``.
        config_file: Forwarded as the second positional argument.
        device: Forwarded as the third positional argument.

    Returns:
        The newly created ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)
def mfd_model_init(weight):
    """Construct the formula-detection model.

    Args:
        weight: Forwarded unchanged as the sole argument of ``YOLO``.

    Returns:
        The newly created ``YOLO`` instance.
    """
    return YOLO(weight)
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
......@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg)
model = task.build_model(cfg)
model = model.to(device)
model = model.to(_device_)
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
return model, vis_processor
def layout_model_init(weight, config_file, device):
    """Construct the layout predictor.

    Args:
        weight: Forwarded as the first positional argument of
            ``Layoutlmv3_Predictor``.
        config_file: Forwarded as the second positional argument.
        device: Forwarded as the third positional argument.

    Returns:
        The newly created ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
......@@ -54,10 +60,11 @@ class MathDataset(Dataset):
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
return image
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
======== model init ========
......@@ -88,24 +95,24 @@ class CustomPEKModel:
self.device = kwargs.get("device", self.configs["config"]["device"])
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
# 初始化layout模型
self.layout_model = layout_model_init(
os.path.join(models_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
device=self.device
)
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
# 初始化公式解析模型
mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
self.mfr_model, mfr_vis_processors = mfr_model_init(
os.path.join(models_dir, self.configs["weights"]["mfr"]),
mfr_config_path,
device=self.device
)
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# 初始化layout模型
self.layout_model = Layoutlmv3_Predictor(
str(os.path.join(models_dir, self.configs['weights']['layout'])),
str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
)
# 初始化ocr
if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment