Unverified Commit f4ffdfe8 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2062 from myhloli/dev

feat: support 3.10~3.12 & remove paddle
parents ec566d22 cb3a4314
......@@ -216,7 +216,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
</tr>
<tr>
<td colspan="3">python版本</td>
<td colspan="3">3.10 (请务必通过conda创建3.10虚拟环境)</td>
<td colspan="3">>=3.9,<=3.12</td>
</tr>
<tr>
<td colspan="3">Nvidia Driver 版本</td>
......@@ -226,8 +226,8 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
</tr>
<tr>
<td colspan="3">CUDA环境</td>
<td>自动安装[12.1(pytorch)+11.8(paddle)]</td>
<td>11.8(手动安装)+cuDNN v8.7.0(手动安装)</td>
<td>11.8/12.4/12.6</td>
<td>11.8/12.4/12.6</td>
<td>None</td>
</tr>
<tr>
......@@ -237,12 +237,12 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
<td>None</td>
</tr>
<tr>
<td rowspan="2">GPU硬件支持列表</td>
<td colspan="2">显存8G以上</td>
<td rowspan="2">GPU/MPS 硬件支持列表</td>
<td colspan="2">显存6G以上</td>
<td colspan="2">
2080~2080Ti / 3060Ti~3090Ti / 4060~4090<br>
8G显存及以上可开启全部加速功能</td>
<td rowspan="2">None</td>
Volta(2017)及之后生产的全部带Tensor Core的GPU <br>
6G显存及以上</td>
<td rowspan="2">apple slicon</td>
</tr>
</table>
......@@ -262,9 +262,9 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
> The latest release may take some time to sync to mirror sources in China; please be patient
```bash
conda create -n mineru python=3.10
conda create -n mineru 'python<3.13' -y
conda activate mineru
pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
pip install -U "magic-pdf[full]" -i https://mirrors.aliyun.com/pypi/simple
```
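As a quick post-install sanity check, a minimal sketch (it only assumes torch is pulled in as a `magic-pdf[full]` dependency):
```python
import sys

import torch

# Supported interpreter range after this change: >=3.9, <3.13
assert (3, 9) <= sys.version_info[:2] <= (3, 12), sys.version

# CUDA is optional; CPU or MPS is used automatically when it is absent.
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())
```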
#### 2. Download model weight files
......
......@@ -34,10 +34,9 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
pip3 install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple && \
wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/docker/ascend_npu/requirements.txt -O requirements.txt && \
pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple && \
pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple && \
wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc2-pytorch2.3.1/torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl && \
pip3 install torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl && \
pip3 install https://gcore.jsdelivr.net/gh/myhloli/wheels@main/assets/whl/paddle-custom-npu/paddle_custom_npu-0.0.0-cp310-cp310-linux_aarch64.whl"
pip3 install torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9,<=1.24.14
PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -11,11 +11,9 @@ torch==2.3.1
torchvision==0.18.1
matplotlib
ultralytics>=8.3.48
paddleocr==2.7.3
paddlepaddle==3.0.0rc1
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
pydantic>=2.7.2,<2.11
transformers>=4.49.0,<5.0.0
\ No newline at end of file
......@@ -31,8 +31,7 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
pip3 install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple && \
wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/docker/china/requirements.txt -O requirements.txt && \
pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple && \
pip3 install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/"
pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple"
# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9,<=1.24.14
PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -11,10 +11,9 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
matplotlib
ultralytics>=8.3.48
paddleocr==2.7.3
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
pydantic>=2.7.2,<2.11
transformers>=4.49.0,<5.0.0
\ No newline at end of file
......@@ -31,8 +31,7 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
pip3 install --upgrade pip && \
wget https://github.com/opendatalab/MinerU/raw/master/docker/global/requirements.txt -O requirements.txt && \
pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com && \
pip3 install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/"
pip3 install -r requirements.txt"
# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9,<=1.24.14
PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -11,10 +11,9 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
matplotlib
ultralytics>=8.3.48
paddleocr==2.7.3
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
pydantic>=2.7.2,<2.11
transformers>=4.49.0,<5.0.0
\ No newline at end of file
......@@ -48,7 +48,18 @@ def measure_time(func):
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
PerformanceStats.add_execution_time(func.__name__, execution_time)
# build a more descriptive identifier for the function
if hasattr(func, "__self__"):  # bound instance method
class_name = func.__self__.__class__.__name__
full_name = f"{class_name}.{func.__name__}"
elif hasattr(func, "__qualname__"):  # class or static method
full_name = func.__qualname__
else:
module_name = func.__module__
full_name = f"{module_name}.{func.__name__}"
PerformanceStats.add_execution_time(full_name, execution_time)
return result
return wrapper
\ No newline at end of file
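For reference, a quick illustration of why `__qualname__` is preferred over `__name__` for the stats key (hypothetical class, not from the repo):
```python
class Parser:
    def parse(self):
        pass

# __name__ drops the class context; __qualname__ keeps it, so two
# methods that are both named `parse` no longer share one stats bucket.
print(Parser.parse.__name__)      # parse
print(Parser.parse.__qualname__)  # Parser.parse
```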
......@@ -5,10 +5,10 @@ import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
......@@ -85,8 +85,8 @@ class BatchAnalyze:
# free GPU memory
clean_vram(self.model.device, vram_threshold=8)
ocr_time = 0
ocr_count = 0
det_time = 0
det_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
......@@ -100,7 +100,7 @@ class BatchAnalyze:
get_res_list_from_layout_res(layout_res)
)
# OCR detection (recognition now runs later as a per-language batch)
ocr_start = time.time()
det_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(
......@@ -113,21 +113,21 @@ class BatchAnalyze:
# OCR recognition
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if ocr_enable:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res
)[0]
else:
# if ocr_enable:
# ocr_res = self.model.ocr_model.ocr(
# new_image, mfd_res=adjusted_mfdetrec_res
# )[0]
# else:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
layout_res.extend(ocr_result_list)
ocr_time += time.time() - ocr_start
ocr_count += len(ocr_res_list)
det_time += time.time() - det_start
det_count += len(ocr_res_list)
# table recognition
if self.model.apply_table:
......@@ -172,9 +172,70 @@ class BatchAnalyze:
table_time += time.time() - table_start
table_count += len(table_res_list)
if self.model.apply_ocr:
logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
rec_time = 0
rec_start = time.time()
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
total_processed += len(img_crop_list)
rec_time += time.time() - rec_start
logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
return images_layout_res
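The rec pass above boils down to a group-by-language batch step; a simplified, self-contained sketch of the idea (the `ocr` callable stands in for the language-specific model, not the repo's API):
```python
from collections import defaultdict

def batch_rec_by_lang(items, ocr):
    """items: dicts holding 'np_img' and 'lang'; ocr(imgs, lang) -> [(text, score), ...]."""
    by_lang = defaultdict(list)
    for item in items:
        by_lang[item["lang"]].append(item)
    for lang, group in by_lang.items():
        # One batched call per language keeps each rec model's input homogeneous.
        results = ocr([it["np_img"] for it in group], lang)
        assert len(results) == len(group)
        for it, (text, score) in zip(group, results):
            it["text"], it["score"] = text, float(round(score, 2))
    return items
```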
import concurrent.futures as fut
import multiprocessing as mp
import os
import time
......@@ -25,8 +23,6 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
# from magic_pdf.operators.models import InferenceResult
class ModelSingleton:
_instance = None
_models = {}
......@@ -141,7 +137,7 @@ def doc_analyze(
else len(dataset) - 1
)
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
images = []
page_wh_list = []
for index in range(len(dataset)):
......@@ -244,9 +240,7 @@ def may_batch_image_analyze(
formula_enable=None,
table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
# disable paddle's signal handlers
import paddle
paddle.disable_signal_handler()
from magic_pdf.model.batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
......
......@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
......
......@@ -16,7 +16,7 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
self.transform = alb.Compose(
[
alb.ToGray(always_apply=True),
alb.ToGray(),
alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
# alb.Sharpen()
ToTensorV2(),
......
......@@ -7,35 +7,36 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
try:
from magic_pdf_ascend_plugin.libs.license_verifier import (
LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
load_license)
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
license_key = load_license()
logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
f' License expired at {license_key["payload"]["date"]["end_date"]}')
except Exception as e:
if isinstance(e, ImportError):
pass
elif isinstance(e, LicenseFormatError):
logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
elif isinstance(e, LicenseSignatureError):
logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
elif isinstance(e, LicenseExpiredError):
logger.error('Ascend Plugin: License has expired. Please renew your license.')
elif isinstance(e, FileNotFoundError):
logger.error('Ascend Plugin: Not found License file.')
else:
logger.error(f'Ascend Plugin: {e}')
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
# try:
# from magic_pdf_ascend_plugin.libs.license_verifier import (
# LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
# load_license)
# from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
# from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
# license_key = load_license()
# logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
# f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
# if isinstance(e, ImportError):
# pass
# elif isinstance(e, LicenseFormatError):
# logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
# elif isinstance(e, LicenseSignatureError):
# logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
# elif isinstance(e, LicenseExpiredError):
# logger.error('Ascend Plugin: License has expired. Please renew your license.')
# elif isinstance(e, FileNotFoundError):
# logger.error('Ascend Plugin: Not found License file.')
# else:
# logger.error(f'Ascend Plugin: {e}')
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
......@@ -47,6 +48,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
......@@ -94,7 +103,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
lang=lang,
......@@ -102,7 +112,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
model = ModifiedPaddleOCR(
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
......@@ -131,7 +142,7 @@ class AtomModelSingleton:
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name)
key = (atom_model_name, table_model_name, lang)
else:
key = atom_model_name
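Adding `lang` to the Table key matters because the table model now embeds a language-specific OCR engine; a minimal sketch of the keyed-cache pattern (`load_model` is a hypothetical loader, not the repo's):
```python
class ModelCache:
    _models = {}

    def get(self, name, **kwargs):
        # Key on every argument that changes the loaded weights; without
        # `lang`, two languages would silently share one cached table model.
        key = (name, kwargs.get("table_model_name"), kwargs.get("lang"))
        if key not in self._models:
            self._models[key] = load_model(name, **kwargs)  # hypothetical
        return self._models[key]
```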
......@@ -179,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine'),
kwargs.get('lang'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
......
import copy
import time
import cv2
import numpy as np
from paddleocr import PaddleOCR
from paddleocr.paddleocr import check_img, logger
from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
from paddleocr.tools.infer.predict_system import sorted_boxes
from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
get_minarea_rect_crop
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
class ModifiedPaddleOCR(PaddleOCR):
def ocr(
self,
img,
det=True,
rec=True,
cls=True,
bin=False,
inv=False,
alpha_color=(255, 255, 255),
slice={},
mfd_res=None,
):
"""
OCR with PaddleOCR
Args:
img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
det: Use text detection or not. If False, only text recognition will be executed. Default is True.
rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
bin: Binarize image to black and white. Default is False.
inv: Invert image colors. Default is False.
alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
Returns:
If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
If det is True and rec is False, returns a list of detected bounding boxes for each image.
If det is False and rec is True, returns a list of recognized text for each image.
If both det and rec are False, returns a list of angle classification results for each image.
Raises:
AssertionError: If the input image is not of type ndarray, list, str, or bytes.
SystemExit: If det is True and the input is a list of images.
Note:
- If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
- For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
- The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
logger.error("When input a list of images, det must be false")
exit(0)
if cls == True and self.use_angle_cls == False:
logger.warning(
"Since the angle classifier is not initialized, it will not be used during the forward process"
)
img, flag_gif, flag_pdf = check_img(img, alpha_color)
# for infer pdf file
if isinstance(img, list) and flag_pdf:
if self.page_num > len(img) or self.page_num == 0:
imgs = img
else:
imgs = img[: self.page_num]
else:
imgs = [img]
def preprocess_image(_image):
_image = alpha_to_color(_image, alpha_color)
if inv:
_image = cv2.bitwise_not(_image)
if bin:
_image = binarize_img(_image)
return _image
if det and rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if dt_boxes.size == 0:
ocr_res.append(None)
continue
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
else:
ocr_res = []
cls_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
return ocr_res
def __call__(self, img, cls=True, slice={}, mfd_res=None):
time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
if img is None:
logger.debug("no valid image provided")
return None, None, time_dict
start = time.time()
ori_im = img.copy()
if slice:
slice_gen = slice_generator(
img,
horizontal_stride=slice["horizontal_stride"],
vertical_stride=slice["vertical_stride"],
)
elapsed = []
dt_slice_boxes = []
for slice_crop, v_start, h_start in slice_gen:
dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
if dt_boxes.size:
dt_boxes[:, :, 0] += h_start
dt_boxes[:, :, 1] += v_start
dt_slice_boxes.append(dt_boxes)
elapsed.append(elapse)
dt_boxes = np.concatenate(dt_slice_boxes)
dt_boxes = merge_fragmented(
boxes=dt_boxes,
x_threshold=slice["merge_x_thres"],
y_threshold=slice["merge_y_thres"],
)
elapse = sum(elapsed)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict["det"] = elapse
if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
end = time.time()
time_dict["all"] = end - start
return None, None, time_dict
else:
logger.debug(
"dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
)
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
if mfd_res:
bef = time.time()
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft - bef))
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
if self.args.det_box_type == "quad":
img_crop = get_rotate_crop_image(ori_im, tmp_box)
else:
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
time_dict["cls"] = elapse
logger.debug(
"cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
)
if len(img_crop_list) > 1000:
logger.debug(
f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
)
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict["rec"] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
if self.args.save_crop_res:
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict["all"] = end - start
return filter_boxes, filter_rec_res, time_dict
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
import copy
import cv2
import numpy as np
from loguru import logger
from io import BytesIO
from PIL import Image
import base64
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
from importlib.resources import files
from paddleocr import PaddleOCR
from ppocr.utils.utility import check_and_read
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
def check_img(img):
if isinstance(img, bytes):
img = img_decode(img)
if isinstance(img, str):
image_file = img
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
with open(image_file, 'rb') as f:
img_str = f.read()
img = img_decode(img_str)
if img is None:
try:
buf = BytesIO()
image = BytesIO(img_str)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes),
encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
except Exception:
logger.error("error in loading image:{}".format(image_file))
return None
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
return img
def alpha_to_color(img, alpha_color=(255, 255, 255)):
if len(img.shape) == 3 and img.shape[2] == 4:
B, G, R, A = cv2.split(img)
alpha = A / 255
R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
img = cv2.merge((B, G, R))
return img
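`alpha_to_color` composites BGRA pixels over a solid background color; a quick numeric check (values chosen for illustration):
```python
import numpy as np

# A half-transparent black pixel over the default white background
# should come out roughly mid-grey.
px = np.zeros((1, 1, 4), dtype=np.uint8)
px[..., 3] = 128  # alpha ~ 0.5
print(alpha_to_color(px)[0, 0])  # -> [127 127 127] (BGR)
```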
def preprocess_image(_image):
alpha_color = (255, 255, 255)
_image = alpha_to_color(_image, alpha_color)
return _image
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
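A small worked example of the tolerance pass: two boxes on the same visual line whose top-left y values differ by a few pixels get reordered by x (coordinates are illustrative):
```python
import numpy as np

boxes = np.array([
    [[100, 52], [200, 52], [200, 80], [100, 80]],  # right box, slightly higher
    [[10, 55], [90, 55], [90, 82], [10, 82]],      # left box, slightly lower
])
# The initial (y, x) sort puts the right box first (52 < 55); the
# bubble pass then swaps them because |55 - 52| < 10 and 10 < 100.
print([box[0].tolist() for box in sorted_boxes(boxes)])
# -> [[10, 55], [100, 52]]
```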
def bbox_to_points(bbox):
""" 将bbox格式转换为四个顶点的数组 """
x0, y0, x1, y1 = bbox
......@@ -252,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list):
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
ocr_result_list = []
ori_im = new_image.copy()
for box_ocr_res in ocr_res:
if len(box_ocr_res) == 2:
......@@ -266,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
else:
p1, p2, p3, p4 = box_ocr_res
text, score = "", 1
if ocr_enable:
tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
img_crop = get_rotate_crop_image(ori_im, tmp_box)
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# if average_angle_degrees > 0.5:
poly = [p1, p2, p3, p4]
......@@ -288,6 +303,16 @@ def get_ocr_result_list(ocr_res, useful_list):
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
if ocr_enable:
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': 1,
'text': text,
'np_img': img_crop,
'lang': lang,
})
else:
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
......@@ -308,57 +333,36 @@ def calculate_is_angle(poly):
return True
class ONNXModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_onnx_model(self, **kwargs):
lang = kwargs.get('lang', None)
det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
use_dilation = kwargs.get('use_dilation', True)
det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
if key not in self._models:
self._models[key] = onnx_model_init(key)
return self._models[key]
def onnx_model_init(key):
if len(key) < 4:
logger.error('Invalid key length, expected at least 4 elements')
exit(1)
try:
resource_path = files("rapidocr_onnxruntime") / "models"
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1)
else:
return onnx_model
except Exception as e:
logger.exception(f'Error initializing model: {e}')
exit(1)
\ No newline at end of file
def get_rotate_crop_image(img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
\ No newline at end of file
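Usage sketch for the helper above: `points` must be a float32 (4, 2) array in clockwise order starting at the top-left; the perspective warp deskews the region (the input image here is a stand-in):
```python
import cv2
import numpy as np

img = (np.random.rand(200, 600, 3) * 255).astype(np.uint8)  # stand-in page image
points = np.float32([[120, 40], [480, 52], [478, 96], [118, 84]])  # tilted text line
crop = get_rotate_crop_image(img, points)  # upright crop; rotated 90° if very tall
print(crop.shape)
```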
# Copyright (c) Opendatalab. All rights reserved.
import copy
import platform
import time
import os.path
from pathlib import Path
import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
import yaml
from loguru import logger
from magic_pdf.libs.config_reader import get_device, get_local_models_dir
from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility
import argparse
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
'sa', 'bgc'
]
def get_model_params(lang, config):
if lang in config['lang']:
params = config['lang'][lang]
det = params.get('det')
rec = params.get('rec')
dict_file = params.get('dict')
return det, rec, dict_file
else:
raise Exception(f'Language {lang} not supported')
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
ONNXModelSingleton
logger = get_logger()
root_dir = Path(__file__).resolve().parent
class ModifiedPaddleOCR(PaddleOCR):
class PytorchPaddleOCR(TextSystem):
def __init__(self, *args, **kwargs):
parser = utility.init_args()
args = parser.parse_args(args)
super().__init__(*args, **kwargs)
self.lang = kwargs.get('lang', 'ch')
# fall back to ONNX when the CPU architecture is ARM and CUDA is unavailable
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
onnx_model_manager = ONNXModelSingleton()
self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
if self.lang in latin_lang:
self.lang = 'latin'
elif self.lang in arabic_lang:
self.lang = 'arabic'
elif self.lang in cyrillic_lang:
self.lang = 'cyrillic'
elif self.lang in devanagari_lang:
self.lang = 'devanagari'
else:
self.use_onnx = False
pass
models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
with open(models_config_path) as file:
config = yaml.safe_load(file)
det, rec, dict_file = get_model_params(self.lang, config)
ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
kwargs['device'] = get_device()
default_args = vars(args)
default_args.update(kwargs)
args = argparse.Namespace(**default_args)
super().__init__(args)
def ocr(self,
img,
det=True,
rec=True,
cls=True,
bin=False,
inv=False,
alpha_color=(255, 255, 255),
mfd_res=None,
):
"""
OCR with PaddleOCR
args:
img: img for OCR, support ndarray, img_path and list or ndarray
det: use text detection or not. If False, only rec will be exec. Default is True
rec: use text recognition or not. If False, only det will be exec. Default is True
cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
bin: binarize image to black and white. Default is False.
inv: invert image colors. Default is False.
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false')
exit(0)
if cls == True and self.use_angle_cls == False:
pass
# logger.warning(
# 'Since the angle classifier is not initialized, it will not be used during the forward process'
# )
img = check_img(img)
# for infer pdf file
if isinstance(img, list):
if self.page_num > len(img) or self.page_num == 0:
self.page_num = len(img)
imgs = img[:self.page_num]
else:
imgs = [img]
def preprocess_image(_image):
_image = alpha_to_color(_image, alpha_color)
if inv:
_image = cv2.bitwise_not(_image)
if bin:
_image = binarize_img(_image)
return _image
if det and rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res]
for box, res in zip(dt_boxes, rec_res)]
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
if dt_boxes is None:
ocr_res.append(None)
continue
......@@ -106,57 +117,36 @@ class ModifiedPaddleOCR(PaddleOCR):
# merge_det_boxes and update_det_boxes both convert polys to bboxes and back, so heavily skewed text boxes must be filtered out first
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
bef = time.time()
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft - bef))
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
else:
elif not det and rec:
ocr_res = []
cls_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res)
if not rec:
return cls_res
return ocr_res
def __call__(self, img, cls=True, mfd_res=None):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
def __call__(self, img, mfd_res=None):
if img is None:
logger.debug("no valid image provided")
return None, None, time_dict
return None, None
start = time.time()
ori_im = img.copy()
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
return None, None
else:
logger.debug("dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), elapse))
pass
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
......@@ -165,41 +155,35 @@ class ModifiedPaddleOCR(PaddleOCR):
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
bef = time.time()
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft - bef))
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
if self.args.det_box_type == "quad":
img_crop = get_rotate_crop_image(ori_im, tmp_box)
else:
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format(
len(rec_res), elapse))
if self.args.save_crop_res:
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
rec_res)
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
return filter_boxes, filter_rec_res, time_dict
\ No newline at end of file
return filter_boxes, filter_rec_res
if __name__ == '__main__':
pytorch_paddle_ocr = PytorchPaddleOCR()
img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
dt_boxes, rec_res = pytorch_paddle_ocr(img)
ocr_res = []
if not dt_boxes and not rec_res:
ocr_res.append(None)
else:
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
print(ocr_res)
import os
import torch
from .modeling.architectures.base_model import BaseModel
class BaseOCRV20:
def __init__(self, config, **kwargs):
self.config = config
self.build_net(**kwargs)
self.net.eval()
def build_net(self, **kwargs):
self.net = BaseModel(self.config, **kwargs)
def read_pytorch_weights(self, weights_path):
if not os.path.exists(weights_path):
raise FileNotFoundError('{} is not existed.'.format(weights_path))
weights = torch.load(weights_path)
return weights
def get_out_channels(self, weights):
if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
out_channels = list(weights.values())[-1].numpy().shape[1]
else:
out_channels = list(weights.values())[-1].numpy().shape[0]
return out_channels
def load_state_dict(self, weights):
self.net.load_state_dict(weights)
# print('weights is loaded.')
def load_pytorch_weights(self, weights_path):
self.net.load_state_dict(torch.load(weights_path, weights_only=True))
# print('model is loaded: {}'.format(weights_path))
def inference(self, inputs):
with torch.no_grad():
infer = self.net(inputs)
return infer
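A quick check of the `get_out_channels` convention (assuming converted paddle checkpoints, where a final Linear weight is stored `(in_features, out_features)`):
```python
import torch

# If the checkpoint ends in a 2-D ".weight", the dict size is its second
# dim; a trailing 1-D bias would use its first (only) dim instead.
state_dict = {"head.fc.weight": torch.zeros(256, 6625)}
print(BaseOCRV20.get_out_channels(None, state_dict))  # -> 6625
```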
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
# import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
# from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
# import paddle.distributed as dist
from .imaug import transform, create_operators