Unverified Commit 3a42ebbf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
from abc import ABC, abstractmethod
class IOReader(ABC):
    """Abstract interface for objects that read raw bytes from a path."""

    @abstractmethod
    def read(self, path: str) -> bytes:
        """Read the whole file.

        Args:
            path (str): file path to read

        Returns:
            bytes: the content of the file
        """

    @abstractmethod
    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read a slice of the file described by an offset and a limit.

        Args:
            path (str): the path of the file; a relative path is joined
                with parent_dir.
            offset (int, optional): number of bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read. Defaults to -1.

        Returns:
            bytes: the content of the file
        """
class IOWriter(ABC):
    """Abstract interface for objects that write raw bytes to a path.

    The class derives from ``ABC`` so that ``@abstractmethod`` is actually
    enforced: without it a subclass that forgets to implement ``write``
    could be instantiated and would silently do nothing.
    """

    @abstractmethod
    def write(self, path: str, data: bytes) -> None:
        """Write data to a file.

        Args:
            path (str): the path of the file; a relative path is joined
                with parent_dir.
            data (bytes): the data to write
        """
import io
import requests
from magic_pdf.data.io.base import IOReader, IOWriter
class HttpReader(IOReader):
    """Reader that fetches content over HTTP with a GET request."""

    def read(self, url: str) -> bytes:
        """Download the resource and return its body.

        Args:
            url (str): address of the resource to fetch

        Returns:
            bytes: the body of the HTTP response
        """
        response = requests.get(url)
        return response.content

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Partial reads are not supported by this reader.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class HttpWriter(IOWriter):
    """Writer that uploads bytes to an HTTP endpoint with a POST request."""

    def write(self, url: str, data: bytes) -> None:
        """Upload data as a multipart form field named ``file``.

        Args:
            url (str): address the data is posted to
            data (bytes): the data to write

        Raises:
            AssertionError: if the response status code is outside 200-299.
        """
        files = {'file': io.BytesIO(data)}
        response = requests.post(url, files=files)
        # Raise explicitly instead of using an `assert` statement: asserts are
        # stripped when Python runs with -O, which would let failed uploads
        # pass silently. AssertionError is kept so existing callers that
        # catch it keep working.
        if not 200 <= response.status_code < 300:
            raise AssertionError(
                f'unexpected status code {response.status_code} when writing to {url}'
            )
import boto3
from botocore.config import Config
from magic_pdf.data.io.base import IOReader, IOWriter
class S3Reader(IOReader):
    """Reader that fetches objects from a single S3 bucket."""

    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 reader client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
                refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=Config(
                s3={'addressing_style': addressing_style},
                retries={'max_attempts': 5, 'mode': 'standard'},
            ),
        )

    def read(self, key: str) -> bytes:
        """Read the whole object.

        Args:
            key (str): key of the object inside the bucket

        Returns:
            bytes: the content of the object
        """
        return self.read_at(key)

    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read part of an object described by an offset and a limit.

        Args:
            key (str): key of the object inside the bucket
            offset (int, optional): number of bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read; a negative value
                reads up to the end of the object. Defaults to -1.

        Returns:
            bytes: the requested content
        """
        if limit == 0:
            # A zero-length read would otherwise produce the inverted range
            # `bytes=offset-(offset-1)`, which is not a valid byte range.
            return b''

        request = {'Bucket': self._bucket, 'Key': key}
        if limit > 0:
            request['Range'] = f'bytes={offset}-{offset + limit - 1}'
        elif offset != 0:
            request['Range'] = f'bytes={offset}-'
        # When the whole object is wanted no Range header is sent at all, so
        # that zero-length objects (for which `bytes=0-` is unsatisfiable)
        # can still be read.
        res = self._s3_client.get_object(**request)
        return res['Body'].read()
class S3Writer(IOWriter):
    """Writer that stores objects in a single S3 bucket."""

    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 writer client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
                refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        client_config = Config(
            s3={'addressing_style': addressing_style},
            retries={'max_attempts': 5, 'mode': 'standard'},
        )
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=client_config,
        )

    def write(self, key: str, data: bytes):
        """Store data under the given key of the configured bucket.

        Args:
            key (str): key of the object inside the bucket
            data (bytes): the data to write
        """
        self._s3_client.put_object(Body=data, Key=key, Bucket=self._bucket)
import json
import os
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
def read_jsonl(
    s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
    """Read the jsonl file and return the list of PymuDocDataset.

    Every non-blank line of the jsonl file is parsed as a JSON object whose
    'file_location' (or, failing that, 'path') entry locates a pdf file.

    Args:
        s3_path_or_local (str): local file or s3 path
        s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.

    Raises:
        InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
        EmptyData: if no pdf file location is provided in some line of jsonl file.
        InvalidParams: if the file location is s3 path but s3_client is not provided

    Returns:
        list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
    """
    bits_arr = []
    if s3_path_or_local.startswith('s3://'):
        if s3_client is None:
            raise InvalidParams('s3_client is required when s3_path is provided')
        jsonl_bits = s3_client.read(s3_path_or_local)
    else:
        jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)

    jsonl_d = [
        json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
    ]
    # Process every record: the previous `[:5]` slice silently dropped all
    # entries after the fifth, contradicting the documented contract.
    for d in jsonl_d:
        pdf_path = d.get('file_location', '') or d.get('path', '')
        if len(pdf_path) == 0:
            raise EmptyData('pdf file location is empty')
        if pdf_path.startswith('s3://'):
            if s3_client is None:
                raise InvalidParams('s3_client is required when s3_path is provided')
            bits_arr.append(s3_client.read(pdf_path))
        else:
            bits_arr.append(FileBasedDataReader('').read(pdf_path))
    return [PymuDocDataset(bits) for bits in bits_arr]
def read_local_pdfs(path: str) -> list[PymuDocDataset]:
    """Load one pdf file, or every ``*.pdf`` directly inside a directory.

    Args:
        path (str): pdf file path or directory that contains pdf files

    Returns:
        list[PymuDocDataset]: one PymuDocDataset per pdf file that was read
    """
    if not os.path.isdir(path):
        single_reader = FileBasedDataReader()
        return [PymuDocDataset(single_reader.read(path))]

    dir_reader = FileBasedDataReader(path)
    datasets = []
    # Files are read by bare name through a reader constructed on the directory.
    for pdf_file in Path(path).glob('*.pdf'):
        datasets.append(PymuDocDataset(dir_reader.read(pdf_file.name)))
    return datasets
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
    """Read images from path or directory.

    When path is a directory it is walked recursively and every file whose
    text after the last '.' is listed in suffixes is read.

    Args:
        path (str): image file path or directory that contains image files
        suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']

    Returns:
        list[ImageDataset]: each image file will converted to a ImageDataset
    """
    if not os.path.isdir(path):
        reader = FileBasedDataReader()
        bits = reader.read(path)
        return [ImageDataset(bits)]

    imgs_bits = []
    s_suffixes = set(suffixes)
    reader = FileBasedDataReader(path)
    for root, _, files in os.walk(path):
        for file in files:
            if file.split('.')[-1] not in s_suffixes:
                continue
            # os.walk descends into sub-directories, so the bare file name is
            # not enough to locate the file: read it by its path relative to
            # `path`. This presumes the reader resolves relative paths against
            # the directory it was constructed with -- TODO confirm.
            relative_path = os.path.relpath(os.path.join(root, file), path)
            imgs_bits.append(reader.read(relative_path))
    return [ImageDataset(bits) for bits in imgs_bits]
from pydantic import BaseModel, Field
class S3Config(BaseModel):
    # Connection settings for one S3 bucket; every string must be non-empty.
    bucket_name: str = Field(min_length=1, description='s3 bucket name')
    access_key: str = Field(min_length=1, description='s3 access key')
    secret_key: str = Field(min_length=1, description='s3 secret key')
    endpoint_url: str = Field(min_length=1, description='s3 endpoint url')
    # The only optional field: falls back to 'auto' when not supplied.
    addressing_style: str = Field(
        default='auto',
        min_length=1,
        description='s3 addressing style',
    )
class PageInfo(BaseModel):
    # Dimensions of a single page; both fields are required.
    w: float = Field(
        description='the width of page',
    )
    h: float = Field(
        description='the height of page',
    )
import fitz
import numpy as np
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
    """Render a page to an RGB image and return it as a numpy array.

    Args:
        doc (_type_): object exposing ``get_pixmap`` (described upstream as a pymudoc page)
        dpi (int, optional): resolution used for rendering. Defaults to 200.

    Returns:
        dict: {'img': numpy array, 'width': width, 'height': height }
    """
    from PIL import Image

    zoom = dpi / 72
    pixmap = doc.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)

    # A rendering wider or taller than 9000 pixels is discarded and the page
    # is rendered again with the identity matrix (no scaling).
    if max(pixmap.width, pixmap.height) > 9000:
        pixmap = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

    size = (pixmap.width, pixmap.height)
    pil_image = Image.frombytes('RGB', size, pixmap.samples)
    return {
        'img': np.array(pil_image),
        'width': pixmap.width,
        'height': pixmap.height,
    }
import re import re
import wordninja
from loguru import logger from loguru import logger
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
...@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang ...@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line): def __is_hyphen_at_line_end(line):
...@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line): ...@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line):
return bool(re.search(r'[A-Za-z]+-\s*$', line)) return bool(re.search(r'[A-Za-z]+-\s*$', line))
def split_long_words(text):
    """Break up over-long words in a space-separated text.

    The text is split on single spaces; each piece is tokenised into runs of
    word characters and individual non-word characters. Any token longer than
    10 characters is passed to ``wordninja.split`` and its parts are joined
    with spaces. All other tokens are kept unchanged.
    """
    rebuilt_segments = []
    for segment in text.split(' '):
        tokens = re.findall(r'\w+|[^\w]', segment, re.UNICODE)
        rebuilt_segments.append(''.join(
            ' '.join(wordninja.split(token)) if len(token) > 10 else token
            for token in tokens
        ))
    return ' '.join(rebuilt_segments)
def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
    """Build one markdown string from every page's 'para_blocks' in 'mm' mode.

    Each page is rendered by ``ocr_mk_markdown_with_para_core_v2`` and the
    resulting pieces of all pages are joined with blank lines.
    """
    pieces = []
    for page in pdf_info_list:
        pieces.extend(
            ocr_mk_markdown_with_para_core_v2(
                page.get('para_blocks'), 'mm', img_buket_path))
    return '\n\n'.join(pieces)
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
    """Build one markdown string from every page's 'para_blocks' in 'nlp' mode.

    Each page is rendered by ``ocr_mk_markdown_with_para_core_v2`` and the
    resulting pieces of all pages are joined with blank lines.
    """
    pieces = []
    for page in pdf_info_dict:
        pieces.extend(
            ocr_mk_markdown_with_para_core_v2(page.get('para_blocks'), 'nlp'))
    return '\n\n'.join(pieces)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path): img_buket_path):
markdown_with_para_and_pagination = [] markdown_with_para_and_pagination = []
...@@ -67,61 +36,23 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, ...@@ -67,61 +36,23 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
paras_of_layout, 'mm', img_buket_path) paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({ markdown_with_para_and_pagination.append({
'page_no': 'page_no':
page_no, page_no,
'md_content': 'md_content':
'\n\n'.join(page_markdown) '\n\n'.join(page_markdown)
}) })
page_no += 1 page_no += 1
return markdown_with_para_and_pagination return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
    """Render nested paragraphs (layouts -> paragraphs -> lines -> spans) to markdown.

    Returns a list with one entry per non-empty paragraph; each entry is the
    stripped paragraph text followed by a single space.
    """
    rendered = []
    for layout_paras in paras_of_layout:
        for para in layout_paras:
            pieces = []
            for line in para:
                for span in line['spans']:
                    kind = span.get('type')
                    text = ''
                    lang = ''
                    if kind == ContentType.Text:
                        raw = span['content']
                        lang = detect_lang(raw)
                        # Only English text gets long-word splitting; splitting
                        # Chinese text would lose content.
                        if lang == 'en':
                            raw = split_long_words(raw)
                        text = ocr_escape_special_markdown_char(raw)
                    elif kind == ContentType.InlineEquation:
                        text = f"${span['content']}$"
                    elif kind == ContentType.InterlineEquation:
                        text = f"\n$$\n{span['content']}\n$$\n"
                    elif kind in [ContentType.Image, ContentType.Table]:
                        # Images and tables are only emitted in 'mm' mode.
                        if mode == 'mm':
                            text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"

                    if text == '':
                        continue
                    # English pieces are separated by a space; other languages
                    # (e.g. Chinese) are concatenated directly.
                    pieces.append((text + ' ') if lang == 'en' else text)

            para_text = ''.join(pieces).strip()
            if para_text != '':
                rendered.append(para_text + ' ')
    return rendered
def ocr_mk_markdown_with_para_core_v2(paras_of_layout, def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode, mode,
img_buket_path=''): img_buket_path='',
):
page_markdown = [] page_markdown = []
for para_block in paras_of_layout: for para_block in paras_of_layout:
para_text = '' para_text = ''
para_type = para_block['type'] para_type = para_block['type']
if para_type == BlockType.Text: if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block) para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title: elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block)}' para_text = f'# {merge_para_with_text(para_block)}'
...@@ -136,20 +67,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, ...@@ -136,20 +67,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
for line in block['lines']: for line in block['lines']:
for span in line['spans']: for span in line['spans']:
if span['type'] == ContentType.Image: if span['type'] == ContentType.Image:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n" if span.get('image_path', ''):
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 2nd.拼image_caption for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption: if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼image_caption for block in para_block['blocks']: # 3rd.拼image_footnote
if block['type'] == BlockType.ImageFootnote: if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block) + ' \n'
elif para_type == BlockType.Table: elif para_type == BlockType.Table:
if mode == 'nlp': if mode == 'nlp':
continue continue
elif mode == 'mm': elif mode == 'mm':
for block in para_block['blocks']: # 1st.拼table_caption for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼table_body for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
for line in block['lines']: for line in block['lines']:
...@@ -160,11 +92,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, ...@@ -160,11 +92,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n\n$\n {span['latex']}\n$\n\n" para_text += f"\n\n$\n {span['latex']}\n$\n\n"
elif span.get('html', ''): elif span.get('html', ''):
para_text += f"\n\n{span['html']}\n\n" para_text += f"\n\n{span['html']}\n\n"
else: elif span.get('image_path', ''):
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n" para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote: if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block) + ' \n'
if para_text.strip() == '': if para_text.strip() == '':
continue continue
...@@ -174,22 +106,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, ...@@ -174,22 +106,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown return page_markdown
def merge_para_with_text(para_block): def detect_language(text):
en_pattern = r'[a-zA-Z]+'
def detect_language(text): en_matches = re.findall(en_pattern, text)
en_pattern = r'[a-zA-Z]+' en_length = sum(len(match) for match in en_matches)
en_matches = re.findall(en_pattern, text) if len(text) > 0:
en_length = sum(len(match) for match in en_matches) if en_length / len(text) >= 0.5:
if len(text) > 0: return 'en'
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'unknown'
else: else:
return 'empty' return 'unknown'
else:
return 'empty'
def merge_para_with_text(para_block):
para_text = '' para_text = ''
for line in para_block['lines']: for i, line in enumerate(para_block['lines']):
if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'
line_text = '' line_text = ''
line_lang = '' line_lang = ''
for span in line['spans']: for span in line['spans']:
...@@ -199,17 +135,11 @@ def merge_para_with_text(para_block): ...@@ -199,17 +135,11 @@ def merge_para_with_text(para_block):
if line_text != '': if line_text != '':
line_lang = detect_lang(line_text) line_lang = detect_lang(line_text)
for span in line['spans']: for span in line['spans']:
span_type = span['type'] span_type = span['type']
content = '' content = ''
if span_type == ContentType.Text: if span_type == ContentType.Text:
content = span['content'] content = ocr_escape_special_markdown_char(span['content'])
# language = detect_lang(content)
language = detect_language(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation: elif span_type == ContentType.InlineEquation:
content = f" ${span['content']}$ " content = f" ${span['content']}$ "
elif span_type == ContentType.InterlineEquation: elif span_type == ContentType.InterlineEquation:
...@@ -230,177 +160,83 @@ def merge_para_with_text(para_block): ...@@ -230,177 +160,83 @@ def merge_para_with_text(para_block):
return para_text return para_text
def para_to_standard_format(para, img_buket_path): def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
para_content = {}
if len(para) == 1:
para_content = line_to_standard_format(para[0], img_buket_path)
elif len(para) > 1:
para_text = ''
inline_equation_num = 0
for line in para:
for span in line['spans']:
language = ''
span_type = span.get('type')
content = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${span['content']}$"
inline_equation_num += 1
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
para_content = {
'type': 'text',
'text': para_text,
'inline_equation_num': inline_equation_num,
}
return para_content
def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
para_type = para_block['type'] para_type = para_block['type']
if para_type == BlockType.Text: para_content = {}
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_content = { para_content = {
'type': 'text', 'type': 'text',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'page_idx': page_idx,
} }
elif para_type == BlockType.Title: elif para_type == BlockType.Title:
para_content = { para_content = {
'type': 'text', 'type': 'text',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'text_level': 1, 'text_level': 1,
'page_idx': page_idx,
} }
elif para_type == BlockType.InterlineEquation: elif para_type == BlockType.InterlineEquation:
para_content = { para_content = {
'type': 'equation', 'type': 'equation',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'text_format': 'latex', 'text_format': 'latex',
'page_idx': page_idx,
} }
elif para_type == BlockType.Image: elif para_type == BlockType.Image:
para_content = {'type': 'image', 'page_idx': page_idx} para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody: if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path( for line in block['lines']:
img_buket_path, for span in line['spans']:
block['lines'][0]['spans'][0]['image_path']) if span['type'] == ContentType.Image:
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.ImageCaption: if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block) para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.ImageFootnote: if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block) para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.Table: elif para_type == BlockType.Table:
para_content = {'type': 'table', 'page_idx': page_idx} para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''): for line in block['lines']:
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n" for span in line['spans']:
elif block["lines"][0]["spans"][0].get('html', ''): if span['type'] == ContentType.Table:
para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block)
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'] = merge_para_with_text(block)
return para_content
if span.get('latex', ''):
para_content['table_body'] = f"\n\n$\n {span['latex']}\n$\n\n"
elif span.get('html', ''):
para_content['table_body'] = f"\n\n{span['html']}\n\n"
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str): if span.get('image_path', ''):
content_list = [] para_content['img_path'] = join_path(img_buket_path, span['image_path'])
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
continue
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block,
img_buket_path)
content_list.append(para_content)
return content_list
if block['type'] == BlockType.TableCaption:
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'].append(merge_para_with_text(block))
def line_to_standard_format(line, img_buket_path): para_content['page_idx'] = page_idx
line_text = ''
inline_equation_num = 0
for span in line['spans']:
if not span.get('content'):
if not span.get('image_path'):
continue
else:
if span['type'] == ContentType.Image:
content = {
'type': 'image',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
elif span['type'] == ContentType.Table:
content = {
'type': 'table',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
else:
if span['type'] == ContentType.InterlineEquation:
interline_equation = span['content']
content = {
'type': 'equation',
'latex': f'$$\n{interline_equation}\n$$'
}
return content
elif span['type'] == ContentType.InlineEquation:
inline_equation = span['content']
line_text += f'${inline_equation}$'
inline_equation_num += 1
elif span['type'] == ContentType.Text:
text_content = ocr_escape_special_markdown_char(
span['content']) # 转义特殊符号
line_text += text_content
content = {
'type': 'text',
'text': line_text,
'inline_equation_num': inline_equation_num,
}
return content
if drop_reason is not None:
para_content['drop_reason'] = drop_reason
def ocr_mk_mm_standard_format(pdf_info_dict: list): return para_content
"""content_list type string
image/text/table/equation(行间的单独拿出来,行内的和text合并) latex string
latex文本字段。 text string 纯文本格式的文本数据。 md string
markdown格式的文本数据。 img_path string s3://full/path/to/img.jpg."""
content_list = []
for page_info in pdf_info_dict:
blocks = page_info.get('preproc_blocks')
if not blocks:
continue
for block in blocks:
for line in block['lines']:
content = line_to_standard_format(line)
content_list.append(content)
return content_list
def union_make(pdf_info_dict: list, def union_make(pdf_info_dict: list,
make_mode: str, make_mode: str,
drop_mode: str, drop_mode: str,
img_buket_path: str = ''): img_buket_path: str = '',
):
output_content = [] output_content = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
drop_reason_flag = False
drop_reason = None
if page_info.get('need_drop', False): if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason') drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE: if drop_mode == DropMode.NONE:
pass pass
elif drop_mode == DropMode.NONE_WITH_REASON:
drop_reason_flag = True
elif drop_mode == DropMode.WHOLE_PDF: elif drop_mode == DropMode.WHOLE_PDF:
raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,' raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}')) f'drop_reason is {drop_reason}'))
...@@ -425,8 +261,12 @@ def union_make(pdf_info_dict: list, ...@@ -425,8 +261,12 @@ def union_make(pdf_info_dict: list,
output_content.extend(page_markdown) output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT: elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout: for para_block in paras_of_layout:
para_content = para_to_standard_format_v2( if drop_reason_flag:
para_block, img_buket_path, page_idx) para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
output_content.append(para_content) output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]: if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content) return '\n\n'.join(output_content)
......
...@@ -10,18 +10,12 @@ block维度自定义字段 ...@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除 # block中lines是否被删除
LINES_DELETED = "lines_deleted" LINES_DELETED = "lines_deleted"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
# table recognition max time default value # table recognition max time default value
TABLE_MAX_TIME_VALUE = 400 TABLE_MAX_TIME_VALUE = 400
# pp_table_result_max_length # pp_table_result_max_length
TABLE_MAX_LEN = 480 TABLE_MAX_LEN = 480
# pp table structure algorithm
TABLE_MASTER = "TableMaster"
# table master structure dict # table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt" TABLE_MASTER_DICT = "table_master_structure_dict.txt"
...@@ -29,12 +23,31 @@ TABLE_MASTER_DICT = "table_master_structure_dict.txt" ...@@ -29,12 +23,31 @@ TABLE_MASTER_DICT = "table_master_structure_dict.txt"
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/" TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
# pp detect model dir # pp detect model dir
DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer" DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
# pp rec model dir # pp rec model dir
REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer" REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
# pp rec char dict path # pp rec char dict path
REC_CHAR_DICT = "ppocr_keys_v1.txt" REC_CHAR_DICT = "ppocr_keys_v1.txt"
# pp rec copy rec directory
PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
# pp rec copy det directory
PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
class MODEL_NAME:
    # Namespace of string constants naming individual models.

    # pp table structure algorithm
    TABLE_MASTER = 'tablemaster'

    # struct eqtable
    STRUCT_EQTABLE = 'struct_eqtable'

    DocLayout_YOLO = 'doclayout_yolo'
    LAYOUTLMv3 = 'layoutlmv3'
    YOLO_V8_MFD = 'yolo_v8_mfd'
    UniMerNet_v2_Small = 'unimernet_small'
\ No newline at end of file
...@@ -8,3 +8,4 @@ class DropMode: ...@@ -8,3 +8,4 @@ class DropMode:
WHOLE_PDF = "whole_pdf" WHOLE_PDF = "whole_pdf"
SINGLE_PAGE = "single_page" SINGLE_PAGE = "single_page"
NONE = "none" NONE = "none"
NONE_WITH_REASON = "none_with_reason"
...@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2): ...@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area # The area of overlap area
return (x_right - x_left) * (y_bottom - y_top) return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
    """Return the share of block1's x-extent that is overlapped by block2.

    Only the x-coordinates of the two boxes are used.

    Args:
        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).

    Returns:
        float: length of the x-axis intersection divided by the x-axis
            length of block1; 0.0 when the projections do not intersect or
            when block1 has zero width.
    """
    left_a, _, right_a, _ = block1
    left_b, _, right_b, _ = block2

    # Bounds of the shared x-interval (empty when end < start).
    overlap_start = max(left_a, left_b)
    overlap_end = min(right_a, right_b)
    width_a = right_a - left_a

    # Guard against both an empty intersection and a division by zero.
    if overlap_end < overlap_start or width_a == 0:
        return 0.0
    return (overlap_end - overlap_start) / width_a
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc
def clean_memory():
    """Empty the CUDA caches when CUDA is available, then run the garbage collector."""
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()
        cuda.ipc_collect()
    gc.collect()
\ No newline at end of file
""" """根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
import json import json
import os import os
from loguru import logger from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = "magic-pdf.json" CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config(): def read_config():
home_dir = os.path.expanduser("~") if os.path.isabs(CONFIG_FILE_NAME):
config_file = CONFIG_FILE_NAME
config_file = os.path.join(home_dir, CONFIG_FILE_NAME) else:
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise FileNotFoundError(f"{config_file} not found") raise FileNotFoundError(f'{config_file} not found')
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f) config = json.load(f)
return config return config
def get_s3_config(bucket_name: str): def get_s3_config(bucket_name: str):
""" """~/magic-pdf.json 读出来."""
~/magic-pdf.json 读出来
"""
config = read_config() config = read_config()
bucket_info = config.get("bucket_info") bucket_info = config.get('bucket_info')
if bucket_name not in bucket_info: if bucket_name not in bucket_info:
access_key, secret_key, storage_endpoint = bucket_info["[default]"] access_key, secret_key, storage_endpoint = bucket_info['[default]']
else: else:
access_key, secret_key, storage_endpoint = bucket_info[bucket_name] access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None: if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}") raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}") # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
...@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str): ...@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
def get_s3_config_dict(path: str): def get_s3_config_dict(path: str):
access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path)) access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint} return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
def get_bucket_name(path): def get_bucket_name(path):
...@@ -59,33 +57,65 @@ def get_bucket_name(path): ...@@ -59,33 +57,65 @@ def get_bucket_name(path):
def get_local_models_dir(): def get_local_models_dir():
config = read_config() config = read_config()
models_dir = config.get("models-dir") models_dir = config.get('models-dir')
if models_dir is None: if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default") logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return "/tmp/models" return '/tmp/models'
else: else:
return models_dir return models_dir
def get_local_layoutreader_model_dir():
    """Return the directory of the layoutreader model.

    The 'layoutreader-model-dir' entry of the config file is returned when it
    is set and the path exists on disk. Otherwise a warning is logged and a
    default path under the user's home directory is returned.

    Returns:
        str: the configured directory, or the default fallback path.
    """
    configured_dir = read_config().get('layoutreader-model-dir')
    if configured_dir is not None and os.path.exists(configured_dir):
        return configured_dir

    # Fall back to the default location under the home directory; its
    # existence is not checked here.
    fallback_dir = os.path.join(
        os.path.expanduser('~'), '.cache/modelscope/hub/ppaanngggg/layoutreader'
    )
    logger.warning(f"'layoutreader-model-dir' not exists, use {fallback_dir} as default")
    return fallback_dir
def get_device(): def get_device():
config = read_config() config = read_config()
device = config.get("device-mode") device = config.get('device-mode')
if device is None: if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default") logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return "cpu" return 'cpu'
else: else:
return device return device
def get_table_recog_config(): def get_table_recog_config():
config = read_config() config = read_config()
table_config = config.get("table-config") table_config = config.get('table-config')
if table_config is None: if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default") logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}') return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else: else:
return table_config return table_config
def get_layout_config():
    """Return the 'layout-config' section of the config file.

    When the section is absent, a warning is logged and a default config
    naming MODEL_NAME.LAYOUTLMv3 as the model is returned instead.

    Returns:
        the configured section, or a dict with a single "model" key.
    """
    layout_config = read_config().get("layout-config")
    if layout_config is not None:
        return layout_config

    logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
    return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
def get_formula_config():
    """Return the 'formula-config' section of the config file.

    When the section is absent, a warning is logged and a default config is
    returned: formula handling enabled, MODEL_NAME.YOLO_V8_MFD as "mfd_model"
    and MODEL_NAME.UniMerNet_v2_Small as "mfr_model".

    Returns:
        the configured section, or the default dict described above.
    """
    formula_config = read_config().get("formula-config")
    if formula_config is not None:
        return formula_config

    logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
    return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
if __name__ == "__main__": if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw") ak, sk, endpoint = get_s3_config("llm-raw")
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
...@@ -33,7 +34,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config): ...@@ -33,7 +34,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
) # Draw the rectangle ) # Draw the rectangle
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config): def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
new_rgb = [] new_rgb = []
for item in rgb_config: for item in rgb_config:
item = float(item) / 255 item = float(item) / 255
...@@ -42,31 +43,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config): ...@@ -42,31 +43,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
for j, bbox in enumerate(page_data): for j, bbox in enumerate(page_data):
x0, y0, x1, y1 = bbox x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if fill_config: if draw_bbox:
page.draw_rect( if fill_config:
rect_coords, page.draw_rect(
color=None, rect_coords,
fill=new_rgb, color=None,
fill_opacity=0.3, fill=new_rgb,
width=0.5, fill_opacity=0.3,
overlay=True, width=0.5,
) # Draw the rectangle overlay=True,
else: ) # Draw the rectangle
page.draw_rect( else:
rect_coords, page.draw_rect(
color=new_rgb, rect_coords,
fill=None, color=new_rgb,
fill_opacity=1, fill=None,
width=0.5, fill_opacity=1,
overlay=True, width=0.5,
) # Draw the rectangle overlay=True,
) # Draw the rectangle
page.insert_text( page.insert_text(
(x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle ) # Insert the index in the top left corner of the rectangle
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
dropped_bbox_list = [] dropped_bbox_list = []
tables_list, tables_body_list = [], [] tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], [] tables_caption_list, tables_footnote_list = [], []
...@@ -75,17 +76,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -75,17 +76,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info: for page in pdf_info:
page_layout_list = []
page_dropped_list = [] page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], [] tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], [] imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = [] titles = []
texts = [] texts = []
interequations = [] interequations = []
for layout in page['layout_bboxes']: lists = []
page_layout_list.append(layout['layout_bbox']) indices = []
layout_bbox_list.append(page_layout_list)
for dropped_bbox in page['discarded_blocks']: for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox']) page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
...@@ -117,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -117,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts.append(bbox) texts.append(bbox)
elif block['type'] == BlockType.InterlineEquation: elif block['type'] == BlockType.InterlineEquation:
interequations.append(bbox) interequations.append(bbox)
elif block['type'] == BlockType.List:
lists.append(bbox)
elif block['type'] == BlockType.Index:
indices.append(bbox)
tables_list.append(tables) tables_list.append(tables)
tables_body_list.append(tables_body) tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption) tables_caption_list.append(tables_caption)
...@@ -128,30 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -128,30 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list.append(titles) titles_list.append(titles)
texts_list.append(texts) texts_list.append(texts)
interequations_list.append(interequations) interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indices)
layout_bbox_list = []
table_type_order = {
'table_caption': 1,
'table_body': 2,
'table_footnote': 3
}
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
bbox = block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Image]:
for sub_block in block['blocks']:
bbox = sub_block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Table]:
sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
for sub_block in sorted_blocks:
bbox = sub_block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes) pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
True) # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True) # color !
draw_bbox_without_number(i, tables_list, page, [153, 153, 0], draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
True) # color ! draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
True) # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_without_number(i, tables_footnote_list, page,
[229, 255, 204], True)
draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
True) draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
True) draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
draw_bbox_with_number(
i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf') pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
...@@ -209,11 +249,14 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -209,11 +249,14 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
page_dropped_list.append(span['bbox']) page_dropped_list.append(span['bbox'])
dropped_list.append(page_dropped_list) dropped_list.append(page_dropped_list)
# 构造其余useful_list # 构造其余useful_list
for block in page['para_blocks']: # for block in page['para_blocks']: # span直接用分段合并前的结果就可以
for block in page['preproc_blocks']:
if block['type'] in [ if block['type'] in [
BlockType.Text, BlockType.Text,
BlockType.Title, BlockType.Title,
BlockType.InterlineEquation, BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]: ]:
for line in block['lines']: for line in block['lines']:
for span in line['spans']: for span in line['spans']:
...@@ -232,10 +275,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -232,10 +275,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
# 获取当前页面的数据 # 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, [255, 0, 0], False) draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
False) draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
False)
draw_bbox_without_number(i, image_list, page, [255, 204, 0], False) draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, [204, 0, 255], False) draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False) draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
...@@ -244,7 +285,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -244,7 +285,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
pdf_docs.save(f'{out_path}/{filename}_spans.pdf') pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
...@@ -252,7 +293,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -252,7 +293,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes) pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
...@@ -278,8 +319,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -278,8 +319,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_body.append(bbox) imgs_body.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageCaption: elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox) imgs_caption.append(bbox)
elif layout_det[ elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
'category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox) interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon: elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox) page_dropped_list.append(bbox)
...@@ -298,21 +338,66 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -298,21 +338,66 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_footnote_list.append(imgs_footnote) imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], draw_bbox_with_number(
True) # color ! i, dropped_bbox_list, page, [158, 158, 158], True
) # color !
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True) draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
True) draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
True)
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
True) draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf') pdf_docs.save(f'{out_path}/{filename}_model.pdf')
def _indexed_line_bboxes(lines):
    """Return one ``{'index', 'bbox'}`` record per line dict in *lines*."""
    return [{'index': line['index'], 'bbox': line['bbox']} for line in lines]


def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
    """Draw every line's bbox, numbered by its sort index, onto the PDF.

    For each page the line boxes of the 'preproc_blocks' are gathered,
    ordered by their 'index' field, and drawn as red outlined rectangles
    numbered 1..n in that order. The annotated document is written to
    ``{out_path}/{filename}_line_sort.pdf``.

    Args:
        pdf_info: per-page dicts; each must provide 'preproc_blocks'.
        pdf_bytes: raw bytes of the PDF to annotate.
        out_path: directory the annotated PDF is written to.
        filename: base name (without suffix) of the output file.
    """
    line_block_types = [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]
    container_block_types = [BlockType.Image, BlockType.Table]
    body_types = [BlockType.ImageBody, BlockType.TableBody]
    caption_footnote_types = [
        BlockType.ImageCaption,
        BlockType.TableCaption,
        BlockType.ImageFootnote,
        BlockType.TableFootnote,
    ]

    layout_bbox_list = []
    for page in pdf_info:
        page_line_list = []
        for block in page['preproc_blocks']:
            if block['type'] in line_block_types:
                page_line_list.extend(_indexed_line_bboxes(block['lines']))
            if block['type'] in container_block_types:
                for sub_block in block['blocks']:
                    if sub_block['type'] in body_types:
                        # Image/table bodies carry their lines under 'virtual_lines'.
                        page_line_list.extend(_indexed_line_bboxes(sub_block['virtual_lines']))
                    elif sub_block['type'] in caption_footnote_types:
                        page_line_list.extend(_indexed_line_bboxes(sub_block['lines']))
        sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
        # Store a real list (not a one-shot generator) so the per-page data can
        # be iterated more than once, indexed, or measured like the lists the
        # other draw_* functions build.
        layout_bbox_list.append([entry['bbox'] for entry in sorted_bboxes])

    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

    pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
    """Draw each paragraph block's bbox, numbered in list order, onto the PDF.

    The bboxes of every page's 'para_blocks' are drawn as red outlined
    rectangles numbered 1..n in the order they appear in the list. The result
    is saved as ``{out_path}/{filename}_layout_sort.pdf``.

    Args:
        pdf_info: per-page dicts; each must provide 'para_blocks'.
        pdf_bytes: raw bytes of the PDF to annotate.
        out_path: directory the annotated PDF is written to.
        filename: base name (without suffix) of the output file.
    """
    layout_bbox_list = [
        [para_block['bbox'] for para_block in page_info['para_blocks']]
        for page_info in pdf_info
    ]

    pdf_docs = fitz.open('pdf', pdf_bytes)
    for page_idx, pdf_page in enumerate(pdf_docs):
        draw_bbox_with_number(page_idx, layout_bbox_list, pdf_page, [255, 0, 0], False)

    pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
...@@ -20,6 +20,8 @@ class BlockType: ...@@ -20,6 +20,8 @@ class BlockType:
InterlineEquation = 'interline_equation' InterlineEquation = 'interline_equation'
Footnote = 'footnote' Footnote = 'footnote'
Discarded = 'discarded' Discarded = 'discarded'
List = 'list'
Index = 'index'
class CategoryId: class CategoryId:
......
__version__ = "0.8.1" __version__ = "0.9.0"
...@@ -4,7 +4,9 @@ import fitz ...@@ -4,7 +4,9 @@ import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config import magic_pdf.model as model_config
...@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst): ...@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
return unique_dicts return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list: def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
try: try:
from PIL import Image from PIL import Image
except ImportError: except ImportError:
...@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list: ...@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
images = [] images = []
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count): for index in range(0, doc.page_count):
page = doc[index] if start_page_id <= index <= end_page_id:
mat = fitz.Matrix(dpi / 72, dpi / 72) page = doc[index]
pm = page.get_pixmap(matrix=mat, alpha=False) mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further. img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
if pm.width > 9000 or pm.height > 9000: img = np.array(img)
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) img_dict = {"img": img, "width": pm.width, "height": pm.height}
else:
img_dict = {"img": [], "width": 0, "height": 0}
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height}
images.append(img_dict) images.append(img_dict)
return images return images
...@@ -57,14 +69,17 @@ class ModelSingleton: ...@@ -57,14 +69,17 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, ocr: bool, show_log: bool): def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log) key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models: if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log) self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key] return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False): def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model = None model = None
if model_config.__model_mode__ == "lite": if model_config.__model_mode__ == "lite":
...@@ -78,18 +93,36 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -78,18 +93,36 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start = time.time() model_init_start = time.time()
if model == MODEL.Paddle: if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model
formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable
table_config = get_table_recog_config() table_config = get_table_recog_config()
model_input = {"ocr": ocr, if table_enable is not None:
"show_log": show_log, table_config["enable"] = table_enable
"models_dir": local_models_dir,
"device": device, model_input = {
"table_config": table_config} "ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error("Not allow model_name!")
...@@ -104,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -104,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model_manager = ModelSingleton() if lang == "":
custom_model = model_manager.get_model(ocr, show_log) lang = None
images = load_images_from_pdf(pdf_bytes) model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
# end_page_id = end_page_id if end_page_id else len(images) - 1 with fitz.open("pdf", pdf_bytes) as doc:
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1 pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
if end_page_id > len(images) - 1: images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
logger.warning("end_page_id is out of range, use images length")
end_page_id = len(images) - 1
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
...@@ -132,7 +169,15 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, ...@@ -132,7 +169,15 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_info = {"page_no": index, "height": page_height, "width": page_width} page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info} page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict) model_json.append(page_dict)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}") gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
f" speed: {doc_analyze_speed} pages/second")
return model_json return model_json
import enum
import json import json
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance, from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou, bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio, calculate_overlap_area_in_bbox1_area_ratio,
...@@ -9,6 +11,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio ...@@ -9,6 +11,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
...@@ -16,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6 ...@@ -16,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class PosRelationEnum(enum.Enum):
    """Labels for a relative position; each value is the lowercase member name."""

    # The four individual directions.
    LEFT = 'left'
    RIGHT = 'right'
    UP = 'up'
    BOTTOM = 'bottom'
    # Label named 'all' (its use is not visible in this part of the file).
    ALL = 'all'
class MagicModel: class MagicModel:
"""每个函数没有得到元素的时候返回空list.""" """每个函数没有得到元素的时候返回空list."""
...@@ -24,7 +35,7 @@ class MagicModel: ...@@ -24,7 +35,7 @@ class MagicModel:
need_remove_list = [] need_remove_list = []
page_no = model_page_info['page_info']['page_no'] page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio( horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no] model_page_info, self.__docs.get_page(page_no)
) )
layout_dets = model_page_info['layout_dets'] layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets: for layout_det in layout_dets:
...@@ -99,7 +110,7 @@ class MagicModel: ...@@ -99,7 +110,7 @@ class MagicModel:
for need_remove in need_remove_list: for need_remove in need_remove_list:
layout_dets.remove(need_remove) layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: fitz.Document): def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list self.__model_list = model_list
self.__docs = docs self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)""" """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
...@@ -110,6 +121,24 @@ class MagicModel: ...@@ -110,6 +121,24 @@ class MagicModel:
self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote() self.__fix_footnote()
def _bbox_distance(self, bbox1, bbox2):
    """Return the distance between two boxes, or infinity if they must not pair.

    Infinity is returned when more than one relative-position flag is set,
    or when bbox2 is over 30% longer than bbox1 along the side they are
    compared on. Otherwise the result of ``bbox_distance`` is returned.

    Args:
        bbox1: first box; indexed as [0..3], assumed (x0, y0, x1, y1).
        bbox2: second box, same layout.

    Returns:
        float: ``float('inf')`` for rejected pairs, else the box distance.
    """
    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
    # More than one flag set (presumably a diagonal placement) is unpairable.
    if sum(1 for flag in (left, right, bottom, top) if flag) > 1:
        return float('inf')

    if left or right:
        # Horizontal neighbours: compare extents along the y axis.
        l1 = bbox1[3] - bbox1[1]
        l2 = bbox2[3] - bbox2[1]
    else:
        # Otherwise compare extents along the x axis.
        l1 = bbox1[2] - bbox1[0]
        l2 = bbox2[2] - bbox2[0]

    # Reject when bbox2 exceeds bbox1 by more than 30% of bbox1's extent.
    # A zero extent for bbox1 would divide by zero; any longer bbox2 is then
    # unboundedly larger in relative terms, so it is rejected as well.
    if l2 > l1 and (l1 == 0 or (l2 - l1) / l1 > 0.3):
        return float('inf')

    return bbox_distance(bbox1, bbox2)
def __fix_footnote(self): def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote # 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list: for model_page_info in self.__model_list:
...@@ -144,7 +173,7 @@ class MagicModel: ...@@ -144,7 +173,7 @@ class MagicModel:
if pos_flag_count > 1: if pos_flag_count > 1:
continue continue
dis_figure_footnote[i] = min( dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')), dis_figure_footnote.get(i, float('inf')),
) )
for i in range(len(footnotes)): for i in range(len(footnotes)):
...@@ -163,7 +192,7 @@ class MagicModel: ...@@ -163,7 +192,7 @@ class MagicModel:
continue continue
dis_table_footnote[i] = min( dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')), dis_table_footnote.get(i, float('inf')),
) )
for i in range(len(footnotes)): for i in range(len(footnotes)):
...@@ -195,9 +224,8 @@ class MagicModel: ...@@ -195,9 +224,8 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离 再求出筛选出的 subjects 和 object 的最短距离
""" """
def search_overlap_between_boxes(
subject_idx, object_idx def search_overlap_between_boxes(subject_idx, object_idx):
):
idxes = [subject_idx, object_idx] idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes] x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes] y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
...@@ -225,9 +253,9 @@ class MagicModel: ...@@ -225,9 +253,9 @@ class MagicModel:
for other_object in other_objects: for other_object in other_objects:
ratio = max( ratio = max(
ratio, ratio,
get_overlap_area( get_overlap_area(merged_bbox, other_object['bbox'])
merged_bbox, other_object['bbox'] * 1.0
) * 1.0 / box_area(all_bboxes[object_idx]['bbox']) / box_area(all_bboxes[object_idx]['bbox']),
) )
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO: if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break break
...@@ -345,12 +373,17 @@ class MagicModel: ...@@ -345,12 +373,17 @@ class MagicModel:
if all_bboxes[j]['category_id'] == subject_category_id: if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO: if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf') dis[i][j] = float('inf')
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
continue continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
used = set() used = set()
...@@ -566,6 +599,289 @@ class MagicModel: ...@@ -566,6 +599,289 @@ class MagicModel:
with_caption_subject.add(j) with_caption_subject.add(j)
return ret, total_subject_object_dis return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
    self,
    page_no: int,
    subject_category_id: int,
    object_category_id: int,
    priority_pos: PosRelationEnum,
):
    """Assign each object box on a page to at most one subject box.

    For every object the nearest subject in each of the four directions
    reported by bbox_relative_pos is recorded. One of those candidates is
    then chosen by comparing distances and, when distances are close, by
    how well the subject's side length matches the object's.

    Args:
        page_no (int): index into the model list of the page to process.
        subject_category_id (int): category_id of the boxes objects are
            attached to.
        object_category_id (int): category_id of the boxes being attached.
        priority_pos (PosRelationEnum): when BOTTOM or UP, and the object has
            candidates in both the 'top' and 'bottom' directions whose
            distances differ by less than 3 * min(object width, object
            height), the 'bottom' candidate (BOTTOM) or the 'top' candidate
            (any other value) wins outright.

    Returns:
        list[dict]: one entry per subject, with keys 'sub_bbox' (bbox and
        score of the subject), 'obj_bboxes' (bbox and score of every object
        attached to it) and 'sub_idx' (the subject's index after sorting).
    """
    AXIS_MULPLICITY = 0.5
    RATIO = 3
    INF = float('inf')

    layout_dets = self.__model_list[page_no]['layout_dets']
    subjects = self.__reduct_overlap(
        [
            {'bbox': det['bbox'], 'score': det['score']}
            for det in layout_dets
            if det['category_id'] == subject_category_id
        ]
    )
    objects = self.__reduct_overlap(
        [
            {'bbox': det['bbox'], 'score': det['score']}
            for det in layout_dets
            if det['category_id'] == object_category_id
        ]
    )

    # Order both lists by squared distance of the bbox's first corner from
    # the origin so that indices are deterministic.
    def origin_key(box):
        return box['bbox'][0] ** 2 + box['bbox'][1] ** 2

    subjects.sort(key=origin_key)
    objects.sort(key=origin_key)

    sub_obj_map_h = {sub_idx: [] for sub_idx in range(len(subjects))}

    # dis_by_directions[direction][i] == [subject index, distance] of the
    # nearest subject found for object i in that direction; [-1, inf] means
    # "none found". Each slot is a distinct list (no shared references).
    dis_by_directions = {
        direction: [[-1, INF] for _ in objects]
        for direction in ('top', 'bottom', 'left', 'right')
    }

    for i, obj in enumerate(objects):
        l_x_axis = obj['bbox'][2] - obj['bbox'][0]
        l_y_axis = obj['bbox'][3] - obj['bbox'][1]
        axis_unit = min(l_x_axis, l_y_axis)

        for j, sub in enumerate(subjects):
            # Relative position is judged on the overlap-trimmed boxes, but
            # the distance is measured on the untouched boxes.
            bbox1, bbox2, _ = _remove_overlap_between_bbox(
                obj['bbox'], sub['bbox']
            )
            left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
            flagged = [
                name
                for name, flag in (
                    ('left', left),
                    ('right', right),
                    ('bottom', bottom),
                    ('top', top),
                )
                if flag
            ]
            # Zero flags: nothing to record. Several flags: diagonal, skipped.
            if len(flagged) != 1:
                continue

            direction = flagged[0]
            distance = bbox_distance(obj['bbox'], sub['bbox'])
            if distance < dis_by_directions[direction][i][1]:
                dis_by_directions[direction][i] = [j, distance]

        near_top = dis_by_directions['top'][i]
        near_bottom = dis_by_directions['bottom'][i]
        near_left = dis_by_directions['left'][i]
        near_right = dis_by_directions['right'][i]

        # Caller-requested vertical priority short-circuits everything else
        # when both vertical candidates exist and are comparably far away.
        if (
            near_top[1] != INF
            and near_bottom[1] != INF
            and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
        ):
            if abs(near_top[1] - near_bottom[1]) < RATIO * axis_unit:
                if priority_pos == PosRelationEnum.BOTTOM:
                    sub_obj_map_h[near_bottom[0]].append(i)
                else:
                    sub_obj_map_h[near_top[0]].append(i)
                continue

        # Best horizontal candidate.
        if near_left[1] != INF and near_right[1] != INF:
            if AXIS_MULPLICITY * axis_unit >= abs(near_left[1] - near_right[1]):
                left_sub_bbox = subjects[near_left[0]]['bbox']
                right_sub_bbox = subjects[near_right[0]]['bbox']
                # Cost = height mismatch + distance. The distance is element
                # [1]; element [0] is the subject index and must not be
                # mixed into the cost.
                left_cost = (
                    abs((left_sub_bbox[3] - left_sub_bbox[1]) - l_y_axis)
                    + near_left[1]
                )
                right_cost = (
                    abs((right_sub_bbox[3] - right_sub_bbox[1]) - l_y_axis)
                    + near_right[1]
                )
                left_or_right = near_right if left_cost > right_cost else near_left
            else:
                left_or_right = (
                    near_right if near_left[1] > near_right[1] else near_left
                )
        elif near_left[1] != INF:
            left_or_right = near_left
        else:
            # Either only the right candidate exists, or neither does, in
            # which case near_right is still the [-1, inf] placeholder.
            left_or_right = near_right

        # Best vertical candidate (ties go to 'bottom', as before).
        if near_top[1] != INF and near_bottom[1] != INF:
            if AXIS_MULPLICITY * axis_unit >= abs(near_top[1] - near_bottom[1]):
                bottom_sub_bbox = subjects[near_bottom[0]]['bbox']
                top_sub_bbox = subjects[near_top[0]]['bbox']
                bottom_cost = (
                    abs((bottom_sub_bbox[2] - bottom_sub_bbox[0]) - l_x_axis)
                    + near_bottom[1]
                )
                top_cost = (
                    abs((top_sub_bbox[2] - top_sub_bbox[0]) - l_x_axis)
                    + near_top[1]
                )
                top_or_bottom = near_top if bottom_cost > top_cost else near_bottom
            else:
                top_or_bottom = (
                    near_bottom if near_top[1] > near_bottom[1] else near_top
                )
        elif near_top[1] != INF:
            top_or_bottom = near_top
        else:
            top_or_bottom = near_bottom

        has_horizontal = left_or_right[1] != INF
        has_vertical = top_or_bottom[1] != INF
        if not (has_horizontal or has_vertical):
            # No subject in any direction: the object stays unattached.
            continue

        if has_horizontal and has_vertical:
            if AXIS_MULPLICITY * axis_unit >= abs(
                left_or_right[1] - top_or_bottom[1]
            ):
                # Distances are close: prefer the axis whose subject matches
                # the object's size better (relative mismatch).
                y_axis_bbox = subjects[left_or_right[0]]['bbox']
                x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
                width_mismatch = (
                    abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
                )
                height_mismatch = (
                    abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis) / l_y_axis
                )
                if width_mismatch > height_mismatch:
                    chosen = left_or_right
                else:
                    chosen = top_or_bottom
            elif left_or_right[1] > top_or_bottom[1]:
                chosen = top_or_bottom
            else:
                chosen = left_or_right
        elif has_horizontal:
            chosen = left_or_right
        else:
            chosen = top_or_bottom

        sub_obj_map_h[chosen[0]].append(i)

    return [
        {
            'sub_bbox': {
                'bbox': subjects[sub_idx]['bbox'],
                'score': subjects[sub_idx]['score'],
            },
            'obj_bboxes': [
                {'score': objects[k]['score'], 'bbox': objects[k]['bbox']}
                for k in obj_idxs
            ],
            'sub_idx': sub_idx,
        }
        for sub_idx, obj_idxs in sub_obj_map_h.items()
    ]
def get_imgs_v2(self, page_no: int):
    """Collect each image on a page together with its captions and footnotes.

    Args:
        page_no (int): index of the page to process.

    Returns:
        list[dict]: one record per image subject with keys 'image_body',
        'image_caption_list' and 'image_footnote_list'.
    """
    # Category 3 is the figure body; 4 is presumably the figure caption
    # (it feeds 'image_caption_list') -- confirm against CategoryId.
    captioned = self.__tie_up_category_by_distance_v2(
        page_no, 3, 4, PosRelationEnum.BOTTOM
    )
    footnoted = self.__tie_up_category_by_distance_v2(
        page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
    )

    images = []
    for entry in captioned:
        sub_idx = entry['sub_idx']
        # Pair the caption result with the footnote result of the same subject.
        matching = next(item for item in footnoted if item['sub_idx'] == sub_idx)
        images.append(
            {
                'image_body': entry['sub_bbox'],
                'image_caption_list': entry['obj_bboxes'],
                'image_footnote_list': matching['obj_bboxes'],
            }
        )
    return images
def get_tables_v2(self, page_no: int) -> list:
    """Collect each table on a page together with its captions and footnotes.

    Args:
        page_no (int): index of the page to process.

    Returns:
        list: one record per table subject with keys 'table_body',
        'table_caption_list' and 'table_footnote_list'.
    """
    # Category 5 is the table body and 7 the footnote; 6 is presumably the
    # table caption (it feeds 'table_caption_list') -- confirm.
    captioned = self.__tie_up_category_by_distance_v2(
        page_no, 5, 6, PosRelationEnum.UP
    )
    footnoted = self.__tie_up_category_by_distance_v2(
        page_no, 5, 7, PosRelationEnum.ALL
    )

    tables = []
    for entry in captioned:
        sub_idx = entry['sub_idx']
        # Pair the caption result with the footnote result of the same subject.
        matching = next(item for item in footnoted if item['sub_idx'] == sub_idx)
        tables.append(
            {
                'table_body': entry['sub_bbox'],
                'table_caption_list': entry['obj_bboxes'],
                'table_footnote_list': matching['obj_bboxes'],
            }
        )
    return tables
def get_imgs(self, page_no: int): def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4) with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance( with_footnotes, _ = self.__tie_up_category_by_distance(
...@@ -699,10 +1015,10 @@ class MagicModel: ...@@ -699,10 +1015,10 @@ class MagicModel:
def get_page_size(self, page_no: int): # 获取页面宽高 def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象 # 获取当前页的page对象
page = self.__docs[page_no] page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高 # 获取当前页的宽高
page_w = page.rect.width page_w = page.w
page_h = page.rect.height page_h = page.h
return page_w, page_h return page_w, page_h
def __get_blocks_by_type( def __get_blocks_by_type(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment