# model_init.py
1
2
from loguru import logger

3
from magic_pdf.config.constants import MODEL_NAME
4
from magic_pdf.model.model_list import AtomicModel
5
6
7
8
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
9
10
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
11
12
13
14
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
15
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
16
17
18
19
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel
20
21
22
23
24
25
26


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Forwarded to ``StructTableModel`` and, as the
            ``'model_dir'`` config entry, to ``TableMasterPaddleModel``.
            Not used by ``RapidTableModel``.
        max_time: Forwarded as ``max_time`` to ``StructTableModel`` only.
        _device_: Forwarded as the ``'device'`` config entry to
            ``TableMasterPaddleModel`` only.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: with code 1 when ``table_model_type`` is not supported.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    if table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        return TableMasterPaddleModel(config)
    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel()

    logger.error('table model type not allow')
    # Use sys.exit rather than the ``exit`` builtin. ``exit`` is injected by
    # the ``site`` module and is missing under ``python -S`` and in some
    # embedded/frozen interpreters.
    import sys
    sys.exit(1)


def mfd_model_init(weight, device='cpu'):
    """Instantiate the YOLOv8-based MFD model.

    Args:
        weight: Passed straight through to ``YOLOv8MFDModel``.
        device: Passed straight through to ``YOLOv8MFDModel``; ``'cpu'`` by default.

    Returns:
        The new ``YOLOv8MFDModel`` instance.
    """
    return YOLOv8MFDModel(weight, device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Instantiate the UniMERNet-based MFR model.

    Args:
        weight_dir: Passed straight through to ``UnimernetModel``.
        cfg_path: Passed straight through to ``UnimernetModel``.
        device: Passed straight through to ``UnimernetModel``; ``'cpu'`` by default.

    Returns:
        The new ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Instantiate the LayoutLMv3 layout predictor.

    Args:
        weight: Passed straight through to ``Layoutlmv3_Predictor``.
        config_file: Passed straight through to ``Layoutlmv3_Predictor``.
        device: Passed straight through to ``Layoutlmv3_Predictor`` (no default).

    Returns:
        The new ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Instantiate the DocLayout-YOLO layout model.

    Args:
        weight: Passed straight through to ``DocLayoutYOLOModel``.
        device: Passed straight through to ``DocLayoutYOLOModel``; ``'cpu'`` by default.

    Returns:
        The new ``DocLayoutYOLOModel`` instance.
    """
    return DocLayoutYOLOModel(weight, device)


# Snapshot of the thread that imported this module. Both names are evaluated
# exactly once, at import time, so they do NOT follow whichever thread later
# calls into this module. Call ``threading.get_ident()`` at call time when the
# caller's thread id is needed. The names are kept at module level in case
# other modules import them.
import threading

current_thread = threading.current_thread()
current_thread_id = current_thread.ident
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``ModifiedPaddleOCR`` instance with the given settings.

    Args:
        show_log: Forwarded as ``show_log``.
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``.
        lang: Forwarded as ``lang`` only when it is neither ``None`` nor the
            empty string. Otherwise it is omitted, so ``ModifiedPaddleOCR``'s
            own ``lang`` default applies.
        use_dilation: Forwarded as ``use_dilation``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.

    Returns:
        The new ``ModifiedPaddleOCR`` instance.
    """
    ocr_kwargs = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang
    return ModifiedPaddleOCR(**ocr_kwargs)


class AtomModelSingleton:
    """Shared cache of initialised atomic models.

    ``__new__`` always hands back the same instance, and ``_models`` is a
    class attribute, so every caller shares a single model cache.
    """

    _instance = None
    # Class-level on purpose: cache key -> constructed model, shared by all callers.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for ``atom_model_name``, building it on first use.

        Cache keys:
          * OCR:    ``(name, lang, calling thread id)``, i.e. one OCR model per
                    thread and language;
          * Layout: ``(name, layout_model_name)``;
          * Table:  ``(name, table_model_name)``;
          * others: the name alone.

        On a cache miss, all ``kwargs`` are forwarded to ``atom_model_init``.
        """
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        table_model_name = kwargs.get('table_model_name', None)

        if atom_model_name in [AtomicModel.OCR]:
            # Resolve the thread id per call. The module-level
            # ``current_thread_id`` is captured once at import time, so using
            # it here gave every thread the same key and one shared OCR model.
            # NOTE: Python may recycle thread ids after a thread exits, so a
            # new thread can inherit a dead thread's cached model.
            key = (atom_model_name, lang, threading.get_ident())
        elif atom_model_name in [AtomicModel.Layout]:
            key = (atom_model_name, layout_model_name)
        elif atom_model_name in [AtomicModel.Table]:
            key = (atom_model_name, table_model_name)
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]

def atom_model_init(model_name: str, **kwargs):
    """Construct the atomic model named ``model_name`` from keyword settings.

    Recognised ``model_name`` values and the kwargs each one reads:
      * ``AtomicModel.Layout``: ``layout_model_name`` selects LayoutLMv3
        (``layout_weights``, ``layout_config_file``, ``device``) or
        DocLayout-YOLO (``doclayout_yolo_weights``, ``device``);
      * ``AtomicModel.MFD``: ``mfd_weights``, ``device``;
      * ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``;
      * ``AtomicModel.OCR``: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``.
        Keys that are omitted fall back to ``ocr_model_init``'s defaults;
      * ``AtomicModel.Table``: ``table_model_name``, ``table_model_path``,
        ``table_max_time``, ``device``.

    Returns:
        The constructed model.

    Raises:
        SystemExit: with code 1 if ``model_name`` is unknown, or if no model
            was built (e.g. an unrecognised ``layout_model_name``).
    """
    # sys.exit instead of the ``site``-injected ``exit`` builtin, which is
    # missing under ``python -S`` and in some embedded/frozen interpreters.
    import sys

    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device'),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device'),
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the options the caller actually supplied. Passing
        # ``kwargs.get(...)`` unconditionally replaced ocr_model_init's
        # defaults (show_log=False, det_db_box_thresh=0.3) with None whenever
        # a key was omitted. Explicitly supplied values, including None, are
        # still forwarded unchanged.
        ocr_options = {
            param: kwargs[key]
            for key, param in (
                ('ocr_show_log', 'show_log'),
                ('det_db_box_thresh', 'det_db_box_thresh'),
                ('lang', 'lang'),
            )
            if key in kwargs
        }
        atom_model = ocr_model_init(**ocr_options)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
        )
    else:
        logger.error('model name not allow')
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)
    return atom_model