Commit 7990e7df authored by myhloli's avatar myhloli
Browse files

feat(model): add npu support and optimize table model

- Add NPU support for memory cleaning and model initialization
- Optimize table model initialization and prediction process
- Update memory utils to support NPU
- Add language parameter for table model
parent 96f8da2a
...@@ -7,4 +7,7 @@ def clean_memory(): ...@@ -7,4 +7,7 @@ def clean_memory():
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
elif torch.npu.is_available():
torch.npu.empty_cache()
torch.npu.ipc_collect()
gc.collect() gc.collect()
\ No newline at end of file
...@@ -87,6 +87,12 @@ class CustomPEKModel: ...@@ -87,6 +87,12 @@ class CustomPEKModel:
) )
# 初始化解析方案 # 初始化解析方案
self.device = kwargs.get('device', 'cpu') self.device = kwargs.get('device', 'cpu')
if str(self.device).startswith("npu"):
import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0'
logger.info('using device: {}'.format(self.device)) logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get( models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models') 'models_dir', os.path.join(root_dir, 'resources', 'models')
...@@ -164,6 +170,7 @@ class CustomPEKModel: ...@@ -164,6 +170,7 @@ class CustomPEKModel:
table_model_path=str(os.path.join(models_dir, table_model_dir)), table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time, table_max_time=self.table_max_time,
device=self.device, device=self.device,
lang=self.lang,
) )
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
......
import torch
from loguru import logger from loguru import logger
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel DocLayoutYOLOModel
...@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \ ...@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE: if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time) table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER: elif table_model_type == MODEL_NAME.TABLE_MASTER:
...@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): ...@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
} }
table_model = TableMasterPaddleModel(config) table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE: elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel() table_model = RapidTableModel(lang)
else: else:
logger.error('table model type not allow') logger.error('table model type not allow')
exit(1) exit(1)
...@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): ...@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def mfd_model_init(weight, device='cpu'): def mfd_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device) mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model return mfd_model
...@@ -53,6 +57,8 @@ def layout_model_init(weight, config_file, device): ...@@ -53,6 +57,8 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'): def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device) model = DocLayoutYOLOModel(weight, device)
return model return model
...@@ -63,6 +69,12 @@ def ocr_model_init(show_log: bool = False, ...@@ -63,6 +69,12 @@ def ocr_model_init(show_log: bool = False,
use_dilation=True, use_dilation=True,
det_db_unclip_ratio=1.8, det_db_unclip_ratio=1.8,
): ):
use_npu = False
device = get_device()
if str(device).startswith("npu"):
use_npu = True
if lang is not None and lang != '': if lang is not None and lang != '':
model = ModifiedPaddleOCR( model = ModifiedPaddleOCR(
show_log=show_log, show_log=show_log,
...@@ -70,6 +82,7 @@ def ocr_model_init(show_log: bool = False, ...@@ -70,6 +82,7 @@ def ocr_model_init(show_log: bool = False,
lang=lang, lang=lang,
use_dilation=use_dilation, use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio, det_db_unclip_ratio=det_db_unclip_ratio,
use_npu=use_npu,
) )
else: else:
model = ModifiedPaddleOCR( model = ModifiedPaddleOCR(
...@@ -77,7 +90,7 @@ def ocr_model_init(show_log: bool = False, ...@@ -77,7 +90,7 @@ def ocr_model_init(show_log: bool = False,
det_db_box_thresh=det_db_box_thresh, det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation, use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio, det_db_unclip_ratio=det_db_unclip_ratio,
# use_angle_cls=True, use_npu=use_npu,
) )
return model return model
...@@ -146,7 +159,8 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -146,7 +159,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_name'), kwargs.get('table_model_name'),
kwargs.get('table_model_path'), kwargs.get('table_model_path'),
kwargs.get('table_max_time'), kwargs.get('table_max_time'),
kwargs.get('device') kwargs.get('device'),
kwargs.get('lang'),
) )
else: else:
logger.error('model name not allow') logger.error('model name not allow')
......
...@@ -54,4 +54,7 @@ def get_vram(device): ...@@ -54,4 +54,7 @@ def get_vram(device):
if torch.cuda.is_available() and device != 'cpu': if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory return total_memory
elif torch.npu.is_available() and device != 'cpu':
total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
return None return None
\ No newline at end of file
import os
import cv2
import numpy as np import numpy as np
from rapid_table import RapidTable from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR from rapidocr_paddle import RapidOCR
try:
import torchtext
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
class RapidTableModel(object): class RapidTableModel(object):
def __init__(self): def __init__(self, lang=None):
self.table_model = RapidTable() self.table_model = RapidTable()
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True) # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
atom_model_manager = AtomModelSingleton()
self.ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang,
)
def predict(self, image): def predict(self, image):
ocr_result, _ = self.ocr_engine(np.asarray(image)) # ocr_result, _ = self.ocr_engine(np.asarray(image))
if ocr_result is None:
return None, None, None bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_engine.ocr(bgr_image)[0]
ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result) html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse return html_code, table_cell_bboxes, elapse
else:
return None, None, None
...@@ -284,6 +284,14 @@ def model_init(model_name: str): ...@@ -284,6 +284,14 @@ def model_init(model_name: str):
supports_bfloat16 = True supports_bfloat16 = True
else: else:
supports_bfloat16 = False supports_bfloat16 = False
elif torch.npu.is_available():
device = torch.device('npu')
if torch.npu.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else: else:
device = torch.device('cpu') device = torch.device('cpu')
supports_bfloat16 = False supports_bfloat16 = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment