Commit f3502226 authored by myhloli's avatar myhloli
Browse files

feat(model): improve batch analysis logic and support npu

- Add support for NPU (Neural Processing Unit) when available
- Implement batch analysis for GPU and NPU devices
- Optimize memory usage and improve performance
- Update logging and error handling
parent 84f808fa
...@@ -7,17 +7,17 @@ from loguru import logger ...@@ -7,17 +7,17 @@ from loguru import logger
from PIL import Image from PIL import Image
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE # from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset # from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory # from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device # from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton # from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult # from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 4 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1 MFD_BASE_BATCH_SIZE = 1
...@@ -91,10 +91,12 @@ class BatchAnalyze: ...@@ -91,10 +91,12 @@ class BatchAnalyze:
images, images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE, batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
) )
mfr_count = 0
for image_index in range(len(images)): for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index] images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
logger.info( logger.info(
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}' f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
) )
# 清理显存 # 清理显存
...@@ -195,81 +197,81 @@ class BatchAnalyze: ...@@ -195,81 +197,81 @@ class BatchAnalyze:
return images_layout_res return images_layout_res
def doc_batch_analyze( # def doc_batch_analyze(
dataset: Dataset, # dataset: Dataset,
ocr: bool = False, # ocr: bool = False,
show_log: bool = False, # show_log: bool = False,
start_page_id=0, # start_page_id=0,
end_page_id=None, # end_page_id=None,
lang=None, # lang=None,
layout_model=None, # layout_model=None,
formula_enable=None, # formula_enable=None,
table_enable=None, # table_enable=None,
batch_ratio: int | None = None, # batch_ratio: int | None = None,
) -> InferenceResult: # ) -> InferenceResult:
"""Perform batch analysis on a document dataset. # """Perform batch analysis on a document dataset.
#
Args: # Args:
dataset (Dataset): The dataset containing document pages to be analyzed. # dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False. # ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False. # show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0. # start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page. # end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None. # lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None. # layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None. # formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None. # table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1. # batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
Raises: # Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode. # CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
Returns: # Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset. # InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
""" # """
#
if not torch.cuda.is_available(): # if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode') # raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
lang = None if lang == '' else lang # lang = None if lang == '' else lang
# TODO: auto detect batch size # # TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio # batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id else len(dataset) # end_page_id = end_page_id if end_page_id else len(dataset)
#
model_manager = ModelSingleton() # model_manager = ModelSingleton()
custom_model: CustomPEKModel = model_manager.get_model( # custom_model: CustomPEKModel = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable # ocr, show_log, lang, layout_model, formula_enable, table_enable
) # )
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio) # batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
model_json = [] # model_json = []
#
# batch analyze # # batch analyze
images = [] # images = []
for index in range(len(dataset)): # for index in range(len(dataset)):
if start_page_id <= index <= end_page_id: # if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index) # page_data = dataset.get_page(index)
img_dict = page_data.get_image() # img_dict = page_data.get_image()
images.append(img_dict['img']) # images.append(img_dict['img'])
analyze_result = batch_model(images) # analyze_result = batch_model(images)
#
for index in range(len(dataset)): # for index in range(len(dataset)):
page_data = dataset.get_page(index) # page_data = dataset.get_page(index)
img_dict = page_data.get_image() # img_dict = page_data.get_image()
page_width = img_dict['width'] # page_width = img_dict['width']
page_height = img_dict['height'] # page_height = img_dict['height']
if start_page_id <= index <= end_page_id: # if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0) # result = analyze_result.pop(0)
else: # else:
result = [] # result = []
#
page_info = {'page_no': index, 'height': page_height, 'width': page_width} # page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info} # page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict) # model_json.append(page_dict)
#
# TODO: clean memory when gpu memory is not enough # # TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time() # clean_memory_start_time = time.time()
clean_memory(get_device()) # clean_memory(get_device())
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}') # logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
return InferenceResult(model_json, dataset) # return InferenceResult(model_json, dataset)
...@@ -3,8 +3,12 @@ import time ...@@ -3,8 +3,12 @@ import time
# 关闭paddle的信号处理 # 关闭paddle的信号处理
import paddle import paddle
import torch
from loguru import logger from loguru import logger
from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram
paddle.disable_signal_handler() paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -154,33 +158,77 @@ def doc_analyze( ...@@ -154,33 +158,77 @@ def doc_analyze(
table_enable=None, table_enable=None,
) -> InferenceResult: ) -> InferenceResult:
end_page_id = end_page_id if end_page_id else len(dataset)
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model( custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable ocr, show_log, lang, layout_model, formula_enable, table_enable
) )
batch_analyze = False
device = get_device()
npu_support = False
if str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True
if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = get_vram(device)
if gpu_memory is not None and gpu_memory >= 7:
batch_ratio = int((gpu_memory-3) // 1.5)
if batch_ratio >= 1:
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
batch_analyze = True
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
if end_page_id is None: if batch_analyze:
end_page_id = len(dataset) # batch analyze
images = []
for index in range(len(dataset)): for index in range(len(dataset)):
page_data = dataset.get_page(index) if start_page_id <= index <= end_page_id:
img_dict = page_data.get_image() page_data = dataset.get_page(index)
img = img_dict['img'] img_dict = page_data.get_image()
page_width = img_dict['width'] images.append(img_dict['img'])
page_height = img_dict['height'] analyze_result = batch_model(images)
if start_page_id <= index <= end_page_id:
page_start = time.time() for index in range(len(dataset)):
result = custom_model(img) page_data = dataset.get_page(index)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----') img_dict = page_data.get_image()
else: page_width = img_dict['width']
result = [] page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
page_info = {'page_no': index, 'height': page_height, 'width': page_width} else:
page_dict = {'layout_dets': result, 'page_info': page_info} # single analyze
model_json.append(page_dict)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
gc_start = time.time() gc_start = time.time()
clean_memory(get_device()) clean_memory(get_device())
......
...@@ -228,7 +228,7 @@ class CustomPEKModel: ...@@ -228,7 +228,7 @@ class CustomPEKModel:
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}') logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存 # 清理显存
clean_vram(self.device, vram_threshold=8) clean_vram(self.device, vram_threshold=6)
# 从layout_res中获取ocr区域、表格区域、公式区域 # 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = ( ocr_res_list, table_res_list, single_page_mfdetrec_res = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment