pdf_extract_kit.py 9.89 KB
Newer Older
1
2
import numpy as np
import torch
3
from loguru import logger
myhloli's avatar
myhloli committed
4
import os
5
import time
6
7
8
import cv2
import yaml
from PIL import Image
9
10

# Environment switches read by third-party libraries.
os.environ.update({
    'NO_ALBUMENTATIONS_UPDATE': '1',  # stop albumentations from checking for updates
    'YOLO_VERBOSE': 'False',  # disable yolo logger
})
12

myhloli's avatar
myhloli committed
13
# torchtext is optional: when it is installed and exposes the switch, silence
# its deprecation warning; otherwise do nothing.
try:
    import torchtext
except ImportError:
    pass
else:
    # Probe for the function instead of comparing ``__version__`` strings:
    # a lexicographic comparison treats e.g. "0.9.0" as >= "0.18.0", which
    # would call a function that does not exist in that release and abort
    # the import of this module with AttributeError.
    _disable_torchtext_warning = getattr(
        torchtext, 'disable_torchtext_deprecation_warning', None
    )
    if callable(_disable_torchtext_warning):
        _disable_torchtext_warning()
20

21
22
23
24
25
from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
26
27


28
class CustomPEKModel:
    """Per-page document analysis pipeline.

    Construction reads ``model_configs.yaml`` and obtains the layout model
    plus, depending on the switches, the formula (MFD and MFR), OCR and table
    models from ``AtomModelSingleton``. Calling the instance with a page image
    runs the enabled stages and returns the accumulated result list.
    """

    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """Read the model configuration and load every enabled model.

        Args:
            ocr: When True, load the OCR model and run OCR in ``__call__``.
            show_log: Forwarded to the OCR model as ``ocr_show_log``.
            **kwargs: Optional settings:
                layout_config: mapping; ``model`` selects the layout model
                    (default ``MODEL_NAME.DocLayout_YOLO``).
                formula_config: mapping; ``mfd_model``, ``mfr_model`` and
                    ``enable`` (default True).
                table_config: mapping; ``enable`` (default False),
                    ``max_time`` (default ``TABLE_MAX_TIME_VALUE``) and
                    ``model`` (default ``MODEL_NAME.RAPID_TABLE``).
                lang: Forwarded to the OCR model (default None).
                device: Device handed to the models (default ``"cpu"``).
                models_dir: Directory holding the weights (default
                    ``<package root>/resources/models``).

        A config section that is missing or None is treated as an empty
        mapping, so the defaults listed above apply.
        """
        # This file lives in the "model" directory; its parent is the package
        # root (magic_pdf), which holds resources/model_config.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.dirname(current_dir)
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        # NOTE(review): the file is located inside the package tree; if it can
        # ever come from an untrusted source, prefer yaml.safe_load.
        with open(config_path, "r", encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # layout config
        self.layout_config = self._get_config_section(kwargs, "layout_config")
        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)

        # formula config
        self.formula_config = self._get_config_section(kwargs, "formula_config")
        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
        self.apply_formula = self.formula_config.get("enable", True)

        # table config
        self.table_config = self._get_config_section(kwargs, "table_config")
        self.apply_table = self.table_config.get("enable", False)
        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get("lang", None)

        logger.info(
            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
            "apply_table: {}, table_model: {}, lang: {}".format(
                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
                self.lang
            )
        )

        # runtime settings
        self.device = kwargs.get("device", "cpu")
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
        logger.info("using models_dir: {}".format(models_dir))

        atom_model_manager = AtomModelSingleton()

        # formula models
        if self.apply_formula:
            # formula detection (MFD)
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=self._weights_path(models_dir, self.mfd_model_name),
                device=self.device
            )

            # formula recognition (MFR)
            mfr_weight_dir = self._weights_path(models_dir, self.mfr_model_name)
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device=self.device
            )

        # layout model
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=self._weights_path(models_dir, self.layout_model_name),
                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
                device=self.device
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=self._weights_path(models_dir, self.layout_model_name),
                device=self.device
            )
        else:
            # No layout model is loaded for any other name and __call__ then
            # starts from an empty layout result; make that visible.
            logger.warning(
                "unsupported layout model: {}, layout detection will be skipped".format(self.layout_model_name)
            )

        # ocr model
        if self.apply_ocr:
            self.ocr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                ocr_show_log=show_log,
                det_db_box_thresh=0.3,
                lang=self.lang
            )

        # table model
        if self.apply_table:
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=self._weights_path(models_dir, self.table_model_name),
                table_max_time=self.table_max_time,
                device=self.device
            )

        logger.info('DocAnalysis init done!')

    @staticmethod
    def _get_config_section(options, key):
        """Return ``options[key]``, or an empty dict when it is missing or None."""
        section = options.get(key)
        return {} if section is None else section

    def _weights_path(self, models_dir, model_name):
        """Join ``models_dir`` with the ``weights`` entry configured for ``model_name``."""
        return str(os.path.join(models_dir, self.configs["weights"][model_name]))

    def __call__(self, image):
        """Run the enabled analysis stages on one page image.

        Args:
            image: Page image. It is passed unchanged to the layout, MFD and
                MFR models and to ``Image.fromarray`` (presumably a numpy
                array -- confirm against the caller).

        Returns:
            list: The layout results, extended with the formula results (when
            formulas are enabled) and the OCR results (when OCR is enabled).
            When tables are enabled, each table region for which valid HTML
            was produced receives an ``"html"`` key.
        """
        page_start = time.time()

        # layout detection
        layout_start = time.time()
        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            layout_res = self.layout_model.predict(image)
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection time: {layout_cost}")

        pil_img = Image.fromarray(image)

        if self.apply_formula:
            # formula detection
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image)
            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")

            # formula recognition
            mfr_start = time.time()
            formula_list = self.mfr_model.predict(mfd_res, image)
            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")

        # clean up video memory
        clean_vram(self.device, vram_threshold=8)

        # split layout_res into OCR regions, table regions and formula regions
        ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)

        # OCR
        if self.apply_ocr:
            self._run_ocr(layout_res, ocr_res_list, single_page_mfdetrec_res, pil_img)

        # table recognition
        if self.apply_table:
            self._run_table_recognition(table_res_list, pil_img)

        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")

        return layout_res

    def _run_ocr(self, layout_res, ocr_res_list, single_page_mfdetrec_res, pil_img):
        """OCR every region in ``ocr_res_list`` and append the results to ``layout_res``."""
        ocr_start = time.time()
        for res in ocr_res_list:
            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

            # The crop is converted from RGB to BGR before being handed to the OCR model.
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
            ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]

            # Merge the results for this region, if any.
            if ocr_res:
                layout_res.extend(get_ocr_result_list(ocr_res, useful_list))

        ocr_cost = round(time.time() - ocr_start, 2)
        logger.info(f"ocr time: {ocr_cost}")

    def _predict_table_html(self, table_image):
        """Return the HTML produced by the configured table model, or None."""
        if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
            with torch.no_grad():
                table_result = self.table_model.predict(table_image, "html")
                if len(table_result) > 0:
                    return table_result[0]
            return None
        if self.table_model_name == MODEL_NAME.TABLE_MASTER:
            return self.table_model.img2html(table_image)
        if self.table_model_name == MODEL_NAME.RAPID_TABLE:
            # predict() returns three values; only the HTML is used here.
            html_code, _cell_bboxes, _elapse = self.table_model.predict(table_image)
            return html_code
        return None

    def _run_table_recognition(self, table_res_list, pil_img):
        """Recognise each table region and store valid HTML under ``res["html"]``."""
        table_start = time.time()
        for res in table_res_list:
            new_image, _ = crop_img(res, pil_img)
            single_table_start_time = time.time()
            html_code = self._predict_table_html(new_image)

            run_time = time.time() - single_table_start_time
            if run_time > self.table_max_time:
                logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")

            # Accept the output only when it ends like a complete HTML table.
            if html_code:
                if html_code.strip().endswith(('</html>', '</table>')):
                    res["html"] = html_code
                else:
                    logger.warning("table recognition processing fails, not found expected HTML table end")
            else:
                logger.warning("table recognition processing fails, not get html return")

        logger.info(f"table time: {round(time.time() - table_start, 2)}")