# flake8: noqa
import os
import time

import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # prevent albumentations from checking for updates

try:
    import torchtext

    # silence torchtext's deprecation warning on newer versions
    if torchtext.__version__ >= '0.18.0':
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)


class CustomPEKModel:
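    """Single-page document analysis model built on the PDF-Extract-Kit stack.

    Bundles layout detection (LayoutLMv3 or DocLayout-YOLO), optional formula
    detection/recognition (MFD/MFR), OCR and optional table recognition, and
    runs the whole pipeline over one page image per call.
    """
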
    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """
        ======== model init ========
        """
        # absolute path of the current file (i.e. pdf_extract_kit.py)
        current_file_path = os.path.abspath(__file__)
        # directory containing the current file (model)
        current_dir = os.path.dirname(current_file_path)
        # parent directory (magic_pdf)
        root_dir = os.path.dirname(current_dir)
        # model_config directory
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # full path of the model_configs.yaml file
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, 'r', encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)
        # initialize the parsing configuration

        # layout config
        self.layout_config = kwargs.get('layout_config')
        self.layout_model_name = self.layout_config.get(
            'model', MODEL_NAME.DocLayout_YOLO
        )

        # formula config
        self.formula_config = kwargs.get('formula_config')
        self.mfd_model_name = self.formula_config.get(
            'mfd_model', MODEL_NAME.YOLO_V8_MFD
        )
        self.mfr_model_name = self.formula_config.get(
            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
        )
        self.apply_formula = self.formula_config.get('enable', True)

        # table config
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', False)
        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
        self.table_sub_model_name = self.table_config.get('sub_model', None)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get('lang', None)

        logger.info(
            'DocAnalysis init, this may take some time, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
            'apply_table: {}, table_model: {}, lang: {}'.format(
                self.layout_model_name,
                self.apply_formula,
                self.apply_ocr,
                self.apply_table,
                self.table_model_name,
                self.lang,
            )
        )
        # initialize the parsing scheme
        self.device = kwargs.get('device', 'cpu')
        logger.info('using device: {}'.format(self.device))
        models_dir = kwargs.get(
            'models_dir', os.path.join(root_dir, 'resources', 'models')
        )
        logger.info('using models_dir: {}'.format(models_dir))

        # all sub-models are obtained through the shared AtomModelSingleton manager
        atom_model_manager = AtomModelSingleton()

        # initialize formula recognition
        if self.apply_formula:
            # initialize the formula detection model
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.mfd_model_name]
                    )
                ),
                device=self.device,
            )

            # initialize the formula recognition model
            mfr_weight_dir = str(
                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
            )
            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))

            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device='cpu' if str(self.device).startswith("mps") else self.device,
            )

        # initialize the layout model
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                layout_config_file=str(
                    os.path.join(
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
                device='cpu' if str(self.device).startswith("mps") else self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                device=self.device,
            )
        # initialize OCR
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            ocr_show_log=show_log,
            det_db_box_thresh=0.3,
            lang=self.lang
        )
        # init table model
        if self.apply_table:
            table_model_dir = self.configs['weights'][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device,
                ocr_engine=self.ocr_model,
                table_sub_model_name=self.table_sub_model_name
            )

        logger.info('DocAnalysis init done!')

    def __call__(self, image):
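        """Run the full analysis pipeline on a single page image.

        Args:
            image: page image as an RGB numpy array (e.g. a rendered PDF page).

        Returns:
            list: region dicts collected from layout detection, formula
            recognition, OCR and (when enabled) table recognition; table
            regions additionally carry an 'html' field with the recognized table.
        """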
        pil_img = Image.fromarray(image)
        width, height = pil_img.size
        # logger.info(f'width: {width}, height: {height}')

        # layout detection
        layout_start = time.time()
        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            # if height > width:
            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
            #     layout_res = self.layout_model.predict(new_image)
            #     for res in layout_res:
            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
            #         p1 = p1 - paste_x + xmin
            #         p2 = p2 - paste_y + ymin
            #         p3 = p3 - paste_x + xmin
            #         p4 = p4 - paste_y + ymin
            #         p5 = p5 - paste_x + xmin
            #         p6 = p6 - paste_y + ymin
            #         p7 = p7 - paste_x + xmin
            #         p8 = p8 - paste_y + ymin
            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
            # else:
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')

        if self.apply_formula:
            # formula detection
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image)
            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

            # formula recognition
            mfr_start = time.time()
            formula_list = self.mfr_model.predict(mfd_res, image)
            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

        # clean up GPU memory
        clean_vram(self.device, vram_threshold=6)

        # extract the OCR, table and formula regions from layout_res
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
            get_res_list_from_layout_res(layout_res)
        )

        # OCR
        ocr_start = time.time()
        # Process each area that requires OCR processing
        for res in ocr_res_list:
            # crop the region with a paste margin and map the formula detection
            # boxes into the cropped coordinate system
            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

            # OCR recognition
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)

            if self.apply_ocr:
                # text detection + recognition
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
            else:
                # text detection only
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]

            # Integrate results
            if ocr_res:
                ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                layout_res.extend(ocr_result_list)

        ocr_cost = round(time.time() - ocr_start, 2)
        if self.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
        else:
            logger.info(f"det time: {ocr_cost}")

        # table recognition
        if self.apply_table:
            table_start = time.time()
            for res in table_res_list:
                new_image, _ = crop_img(res, pil_img)
                single_table_start_time = time.time()
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        table_result = self.table_model.predict(new_image, 'html')
                        if len(table_result) > 0:
                            html_code = table_result[0]
                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                        new_image
                    )
                run_time = time.time() - single_table_start_time
                if run_time > self.table_max_time:
                    logger.warning(
                        f'table recognition processing exceeds max time {self.table_max_time}s'
                    )
                # check whether the returned HTML looks complete
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        res['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, expected HTML table ending not found'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, no HTML returned'
                    )
            logger.info(f'table time: {round(time.time() - table_start, 2)}')

        return layout_res
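

# Example usage (illustrative sketch: the config dicts below only show the keys
# that __init__ actually reads from **kwargs, and `page_image` is assumed to be
# an RGB numpy array of a rendered PDF page):
#
#   model = CustomPEKModel(
#       ocr=True,
#       show_log=False,
#       layout_config={'model': MODEL_NAME.DocLayout_YOLO},
#       formula_config={'enable': True},
#       table_config={'enable': False},
#       device='cpu',
#       lang=None,
#   )
#   layout_res = model(page_image)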