pdf_extract_kit.py
from loguru import logger
import os
import time

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
try:
    import cv2
    import yaml
    import argparse
    import numpy as np
    import torch

    from paddleocr import draw_ocr
    from PIL import Image
    from torchvision import transforms
    from torch.utils.data import Dataset, DataLoader
    from ultralytics import YOLO
    from unimernet.common.config import Config
    import unimernet.tasks as tasks
    from unimernet.processors import load_processor

except ImportError as e:
    logger.exception(e)
    logger.error(
        'Required dependency not installed, please install by \n'
        '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
    exit(1)

from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR


def mfd_model_init(weight):
    mfd_model = YOLO(weight)
    return mfd_model


def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    model = model.to(_device_)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor


def layout_model_init(weight, config_file, device):
    model = Layoutlmv3_Predictor(weight, config_file, device)
    return model


class MathDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open from path when a string is given, otherwise assume a PIL image
        if isinstance(self.image_paths[idx], str):
            raw_image = Image.open(self.image_paths[idx])
        else:
            raw_image = self.image_paths[idx]
        if self.transform:
            return self.transform(raw_image)
        return raw_image


class CustomPEKModel:
    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """
        ======== model init ========
        """
        # Absolute path of the current file (pdf_extract_kit.py)
        current_file_path = os.path.abspath(__file__)
        # Directory containing this file (model)
        current_dir = os.path.dirname(current_file_path)
        # Parent directory (magic_pdf)
        root_dir = os.path.dirname(current_dir)
        # model_config directory
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # Full path to the model_configs.yaml file
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, "r") as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)
        # Initialize parsing configuration
        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
        self.apply_ocr = ocr
        logger.info(
            "DocAnalysis init, this may take some time. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
                self.apply_layout, self.apply_formula, self.apply_ocr
            )
        )
        assert self.apply_layout, "DocAnalysis must contain a layout model."
        # Initialize the parsing pipeline
        self.device = kwargs.get("device", self.configs["config"]["device"])
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))

        # Initialize formula recognition
        if self.apply_formula:
            # Initialize the formula detection (MFD) model
            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))

            # Initialize the formula recognition (MFR) model
            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])

        # Initialize the layout model
        self.layout_model = Layoutlmv3_Predictor(
            str(os.path.join(models_dir, self.configs['weights']['layout'])),
            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
            device=self.device
        )
        # Initialize OCR
        if self.apply_ocr:
            self.ocr_model = ModifiedPaddleOCR(show_log=show_log)

        logger.info('DocAnalysis init done!')

    def __call__(self, image):

        latex_filling_list = []
        mf_image_list = []

        # Layout detection
        layout_start = time.time()
        layout_res = self.layout_model(image, ignore_catids=[])
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection cost: {layout_cost}")

        # Formula detection (MFD)
        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            layout_res.append(new_item)
            latex_filling_list.append(new_item)
            bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
            mf_image_list.append(bbox_img)

        # Formula recognition (MFR)
        mfr_start = time.time()
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            output = self.mfr_model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        for res, latex in zip(latex_filling_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        mfr_cost = round(time.time() - mfr_start, 2)
        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

        # OCR recognition
        if self.apply_ocr:
            ocr_start = time.time()
            pil_img = Image.fromarray(image)

            # Separate regions that need OCR from formula regions
            ocr_res_list = []
            single_page_mfdetrec_res = []
            for res in layout_res:
                if int(res['category_id']) in [13, 14]:
                    single_page_mfdetrec_res.append({
                        "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                                 int(res['poly'][4]), int(res['poly'][5])],
                    })
                elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
                    ocr_res_list.append(res)

            # Process each region that requires OCR
            for res in ocr_res_list:
                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])

                paste_x = 50
                paste_y = 50
                # Create a white background padded by 50 px on each side
                new_width = xmax - xmin + paste_x * 2
                new_height = ymax - ymin + paste_y * 2
                new_image = Image.new('RGB', (new_width, new_height), 'white')

                # Crop the region from the page image
                crop_box = (xmin, ymin, xmax, ymax)
                cropped_img = pil_img.crop(crop_box)
                new_image.paste(cropped_img, (paste_x, paste_y))

                # Adjust formula region coordinates
                adjusted_mfdetrec_res = []
                for mf_res in single_page_mfdetrec_res:
                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
                    # Shift formula coordinates to be relative to the cropped region
                    x0 = mf_xmin - xmin + paste_x
                    y0 = mf_ymin - ymin + paste_y
                    x1 = mf_xmax - xmin + paste_x
                    y1 = mf_ymax - ymin + paste_y
                    # Skip formula blocks that fall outside the cropped image
                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
                        continue
                    else:
                        adjusted_mfdetrec_res.append({
                            "bbox": [x0, y0, x1, y1],
                        })

                # Run OCR on the padded crop
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]

                # Merge OCR results into layout_res
                if ocr_res:
                    for box_ocr_res in ocr_res:
                        p1, p2, p3, p4 = box_ocr_res[0]
                        text, score = box_ocr_res[1]

                        # Map coordinates back to the original image coordinate system
                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]

                        layout_res.append({
                            'category_id': 15,
                            'poly': p1 + p2 + p3 + p4,
                            'score': round(score, 2),
                            'text': text,
                        })

            ocr_cost = round(time.time() - ocr_start, 2)
            logger.info(f"ocr cost: {ocr_cost}")

        return layout_res
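

# --- Usage sketch (illustrative only, not part of the upstream pipeline) ---
# A minimal, hypothetical example of driving CustomPEKModel on a single page
# image. The file name 'page_0.png' below is an assumption for illustration;
# in magic-pdf this class is normally instantiated by the library's own
# document-analysis entry points rather than called directly.
if __name__ == '__main__':
    demo_path = 'page_0.png'  # hypothetical rendering of one PDF page
    if os.path.exists(demo_path):
        bgr = cv2.imread(demo_path)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # the model expects an RGB ndarray
        doc_model = CustomPEKModel(ocr=True, show_log=False)
        for item in doc_model(rgb):
            # Each item carries a category_id, an 8-value 'poly' box, a 'score',
            # and optionally recognized 'text' (OCR) or 'latex' (formula) content.
            print(item['category_id'], item.get('text') or item.get('latex', ''))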