Commit 7990e7df authored by myhloli's avatar myhloli
Browse files

feat(model): add npu support and optimize table model

- Add NPU support for memory cleaning and model initialization
- Optimize table model initialization and prediction process
- Update memory utils to support NPU
- Add language parameter for table model
parent 96f8da2a
......@@ -7,4 +7,7 @@ def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.npu.is_available():
torch.npu.empty_cache()
torch.npu.ipc_collect()
gc.collect()
\ No newline at end of file
......@@ -87,6 +87,12 @@ class CustomPEKModel:
)
# 初始化解析方案
self.device = kwargs.get('device', 'cpu')
if str(self.device).startswith("npu"):
import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0'
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
......@@ -164,6 +170,7 @@ class CustomPEKModel:
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
......
import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
......@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel()
table_model = RapidTableModel(lang)
else:
logger.error('table model type not allow')
exit(1)
......@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def mfd_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
......@@ -53,6 +57,8 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
......@@ -63,6 +69,12 @@ def ocr_model_init(show_log: bool = False,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
use_npu = False
device = get_device()
if str(device).startswith("npu"):
use_npu = True
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
......@@ -70,6 +82,7 @@ def ocr_model_init(show_log: bool = False,
lang=lang,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
use_npu=use_npu,
)
else:
model = ModifiedPaddleOCR(
......@@ -77,7 +90,7 @@ def ocr_model_init(show_log: bool = False,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
# use_angle_cls=True,
use_npu=use_npu,
)
return model
......@@ -146,7 +159,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_name'),
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device')
kwargs.get('device'),
kwargs.get('lang'),
)
else:
logger.error('model name not allow')
......
......@@ -54,4 +54,7 @@ def get_vram(device):
if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
elif torch.npu.is_available() and device != 'cpu':
total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
return None
\ No newline at end of file
import os
import cv2
import numpy as np
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
try:
import torchtext
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
class RapidTableModel(object):
def __init__(self):
def __init__(self, lang=None):
self.table_model = RapidTable()
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
atom_model_manager = AtomModelSingleton()
self.ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang,
)
def predict(self, image):
ocr_result, _ = self.ocr_engine(np.asarray(image))
if ocr_result is None:
# ocr_result, _ = self.ocr_engine(np.asarray(image))
bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_engine.ocr(bgr_image)[0]
ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
else:
return None, None, None
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
\ No newline at end of file
......@@ -284,6 +284,14 @@ def model_init(model_name: str):
supports_bfloat16 = True
else:
supports_bfloat16 = False
elif torch.npu.is_available():
device = torch.device('npu')
if torch.npu.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment