"build_tools/vscode:/vscode.git/clone" did not exist on "e354062580c5ec83805712e53c1c54f72771707b"
Commit a565fa3a authored by luopl's avatar luopl
Browse files

Initial commit

parents
This diff is collapsed.
import os
import torch
from loguru import logger
from .model_list import AtomicModel
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
# from ...model.table.rec.RapidTable import RapidTableModel
from ...model.table.rec.slanet_plus.main import RapidTableModel
from ...model.table.rec.unet_table.main import UnetTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def img_orientation_cls_model_init():
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang="ch_lite",
enable_merge_det_boxes=False
)
cls_model = PaddleOrientationClsModel(ocr_engine)
return cls_model
def table_cls_model_init():
return PaddleTableClsModel()
def wired_table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang,
enable_merge_det_boxes=False
)
table_model = UnetTableModel(ocr_engine)
return table_model
def wireless_table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang,
enable_merge_det_boxes=False
)
table_model = RapidTableModel(ocr_engine)
return table_model
def mfd_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
def mfr_model_init(weight_dir, device='cpu'):
mfr_model = UnimernetModel(weight_dir, device)
return mfr_model
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
det_db_unclip_ratio=1.8,
enable_merge_det_boxes=True
):
if lang is not None and lang != '':
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
lang=lang,
use_dilation=True,
det_db_unclip_ratio=det_db_unclip_ratio,
enable_merge_det_boxes=enable_merge_det_boxes,
)
else:
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
use_dilation=True,
det_db_unclip_ratio=det_db_unclip_ratio,
enable_merge_det_boxes=enable_merge_det_boxes,
)
return model
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
key = (
atom_model_name,
lang
)
elif atom_model_name in [AtomicModel.OCR]:
key = (
atom_model_name,
kwargs.get('det_db_box_thresh', 0.3),
lang,
kwargs.get('det_db_unclip_ratio', 1.8),
kwargs.get('enable_merge_det_boxes', True)
)
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get('mfr_weight_dir'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get('det_db_box_thresh', 0.3),
kwargs.get('lang'),
kwargs.get('det_db_unclip_ratio', 1.8),
kwargs.get('enable_merge_det_boxes', True)
)
elif model_name == AtomicModel.WirelessTable:
atom_model = wireless_table_model_init(
kwargs.get('lang'),
)
elif model_name == AtomicModel.WiredTable:
atom_model = wired_table_model_init(
kwargs.get('lang'),
)
elif model_name == AtomicModel.TableCls:
atom_model = table_cls_model_init()
elif model_name == AtomicModel.ImgOrientationCls:
atom_model = img_orientation_cls_model_init()
else:
logger.error('model name not allow')
exit(1)
if atom_model is None:
logger.error('model init failed')
exit(1)
else:
return atom_model
class MineruPipelineModel:
def __init__(self, **kwargs):
self.formula_config = kwargs.get('formula_config')
self.apply_formula = self.formula_config.get('enable', True)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', True)
self.lang = kwargs.get('lang', None)
self.device = kwargs.get('device', 'cpu')
logger.info(
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
device=self.device,
)
# 初始化layout模型
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
# 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
self.wired_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WiredTable,
lang=self.lang,
)
self.wireless_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WirelessTable,
lang=self.lang,
)
self.table_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.TableCls,
)
self.img_orientation_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.ImgOrientationCls,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) Opendatalab. All rights reserved.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment