doc_analyze_by_custom_model.py 10.1 KB
Newer Older
icecraft's avatar
icecraft committed
1
2
import concurrent.futures as fut
import multiprocessing as mp
3
import os
4
import time
icecraft's avatar
icecraft committed
5

icecraft's avatar
icecraft committed
6
import numpy as np
icecraft's avatar
icecraft committed
7
8
import torch

9
10
11
12
# Environment flags set at module import time, before any model code in this file runs.
os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle's JIT compilation (NPU, per the flag name)
os.environ['FLAGS_use_stride_kernel'] = '0'  # NOTE(review): Paddle flag; presumably disables stride kernels — confirm
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let the MPS backend fall back (presumably to CPU for unsupported ops)
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
icecraft's avatar
icecraft committed
13

14

icecraft's avatar
icecraft committed
15
16
from loguru import logger

17
from magic_pdf.model.sub_modules.model_utils import get_vram
18
from magic_pdf.config.enums import SupportedPdfParseMethod
19
20
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
21
from magic_pdf.libs.clean_memory import clean_memory
22
23
24
25
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
26
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
27

icecraft's avatar
icecraft committed
28
# from magic_pdf.operators.models import InferenceResult
赵小蒙's avatar
赵小蒙 committed
29

30
31
32
33
34
35
36
37
38
class ModelSingleton:
    """Process-wide cache of initialised analysis models.

    Every instantiation returns the same object. Models are built lazily by
    ``custom_model_init`` the first time a given argument combination is
    requested, then reused for identical requests.
    """

    _instance = None
    # Maps the full argument tuple of ``get_model`` to the model built for it.
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Hand back the already-created shared instance when there is one.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, building it on first use."""
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            pass

        built_model = custom_model_init(
            ocr=ocr,
            show_log=show_log,
            lang=lang,
            layout_model=layout_model,
            formula_enable=formula_enable,
            table_enable=table_enable,
        )
        self._models[cache_key] = built_model
        return built_model


61
62
63
64
65
66
67
68
def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Build the analysis model selected by ``magic_pdf.model`` configuration.

    ``model_config.__model_mode__`` picks the backend: ``'lite'`` selects
    ``MODEL.Paddle`` (with a warning), ``'full'`` selects ``MODEL.PEK``.
    For the PEK backend, the layout/formula/table configs come from the
    config file, and any non-None ``layout_model`` / ``formula_enable`` /
    ``table_enable`` argument overrides the matching entry.

    Args:
        ocr: OCR flag passed to the model constructor.
        show_log: passed to the model constructor.
        lang: language passed to the model constructor.
        layout_model: overrides ``layout_config['model']`` (PEK only).
        formula_enable: overrides ``formula_config['enable']`` (PEK only).
        table_enable: overrides ``table_config['enable']`` (PEK only).

    Returns:
        The constructed ``CustomPaddleModel`` or ``CustomPEKModel``.

    Raises:
        SystemExit: with code 1 when ``__use_inside_model__`` is false or the
            mode does not map to a known model.
    """
    import sys

    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if not model_config.__use_inside_model__:
        logger.error('use_inside_model is False, not allow to use inside model')
        # sys.exit, not the site-module `exit` helper, which is meant only
        # for the interactive interpreter and may be absent (e.g. `python -S`).
        sys.exit(1)

    model_init_start = time.time()
    if model == MODEL.Paddle:
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel

        model_input = _build_pek_model_input(
            ocr, show_log, lang, layout_model, formula_enable, table_enable
        )
        custom_model = CustomPEKModel(**model_input)
    else:
        logger.error('Not allow model_name!')
        sys.exit(1)

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model


def _build_pek_model_input(ocr, show_log, lang, layout_model, formula_enable, table_enable):
    """Assemble the keyword arguments for ``CustomPEKModel`` from config plus overrides."""
    # Model directory and device are read from the config file.
    local_models_dir = get_local_models_dir()
    device = get_device()

    # NOTE(review): the config dicts below are modified in place; if the
    # config readers return shared/cached objects, overrides persist across
    # calls — confirm against magic_pdf.libs.config_reader.
    layout_config = get_layout_config()
    if layout_model is not None:
        layout_config['model'] = layout_model

    formula_config = get_formula_config()
    if formula_enable is not None:
        formula_config['enable'] = formula_enable

    table_config = get_table_recog_config()
    if table_enable is not None:
        table_config['enable'] = table_enable

    return {
        'ocr': ocr,
        'show_log': show_log,
        'models_dir': local_models_dir,
        'device': device,
        'table_config': table_config,
        'layout_config': layout_config,
        'formula_config': formula_config,
        'lang': lang,
    }

127
128
129
130
131
132
133
134
135
136
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model analysis over a page range of ``dataset``.

    Pages whose index lies in ``[start_page_id, end_page_id]`` are rendered to
    images and sent to ``may_batch_image_analyze``. When at least
    ``MINERU_MIN_BATCH_INFERENCE_SIZE`` pages are selected (env var, default
    100), they are split into chunks of that size. Otherwise they are sent as
    a single batch.

    Every dataset page gets an entry in the output. Pages outside the range
    get empty ``layout_dets`` and zero width/height.

    Args:
        dataset: document to analyse.
        ocr: OCR flag attached to every selected page and forwarded to
            ``may_batch_image_analyze``.
        show_log: forwarded to the model.
        start_page_id: first page index to analyse (inclusive).
        end_page_id: last page index to analyse (inclusive); ``None`` or a
            negative value means the last page.
        lang: language for every page; ``None`` or ``'auto'`` uses
            ``dataset._lang``.
        layout_model, formula_enable, table_enable: model overrides forwarded
            to ``may_batch_image_analyze``.

    Returns:
        ``InferenceResult`` holding one ``{'layout_dets', 'page_info'}`` dict
        per dataset page.
    """
    from collections import deque

    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Same language for every page: 'auto'/None defers to the dataset's own.
    page_lang = dataset._lang if lang is None or lang == 'auto' else lang

    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            img_dict = dataset.get_page(index).get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # One entry per *selected* page. `images` only holds the requested range,
    # so it must not be indexed with absolute dataset page numbers.
    images_with_extra_info = [(img, ocr, page_lang) for img in images]

    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [
            images_with_extra_info[i:i + batch_size]
            for i in range(0, len(images_with_extra_info), batch_size)
        ]
    else:
        batch_images = [images_with_extra_info]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(
            batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable
        )
        results.extend(result)

    # Consume results in page order; deque.popleft is O(1) (list.pop(0) is
    # O(n)) and still raises IndexError if the model returned too few results.
    pending_results = deque(results)
    pending_wh = deque(page_wh_list)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = pending_results.popleft()
            page_width, page_height = pending_wh.popleft()
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    parse_method: str,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Analyse every page of several datasets, sharing inference batches across them.

    Pages from all datasets are concatenated and split into chunks of
    ``MINERU_MIN_BATCH_INFERENCE_SIZE`` (env var, default 100). Each chunk
    goes to ``may_batch_image_analyze``, and the results are redistributed
    to their datasets in order.

    Args:
        datasets: documents to analyse.
        parse_method: ``'auto'`` decides OCR per dataset via
            ``dataset.classify() == SupportedPdfParseMethod.OCR``; ``'ocr'``
            forces OCR; any other value disables it.
        show_log: forwarded to the model.
        lang: language for every page; ``None`` or ``'auto'`` uses each
            dataset's ``_lang``.
        layout_model, formula_enable, table_enable: model overrides forwarded
            to ``may_batch_image_analyze``.

    Returns:
        One ``InferenceResult`` per dataset, in input order.
    """
    from collections import deque

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    batch_size = MIN_BATCH_INFERENCE_SIZE

    page_wh_list = []
    images_with_extra_info = []
    for dataset in datasets:
        if len(dataset) == 0:
            # Nothing to render; also avoids classifying an empty dataset.
            continue

        # Both values depend only on the dataset, so resolve them once rather
        # than per page (classify() in particular may be expensive).
        _lang = dataset._lang if lang is None or lang == 'auto' else lang
        if parse_method == 'auto':
            use_ocr = dataset.classify() == SupportedPdfParseMethod.OCR
        else:
            use_ocr = parse_method == 'ocr'

        for index in range(len(dataset)):
            img_dict = dataset.get_page(index).get_image()
            page_wh_list.append((img_dict['width'], img_dict['height']))
            images_with_extra_info.append((img_dict['img'], use_ocr, _lang))

    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
    results = []
    for sn, batch_image in enumerate(batch_images):
        # The per-page OCR flag travels in each tuple; `True` is the batch-level
        # `ocr` argument.
        _, result = may_batch_image_analyze(
            batch_image, sn, True, show_log, layout_model, formula_enable, table_enable
        )
        results.extend(result)

    # O(1) in-order consumption; raises IndexError on underflow like list.pop(0).
    pending_results = deque(results)
    pending_wh = deque(page_wh_list)

    from magic_pdf.operators.models import InferenceResult

    infer_results = []
    for dataset in datasets:
        model_json = []
        for page_no in range(len(dataset)):
            result = pending_results.popleft()
            page_width, page_height = pending_wh.popleft()
            page_info = {'page_no': page_no, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
    images_with_extra_info: list[tuple[np.ndarray, bool, str]],
    idx: int,
    ocr: bool,
    show_log: bool = False,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run ``BatchAnalyze`` over one batch of page images.

    The batch ratio passed to ``BatchAnalyze`` is chosen from accelerator
    memory on ``npu``/``cuda`` devices and is 1 otherwise.

    Args:
        images_with_extra_info: ``(image, ocr_flag, lang)`` tuples, one per page.
        idx: batch sequence number; returned unchanged so callers can match
            results to batches.
        ocr: accepted for interface compatibility; not read here — the
            per-page flag inside each tuple is what gets passed on.
        show_log, layout_model, formula_enable, table_enable: forwarded to
            ``BatchAnalyze``.

    Returns:
        ``(idx, results)`` where ``results`` is what ``BatchAnalyze`` returns
        for the batch.
    """
    # Turn off Paddle's signal handling.
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    device = get_device()
    device_name = str(device)

    if device_name.startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    batch_ratio = 1
    if device_name.startswith(('npu', 'cuda')):
        gpu_memory = _resolve_gpu_memory_gb(device)
        if gpu_memory is not None:
            batch_ratio = _batch_ratio_for_gpu_memory(gpu_memory)
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')

    doc_analyze_start = time.time()

    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    gc_start = time.time()
    clean_memory(device)
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    elapsed = time.time() - doc_analyze_start
    doc_analyze_time = round(elapsed, 2)
    # Use the unrounded duration and guard zero: a fast (or empty) batch can
    # round to 0.0 and would otherwise raise ZeroDivisionError.
    if elapsed > 0:
        doc_analyze_speed = round(len(images_with_extra_info) / elapsed, 2)
    else:
        doc_analyze_speed = 'n/a'
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)


def _resolve_gpu_memory_gb(device):
    """Return accelerator memory in whole GB, or ``None`` if it is unknown.

    ``VIRTUAL_VRAM_SIZE`` (env var) overrides the detected value, and
    ``get_vram`` is only queried when that override is absent.
    """
    override = os.getenv('VIRTUAL_VRAM_SIZE')
    if override is not None:
        return int(override)
    vram = get_vram(device)
    # NOTE(review): guards against get_vram returning None — confirm when it does.
    if vram is None:
        return None
    return int(round(vram))


def _batch_ratio_for_gpu_memory(gpu_memory):
    """Map accelerator memory in GB to the ``BatchAnalyze`` batch ratio."""
    for min_gb, ratio in ((16, 16), (12, 8), (8, 4), (6, 2)):
        if gpu_memory >= min_gb:
            return ratio
    return 1