"tools/base_data.json" did not exist on "bc20cfa2ddad93339edbd4688a0d43cada535e3b"
Unverified Commit 3a42ebbf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
from abc import ABC, abstractmethod
class IOReader(ABC):
@abstractmethod
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
pass
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
pass
class IOWriter:
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
pass
import io
import requests
from magic_pdf.data.io.base import IOReader, IOWriter
class HttpReader(IOReader):
def read(self, url: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return requests.get(url).content
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Not Implemented."""
raise NotImplementedError
class HttpWriter(IOWriter):
def write(self, url: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
files = {'file': io.BytesIO(data)}
response = requests.post(url, files=files)
assert 300 > response.status_code and response.status_code > 199
import boto3
from botocore.config import Config
from magic_pdf.data.io.base import IOReader, IOWriter
class S3Reader(IOReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def read(self, key: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(key)
def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
if limit > -1:
range_header = f'bytes={offset}-{offset+limit-1}'
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=range_header
)
else:
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
)
return res['Body'].read()
class S3Writer(IOWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def write(self, key: str, data: bytes):
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
import json
import os
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
def read_jsonl(
s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
"""Read the jsonl file and return the list of PymuDocDataset.
Args:
s3_path_or_local (str): local file or s3 path
s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.
Raises:
InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
EmptyData: if no pdf file location is provided in some line of jsonl file.
InvalidParams: if the file location is s3 path but s3_client is not provided
Returns:
list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
"""
bits_arr = []
if s3_path_or_local.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
jsonl_bits = s3_client.read(s3_path_or_local)
else:
jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
jsonl_d = [
json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
]
for d in jsonl_d[:5]:
pdf_path = d.get('file_location', '') or d.get('path', '')
if len(pdf_path) == 0:
raise EmptyData('pdf file location is empty')
if pdf_path.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
bits_arr.append(s3_client.read(pdf_path))
else:
bits_arr.append(FileBasedDataReader('').read(pdf_path))
return [PymuDocDataset(bits) for bits in bits_arr]
def read_local_pdfs(path: str) -> list[PymuDocDataset]:
"""Read pdf from path or directory.
Args:
path (str): pdf file path or directory that contains pdf files
Returns:
list[PymuDocDataset]: each pdf file will converted to a PymuDocDataset
"""
if os.path.isdir(path):
reader = FileBasedDataReader(path)
return [
PymuDocDataset(reader.read(doc_path.name))
for doc_path in Path(path).glob('*.pdf')
]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [PymuDocDataset(bits)]
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
Returns:
list[ImageDataset]: each image file will converted to a ImageDataset
"""
if os.path.isdir(path):
imgs_bits = []
s_suffixes = set(suffixes)
reader = FileBasedDataReader(path)
for root, _, files in os.walk(path):
for file in files:
suffix = file.split('.')
if suffix[-1] in s_suffixes:
imgs_bits.append(reader.read(file))
return [ImageDataset(bits) for bits in imgs_bits]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [ImageDataset(bits)]
from pydantic import BaseModel, Field
class S3Config(BaseModel):
bucket_name: str = Field(description='s3 bucket name', min_length=1)
access_key: str = Field(description='s3 access key', min_length=1)
secret_key: str = Field(description='s3 secret key', min_length=1)
endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
class PageInfo(BaseModel):
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
import fitz
import numpy as np
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
Args:
doc (_type_): pymudoc page
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
import re
import wordninja
from loguru import logger
from magic_pdf.libs.commons import join_path
......@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line):
......@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line):
return bool(re.search(r'[A-Za-z]+-\s*$', line))
def split_long_words(text):
segments = text.split(' ')
for i in range(len(segments)):
words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE)
for j in range(len(words)):
if len(words[j]) > 10:
words[j] = ' '.join(wordninja.split(words[j]))
segments[i] = ''.join(words)
return ' '.join(segments)
def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
markdown = []
for page_info in pdf_info_list:
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
markdown = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = []
......@@ -67,61 +36,23 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({
'page_no':
page_no,
page_no,
'md_content':
'\n\n'.join(page_markdown)
'\n\n'.join(page_markdown)
})
page_no += 1
return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
page_markdown = []
for paras in paras_of_layout:
for para in paras:
para_text = ''
for line in para:
for span in line['spans']:
span_type = span.get('type')
content = ''
language = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if (language == 'en'): # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${span['content']}$"
elif span_type == ContentType.InterlineEquation:
content = f"\n$$\n{span['content']}\n$$\n"
elif span_type in [ContentType.Image, ContentType.Table]:
if mode == 'mm':
content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
elif mode == 'nlp':
pass
if content != '':
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
if para_text.strip() == '':
continue
else:
page_markdown.append(para_text.strip() + ' ')
return page_markdown
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path=''):
img_buket_path='',
):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type == BlockType.Text:
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block)}'
......@@ -136,20 +67,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
if span.get('image_path', ''):
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block)
for block in para_block['blocks']: # 2nd.拼image_caption
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 3rd.拼image_footnote
if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block) + ' \n'
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody:
for line in block['lines']:
......@@ -160,11 +92,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n\n$\n {span['latex']}\n$\n\n"
elif span.get('html', ''):
para_text += f"\n\n{span['html']}\n\n"
else:
elif span.get('image_path', ''):
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block) + ' \n'
if para_text.strip() == '':
continue
......@@ -174,22 +106,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown
def merge_para_with_text(para_block):
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'unknown'
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'empty'
return 'unknown'
else:
return 'empty'
def merge_para_with_text(para_block):
para_text = ''
for line in para_block['lines']:
for i, line in enumerate(para_block['lines']):
if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'
line_text = ''
line_lang = ''
for span in line['spans']:
......@@ -199,17 +135,11 @@ def merge_para_with_text(para_block):
if line_text != '':
line_lang = detect_lang(line_text)
for span in line['spans']:
span_type = span['type']
content = ''
if span_type == ContentType.Text:
content = span['content']
# language = detect_lang(content)
language = detect_language(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
content = ocr_escape_special_markdown_char(span['content'])
elif span_type == ContentType.InlineEquation:
content = f" ${span['content']}$ "
elif span_type == ContentType.InterlineEquation:
......@@ -230,177 +160,83 @@ def merge_para_with_text(para_block):
return para_text
def para_to_standard_format(para, img_buket_path):
para_content = {}
if len(para) == 1:
para_content = line_to_standard_format(para[0], img_buket_path)
elif len(para) > 1:
para_text = ''
inline_equation_num = 0
for line in para:
for span in line['spans']:
language = ''
span_type = span.get('type')
content = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${span['content']}$"
inline_equation_num += 1
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
para_content = {
'type': 'text',
'text': para_text,
'inline_equation_num': inline_equation_num,
}
return para_content
def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
para_type = para_block['type']
if para_type == BlockType.Text:
para_content = {}
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'page_idx': page_idx,
}
elif para_type == BlockType.Title:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'text_level': 1,
'page_idx': page_idx,
}
elif para_type == BlockType.InterlineEquation:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text_format': 'latex',
'page_idx': page_idx,
}
elif para_type == BlockType.Image:
para_content = {'type': 'image', 'page_idx': page_idx}
para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path(
img_buket_path,
block['lines'][0]['spans'][0]['image_path'])
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block)
para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block)
para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.Table:
para_content = {'type': 'table', 'page_idx': page_idx}
para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''):
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
elif block["lines"][0]["spans"][0].get('html', ''):
para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block)
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'] = merge_para_with_text(block)
return para_content
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
if span.get('latex', ''):
para_content['table_body'] = f"\n\n$\n {span['latex']}\n$\n\n"
elif span.get('html', ''):
para_content['table_body'] = f"\n\n{span['html']}\n\n"
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
content_list = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
continue
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block,
img_buket_path)
content_list.append(para_content)
return content_list
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'].append(merge_para_with_text(block))
def line_to_standard_format(line, img_buket_path):
line_text = ''
inline_equation_num = 0
for span in line['spans']:
if not span.get('content'):
if not span.get('image_path'):
continue
else:
if span['type'] == ContentType.Image:
content = {
'type': 'image',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
elif span['type'] == ContentType.Table:
content = {
'type': 'table',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
else:
if span['type'] == ContentType.InterlineEquation:
interline_equation = span['content']
content = {
'type': 'equation',
'latex': f'$$\n{interline_equation}\n$$'
}
return content
elif span['type'] == ContentType.InlineEquation:
inline_equation = span['content']
line_text += f'${inline_equation}$'
inline_equation_num += 1
elif span['type'] == ContentType.Text:
text_content = ocr_escape_special_markdown_char(
span['content']) # 转义特殊符号
line_text += text_content
content = {
'type': 'text',
'text': line_text,
'inline_equation_num': inline_equation_num,
}
return content
para_content['page_idx'] = page_idx
if drop_reason is not None:
para_content['drop_reason'] = drop_reason
def ocr_mk_mm_standard_format(pdf_info_dict: list):
"""content_list type string
image/text/table/equation(行间的单独拿出来,行内的和text合并) latex string
latex文本字段。 text string 纯文本格式的文本数据。 md string
markdown格式的文本数据。 img_path string s3://full/path/to/img.jpg."""
content_list = []
for page_info in pdf_info_dict:
blocks = page_info.get('preproc_blocks')
if not blocks:
continue
for block in blocks:
for line in block['lines']:
content = line_to_standard_format(line)
content_list.append(content)
return content_list
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = ''):
img_buket_path: str = '',
):
output_content = []
for page_info in pdf_info_dict:
drop_reason_flag = False
drop_reason = None
if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE:
pass
elif drop_mode == DropMode.NONE_WITH_REASON:
drop_reason_flag = True
elif drop_mode == DropMode.WHOLE_PDF:
raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}'))
......@@ -425,8 +261,12 @@ def union_make(pdf_info_dict: list,
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
if drop_reason_flag:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
......
......@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除
LINES_DELETED = "lines_deleted"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400
# pp_table_result_max_length
TABLE_MAX_LEN = 480
# pp table structure algorithm
TABLE_MASTER = "TableMaster"
# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"
......@@ -29,12 +23,31 @@ TABLE_MASTER_DICT = "table_master_structure_dict.txt"
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
# pp detect model dir
DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer"
DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
# pp rec model dir
REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
# pp rec char dict path
REC_CHAR_DICT = "ppocr_keys_v1.txt"
# pp rec copy rec directory
PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
# pp rec copy det directory
PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
class MODEL_NAME:
# pp table structure algorithm
TABLE_MASTER = "tablemaster"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
DocLayout_YOLO = "doclayout_yolo"
LAYOUTLMv3 = "layoutlmv3"
YOLO_V8_MFD = "yolo_v8_mfd"
UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
......@@ -8,3 +8,4 @@ class DropMode:
WHOLE_PDF = "whole_pdf"
SINGLE_PAGE = "single_page"
NONE = "none"
NONE_WITH_REASON = "none_with_reason"
......@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
Args:
block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
Returns:
float: The proportion of the x-axis covered by the vertical projection of the two blocks.
"""
x0_1, _, x1_1, _ = block1
x0_2, _, x1_2, _ = block2
# Calculate the intersection of the x-coordinates
x_left = max(x0_1, x0_2)
x_right = min(x1_1, x1_2)
if x_right < x_left:
return 0.0
# Length of the intersection
intersection_length = x_right - x_left
# Length of the x-axis projection of the first block
block1_length = x1_1 - x0_1
if block1_length == 0:
return 0.0
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return intersection_length / block1_length
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc
def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
\ No newline at end of file
"""
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
import json
import os
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量
CONFIG_FILE_NAME = "magic-pdf.json"
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config():
home_dir = os.path.expanduser("~")
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if os.path.isabs(CONFIG_FILE_NAME):
config_file = CONFIG_FILE_NAME
else:
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file):
raise FileNotFoundError(f"{config_file} not found")
raise FileNotFoundError(f'{config_file} not found')
with open(config_file, "r", encoding="utf-8") as f:
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def get_s3_config(bucket_name: str):
"""
~/magic-pdf.json 读出来
"""
"""~/magic-pdf.json 读出来."""
config = read_config()
bucket_info = config.get("bucket_info")
bucket_info = config.get('bucket_info')
if bucket_name not in bucket_info:
access_key, secret_key, storage_endpoint = bucket_info["[default]"]
access_key, secret_key, storage_endpoint = bucket_info['[default]']
else:
access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
......@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
def get_s3_config_dict(path: str):
access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
def get_bucket_name(path):
......@@ -59,33 +57,65 @@ def get_bucket_name(path):
def get_local_models_dir():
config = read_config()
models_dir = config.get("models-dir")
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return "/tmp/models"
return '/tmp/models'
else:
return models_dir
def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def get_device():
config = read_config()
device = config.get("device-mode")
device = config.get('device-mode')
if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return "cpu"
return 'cpu'
else:
return device
def get_table_recog_config():
config = read_config()
table_config = config.get("table-config")
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else:
return table_config
def get_layout_config():
config = read_config()
layout_config = config.get("layout-config")
if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
else:
return layout_config
def get_formula_config():
config = read_config()
formula_config = config.get("formula-config")
if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
else:
return formula_config
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
......@@ -33,7 +34,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
) # Draw the rectangle
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
new_rgb = []
for item in rgb_config:
item = float(item) / 255
......@@ -42,31 +43,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
for j, bbox in enumerate(page_data):
x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
if draw_bbox:
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
page.insert_text(
(x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
(x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
......@@ -75,17 +76,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list = []
texts_list = []
interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info:
page_layout_list = []
page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
for layout in page['layout_bboxes']:
page_layout_list.append(layout['layout_bbox'])
layout_bbox_list.append(page_layout_list)
lists = []
indices = []
for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list)
......@@ -117,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts.append(bbox)
elif block['type'] == BlockType.InterlineEquation:
interequations.append(bbox)
elif block['type'] == BlockType.List:
lists.append(bbox)
elif block['type'] == BlockType.Index:
indices.append(bbox)
tables_list.append(tables)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
......@@ -128,30 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indices)
layout_bbox_list = []
table_type_order = {
'table_caption': 1,
'table_body': 2,
'table_footnote': 3
}
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
bbox = block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Image]:
for sub_block in block['blocks']:
bbox = sub_block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Table]:
sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
for sub_block in sorted_blocks:
bbox = sub_block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
True)
draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
True) # color !
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0],
True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_without_number(i, tables_footnote_list, page,
[229, 255, 204], True)
draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
# draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True) # color !
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
# draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
draw_bbox_with_number(
i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
)
# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
......@@ -209,11 +249,14 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
page_dropped_list.append(span['bbox'])
dropped_list.append(page_dropped_list)
# 构造其余useful_list
for block in page['para_blocks']:
# for block in page['para_blocks']: # span直接用分段合并前的结果就可以
for block in page['preproc_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
for line in block['lines']:
for span in line['spans']:
......@@ -232,10 +275,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
for i, page in enumerate(pdf_docs):
# 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0],
False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
......@@ -244,7 +285,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
......@@ -252,7 +293,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list = []
interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)):
page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], []
......@@ -278,8 +319,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_body.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox)
elif layout_det[
'category_id'] == CategoryId.InterlineEquation_YOLO:
elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
......@@ -298,21 +338,66 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
True) # color !
draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True
) # color !
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf')
def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
......@@ -20,6 +20,8 @@ class BlockType:
InterlineEquation = 'interline_equation'
Footnote = 'footnote'
Discarded = 'discarded'
List = 'list'
Index = 'index'
class CategoryId:
......
__version__ = "0.8.1"
__version__ = "0.9.0"
......@@ -4,7 +4,9 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
......@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
try:
from PIL import Image
except ImportError:
......@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
images = []
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height}
else:
img_dict = {"img": [], "width": 0, "height": 0}
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height}
images.append(img_dict)
return images
......@@ -57,14 +69,17 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool):
key = (ocr, show_log)
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model = None
if model_config.__model_mode__ == "lite":
......@@ -78,18 +93,36 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model
formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable
table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config}
if table_enable is not None:
table_config["enable"] = table_enable
model_input = {
"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
......@@ -104,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log)
if lang == "":
lang = None
images = load_images_from_pdf(pdf_bytes)
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
# end_page_id = end_page_id if end_page_id else len(images) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
if end_page_id > len(images) - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = len(images) - 1
images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
model_json = []
doc_analyze_start = time.time()
......@@ -132,7 +169,15 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
f" speed: {doc_analyze_speed} pages/second")
return model_json
import enum
import json
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
......@@ -9,6 +11,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
......@@ -16,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class PosRelationEnum(enum.Enum):
LEFT = 'left'
RIGHT = 'right'
UP = 'up'
BOTTOM = 'bottom'
ALL = 'all'
class MagicModel:
"""每个函数没有得到元素的时候返回空list."""
......@@ -24,7 +35,7 @@ class MagicModel:
need_remove_list = []
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no]
model_page_info, self.__docs.get_page(page_no)
)
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
......@@ -99,7 +110,7 @@ class MagicModel:
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: fitz.Document):
def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list
self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
......@@ -110,6 +121,24 @@ class MagicModel:
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def _bbox_distance(self, bbox1, bbox2):
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
count = sum([1 if v else 0 for v in flags])
if count > 1:
return float('inf')
if left or right:
l1 = bbox1[3] - bbox1[1]
l2 = bbox2[3] - bbox2[1]
else:
l1 = bbox1[2] - bbox1[0]
l2 = bbox2[2] - bbox2[0]
if l2 > l1 and (l2 - l1) / l1 > 0.3:
return float('inf')
return bbox_distance(bbox1, bbox2)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
......@@ -144,7 +173,7 @@ class MagicModel:
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
......@@ -163,7 +192,7 @@ class MagicModel:
continue
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
......@@ -195,9 +224,8 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(
subject_idx, object_idx
):
def search_overlap_between_boxes(subject_idx, object_idx):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
......@@ -225,9 +253,9 @@ class MagicModel:
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
get_overlap_area(merged_bbox, other_object['bbox'])
* 1.0
/ box_area(all_bboxes[object_idx]['bbox']),
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
......@@ -345,12 +373,17 @@ class MagicModel:
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j]
used = set()
......@@ -566,6 +599,289 @@ class MagicModel:
with_caption_subject.add(j)
return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
"""_summary_
Args:
page_no (int): _description_
subject_category_id (int): _description_
object_category_id (int): _description_
priority_pos (PosRelationEnum): _description_
Returns:
_type_: _description_
"""
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if (
dis_by_directions['top'][i][1] != float('inf')
and dis_by_directions['bottom'][i][1] != float('inf')
and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
):
RATIO = 3
if (
abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
)
< RATIO * axis_unit
):
if priority_pos == PosRelationEnum.BOTTOM:
sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
else:
sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
continue
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if (
abs(top_bottom_x_axis - l_x_axis)
+ dis_by_directions['bottom'][i][1]
> abs(bottom_top_x_axis - l_x_axis)
+ dis_by_directions['top'][i][1]
):
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v2(
page_no, 3, 4, PosRelationEnum.BOTTOM
)
with_footnotes = self.__tie_up_category_by_distance_v2(
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
)
ret = []
for v in with_captions:
record = {
'image_body': v['sub_bbox'],
'image_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v2(
page_no, 5, 6, PosRelationEnum.UP
)
with_footnotes = self.__tie_up_category_by_distance_v2(
page_no, 5, 7, PosRelationEnum.ALL
)
ret = []
for v in with_captions:
record = {
'table_body': v['sub_bbox'],
'table_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance(
......@@ -699,10 +1015,10 @@ class MagicModel:
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs[page_no]
page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高
page_w = page.rect.width
page_h = page.rect.height
page_w = page.w
page_h = page.h
return page_w, page_h
def __get_blocks_by_type(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment