# model_init.py
import sys

from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Model location. Passed positionally to StructTableModel
            and as ``model_dir`` to TableMasterPaddleModel; unused for
            RapidTable.
        max_time: Forwarded to StructTableModel only.
        _device_: Forwarded (as ``device``) to TableMasterPaddleModel only.

    Returns:
        The constructed table model instance.

    Logs an error and terminates the process with status 1 (``SystemExit``)
    when ``table_model_type`` is not one of the supported values.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        # The bare ``exit`` builtin is injected by the ``site`` module for
        # interactive use and may be missing (python -S, frozen apps);
        # sys.exit raises the same SystemExit reliably.
        sys.exit(1)

    return table_model


def mfd_model_init(weight, device='cpu'):
    """Construct the MFD sub-model.

    Args:
        weight: Weights argument handed unchanged to ``YOLOv8MFDModel``.
        device: Device argument handed unchanged to ``YOLOv8MFDModel``
            (defaults to ``'cpu'``).

    Returns:
        A freshly created ``YOLOv8MFDModel``.
    """
    return YOLOv8MFDModel(weight, device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Construct the MFR sub-model backed by ``UnimernetModel``.

    Args:
        weight_dir: Weights directory handed unchanged to ``UnimernetModel``.
        cfg_path: Config path handed unchanged to ``UnimernetModel``.
        device: Device argument (defaults to ``'cpu'``).

    Returns:
        A freshly created ``UnimernetModel``.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Construct the LayoutLMv3-based layout predictor.

    Args:
        weight: Weights argument handed unchanged to ``Layoutlmv3_Predictor``.
        config_file: Config file argument handed unchanged to the predictor.
        device: Device argument handed unchanged to the predictor.

    Returns:
        A freshly created ``Layoutlmv3_Predictor``.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Construct the DocLayout-YOLO layout model.

    Args:
        weight: Weights argument handed unchanged to ``DocLayoutYOLOModel``.
        device: Device argument (defaults to ``'cpu'``).

    Returns:
        A freshly created ``DocLayoutYOLOModel``.
    """
    return DocLayoutYOLOModel(weight, device)


def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``ModifiedPaddleOCR`` instance.

    Args:
        show_log: Forwarded as ``show_log``.
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``. ``None`` is
            treated as the default ``0.3``. Callers such as ``atom_model_init``
            forward ``kwargs.get('det_db_box_thresh')``, which yields ``None``
            when the option is absent and would otherwise override the default.
        lang: Forwarded as ``lang`` only when it is neither ``None`` nor ``''``.
            Otherwise the argument is omitted and ``ModifiedPaddleOCR`` applies
            its own default (presumably a built-in default language; confirm
            in ppocr_273_mod).
        use_dilation: Forwarded as ``use_dilation``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.

    Returns:
        The constructed ``ModifiedPaddleOCR`` model.
    """
    if det_db_box_thresh is None:
        det_db_box_thresh = 0.3

    ocr_kwargs = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    # Omit ``lang`` entirely when unset so the OCR class keeps its default.
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang

    return ModifiedPaddleOCR(**ocr_kwargs)


class AtomModelSingleton:
    """Shared cache of initialised atomic models.

    ``AtomModelSingleton()`` always returns the same instance (see
    ``__new__``), and ``get_atom_model`` builds each model at most once per
    cache key by delegating construction to ``atom_model_init``.
    """

    _instance = None
    # Class-level cache shared by all references to the singleton:
    # (atom_model_name, layout_model_name, lang, table_model_name) -> model
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this request, creating it on first use.

        Args:
            atom_model_name: Which atomic model to fetch.
            **kwargs: Options forwarded to ``atom_model_init`` on a cache
                miss. Only ``layout_model_name``, ``lang`` and
                ``table_model_name`` take part in the cache key. Other options
                (weights, device, ...) do not, so the first call's values win
                for a given key.

        Returns:
            The (possibly cached) model object.
        """
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        # Without the table backend in the key, asking for a different table
        # model would silently return whichever one was built first.
        table_model_name = kwargs.get('table_model_name', None)

        key = (atom_model_name, layout_model_name, lang, table_model_name)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]


def atom_model_init(model_name: str, **kwargs):
    """Instantiate one atomic model by name.

    Args:
        model_name: An ``AtomicModel`` value: Layout, MFD, MFR, OCR or Table.
        **kwargs: Model-specific options, read with ``kwargs.get`` so that
            missing keys become ``None``:

            - Layout: ``layout_model_name``, ``layout_weights``,
              ``layout_config_file``, ``doclayout_yolo_weights``, ``device``
            - MFD: ``mfd_weights``, ``device``
            - MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``
            - OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``
            - Table: ``table_model_name``, ``table_model_path``,
              ``table_max_time``, ``device``

    Returns:
        The constructed model.

    Logs an error and terminates the process with status 1 (``SystemExit``)
    when ``model_name`` is unknown, or when no model was produced (for
    example, a Layout request with an unrecognised ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device'),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device'),
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
        )
    else:
        logger.error('model name not allow')
        # sys.exit instead of the site-provided ``exit`` helper, which is
        # meant for the interactive interpreter and may be unavailable.
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)
    return atom_model