Commit 4101c357 authored by zhaoxiaomeng
Browse files

refactor(model): update init methods and improve model loading logic

parent b6df9b18
__use_inside_model__ = True __use_inside_model__ = False
__model_mode__ = "full" __model_mode__ = "full"
import os import os
import time
import cv2 import cv2
import numpy as np
import yaml import yaml
from PIL import Image import time
from ultralytics import YOLO import argparse
import numpy as np
import torch
from loguru import logger from loguru import logger
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from paddleocr import draw_ocr
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
from unimernet.common.config import Config from unimernet.common.config import Config
import unimernet.tasks as tasks import unimernet.tasks as tasks
from unimernet.processors import load_processor from unimernet.processors import load_processor
import argparse
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
def layout_model_init(weight, config_file, device): def mfd_model_init(weight):
model = Layoutlmv3_Predictor(weight, config_file, device) mfd_model = YOLO(weight)
return model return mfd_model
def mfr_model_init(weight_dir, cfg_path, device='cpu'): def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
args = argparse.Namespace(cfg_path=cfg_path, options=None) args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args) cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin") cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
...@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'): ...@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
cfg.config.model.tokenizer_config.path = weight_dir cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg) task = tasks.setup_task(cfg)
model = task.build_model(cfg) model = task.build_model(cfg)
model = model.to(device) model = model.to(_device_)
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
return model, vis_processor return model, vis_processor
def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
class MathDataset(Dataset): class MathDataset(Dataset):
def __init__(self, image_paths, transform=None): def __init__(self, image_paths, transform=None):
self.image_paths = image_paths self.image_paths = image_paths
...@@ -54,10 +60,11 @@ class MathDataset(Dataset): ...@@ -54,10 +60,11 @@ class MathDataset(Dataset):
raw_image = self.image_paths[idx] raw_image = self.image_paths[idx]
if self.transform: if self.transform:
image = self.transform(raw_image) image = self.transform(raw_image)
return image return image
class CustomPEKModel: class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
""" """
======== model init ======== ======== model init ========
...@@ -88,24 +95,24 @@ class CustomPEKModel: ...@@ -88,24 +95,24 @@ class CustomPEKModel:
self.device = kwargs.get("device", self.configs["config"]["device"]) self.device = kwargs.get("device", self.configs["config"]["device"])
logger.info("using device: {}".format(self.device)) logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models")) models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
# 初始化layout模型
self.layout_model = layout_model_init(
os.path.join(models_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
device=self.device
)
# 初始化公式识别 # 初始化公式识别
if self.apply_formula: if self.apply_formula:
# 初始化公式检测模型 # 初始化公式检测模型
self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))) self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
# 初始化公式解析模型 # 初始化公式解析模型
mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml') mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
self.mfr_model, mfr_vis_processors = mfr_model_init( mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
os.path.join(models_dir, self.configs["weights"]["mfr"]), self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
mfr_config_path,
device=self.device
)
self.mfr_transform = transforms.Compose([mfr_vis_processors, ]) self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# 初始化layout模型
self.layout_model = Layoutlmv3_Predictor(
str(os.path.join(models_dir, self.configs['weights']['layout'])),
str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
)
# 初始化ocr # 初始化ocr
if self.apply_ocr: if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log) self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment