Commit 50f48417 authored by myhloli

refactor(device): optimize memory cleaning and device selection

- Update clean_memory function to support both CUDA and NPU devices
- Implement get_device function to centralize device selection logic (sketched below)
- Modify model initialization and memory cleaning to use the selected device
- Update RapidTableModel to support both RapidOCR and PaddleOCR engines
parent 7990e7df
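
The hunks below call get_device() from magic_pdf.libs.config_reader, whose body is not part of this diff. A minimal sketch, assuming the selector simply reads a 'device-mode' string (e.g. 'cpu', 'cuda', 'cuda:0', 'npu') from the user's magic-pdf.json; the real reader may resolve the path and defaults differently:

    import json
    import os

    def read_config():
        # Assumed config location; illustrative only.
        config_path = os.path.join(os.path.expanduser('~'), 'magic-pdf.json')
        with open(config_path, encoding='utf-8') as f:
            return json.load(f)

    def get_device():
        # Fall back to CPU when no device mode is configured.
        return read_config().get('device-mode') or 'cpu'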
@@ -3,11 +3,14 @@ import torch
 import gc
 
 
-def clean_memory():
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-    elif torch.npu.is_available():
-        torch.npu.empty_cache()
-        torch.npu.ipc_collect()
+def clean_memory(device='cuda'):
+    if device == 'cuda':
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            torch_npu.empty_cache()
+            torch_npu.ipc_collect()
     gc.collect()
\ No newline at end of file
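
A usage sketch of the reworked helper: callers now pass the configured device string instead of letting clean_memory probe every backend (device strings follow the conventions used elsewhere in this diff):

    from magic_pdf.libs.clean_memory import clean_memory

    clean_memory('cuda')   # frees the CUDA allocator cache and collects IPC handles
    clean_memory('npu:0')  # any string starting with 'npu' routes through torch_npu
    clean_memory('cpu')    # no accelerator work, just gc.collect()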
@@ -10,6 +10,7 @@ from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
@@ -268,7 +269,7 @@ def doc_batch_analyze(
 
     # TODO: clean memory when gpu memory is not enough
     clean_memory_start_time = time.time()
-    clean_memory()
+    clean_memory(get_device())
     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
 
     return InferenceResult(model_json, dataset)
@@ -183,7 +183,7 @@ def doc_analyze(
         model_json.append(page_dict)
 
     gc_start = time.time()
-    clean_memory()
+    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
     logger.info(f'gc time: {gc_time}')
@@ -170,7 +170,7 @@ class CustomPEKModel:
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device,
-                lang=self.lang,
+                ocr_engine=self.ocr_model,
             )
 
         logger.info('DocAnalysis init done!')
@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(lang)
+        table_model = RapidTableModel(ocr_engine)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -160,7 +160,6 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('lang'),
         )
     else:
         logger.error('model name not allow')
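
With the lang parameter gone, the RapidTable branch is driven entirely by the injected OCR engine. A hedged call sketch (argument values are illustrative, not from the diff):

    from magic_pdf.config.constants import MODEL_NAME

    # ocr_engine is the pipeline's shared PaddleOCR instance (CustomPEKModel
    # passes self.ocr_model); None makes RapidTableModel build its own RapidOCR.
    table_model = table_model_init(
        MODEL_NAME.RAPID_TABLE,
        model_path='',    # ignored by the RapidTable branch
        max_time=400,     # illustrative timeout
        _device_='cpu',
        ocr_engine=None,
    )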
@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
     total_memory = get_vram(device)
     if total_memory and total_memory <= vram_threshold:
         gc_start = time.time()
-        clean_memory()
+        clean_memory(device)
         gc_time = round(time.time() - gc_start, 2)
         logger.info(f"gc time: {gc_time}")
 
@@ -54,7 +54,10 @@ def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
         return total_memory
-    elif torch.npu.is_available() and device != 'cpu':
-        total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert to GB
-        return total_memory
-    return None
\ No newline at end of file
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert to GB
+            return total_memory
+    else:
+        return None
\ No newline at end of file
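
Taken together, clean_vram and get_vram encode the policy of cleaning eagerly only on small-memory cards. Condensed, the flow in the hunk above is:

    device = get_device()
    total_gb = get_vram(device)      # None when neither CUDA nor an NPU applies
    if total_gb and total_gb <= 8:   # vram_threshold, in GB
        clean_memory(device)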
-import os
 import cv2
 import numpy as np
+from loguru import logger
 from rapid_table import RapidTable
 from rapidocr_paddle import RapidOCR
 
-try:
-    import torchtext
-    if torchtext.__version__ >= '0.18.0':
-        torchtext.disable_torchtext_deprecation_warning()
-except ImportError:
-    pass
-
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates
-
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 
 
 class RapidTableModel(object):
-    def __init__(self, lang=None):
+    def __init__(self, ocr_engine):
         self.table_model = RapidTable()
-        # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-        atom_model_manager = AtomModelSingleton()
-        self.ocr_engine = atom_model_manager.get_atom_model(
-            atom_model_name='ocr',
-            ocr_show_log=False,
-            det_db_box_thresh=0.3,
-            lang=lang,
-        )
+        if ocr_engine is None:
+            self.ocr_model_name = "RapidOCR"
+            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        else:
+            self.ocr_model_name = "PaddleOCR"
+            self.ocr_engine = ocr_engine
 
     def predict(self, image):
-        # ocr_result, _ = self.ocr_engine(np.asarray(image))
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
+        if self.ocr_model_name == "RapidOCR":
+            ocr_result, _ = self.ocr_engine(np.asarray(image))
+        elif self.ocr_model_name == "PaddleOCR":
+            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                          len(item) == 2 and isinstance(item[1], tuple)]
+        else:
+            logger.error("OCR model not supported")
+            ocr_result = None
 
         if ocr_result:
             html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
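
Both OCR paths now sit behind the same predict() interface. A usage sketch (the input image and the shared engine are placeholders):

    from PIL import Image
    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

    image = Image.open('table.png')  # placeholder input
    model = RapidTableModel(None)    # None -> model builds its own RapidOCR engine
    # To reuse the pipeline's OCR instead, pass an object exposing PaddleOCR's
    # .ocr(bgr_image) method, e.g. CustomPEKModel's self.ocr_model.
    result = model.predict(image)    # runs OCR, then rapid_table on the results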
@@ -14,7 +14,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
@@ -277,21 +277,24 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
 
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
+    device = get_device()
     if torch.cuda.is_available():
         device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
-    elif torch.npu.is_available():
-        device = torch.device('npu')
-        if torch.npu.is_bf16_supported():
-            supports_bfloat16 = True
-        else:
-            supports_bfloat16 = False
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch.npu.is_available():
+            device = torch.device('npu')
+            if torch.npu.is_bf16_supported():
+                supports_bfloat16 = True
+            else:
+                supports_bfloat16 = False
+        else:
+            device = torch.device('cpu')
+            supports_bfloat16 = False
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False
@@ -865,7 +868,7 @@ def pdf_parse_union(
         'pdf_info': pdf_info_list,
     }
 
-    clean_memory()
+    clean_memory(get_device())
 
     return new_pdf_info_dict
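
For reference, the device and supports_bfloat16 values computed by model_init are presumably consumed right after the probe when the LayoutLMv3 reader is loaded; a sketch under that assumption (the loading lines are not part of this diff):

    model = LayoutLMv3ForTokenClassification.from_pretrained(
        get_local_layoutreader_model_dir()
    )
    if supports_bfloat16:
        model.bfloat16()         # halves weight memory where bf16 is supported
    model.to(device).eval()      # move to the probed device for inference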