import os
import time
from typing import List, Tuple
import PIL.Image
from loguru import logger

from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from ...utils.model_utils import get_vram, clean_memory


os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS ops to fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check


class ModelSingleton:
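    """Process-wide singleton that caches models keyed by (lang, formula_enable, table_enable)."""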
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (lang, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]
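
    # Example (hypothetical usage): both calls below return the same cached
    # instance, so only the first call pays the initialization cost.
    #   m1 = ModelSingleton().get_model(lang='en', formula_enable=True, table_enable=True)
    #   m2 = ModelSingleton().get_model(lang='en', formula_enable=True, table_enable=True)
    #   assert m1 is m2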


def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
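    """Initialize a MineruPipelineModel on the device read from the config file."""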
    model_init_start = time.time()
    # Read the device from the config file
    device = get_device()

    formula_config = {"enable": formula_enable}
    table_config = {"enable": table_enable}

    model_input = {
        'device': device,
        'table_config': table_config,
        'formula_config': formula_config,
        'lang': lang,
    }

    custom_model = MineruPipelineModel(**model_input)

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')

    return custom_model


def doc_analyze(
        pdf_bytes_list: List[bytes],
        lang_list: List[str],
        parse_method: str = 'auto',
        formula_enable=True,
        table_enable=True,
):
    """
77
78
    适当调大MIN_BATCH_INFERENCE_SIZE可以提高性能,更大的 MIN_BATCH_INFERENCE_SIZE会消耗更多内存,
    可通过环境变量MINERU_MIN_BATCH_INFERENCE_SIZE设置,默认值为384。
79
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 384))
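    # e.g. exporting MINERU_MIN_BATCH_INFERENCE_SIZE=512 (hypothetical value)
    # trades higher memory use for fewer, larger batches.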

    # Collect info for every page across all PDFs
    all_pages_info = []  # stores (pdf_idx, page_idx, pil_img, ocr_enable, lang)

    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether OCR is needed for this PDF
        _ocr_enable = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr_enable = True
        elif parse_method == 'ocr':
            _ocr_enable = True

        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]

        # Load and collect the pages of this PDF
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx in range(len(images_list)):
            img_dict = images_list[page_idx]
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr_enable, _lang,
            ))

    # Prepare batches: each entry is (pil_img, ocr_enable, lang)
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
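    # e.g. 1000 pages (hypothetical count) with the default batch_size of 384
    # are split into batches of 384, 384 and 232 pages.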

    # Run inference batch by batch
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count}/{len(images_with_extra_info)} pages processed'
        )
        batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)

    # Assemble one result list per input PDF
    infer_results = [[] for _ in pdf_bytes_list]

    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        result = results[i]

        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}

        infer_results[pdf_idx].append(page_dict)

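    # infer_results[i] holds the page dicts for pdf_bytes_list[i]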
    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list


def batch_image_analyze(
        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
        formula_enable=True,
        table_enable=True):
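    """Run layout analysis over a flat list of page images.

    Each tuple in images_with_extra_info is (pil_img, ocr_enable, lang),
    as assembled in doc_analyze.
    """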
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)

    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
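            # e.g. MINERU_VIRTUAL_VRAM_SIZE=8 (hypothetical override) caps the
            # effective VRAM at 8 GB regardless of the detected value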
            # Choose batch_ratio from available VRAM (GB)
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    # Release cached accelerator memory after inference
    clean_memory(get_device())

    return results