Unverified Commit f4ffdfe8 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2062 from myhloli/dev

feat: support 3.10~3.12 & remove paddle
parents ec566d22 cb3a4314
@@ -216,7 +216,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
    </tr>
    <tr>
        <td colspan="3">Python version</td>
-        <td colspan="3">3.10 (be sure to create a 3.10 virtual environment via conda)</td>
+        <td colspan="3">>=3.9,<=3.12</td>
    </tr>
    <tr>
        <td colspan="3">Nvidia Driver version</td>
@@ -226,8 +226,8 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
    </tr>
    <tr>
        <td colspan="3">CUDA environment</td>
-        <td>installed automatically [12.1 (pytorch) + 11.8 (paddle)]</td>
-        <td>11.8 (manual install) + cuDNN v8.7.0 (manual install)</td>
+        <td>11.8/12.4/12.6</td>
+        <td>11.8/12.4/12.6</td>
        <td>None</td>
    </tr>
    <tr>
@@ -237,12 +237,12 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
        <td>None</td>
    </tr>
    <tr>
-        <td rowspan="2">GPU hardware support list</td>
-        <td colspan="2">8 GB+ VRAM</td>
+        <td rowspan="2">GPU/MPS hardware support list</td>
+        <td colspan="2">6 GB+ VRAM</td>
        <td colspan="2">
-        2080~2080Ti / 3060Ti~3090Ti / 4060~4090<br>
-        8 GB+ VRAM enables all acceleration features</td>
-        <td rowspan="2">None</td>
+        all GPUs with Tensor Cores built since Volta (2017)<br>
+        6 GB+ VRAM</td>
+        <td rowspan="2">Apple silicon</td>
    </tr>
</table>
@@ -262,9 +262,9 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
> Syncing the latest release to mirror sources in China may lag; please be patient.
```bash
-conda create -n mineru python=3.10
+conda create -n mineru 'python<3.13' -y
conda activate mineru
-pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
+pip install -U "magic-pdf[full]" -i https://mirrors.aliyun.com/pypi/simple
```
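A quick way to confirm the widened interpreter range after installation — a minimal sketch, assuming the wheel exposes the `magic_pdf` import name and (optionally) a `__version__` attribute:

```python
import sys

# The support matrix above now allows >=3.9,<=3.12; fail fast outside it.
assert (3, 9) <= sys.version_info[:2] <= (3, 12), "unsupported Python for magic-pdf[full]"

import magic_pdf  # import name of the magic-pdf distribution

print(getattr(magic_pdf, "__version__", "installed"))  # __version__ may not exist; hedged
```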
#### 2. Download the model weight files
......
@@ -34,10 +34,9 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
    pip3 install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple && \
    wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/docker/ascend_npu/requirements.txt -O requirements.txt && \
-    pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple && \
+    pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple && \
    wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc2-pytorch2.3.1/torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl && \
-    pip3 install torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl && \
-    pip3 install https://gcore.jsdelivr.net/gh/myhloli/wheels@main/assets/whl/paddle-custom-npu/paddle_custom_npu-0.0.0-cp310-cp310-linux_aarch64.whl"
+    pip3 install torch_npu-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"

# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
-PyMuPDF>=1.24.9,<=1.24.14
+PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
@@ -11,11 +11,9 @@ torch==2.3.1
torchvision==0.18.1
matplotlib
ultralytics>=8.3.48
-paddleocr==2.7.3
-paddlepaddle==3.0.0rc1
-rapidocr-paddle>=1.4.5,<2.0.0
-rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
+pydantic>=2.7.2,<2.11
+transformers>=4.49.0,<5.0.0
\ No newline at end of file
@@ -31,8 +31,7 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
    pip3 install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple && \
    wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/docker/china/requirements.txt -O requirements.txt && \
-    pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple && \
-    pip3 install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/"
+    pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple"

# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
-PyMuPDF>=1.24.9,<=1.24.14
+PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
@@ -11,10 +11,9 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
matplotlib
ultralytics>=8.3.48
-paddleocr==2.7.3
-rapidocr-paddle>=1.4.5,<2.0.0
-rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
+pydantic>=2.7.2,<2.11
+transformers>=4.49.0,<5.0.0
\ No newline at end of file
@@ -31,8 +31,7 @@ RUN python3 -m venv /opt/mineru_venv
RUN /bin/bash -c "source /opt/mineru_venv/bin/activate && \
    pip3 install --upgrade pip && \
    wget https://github.com/opendatalab/MinerU/raw/master/docker/global/requirements.txt -O requirements.txt && \
-    pip3 install -r requirements.txt --extra-index-url https://wheels.myhloli.com && \
-    pip3 install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/"
+    pip3 install -r requirements.txt"

# Copy the configuration file template and install magic-pdf latest
RUN /bin/bash -c "wget https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json && \
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
-PyMuPDF>=1.24.9,<=1.24.14
+PyMuPDF>=1.24.9,<1.25.0
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
@@ -11,10 +11,9 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
matplotlib
ultralytics>=8.3.48
-paddleocr==2.7.3
-rapidocr-paddle>=1.4.5,<2.0.0
-rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
ftfy
openai
+pydantic>=2.7.2,<2.11
+transformers>=4.49.0,<5.0.0
\ No newline at end of file
@@ -48,7 +48,18 @@ def measure_time(func):
        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time
-        PerformanceStats.add_execution_time(func.__name__, execution_time)
+
+        # Build a more specific function identifier
+        if hasattr(func, "__self__"):  # bound instance method
+            class_name = func.__self__.__class__.__name__
+            full_name = f"{class_name}.{func.__name__}"
+        elif hasattr(func, "__qualname__"):  # classmethod/staticmethod or plain function
+            full_name = func.__qualname__
+        else:
+            module_name = func.__module__
+            full_name = f"{module_name}.{func.__name__}"
+
+        PerformanceStats.add_execution_time(full_name, execution_time)
        return result
    return wrapper
\ No newline at end of file
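The change above records qualified names, so same-named methods on different classes no longer collapse into one timing bucket. A standalone sketch of the idea, with a hypothetical `stats` dict standing in for `PerformanceStats`:

```python
import time
from collections import defaultdict
from functools import wraps

stats = defaultdict(list)  # hypothetical stand-in for PerformanceStats

def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - start
        # __qualname__ carries the class prefix for methods ("Parser.run"),
        # so timings from same-named methods on different classes stay apart.
        full_name = getattr(func, "__qualname__", f"{func.__module__}.{func.__name__}")
        stats[full_name].append(elapsed)
        return result
    return wrapper

class Parser:
    @measure_time
    def run(self):
        time.sleep(0.01)

Parser().run()
print(dict(stats))  # e.g. {'Parser.run': [0.0101...]}
```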
@@ -5,10 +5,10 @@ import torch
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)

YOLO_LAYOUT_BASE_BATCH_SIZE = 1
@@ -85,8 +85,8 @@ class BatchAnalyze:
        # free VRAM
        clean_vram(self.model.device, vram_threshold=8)

-        ocr_time = 0
-        ocr_count = 0
+        det_time = 0
+        det_count = 0
        table_time = 0
        table_count = 0
        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
@@ -100,7 +100,7 @@ class BatchAnalyze:
                get_res_list_from_layout_res(layout_res)
            )
            # OCR recognition
-            ocr_start = time.time()
+            det_start = time.time()
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(
@@ -113,21 +113,21 @@ class BatchAnalyze:
                # OCR recognition
                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

-                if ocr_enable:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res
-                    )[0]
-                else:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                    )[0]
+                # if ocr_enable:
+                #     ocr_res = self.model.ocr_model.ocr(
+                #         new_image, mfd_res=adjusted_mfdetrec_res
+                #     )[0]
+                # else:
+                ocr_res = self.model.ocr_model.ocr(
+                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                )[0]

                # Integration results
                if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
                    layout_res.extend(ocr_result_list)

-            ocr_time += time.time() - ocr_start
-            ocr_count += len(ocr_res_list)
+            det_time += time.time() - det_start
+            det_count += len(ocr_res_list)

            # table recognition
            if self.model.apply_table:
@@ -172,9 +172,70 @@ class BatchAnalyze:
                table_time += time.time() - table_start
                table_count += len(table_res_list)

-        if self.model.apply_ocr:
-            logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
+        logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
        if self.model.apply_table:
            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')

+        # Create dictionaries to store items by language
+        need_ocr_lists_by_lang = {}  # Dict of lists for each language
+        img_crop_lists_by_lang = {}  # Dict of lists for each language
+
+        for layout_res in images_layout_res:
+            for layout_res_item in layout_res:
+                if layout_res_item['category_id'] in [15]:
+                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
+                        lang = layout_res_item['lang']
+
+                        # Initialize lists for this language if not exist
+                        if lang not in need_ocr_lists_by_lang:
+                            need_ocr_lists_by_lang[lang] = []
+                            img_crop_lists_by_lang[lang] = []
+
+                        # Add to the appropriate language-specific lists
+                        need_ocr_lists_by_lang[lang].append(layout_res_item)
+                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
+
+                        # Remove the fields after adding to lists
+                        layout_res_item.pop('np_img')
+                        layout_res_item.pop('lang')
+
+        if len(img_crop_lists_by_lang) > 0:
+
+            # Process OCR by language
+            rec_time = 0
+            rec_start = time.time()
+            total_processed = 0
+
+            # Process each language separately
+            for lang, img_crop_list in img_crop_lists_by_lang.items():
+                if len(img_crop_list) > 0:
+                    # Get OCR results for this language's images
+                    atom_model_manager = AtomModelSingleton()
+                    ocr_model = atom_model_manager.get_atom_model(
+                        atom_model_name='ocr',
+                        ocr_show_log=False,
+                        det_db_box_thresh=0.3,
+                        lang=lang
+                    )
+                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+
+                    # Verify we have matching counts
+                    assert len(ocr_res_list) == len(need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
+
+                    # Process OCR results for this language
+                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
+                        ocr_text, ocr_score = ocr_res_list[index]
+                        layout_res_item['text'] = ocr_text
+                        layout_res_item['score'] = float(round(ocr_score, 2))
+
+                    total_processed += len(img_crop_list)
+
+            rec_time += time.time() - rec_start
+            logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')

        return images_layout_res
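The new batch path splits OCR in two: detection tags each text span with a crop (`np_img`) and a `lang`, and recognition later runs once per language so each batch hits a single rec model. A reduced sketch of that grouping with stand-in data (the item dicts and `<text:...>` results are hypothetical):

```python
from collections import defaultdict

# Hypothetical detection output: category-15 spans carrying a crop and a
# language tag, mirroring the np_img/lang fields the detection pass attaches.
layout_items = [
    {"category_id": 15, "text": "", "np_img": "crop-a", "lang": "ch"},
    {"category_id": 15, "text": "", "np_img": "crop-b", "lang": "en"},
    {"category_id": 15, "text": "", "np_img": "crop-c", "lang": "ch"},
]

crops_by_lang = defaultdict(list)
items_by_lang = defaultdict(list)
for item in layout_items:
    if item["category_id"] == 15 and "np_img" in item:
        lang = item.pop("lang")
        crops_by_lang[lang].append(item.pop("np_img"))
        items_by_lang[lang].append(item)

for lang, crops in crops_by_lang.items():
    # One batched rec-only call per language; the real code calls
    # ocr_model.ocr(crops, det=False)[0] on a per-lang PytorchPaddleOCR.
    fake_results = [(f"<text:{lang}>", 0.99) for _ in crops]
    for item, (text, score) in zip(items_by_lang[lang], fake_results):
        item["text"], item["score"] = text, float(round(score, 2))

print(layout_items)
```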
-import concurrent.futures as fut
-import multiprocessing as mp
import os
import time
@@ -25,8 +23,6 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL

-# from magic_pdf.operators.models import InferenceResult

class ModelSingleton:
    _instance = None
    _models = {}
@@ -141,7 +137,7 @@ def doc_analyze(
        else len(dataset) - 1
    )
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
@@ -244,9 +240,7 @@ def may_batch_image_analyze(
        formula_enable=None,
        table_enable=None):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
-    # disable paddle's signal handler
-    import paddle
-    paddle.disable_signal_handler()

    from magic_pdf.model.batch_analyze import BatchAnalyze
    model_manager = ModelSingleton()
......
@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)
......
@@ -16,7 +16,7 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
        self.transform = alb.Compose(
            [
-                alb.ToGray(always_apply=True),
+                alb.ToGray(),
                alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
                # alb.Sharpen()
                ToTensorV2(),
......
@@ -7,35 +7,36 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
+from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

-try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import (
-        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
-        load_license)
-    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
-    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
-    license_key = load_license()
-    logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
-                f' License expired at {license_key["payload"]["date"]["end_date"]}')
-except Exception as e:
-    if isinstance(e, ImportError):
-        pass
-    elif isinstance(e, LicenseFormatError):
-        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
-    elif isinstance(e, LicenseSignatureError):
-        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
-    elif isinstance(e, LicenseExpiredError):
-        logger.error('Ascend Plugin: License has expired. Please renew your license.')
-    elif isinstance(e, FileNotFoundError):
-        logger.error('Ascend Plugin: Not found License file.')
-    else:
-        logger.error(f'Ascend Plugin: {e}')
-    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
-    # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
-from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+# try:
+#     from magic_pdf_ascend_plugin.libs.license_verifier import (
+#         LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+#         load_license)
+#     from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
+#     from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+#     license_key = load_license()
+#     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
+#                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
+# except Exception as e:
+#     if isinstance(e, ImportError):
+#         pass
+#     elif isinstance(e, LicenseFormatError):
+#         logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
+#     elif isinstance(e, LicenseSignatureError):
+#         logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
+#     elif isinstance(e, LicenseExpiredError):
+#         logger.error('Ascend Plugin: License has expired. Please renew your license.')
+#     elif isinstance(e, FileNotFoundError):
+#         logger.error('Ascend Plugin: Not found License file.')
+#     else:
+#         logger.error(f'Ascend Plugin: {e}')
+# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
@@ -47,6 +48,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        atom_model_manager = AtomModelSingleton()
+        ocr_engine = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.5,
+            det_db_unclip_ratio=1.6,
+            lang=lang
+        )
        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
    else:
        logger.error('table model type not allow')
@@ -94,7 +103,8 @@ def ocr_model_init(show_log: bool = False,
                   det_db_unclip_ratio=1.8,
                   ):
    if lang is not None and lang != '':
-        model = ModifiedPaddleOCR(
+        # model = ModifiedPaddleOCR(
+        model = PytorchPaddleOCR(
            show_log=show_log,
            det_db_box_thresh=det_db_box_thresh,
            lang=lang,
@@ -102,7 +112,8 @@ def ocr_model_init(show_log: bool = False,
            det_db_unclip_ratio=det_db_unclip_ratio,
        )
    else:
-        model = ModifiedPaddleOCR(
+        # model = ModifiedPaddleOCR(
+        model = PytorchPaddleOCR(
            show_log=show_log,
            det_db_box_thresh=det_db_box_thresh,
            use_dilation=use_dilation,
@@ -131,7 +142,7 @@ class AtomModelSingleton:
        elif atom_model_name in [AtomicModel.Layout]:
            key = (atom_model_name, layout_model_name)
        elif atom_model_name in [AtomicModel.Table]:
-            key = (atom_model_name, table_model_name)
+            key = (atom_model_name, table_model_name, lang)
        else:
            key = atom_model_name
@@ -179,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
-            kwargs.get('ocr_engine'),
+            kwargs.get('lang'),
            kwargs.get('table_sub_model_name')
        )
    elif model_name == AtomicModel.LangDetect:
......
import copy
import time
import cv2
import numpy as np
from paddleocr import PaddleOCR
from paddleocr.paddleocr import check_img, logger
from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
from paddleocr.tools.infer.predict_system import sorted_boxes
from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
get_minarea_rect_crop
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
class ModifiedPaddleOCR(PaddleOCR):
def ocr(
self,
img,
det=True,
rec=True,
cls=True,
bin=False,
inv=False,
alpha_color=(255, 255, 255),
slice={},
mfd_res=None,
):
"""
OCR with PaddleOCR
Args:
img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
det: Use text detection or not. If False, only text recognition will be executed. Default is True.
rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
bin: Binarize image to black and white. Default is False.
inv: Invert image colors. Default is False.
alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
Returns:
If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
If det is True and rec is False, returns a list of detected bounding boxes for each image.
If det is False and rec is True, returns a list of recognized text for each image.
If both det and rec are False, returns a list of angle classification results for each image.
Raises:
AssertionError: If the input image is not of type ndarray, list, str, or bytes.
SystemExit: If det is True and the input is a list of images.
Note:
- If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
- For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
- The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
logger.error("When input a list of images, det must be false")
exit(0)
if cls == True and self.use_angle_cls == False:
logger.warning(
"Since the angle classifier is not initialized, it will not be used during the forward process"
)
img, flag_gif, flag_pdf = check_img(img, alpha_color)
# for infer pdf file
if isinstance(img, list) and flag_pdf:
if self.page_num > len(img) or self.page_num == 0:
imgs = img
else:
imgs = img[: self.page_num]
else:
imgs = [img]
def preprocess_image(_image):
_image = alpha_to_color(_image, alpha_color)
if inv:
_image = cv2.bitwise_not(_image)
if bin:
_image = binarize_img(_image)
return _image
if det and rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if dt_boxes.size == 0:
ocr_res.append(None)
continue
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
else:
ocr_res = []
cls_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
return ocr_res
def __call__(self, img, cls=True, slice={}, mfd_res=None):
time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
if img is None:
logger.debug("no valid image provided")
return None, None, time_dict
start = time.time()
ori_im = img.copy()
if slice:
slice_gen = slice_generator(
img,
horizontal_stride=slice["horizontal_stride"],
vertical_stride=slice["vertical_stride"],
)
elapsed = []
dt_slice_boxes = []
for slice_crop, v_start, h_start in slice_gen:
dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
if dt_boxes.size:
dt_boxes[:, :, 0] += h_start
dt_boxes[:, :, 1] += v_start
dt_slice_boxes.append(dt_boxes)
elapsed.append(elapse)
dt_boxes = np.concatenate(dt_slice_boxes)
dt_boxes = merge_fragmented(
boxes=dt_boxes,
x_threshold=slice["merge_x_thres"],
y_threshold=slice["merge_y_thres"],
)
elapse = sum(elapsed)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict["det"] = elapse
if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
end = time.time()
time_dict["all"] = end - start
return None, None, time_dict
else:
logger.debug(
"dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
)
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
if mfd_res:
bef = time.time()
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft - bef))
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
if self.args.det_box_type == "quad":
img_crop = get_rotate_crop_image(ori_im, tmp_box)
else:
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
time_dict["cls"] = elapse
logger.debug(
"cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
)
if len(img_crop_list) > 1000:
logger.debug(
f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
)
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict["rec"] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
if self.args.save_crop_res:
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict["all"] = end - start
return filter_boxes, filter_rec_res, time_dict
# Copyright (c) Opendatalab. All rights reserved.
+import copy
import cv2
import numpy as np
-from loguru import logger
-from io import BytesIO
-from PIL import Image
-import base64
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
-from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
-from importlib.resources import files
-from paddleocr import PaddleOCR
-from ppocr.utils.utility import check_and_read

def img_decode(content: bytes):
    np_arr = np.frombuffer(content, dtype=np.uint8)
    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)

def check_img(img):
    if isinstance(img, bytes):
        img = img_decode(img)
-    if isinstance(img, str):
-        image_file = img
-        img, flag_gif, flag_pdf = check_and_read(image_file)
-        if not flag_gif and not flag_pdf:
-            with open(image_file, 'rb') as f:
-                img_str = f.read()
-                img = img_decode(img_str)
-            if img is None:
-                try:
-                    buf = BytesIO()
-                    image = BytesIO(img_str)
-                    im = Image.open(image)
-                    rgb = im.convert('RGB')
-                    rgb.save(buf, 'jpeg')
-                    buf.seek(0)
-                    image_bytes = buf.read()
-                    data_base64 = str(base64.b64encode(image_bytes),
-                                      encoding="utf-8")
-                    image_decode = base64.b64decode(data_base64)
-                    img_array = np.frombuffer(image_decode, np.uint8)
-                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
-                except:
-                    logger.error("error in loading image:{}".format(image_file))
-                    return None
-        if img is None:
-            logger.error("error in loading image:{}".format(image_file))
-            return None
    if isinstance(img, np.ndarray) and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return img

+def alpha_to_color(img, alpha_color=(255, 255, 255)):
+    if len(img.shape) == 3 and img.shape[2] == 4:
+        B, G, R, A = cv2.split(img)
+        alpha = A / 255
+        R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
+        G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
+        B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
+        img = cv2.merge((B, G, R))
+    return img
+
+def preprocess_image(_image):
+    alpha_color = (255, 255, 255)
+    _image = alpha_to_color(_image, alpha_color)
+    return _image
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array): detected text boxes with shape [4, 2]
+    return:
+        sorted boxes(array) with shape [4, 2]
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(sorted_boxes)
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]
+                _boxes[j + 1] = tmp
+            else:
+                break
+    return _boxes
def bbox_to_points(bbox):
    """ Convert a bbox into an array of its four corner points """
    x0, y0, x1, y1 = bbox
@@ -252,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    return adjusted_mfdetrec_res

-def get_ocr_result_list(ocr_res, useful_list):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
    ocr_result_list = []
+    ori_im = new_image.copy()
    for box_ocr_res in ocr_res:

        if len(box_ocr_res) == 2:
@@ -266,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
        else:
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1
+            if ocr_enable:
+                tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)

        # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
        # if average_angle_degrees > 0.5:
        poly = [p1, p2, p3, p4]
@@ -288,12 +303,22 @@ def get_ocr_result_list(ocr_res, useful_list):
        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]

-        ocr_result_list.append({
-            'category_id': 15,
-            'poly': p1 + p2 + p3 + p4,
-            'score': float(round(score, 2)),
-            'text': text,
-        })
+        if ocr_enable:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': 1,
+                'text': text,
+                'np_img': img_crop,
+                'lang': lang,
+            })
+        else:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': float(round(score, 2)),
+                'text': text,
+            })

    return ocr_result_list
@@ -308,57 +333,36 @@ def calculate_is_angle(poly):
    return True

-class ONNXModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_onnx_model(self, **kwargs):
-
-        lang = kwargs.get('lang', None)
-        det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
-        use_dilation = kwargs.get('use_dilation', True)
-        det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
-        key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
-        if key not in self._models:
-            self._models[key] = onnx_model_init(key)
-        return self._models[key]
-
-def onnx_model_init(key):
-    if len(key) < 4:
-        logger.error('Invalid key length, expected at least 4 elements')
-        exit(1)
-
-    try:
-        resource_path = files("rapidocr_onnxruntime") / "models"
-        additional_ocr_params = {
-            "use_onnx": True,
-            "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
-            "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
-            "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
-            "det_db_box_thresh": key[1],
-            "use_dilation": key[2],
-            "det_db_unclip_ratio": key[3],
-        }
-        if key[0] is not None:
-            additional_ocr_params["lang"] = key[0]
-        # logger.info(f"additional_ocr_params: {additional_ocr_params}")
-        onnx_model = PaddleOCR(**additional_ocr_params)
-        if onnx_model is None:
-            logger.error('model init failed')
-            exit(1)
-        else:
-            return onnx_model
-    except Exception as e:
-        logger.exception(f'Error initializing model: {e}')
-        exit(1)
+def get_rotate_crop_image(img, points):
+    '''
+    img_height, img_width = img.shape[0:2]
+    left = int(np.min(points[:, 0]))
+    right = int(np.max(points[:, 0]))
+    top = int(np.min(points[:, 1]))
+    bottom = int(np.max(points[:, 1]))
+    img_crop = img[top:bottom, left:right, :].copy()
+    points[:, 0] = points[:, 0] - left
+    points[:, 1] = points[:, 1] - top
+    '''
+    assert len(points) == 4, "shape of points must be 4*2"
+    img_crop_width = int(
+        max(
+            np.linalg.norm(points[0] - points[1]),
+            np.linalg.norm(points[2] - points[3])))
+    img_crop_height = int(
+        max(
+            np.linalg.norm(points[0] - points[3]),
+            np.linalg.norm(points[1] - points[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                          [img_crop_width, img_crop_height],
+                          [0, img_crop_height]])
+    M = cv2.getPerspectiveTransform(points, pts_std)
+    dst_img = cv2.warpPerspective(
+        img,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_REPLICATE,
+        flags=cv2.INTER_CUBIC)
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+    return dst_img
\ No newline at end of file
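For reference, a hedged usage sketch of the `get_rotate_crop_image` helper added above (assumes the function is in scope; the quad coordinates are made up):

```python
import cv2
import numpy as np

# A slightly rotated quad (clockwise from top-left) on a dummy image;
# get_rotate_crop_image warps it to an axis-aligned crop.
img = np.full((100, 200, 3), 255, dtype=np.uint8)
points = np.float32([[20, 20], [120, 26], [118, 56], [18, 50]])
crop = get_rotate_crop_image(img, points)
print(crop.shape)  # roughly (30, 100, 3); crops taller than 1.5x wide get rotated 90 degrees
```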
+# Copyright (c) Opendatalab. All rights reserved.
import copy
-import platform
-import time
+import os.path
+from pathlib import Path
import cv2
import numpy as np
-import torch
+import yaml
+from loguru import logger

-from paddleocr import PaddleOCR
-from ppocr.utils.logging import get_logger
-from ppocr.utils.utility import alpha_to_color, binarize_img
-from tools.infer.predict_system import sorted_boxes
-from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
-    ONNXModelSingleton
+from magic_pdf.libs.config_reader import get_device, get_local_models_dir
+from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
+from .tools.infer.predict_system import TextSystem
+from .tools.infer import pytorchocr_utility as utility
+import argparse

+latin_lang = [
+    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
+    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+    'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+]
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = [
+    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
+    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+]
+devanagari_lang = [
+    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
+    'sa', 'bgc'
+]

+def get_model_params(lang, config):
+    if lang in config['lang']:
+        params = config['lang'][lang]
+        det = params.get('det')
+        rec = params.get('rec')
+        dict_file = params.get('dict')
+        return det, rec, dict_file
+    else:
+        raise Exception(f'Language {lang} not supported')

+root_dir = Path(__file__).resolve().parent

-logger = get_logger()

-class ModifiedPaddleOCR(PaddleOCR):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lang = kwargs.get('lang', 'ch')
-        # fall back to ONNX when the CPU is ARM and CUDA is unavailable
-        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
-            self.use_onnx = True
-            onnx_model_manager = ONNXModelSingleton()
-            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
-        else:
-            self.use_onnx = False
+class PytorchPaddleOCR(TextSystem):
+    def __init__(self, *args, **kwargs):
+        parser = utility.init_args()
+        args = parser.parse_args(args)
+
+        self.lang = kwargs.get('lang', 'ch')
+
+        if self.lang in latin_lang:
+            self.lang = 'latin'
+        elif self.lang in arabic_lang:
+            self.lang = 'arabic'
+        elif self.lang in cyrillic_lang:
+            self.lang = 'cyrillic'
+        elif self.lang in devanagari_lang:
+            self.lang = 'devanagari'
+        else:
+            pass
+
+        models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
+        with open(models_config_path) as file:
+            config = yaml.safe_load(file)
+        det, rec, dict_file = get_model_params(self.lang, config)
+        ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
+        kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
+        kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
+        kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+        kwargs['device'] = get_device()
+
+        default_args = vars(args)
+        default_args.update(kwargs)
+        args = argparse.Namespace(**default_args)
+
+        super().__init__(args)
    def ocr(self,
            img,
            det=True,
            rec=True,
-            cls=True,
-            bin=False,
-            inv=False,
-            alpha_color=(255, 255, 255),
            mfd_res=None,
            ):
-        """
-        OCR with PaddleOCR
-        args:
-            img: img for OCR, support ndarray, img_path and list or ndarray
-            det: use text detection or not. If False, only rec will be exec. Default is True
-            rec: use text recognition or not. If False, only det will be exec. Default is True
-            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
-            bin: binarize image to black and white. Default is False.
-            inv: invert image colors. Default is False.
-            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
-        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            exit(0)
-        if cls == True and self.use_angle_cls == False:
-            pass
-            # logger.warning(
-            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
-            # )
-        img, flag_gif, flag_pdf = check_img(img, alpha_color)
-        # for infer pdf file
-        if isinstance(img, list):
-            if self.page_num > len(img) or self.page_num == 0:
-                self.page_num = len(img)
-            imgs = img[:self.page_num]
-        else:
-            imgs = [img]
-
-        def preprocess_image(_image):
-            _image = alpha_to_color(_image, alpha_color)
-            if inv:
-                _image = cv2.bitwise_not(_image)
-            if bin:
-                _image = binarize_img(_image)
-            return _image
+        img = check_img(img)
+        imgs = [img]

        if det and rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
+                dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
-                tmp_res = [[box.tolist(), res]
-                           for box, res in zip(dt_boxes, rec_res)]
+                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
-                if self.lang in ['ch'] and self.use_onnx:
-                    dt_boxes, elapse = self.additional_ocr.text_detector(img)
-                else:
-                    dt_boxes, elapse = self.text_detector(img)
+                dt_boxes, elapse = self.text_detector(img)
+                # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                if dt_boxes is None:
                    ocr_res.append(None)
                    continue
@@ -106,57 +117,36 @@
                # merge_det_boxes and update_det_boxes both convert polys to bboxes and back, so heavily skewed text boxes must be filtered out first
                dt_boxes = merge_det_boxes(dt_boxes)
                if mfd_res:
-                    bef = time.time()
                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-                    aft = time.time()
-                    logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                        len(dt_boxes), aft - bef))
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
-        else:
+        elif not det and rec:
            ocr_res = []
-            cls_res = []
            for img in imgs:
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                if self.lang in ['ch'] and self.use_onnx:
-                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
-                else:
-                    rec_res, elapse = self.text_recognizer(img)
+                rec_res, elapse = self.text_recognizer(img)
+                # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
                ocr_res.append(rec_res)
-            if not rec:
-                return cls_res
            return ocr_res
-    def __call__(self, img, cls=True, mfd_res=None):
-        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+    def __call__(self, img, mfd_res=None):
        if img is None:
            logger.debug("no valid image provided")
-            return None, None, time_dict
+            return None, None

-        start = time.time()
        ori_im = img.copy()
-        if self.lang in ['ch'] and self.use_onnx:
-            dt_boxes, elapse = self.additional_ocr.text_detector(img)
-        else:
-            dt_boxes, elapse = self.text_detector(img)
-        time_dict['det'] = elapse
+        dt_boxes, elapse = self.text_detector(img)

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
-            end = time.time()
-            time_dict['all'] = end - start
-            return None, None, time_dict
+            return None, None
        else:
-            logger.debug("dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), elapse))
+            pass
+            # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))

        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
@@ -165,41 +155,35 @@
        dt_boxes = merge_det_boxes(dt_boxes)

        if mfd_res:
-            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-            aft = time.time()
-            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.args.det_box_type == "quad":
-                img_crop = get_rotate_crop_image(ori_im, tmp_box)
-            else:
-                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
-        if self.use_angle_cls and cls:
-            img_crop_list, angle_list, elapse = self.text_classifier(
-                img_crop_list)
-            time_dict['cls'] = elapse
-            logger.debug("cls num : {}, elapsed : {}".format(
-                len(img_crop_list), elapse))

-        if self.lang in ['ch'] and self.use_onnx:
-            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
-        else:
-            rec_res, elapse = self.text_recognizer(img_crop_list)
-        time_dict['rec'] = elapse
-        logger.debug("rec_res num : {}, elapsed : {}".format(
-            len(rec_res), elapse))
-        if self.args.save_crop_res:
-            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
-                                   rec_res)
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)

-        end = time.time()
-        time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict
\ No newline at end of file
+        return filter_boxes, filter_rec_res

+if __name__ == '__main__':
+    pytorch_paddle_ocr = PytorchPaddleOCR()
+    img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
+    dt_boxes, rec_res = pytorch_paddle_ocr(img)
+    ocr_res = []
+    if not dt_boxes and not rec_res:
+        ocr_res.append(None)
+    else:
+        tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+        ocr_res.append(tmp_res)
+    print(ocr_res)
import os
import torch
from .modeling.architectures.base_model import BaseModel
class BaseOCRV20:
def __init__(self, config, **kwargs):
self.config = config
self.build_net(**kwargs)
self.net.eval()
def build_net(self, **kwargs):
self.net = BaseModel(self.config, **kwargs)
def read_pytorch_weights(self, weights_path):
if not os.path.exists(weights_path):
raise FileNotFoundError('{} is not existed.'.format(weights_path))
weights = torch.load(weights_path)
return weights
def get_out_channels(self, weights):
if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
out_channels = list(weights.values())[-1].numpy().shape[1]
else:
out_channels = list(weights.values())[-1].numpy().shape[0]
return out_channels
def load_state_dict(self, weights):
self.net.load_state_dict(weights)
# print('weights is loaded.')
def load_pytorch_weights(self, weights_path):
self.net.load_state_dict(torch.load(weights_path, weights_only=True))
# print('model is loaded: {}'.format(weights_path))
def inference(self, inputs):
with torch.no_grad():
infer = self.net(inputs)
return infer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
# import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
# from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
# import paddle.distributed as dist
from .imaug import transform, create_operators