Unverified Commit b4f7b53e authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1261 from opendatalab/release-0.10.6

Release 0.10.6
parents a962824b d3b51aa5
...@@ -30,7 +30,7 @@ jobs: ...@@ -30,7 +30,7 @@ jobs:
source activate mineru source activate mineru
conda env list conda env list
pip show coverage pip show coverage
# cd $GITHUB_WORKSPACE && sh tests/retry_env.sh cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
cd $GITHUB_WORKSPACE && python tests/clean_coverage.py cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/ --cov-report html --cov-report term-missing cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/ --cov-report html --cov-report term-missing
cd $GITHUB_WORKSPACE && python tests/get_coverage.py cd $GITHUB_WORKSPACE && python tests/get_coverage.py
...@@ -41,22 +41,6 @@ jobs: ...@@ -41,22 +41,6 @@ jobs:
needs: cli-test needs: cli-test
runs-on: pdf runs-on: pdf
steps: steps:
- name: get_actor
run: |
metion_list="dt-yy"
echo $GITHUB_ACTOR
if [[ $GITHUB_ACTOR == "drunkpig" ]]; then
metion_list="xuchao"
elif [[ $GITHUB_ACTOR == "myhloli" ]]; then
metion_list="zhaoxiaomeng"
elif [[ $GITHUB_ACTOR == "icecraft" ]]; then
metion_list="xurui1"
fi
echo $metion_list
echo "METIONS=$metion_list" >> "$GITHUB_ENV"
echo ${{ env.METIONS }}
- name: notify - name: notify
run: | run: |
echo ${{ secrets.USER_ID }} curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'$USER_ID'"}]]}}}}' $WEBHOOK_URL
curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}' ${{ secrets.WEBHOOK_URL }}
...@@ -29,14 +29,14 @@ jobs: ...@@ -29,14 +29,14 @@ jobs:
source activate mineru source activate mineru
conda env list conda env list
pip show coverage pip show coverage
# cd $GITHUB_WORKSPACE && sh tests/retry_env.sh cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
cd $GITHUB_WORKSPACE && python tests/clean_coverage.py cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/ --cov-report html --cov-report term-missing cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/ --cov-report html --cov-report term-missing
cd $GITHUB_WORKSPACE && python tests/get_coverage.py cd $GITHUB_WORKSPACE && python tests/get_coverage.py
cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py
notify_to_feishu: notify_to_feishu:
if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }} if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure')}}
needs: cli-test needs: cli-test
runs-on: pdf runs-on: pdf
steps: steps:
...@@ -57,5 +57,5 @@ jobs: ...@@ -57,5 +57,5 @@ jobs:
- name: notify - name: notify
run: | run: |
echo ${{ secrets.USER_ID }} #echo ${{ secrets.USER_ID }}
curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}' ${{ secrets.WEBHOOK_URL }} curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'$USER_ID'"}]]}}}}' $WEBHOOK_URL
...@@ -67,14 +67,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA- ...@@ -67,14 +67,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-
``` ```
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118 pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
``` ```
> [!IMPORTANT]
> Ensure the following versions are specified in the command:
>
> ```
> torch==2.3.1 torchvision==0.18.1
> ```
>
> These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory. 2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
......
...@@ -69,15 +69,6 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h ...@@ -69,15 +69,6 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118 pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
``` ```
> [!IMPORTANT]
> 务必在命令中指定以下版本
>
> ```bash
> torch==2.3.1 torchvision==0.18.1
> ```
>
> 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
**2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值** **2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
```json ```json
......
...@@ -51,3 +51,8 @@ class MODEL_NAME: ...@@ -51,3 +51,8 @@ class MODEL_NAME:
UniMerNet_v2_Small = 'unimernet_small' UniMerNet_v2_Small = 'unimernet_small'
RAPID_TABLE = 'rapid_table' RAPID_TABLE = 'rapid_table'
PARSE_TYPE_TXT = 'txt'
PARSE_TYPE_OCR = 'ocr'
...@@ -48,4 +48,16 @@ class DataWriter(ABC): ...@@ -48,4 +48,16 @@ class DataWriter(ABC):
path (str): the target file where to write path (str): the target file where to write
data (str): the data want to write data (str): the data want to write
""" """
self.write(path, data.encode())
def safe_encode(data: str, method: str):
try:
bit_data = data.encode(encoding=method, errors='replace')
return bit_data, True
except: # noqa
return None, False
for method in ['utf-8', 'ascii']:
bit_data, flag = safe_encode(data, method)
if flag:
self.write(path, bit_data)
break
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterator from typing import Callable, Iterator
import fitz import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify
class PageableData(ABC): class PageableData(ABC):
...@@ -28,6 +30,32 @@ class PageableData(ABC): ...@@ -28,6 +30,32 @@ class PageableData(ABC):
""" """
pass pass
@abstractmethod
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
pass
@abstractmethod
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
pass
class Dataset(ABC): class Dataset(ABC):
@abstractmethod @abstractmethod
...@@ -66,6 +94,43 @@ class Dataset(ABC): ...@@ -66,6 +94,43 @@ class Dataset(ABC):
""" """
pass pass
@abstractmethod
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(self, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
pass
@abstractmethod
def clone(self):
"""clone this dataset
"""
pass
class PymuDocDataset(Dataset): class PymuDocDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -74,7 +139,8 @@ class PymuDocDataset(Dataset): ...@@ -74,7 +139,8 @@ class PymuDocDataset(Dataset):
Args: Args:
bits (bytes): the bytes of the pdf bits (bytes): the bytes of the pdf
""" """
self._records = [Doc(v) for v in fitz.open('pdf', bits)] self._raw_fitz = fitz.open('pdf', bits)
self._records = [Doc(v) for v in self._raw_fitz]
self._data_bits = bits self._data_bits = bits
self._raw_data = bits self._raw_data = bits
...@@ -109,6 +175,43 @@ class PymuDocDataset(Dataset): ...@@ -109,6 +175,43 @@ class PymuDocDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
return classify(self._data_bits)
def clone(self):
"""clone this dataset
"""
return PymuDocDataset(self._raw_data)
class ImageDataset(Dataset): class ImageDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -118,7 +221,8 @@ class ImageDataset(Dataset): ...@@ -118,7 +221,8 @@ class ImageDataset(Dataset):
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
""" """
pdf_bytes = fitz.open(stream=bits).convert_to_pdf() pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)] self._raw_fitz = fitz.open('pdf', pdf_bytes)
self._records = [Doc(v) for v in self._raw_fitz]
self._raw_data = bits self._raw_data = bits
self._data_bits = pdf_bytes self._data_bits = pdf_bytes
...@@ -153,14 +257,50 @@ class ImageDataset(Dataset): ...@@ -153,14 +257,50 @@ class ImageDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
return SupportedPdfParseMethod.OCR
def clone(self):
"""clone this dataset
"""
return ImageDataset(self._raw_data)
class Doc(PageableData): class Doc(PageableData):
"""Initialized with pymudoc object.""" """Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page): def __init__(self, doc: fitz.Page):
self._doc = doc self._doc = doc
def get_image(self): def get_image(self):
"""Return the imge info. """Return the image info.
Returns: Returns:
dict: { dict: {
...@@ -192,3 +332,34 @@ class Doc(PageableData): ...@@ -192,3 +332,34 @@ class Doc(PageableData):
def __getattr__(self, name): def __getattr__(self, name):
if hasattr(self._doc, name): if hasattr(self._doc, name):
return getattr(self._doc, name) return getattr(self._doc, name)
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
self._doc.draw_rect(
rect_coords,
color=color,
fill=fill,
fill_opacity=fill_opacity,
width=width,
overlay=overlay,
)
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
...@@ -165,8 +165,8 @@ def merge_para_with_text(para_block): ...@@ -165,8 +165,8 @@ def merge_para_with_text(para_block):
if content: if content:
langs = ['zh', 'ja', 'ko'] langs = ['zh', 'ja', 'ko']
# logger.info(f'block_lang: {block_lang}, content: {content}') # logger.info(f'block_lang: {block_lang}, content: {content}')
if block_lang in langs: # 中文/日语/韩文语境下,换行不需要空格分隔 if block_lang in langs: # 中文/日语/韩文语境下,换行不需要空格分隔,但是如果是行内公式结尾,还是要加空格
if j == len(line['spans']) - 1: if j == len(line['spans']) - 1 and span_type not in [ContentType.InlineEquation]:
para_text += content para_text += content
else: else:
para_text += f'{content} ' para_text += f'{content} '
......
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get('_need_drop', False): # 如果返回了需要丢弃的标志,则抛出异常
raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta['is_encrypted']
is_needs_password = pdf_meta['is_needs_password']
if is_encrypted or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
else:
is_text_pdf, results = do_classify(
pdf_meta['total_page'],
pdf_meta['page_width_pts'],
pdf_meta['page_height_pts'],
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
return SupportedPdfParseMethod.TXT
else:
return SupportedPdfParseMethod.OCR
...@@ -8,7 +8,7 @@ from loguru import logger ...@@ -8,7 +8,7 @@ from loguru import logger
from magic_pdf.config.drop_reason import DropReason from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import get_top_percent_list, mymax from magic_pdf.libs.commons import get_top_percent_list, mymax
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.pdf_check import detect_invalid_chars_by_pymupdf from magic_pdf.libs.pdf_check import detect_invalid_chars_by_pymupdf, detect_invalid_chars
scan_max_page = 50 scan_max_page = 50
junk_limit_min = 10 junk_limit_min = 10
...@@ -323,7 +323,8 @@ def get_language(doc: fitz.Document): ...@@ -323,7 +323,8 @@ def get_language(doc: fitz.Document):
def check_invalid_chars(pdf_bytes): def check_invalid_chars(pdf_bytes):
"""乱码检测.""" """乱码检测."""
return detect_invalid_chars_by_pymupdf(pdf_bytes) # return detect_invalid_chars_by_pymupdf(pdf_bytes)
return detect_invalid_chars(pdf_bytes)
def pdf_meta_scan(pdf_bytes: bytes): def pdf_meta_scan(pdf_bytes: bytes):
......
import fitz import fitz
from magic_pdf.config.constants import CROSS_PAGE from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
from magic_pdf.data.dataset import PymuDocDataset ContentType)
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
...@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
) )
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
...@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False) draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_spans.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes) magic_model = MagicModel(model_list, dataset)
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
...@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote) imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i in range(len(dataset)):
page = dataset.get_page(i)
draw_bbox_with_number( draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True i, dropped_bbox_list, page, [158, 158, 158], True
) # color ! ) # color !
...@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf') dataset.dump_to_file(f'{out_path}/{filename}')
def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
...@@ -390,7 +391,7 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -390,7 +391,7 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False) draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
......
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
# import re import re
# from io import BytesIO from io import BytesIO
# from pdfminer.high_level import extract_text from pdfminer.high_level import extract_text
def calculate_sample_count(total_page: int): def calculate_sample_count(total_page: int):
...@@ -33,33 +33,33 @@ def extract_pages(src_pdf_bytes: bytes) -> fitz.Document: ...@@ -33,33 +33,33 @@ def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
return sample_docs return sample_docs
# def detect_invalid_chars(src_pdf_bytes: bytes) -> bool: def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
# """" """"
# 检测PDF中是否包含非法字符 检测PDF中是否包含非法字符
# """ """
# '''pdfminer比较慢,需要先随机抽取10页左右的sample''' '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
# sample_docs = extract_pages(src_pdf_bytes) sample_docs = extract_pages(src_pdf_bytes)
# sample_pdf_bytes = sample_docs.tobytes() sample_pdf_bytes = sample_docs.tobytes()
# sample_pdf_file_like_object = BytesIO(sample_pdf_bytes) sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
# text = extract_text(sample_pdf_file_like_object) text = extract_text(sample_pdf_file_like_object)
# text = text.replace("\n", "") text = text.replace("\n", "")
# # logger.info(text) # logger.info(text)
# '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)''' '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
# cid_pattern = re.compile(r'\(cid:\d+\)') cid_pattern = re.compile(r'\(cid:\d+\)')
# matches = cid_pattern.findall(text) matches = cid_pattern.findall(text)
# cid_count = len(matches) cid_count = len(matches)
# cid_len = sum(len(match) for match in matches) cid_len = sum(len(match) for match in matches)
# text_len = len(text) text_len = len(text)
# if text_len == 0: if text_len == 0:
# cid_chars_radio = 0 cid_chars_radio = 0
# else: else:
# cid_chars_radio = cid_count/(cid_count + text_len - cid_len) cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
# logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}") logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
# '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档''' '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
# if cid_chars_radio > 0.05: if cid_chars_radio > 0.05:
# return False # 乱码文档 return False # 乱码文档
# else: else:
# return True # 正常文档 return True # 正常文档
def count_replacement_characters(text: str) -> int: def count_replacement_characters(text: str) -> int:
......
from typing import Callable
from abc import ABC, abstractmethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.pipe.operators import PipeResult
__use_inside_model__ = True __use_inside_model__ = True
__model_mode__ = "full" __model_mode__ = "full"
class InferenceResultBase(ABC):
@abstractmethod
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self._infer_res = inference_results
self._dataset = dataset
@abstractmethod
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@abstractmethod
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@abstractmethod
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
pass
import os
import time import time
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
# 关闭paddle的信号处理
import paddle
paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_formula_config get_layout_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config from magic_pdf.model.operators import InferenceResult
def dict_compare(d1, d2): def dict_compare(d1, d2):
...@@ -25,19 +45,25 @@ def remove_duplicates_dicts(lst): ...@@ -25,19 +45,25 @@ def remove_duplicates_dicts(lst):
return unique_dicts return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list: def load_images_from_pdf(
pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
try: try:
from PIL import Image from PIL import Image
except ImportError: except ImportError:
logger.error("Pillow not installed, please install by pip.") logger.error('Pillow not installed, please install by pip.')
exit(1) exit(1)
images = [] images = []
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1: if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length") logger.warning('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1 end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count): for index in range(0, doc.page_count):
...@@ -50,11 +76,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id ...@@ -50,11 +76,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500: if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img) img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height} img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else: else:
img_dict = {"img": [], "width": 0, "height": 0} img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict) images.append(img_dict)
return images return images
...@@ -69,117 +95,150 @@ class ModelSingleton: ...@@ -69,117 +95,150 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None): def get_model(
self,
ocr: bool,
show_log: bool,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable) key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models: if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model, self._models[key] = custom_model_init(
formula_enable=formula_enable, table_enable=table_enable) ocr=ocr,
show_log=show_log,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key] return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None, def custom_model_init(
layout_model=None, formula_enable=None, table_enable=None): ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
model = None model = None
if model_config.__model_mode__ == "lite": if model_config.__model_mode__ == 'lite':
logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is " logger.warning(
"not guaranteed to be reliable.") 'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'not guaranteed to be reliable.'
)
model = MODEL.Paddle model = MODEL.Paddle
elif model_config.__model_mode__ == "full": elif model_config.__model_mode__ == 'full':
model = MODEL.PEK model = MODEL.PEK
if model_config.__use_inside_model__: if model_config.__use_inside_model__:
model_init_start = time.time() model_init_start = time.time()
if model == MODEL.Paddle: if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
layout_config = get_layout_config() layout_config = get_layout_config()
if layout_model is not None: if layout_model is not None:
layout_config["model"] = layout_model layout_config['model'] = layout_model
formula_config = get_formula_config() formula_config = get_formula_config()
if formula_enable is not None: if formula_enable is not None:
formula_config["enable"] = formula_enable formula_config['enable'] = formula_enable
table_config = get_table_recog_config() table_config = get_table_recog_config()
if table_enable is not None: if table_enable is not None:
table_config["enable"] = table_enable table_config['enable'] = table_enable
model_input = { model_input = {
"ocr": ocr, 'ocr': ocr,
"show_log": show_log, 'show_log': show_log,
"models_dir": local_models_dir, 'models_dir': local_models_dir,
"device": device, 'device': device,
"table_config": table_config, 'table_config': table_config,
"layout_config": layout_config, 'layout_config': layout_config,
"formula_config": formula_config, 'formula_config': formula_config,
"lang": lang, 'lang': lang,
} }
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error('Not allow model_name!')
exit(1) exit(1)
model_init_cost = time.time() - model_init_start model_init_cost = time.time() - model_init_start
logger.info(f"model init cost: {model_init_cost}") logger.info(f'model init cost: {model_init_cost}')
else: else:
logger.error("use_inside_model is False, not allow to use inside model") logger.error('use_inside_model is False, not allow to use inside model')
exit(1) exit(1)
return custom_model return custom_model
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, def doc_analyze(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
layout_model=None, formula_enable=None, table_enable=None): ocr: bool = False,
show_log: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
) -> InferenceResult:
if lang == "": if lang == '':
lang = None lang = None
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable) custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
with fitz.open("pdf", pdf_bytes) as doc: )
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
for index, img_dict in enumerate(images): if end_page_id is None:
img = img_dict["img"] end_page_id = len(dataset)
page_width = img_dict["width"]
page_height = img_dict["height"] for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id: if start_page_id <= index <= end_page_id:
page_start = time.time() page_start = time.time()
result = custom_model(img) result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----') logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else: else:
result = [] result = []
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info} page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict) model_json.append(page_dict)
gc_start = time.time() gc_start = time.time()
clean_memory() clean_memory()
gc_time = round(time.time() - gc_start, 2) gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}") logger.info(f'gc time: {gc_time}')
doc_analyze_time = round(time.time() - doc_analyze_start, 2) doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2) doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)}," logger.info(
f" speed: {doc_analyze_speed} pages/second") f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
f' speed: {doc_analyze_speed} pages/second'
)
return model_json return InferenceResult(model_json, dataset)
import copy
import json
import os
from typing import Callable
from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.libs.version import __version__
from magic_pdf.model import InferenceResultBase
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
class InferenceResult(InferenceResultBase):
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self._infer_res = inference_results
self._dataset = dataset
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
draw_model_bbox(
copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
)
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
writer.write_string(
file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
)
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
return self._infer_res
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pdf_proc_method = classify(self._dataset.data_bits())
if pdf_proc_method == SupportedPdfParseMethod.TXT:
return self.pipe_txt_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
else:
return self.pipe_ocr_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
res['_parse_type'] = PARSE_TYPE_TXT
res['_version_name'] = __version__
if 'lang' in kwargs and kwargs['lang'] is not None:
res['lang'] = kwargs['lang']
return PipeResult(res, self._dataset)
res = self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
return res
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using `OCR`
technical.
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
res['_parse_type'] = PARSE_TYPE_OCR
res['_version_name'] = __version__
if 'lang' in kwargs and kwargs['lang'] is not None:
res['lang'] = kwargs['lang']
return PipeResult(res, self._dataset)
res = self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.OCR,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
return res
...@@ -179,7 +179,25 @@ class CustomPEKModel: ...@@ -179,7 +179,25 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[]) layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
layout_res = self.layout_model.predict(image) img_pil = Image.fromarray(image)
width, height = img_pil.size
# logger.info(f'width: {width}, height: {height}')
input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image)
for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}') logger.info(f'layout detection time: {layout_cost}')
...@@ -215,6 +233,7 @@ class CustomPEKModel: ...@@ -215,6 +233,7 @@ class CustomPEKModel:
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
if self.apply_ocr: if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else: else:
......
...@@ -92,14 +92,24 @@ class AtomModelSingleton: ...@@ -92,14 +92,24 @@ class AtomModelSingleton:
return cls._instance return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs): def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None) lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None) layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang) table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name)
else:
key = atom_model_name
if key not in self._models: if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key] return self._models[key]
def atom_model_init(model_name: str, **kwargs): def atom_model_init(model_name: str, **kwargs):
atom_model = None atom_model = None
if model_name == AtomicModel.Layout: if model_name == AtomicModel.Layout:
...@@ -129,7 +139,7 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -129,7 +139,7 @@ def atom_model_init(model_name: str, **kwargs):
atom_model = ocr_model_init( atom_model = ocr_model_init(
kwargs.get('ocr_show_log'), kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'), kwargs.get('det_db_box_thresh'),
kwargs.get('lang') kwargs.get('lang'),
) )
elif model_name == AtomicModel.Table: elif model_name == AtomicModel.Table:
atom_model = table_model_init( atom_model = table_model_init(
......
...@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res): ...@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):
def clean_vram(device, vram_threshold=8): def clean_vram(device, vram_threshold=8):
if torch.cuda.is_available() and device != 'cpu': total_memory = get_vram(device)
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB if total_memory and total_memory <= vram_threshold:
if total_memory <= vram_threshold:
gc_start = time.time() gc_start = time.time()
clean_memory() clean_memory()
gc_time = round(time.time() - gc_start, 2) gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}") logger.info(f"gc time: {gc_time}")
def get_vram(device):
if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
return None
\ No newline at end of file
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(pdf_bytes, def parse_pdf_by_ocr(dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
......
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt( def parse_pdf_by_txt(
pdf_bytes, dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -12,9 +12,8 @@ def parse_pdf_by_txt( ...@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.TXT, SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment