doc_analyze_by_custom_model.py 7.41 KB
Newer Older
1
import os
2
3
import time

赵小蒙's avatar
赵小蒙 committed
4
5
import fitz
import numpy as np
6
from loguru import logger
7

8
9
10
11
# Turn off paddle's signal handling. This runs as an import-time side effect
# of this module.
import paddle
paddle.disable_signal_handler()

12
13
14
15
16
17
18
19
20
21
22
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

# Silence torchtext's deprecation warning when the installed release offers a
# switch for it. Detect the function directly instead of comparing version
# strings: a plain string comparison is lexicographic ('0.9.0' >= '0.18.0' is
# True), which would call a function that such older releases may not define
# and abort this module's import with an AttributeError.
try:
    import torchtext
except ImportError:
    pass
else:
    if hasattr(torchtext, 'disable_torchtext_deprecation_warning'):
        torchtext.disable_torchtext_deprecation_warning()

23
24
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
25
from magic_pdf.libs.clean_memory import clean_memory
26
27
28
29
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
30
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
31
from magic_pdf.model.operators import InferenceResult
赵小蒙's avatar
赵小蒙 committed
32
33
34
35
36
37
38
39
40
41


def dict_compare(d1, d2):
    """Return True when both mappings hold exactly the same key/value pairs.

    The comparison is made on the ``items()`` views, so key order is ignored.
    """
    pairs_a = d1.items()
    pairs_b = d2.items()
    return pairs_a == pairs_b


def remove_duplicates_dicts(lst):
    """Return the dicts of *lst* with repeated entries dropped.

    Two dicts are duplicates when their key/value pairs compare equal
    (the same ``items()`` comparison ``dict_compare`` performs). The first
    occurrence of each distinct dict is kept, in input order. Each candidate is
    compared against every kept dict, so the values never need to be hashable.
    """
    kept = []
    for candidate in lst:
        for seen in kept:
            if candidate.items() == seen.items():
                break
        else:
            # No kept dict matched: this is a new distinct entry.
            kept.append(candidate)
    return kept


48
49
50
def load_images_from_pdf(
    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
    """Render the pages of a PDF into RGB numpy images.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        dpi: Target rendering resolution. A page whose rendering would exceed
            4500 px in width or height is re-rendered with an identity matrix
            (scale 1, i.e. 72 dpi) instead.
        start_page_id: First page (0-based, inclusive) to render.
        end_page_id: Last page (0-based, inclusive) to render. ``None`` or a
            negative value means "through the last page"; values past the end
            are clamped to the last page with a warning.

    Returns:
        One dict per page of the document, in page order, with keys
        ``'img'`` (numpy array built from an RGB PIL image), ``'width'`` and
        ``'height'``. Pages outside ``[start_page_id, end_page_id]`` get the
        placeholder ``{'img': [], 'width': 0, 'height': 0}`` so list indices
        stay aligned with page numbers.

    Raises:
        SystemExit: (status 1) if Pillow is not installed.
    """
    try:
        from PIL import Image
    except ImportError:
        logger.error('Pillow not installed, please install by pip.')
        # The ``exit`` builtin is injected by the ``site`` module for the
        # interactive shell and is missing under ``python -S``; raise the
        # equivalent SystemExit directly.
        raise SystemExit(1)

    # The scale is the same for every page, so build the matrix once.
    scale = dpi / 72
    zoom = fitz.Matrix(scale, scale)

    images = []
    with fitz.open('pdf', pdf_bytes) as doc:
        pdf_page_num = doc.page_count
        last_page = pdf_page_num - 1
        end_page_id = (
            end_page_id
            if end_page_id is not None and end_page_id >= 0
            else last_page
        )
        if end_page_id > last_page:
            logger.warning('end_page_id is out of range, use images length')
            end_page_id = last_page

        for index in range(pdf_page_num):
            if not start_page_id <= index <= end_page_id:
                # Placeholder keeps the output aligned with page indices.
                images.append({'img': [], 'width': 0, 'height': 0})
                continue

            page = doc[index]
            pm = page.get_pixmap(matrix=zoom, alpha=False)
            # If the width or height exceeds 4500 after scaling, do not scale further.
            if pm.width > 4500 or pm.height > 4500:
                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

            img = np.array(Image.frombytes('RGB', (pm.width, pm.height), pm.samples))
            images.append({'img': img, 'width': pm.width, 'height': pm.height})
    return images


89
90
91
92
93
94
95
96
97
class ModelSingleton:
    """Singleton that caches initialized models per configuration.

    Every ``ModelSingleton()`` call returns the same instance. Models live in a
    class-level dict keyed by the full argument tuple of ``get_model``, so a
    given configuration is built once and then reused. The cache is not
    guarded by a lock.
    """

    _instance = None
    # (ocr, show_log, lang, layout_model, formula_enable, table_enable) -> model
    _models = {}

    def __new__(cls, *args, **kwargs):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the model for this configuration, building it on first use.

        All arguments are forwarded unchanged to ``custom_model_init`` on a
        cache miss and together form the cache key.
        """
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            pass

        built = custom_model_init(
            ocr=ocr,
            show_log=show_log,
            lang=lang,
            layout_model=layout_model,
            formula_enable=formula_enable,
            table_enable=table_enable,
        )
        self._models[cache_key] = built
        return built


120
121
122
123
124
125
126
127
def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Build the document-analysis model selected by the magic_pdf config.

    The backend is chosen from ``model_config.__model_mode__``:
    ``'lite'`` builds ``CustomPaddleModel`` and ``'full'`` builds
    ``CustomPEKModel``. For the PEK model, the models directory, device and
    the layout/formula/table settings are read from the config file, and the
    corresponding non-``None`` arguments override those settings.

    Args:
        ocr: forwarded to the model constructor.
        show_log: forwarded to the model constructor.
        lang: forwarded to the model constructor.
        layout_model: if not ``None``, overrides ``layout_config['model']``
            (PEK only).
        formula_enable: if not ``None``, overrides ``formula_config['enable']``
            (PEK only).
        table_enable: if not ``None``, overrides ``table_config['enable']``
            (PEK only).

    Returns:
        The constructed model instance.

    Raises:
        SystemExit: (status 1) if ``model_config.__use_inside_model__`` is
            false or the model mode is neither ``'lite'`` nor ``'full'``.
    """
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if not model_config.__use_inside_model__:
        logger.error('use_inside_model is False, not allow to use inside model')
        # ``exit`` is a site-module convenience for the interactive shell and
        # is absent under ``python -S``; raise the equivalent SystemExit.
        raise SystemExit(1)

    model_init_start = time.time()
    if model == MODEL.Paddle:
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel

        # Read the model directory and the device from the config file.
        local_models_dir = get_local_models_dir()
        device = get_device()

        layout_config = get_layout_config()
        if layout_model is not None:
            layout_config['model'] = layout_model

        formula_config = get_formula_config()
        if formula_enable is not None:
            formula_config['enable'] = formula_enable

        table_config = get_table_recog_config()
        if table_enable is not None:
            table_config['enable'] = table_enable

        model_input = {
            'ocr': ocr,
            'show_log': show_log,
            'models_dir': local_models_dir,
            'device': device,
            'table_config': table_config,
            'layout_config': layout_config,
            'formula_config': formula_config,
            'lang': lang,
        }
        custom_model = CustomPEKModel(**model_input)
    else:
        logger.error('Not allow model_name!')
        raise SystemExit(1)

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model


189
190
191
192
193
194
195
196
197
198
199
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
) -> InferenceResult:
    """Run the analysis model over the pages of *dataset*.

    Args:
        dataset: document to analyze. Each ``dataset.get_page(i).get_image()``
            must return a dict with ``'img'``, ``'width'`` and ``'height'``.
        ocr, show_log, lang, layout_model, formula_enable, table_enable:
            model configuration passed to ``ModelSingleton.get_model``;
            ``lang=''`` is treated as ``None``.
        start_page_id: first page (0-based, inclusive) to run the model on.
        end_page_id: last page (0-based, inclusive) to run the model on.
            ``None`` or a negative value means "through the last page".

    Returns:
        ``InferenceResult`` built from one entry per page of *dataset*:
        ``{'layout_dets': <model output, or [] for pages outside the range>,
        'page_info': {'page_no', 'height', 'width'}}``.
    """
    if lang == '':
        lang = None

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    page_count = len(dataset)
    # end_page_id is an inclusive index; None / negative means the last page.
    if end_page_id is None or end_page_id < 0:
        end_page_id = page_count - 1

    model_json = []
    analyzed_pages = 0
    doc_analyze_start = time.time()

    for index in range(page_count):
        page_data = dataset.get_page(index)
        img_dict = page_data.get_image()
        img = img_dict['img']
        page_width = img_dict['width']
        page_height = img_dict['height']
        if start_page_id <= index <= end_page_id:
            page_start = time.time()
            result = custom_model(img)
            analyzed_pages += 1
            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
        else:
            result = []

        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    elapsed = time.time() - doc_analyze_start
    doc_analyze_time = round(elapsed, 2)
    # Speed is based on pages actually analyzed. Guard the division: elapsed
    # can be 0.0 on coarse clocks or when no page falls in the range.
    doc_analyze_speed = round(analyzed_pages / elapsed, 2) if elapsed > 0 else 0.0
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )

    return InferenceResult(model_json, dataset)