pdf_extract_kit.py 10.4 KB
Newer Older
1
from loguru import logger
myhloli's avatar
myhloli committed
2
import os
3
import time
4
5

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
myhloli's avatar
myhloli committed
6
7
8
9
10
11
try:
    import cv2
    import yaml
    import argparse
    import numpy as np
    import torch
12

myhloli's avatar
myhloli committed
13
14
15
16
17
18
19
20
    from paddleocr import draw_ocr
    from PIL import Image
    from torchvision import transforms
    from torch.utils.data import Dataset, DataLoader
    from ultralytics import YOLO
    from unimernet.common.config import Config
    import unimernet.tasks as tasks
    from unimernet.processors import load_processor
赵小蒙's avatar
update:  
赵小蒙 committed
21

22
23
except ImportError as e:
    logger.exception(e)
24
25
26
    logger.error(
        'Required dependency not installed, please install by \n'
        '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
myhloli's avatar
myhloli committed
27
    exit(1)
赵小蒙's avatar
update:  
赵小蒙 committed
28

29
30
31
32
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR

赵小蒙's avatar
update:  
赵小蒙 committed
33

34
35
36
def mfd_model_init(weight):
    """Create the formula-detection (MFD) model.

    Args:
        weight: Path to the YOLO weight file.

    Returns:
        The ``YOLO`` model built from ``weight``.
    """
    return YOLO(weight)
赵小蒙's avatar
update:  
赵小蒙 committed
37
38


39
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
    """Build the UniMERNet formula-recognition (MFR) model and its image processor.

    The model, tokenizer and pretrained-checkpoint locations in the config
    loaded from ``cfg_path`` are all redirected to ``weight_dir``.

    Args:
        weight_dir: Directory holding ``pytorch_model.bin`` and the tokenizer files.
        cfg_path: Path to the UniMERNet YAML config.
        _device_: Device the model is moved to (default ``'cpu'``).

    Returns:
        tuple: ``(model, vis_processor)``, where ``vis_processor`` is the
        ``formula_image_eval`` processor built from the config's eval settings.
    """
    cfg = Config(argparse.Namespace(cfg_path=cfg_path, options=None))

    # Point every weight-related setting at the local weight directory.
    checkpoint_file = os.path.join(weight_dir, "pytorch_model.bin")
    cfg.config.model.pretrained = checkpoint_file
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir

    mfr_task = tasks.setup_task(cfg)
    mfr_model = mfr_task.build_model(cfg).to(_device_)

    eval_processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_processor_cfg)
    return mfr_model, processor
赵小蒙's avatar
update:  
赵小蒙 committed
50
51


52
53
54
55
56
def layout_model_init(weight, config_file, device):
    """Create the LayoutLMv3 layout-detection predictor.

    Args:
        weight: Path to the layout model weights.
        config_file: Path to the LayoutLMv3 inference config.
        device: Device passed straight through to the predictor.

    Returns:
        A ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class MathDataset(Dataset):
    """Dataset over formula images, used to batch them for recognition.

    Each entry of ``image_paths`` is either a file path (``str``), which is
    opened with PIL on access, or an already-loaded image returned as-is
    before the optional transform is applied.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, transformed when a transform is configured."""
        item = self.image_paths[idx]
        # Only strings are treated as paths; anything else is already an image.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        # Without a transform, hand back the image itself instead of None.
        return raw_image
74
75


76
class CustomPEKModel:
    """Page-level document analysis built from PDF-Extract-Kit models.

    Layout detection always runs. Formula detection + recognition and OCR are
    optional. Calling an instance on a page image returns the list of detected
    regions as dicts.
    """

    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """Load the model configuration and initialise the enabled models.

        Args:
            ocr: Initialise the OCR model and run OCR in ``__call__``.
            show_log: Passed through to ``ModifiedPaddleOCR``.
            **kwargs: Optional overrides of the YAML configuration:
                ``apply_layout``, ``apply_formula``, ``device`` and ``models_dir``.

        Raises:
            AssertionError: If layout detection is disabled; the pipeline
                always relies on the layout model.
        """
        # This file sits in the ``model`` directory; the package root is its parent.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.dirname(current_dir)
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, "r", encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # Analysis options: keyword arguments take precedence over the YAML values.
        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
        self.apply_ocr = ocr
        logger.info(
            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
                self.apply_layout, self.apply_formula, self.apply_ocr
            )
        )
        # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``.
        if not self.apply_layout:
            raise AssertionError("DocAnalysis must contain layout model.")

        self.device = kwargs.get("device", self.configs["config"]["device"])
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
        logger.info("using models_dir: {}".format(models_dir))

        if self.apply_formula:
            # Formula detection (MFD) model.
            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
            # Formula recognition (MFR) model and its preprocessing transform.
            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])

        # Layout model (always built).
        self.layout_model = Layoutlmv3_Predictor(
            str(os.path.join(models_dir, self.configs['weights']['layout'])),
            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
            device=self.device
        )

        if self.apply_ocr:
            self.ocr_model = ModifiedPaddleOCR(show_log=show_log)

        logger.info('DocAnalysis init done!')

    def __call__(self, image):
        """Analyse one page image.

        Args:
            image: The page image. It is fed to the layout/MFD models and to
                ``Image.fromarray``. Presumably an RGB numpy array; TODO confirm
                against callers.

        Returns:
            list[dict]: The layout detections. Formula entries (category
            ``13 + MFD class``, with a ``latex`` field) are appended when formula
            support is on. OCR text lines (category 15, with a ``text`` field)
            are appended when OCR is on.
        """
        # Layout detection.
        layout_start = time.time()
        layout_res = self.layout_model(image, ignore_catids=[])
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection cost: {layout_cost}")

        if self.apply_formula:
            self._detect_and_recognize_formulas(image, layout_res)

        if self.apply_ocr:
            ocr_start = time.time()
            self._ocr_text_regions(image, layout_res)
            ocr_cost = round(time.time() - ocr_start, 2)
            logger.info(f"ocr cost: {ocr_cost}")

        return layout_res

    def _detect_and_recognize_formulas(self, image, layout_res):
        """Detect formulas, recognise their LaTeX and append them to ``layout_res`` in place."""
        latex_filling_list = []
        mf_image_list = []

        # Formula detection.
        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
        # The page is converted to PIL at most once (lazily, on the first detection)
        # rather than once per detected formula.
        pil_img = None
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            layout_res.append(new_item)
            latex_filling_list.append(new_item)
            if pil_img is None:
                pil_img = Image.fromarray(image)
            mf_image_list.append(get_croped_image(pil_img, [xmin, ymin, xmax, ymax]))

        # Formula recognition, batched through a DataLoader.
        mfr_start = time.time()
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        # Pure inference: skip building autograd graphs.
        with torch.no_grad():
            for mf_img in dataloader:
                mf_img = mf_img.to(self.device)
                output = self.mfr_model.generate({'image': mf_img})
                mfr_res.extend(output['pred_str'])
        for res, latex in zip(latex_filling_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        mfr_cost = round(time.time() - mfr_start, 2)
        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

    def _ocr_text_regions(self, image, layout_res):
        """OCR selected layout regions and append the text lines to ``layout_res`` in place.

        Formula boxes (categories 13/14) are passed to the OCR model through
        ``mfd_res``, translated into each padded crop's coordinates.
        """
        pil_img = Image.fromarray(image)
        paste_x = 50
        paste_y = 50

        # Split the layout results into formula boxes and regions that need OCR.
        single_page_mfdetrec_res = []
        ocr_res_list = []
        for res in layout_res:
            category_id = int(res['category_id'])
            if category_id in (13, 14):
                single_page_mfdetrec_res.append({
                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                             int(res['poly'][4]), int(res['poly'][5])],
                })
            elif category_id in (0, 1, 2, 4, 6, 7):
                ocr_res_list.append(res)

        for res in ocr_res_list:
            xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
            xmax, ymax = int(res['poly'][4]), int(res['poly'][5])

            # Paste the region onto a white canvas with a margin of paste_x / paste_y on each side.
            new_width = xmax - xmin + paste_x * 2
            new_height = ymax - ymin + paste_y * 2
            new_image = Image.new('RGB', (new_width, new_height), 'white')
            cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
            new_image.paste(cropped_img, (paste_x, paste_y))

            # Translate formula boxes into the padded crop's coordinates and
            # drop those lying entirely outside it.
            adjusted_mfdetrec_res = []
            for mf_res in single_page_mfdetrec_res:
                mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
                x0 = mf_xmin - xmin + paste_x
                y0 = mf_ymin - ymin + paste_y
                x1 = mf_xmax - xmin + paste_x
                y1 = mf_ymax - ymin + paste_y
                if x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height:
                    continue
                adjusted_mfdetrec_res.append({
                    "bbox": [x0, y0, x1, y1],
                })

            # The OCR model is fed a BGR array.
            ocr_input = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
            ocr_res = self.ocr_model.ocr(ocr_input, mfd_res=adjusted_mfdetrec_res)[0]
            if not ocr_res:
                continue

            for box_ocr_res in ocr_res:
                p1, p2, p3, p4 = box_ocr_res[0]
                text, score = box_ocr_res[1]
                # Map the four corner points back to page coordinates.
                poly = []
                for point in (p1, p2, p3, p4):
                    poly += [point[0] - paste_x + xmin, point[1] - paste_y + ymin]
                layout_res.append({
                    'category_id': 15,
                    'poly': poly,
                    'score': round(score, 2),
                    'text': text,
                })