Commit 50f48417 authored by myhloli's avatar myhloli
Browse files

refactor(device): optimize memory cleaning and device selection

- Update clean_memory function to support both CUDA and NPU devices
- Implement get_device function to centralize device selection logic
- Modify model initialization and memory cleaning to use the selected device
- Update RapidTableModel to support both RapidOCR and PaddleOCR engines
parent 7990e7df
......@@ -3,11 +3,14 @@ import torch
import gc
def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.npu.is_available():
torch.npu.empty_cache()
torch.npu.ipc_collect()
def clean_memory(device='cuda'):
if device == 'cuda':
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif str(device).startswith("npu"):
import torch_npu
if torch.npu.is_available():
torch_npu.empty_cache()
torch_npu.ipc_collect()
gc.collect()
\ No newline at end of file
......@@ -10,6 +10,7 @@ from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
......@@ -268,7 +269,7 @@ def doc_batch_analyze(
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory()
clean_memory(get_device())
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
return InferenceResult(model_json, dataset)
......@@ -183,7 +183,7 @@ def doc_analyze(
model_json.append(page_dict)
gc_start = time.time()
clean_memory()
clean_memory(get_device())
gc_time = round(time.time() - gc_start, 2)
logger.info(f'gc time: {gc_time}')
......
......@@ -170,7 +170,7 @@ class CustomPEKModel:
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device,
lang=self.lang,
ocr_engine=self.ocr_model,
)
logger.info('DocAnalysis init done!')
......
......@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lan
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(lang)
table_model = RapidTableModel(ocr_engine)
else:
logger.error('table model type not allow')
exit(1)
......@@ -160,7 +160,6 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('lang'),
)
else:
logger.error('model name not allow')
......
......@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
total_memory = get_vram(device)
if total_memory and total_memory <= vram_threshold:
gc_start = time.time()
clean_memory()
clean_memory(device)
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
......@@ -54,7 +54,10 @@ def get_vram(device):
if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
elif torch.npu.is_available() and device != 'cpu':
total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
return None
\ No newline at end of file
elif str(device).startswith("npu"):
import torch_npu
if torch.npu.is_available():
total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
else:
return None
\ No newline at end of file
import os
import cv2
import numpy as np
from loguru import logger
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
try:
import torchtext
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
class RapidTableModel(object):
def __init__(self, lang=None):
def __init__(self, ocr_engine):
self.table_model = RapidTable()
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
atom_model_manager = AtomModelSingleton()
self.ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang,
)
if ocr_engine is None:
self.ocr_model_name = "RapidOCR"
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
self.ocr_model_name = "PaddleOCR"
self.ocr_engine = ocr_engine
def predict(self, image):
# ocr_result, _ = self.ocr_engine(np.asarray(image))
bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_engine.ocr(bgr_image)[0]
ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
if self.ocr_model_name == "RapidOCR":
ocr_result, _ = self.ocr_engine(np.asarray(image))
elif self.ocr_model_name == "PaddleOCR":
bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_engine.ocr(bgr_image)[0]
ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
else:
logger.error("OCR model not supported")
ocr_result = None
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
......
......@@ -14,7 +14,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
......@@ -277,21 +277,24 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
device = get_device()
if torch.cuda.is_available():
device = torch.device('cuda')
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
elif torch.npu.is_available():
device = torch.device('npu')
if torch.npu.is_bf16_supported():
supports_bfloat16 = True
elif str(device).startswith("npu"):
import torch_npu
if torch.npu.is_available():
device = torch.device('npu')
if torch.npu.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
......@@ -865,7 +868,7 @@ def pdf_parse_union(
'pdf_info': pdf_info_list,
}
clean_memory()
clean_memory(get_device())
return new_pdf_info_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment