Unverified Commit 80fd937f authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1672 from myhloli/dev

refactor(model): integrate Ascend plugin for NPU support
parents 6e1fba93 f5112e21
import os
import time
import torch
os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
# 关闭paddle的信号处理
import paddle
import torch
paddle.disable_signal_handler()
from loguru import logger
from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram
paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
try:
import torchtext
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
......@@ -32,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
from magic_pdf.operators.models import InferenceResult
def dict_compare(d1, d2):
return d1.items() == d2.items()
def remove_duplicates_dicts(lst):
unique_dicts = []
for dict_item in lst:
if not any(
dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
):
unique_dicts.append(dict_item)
return unique_dicts
class ModelSingleton:
_instance = None
_models = {}
......
......@@ -89,13 +89,6 @@ class CustomPEKModel:
# 初始化解析方案
self.device = kwargs.get('device', 'cpu')
if str(self.device).startswith("npu"):
import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0'
elif str(self.device).startswith("mps"):
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
......
......@@ -4,22 +4,22 @@ from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
try:
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
logger.info('Using Ascend Plugin')
except ImportError:
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
......@@ -76,7 +76,6 @@ def ocr_model_init(show_log: bool = False,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment