Commit 283b597a authored by icecraft's avatar icecraft
Browse files

feat: add [figure | table] match [caption | footnote] match algorithm v2

feat: add Data api
parent e36627be
from abc import ABC, abstractmethod
class IOReader(ABC):
@abstractmethod
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
pass
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
pass
class IOWriter:
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
pass
import io
import requests
from magic_pdf.data.io.base import IOReader, IOWriter
class HttpReader(IOReader):
def read(self, url: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return requests.get(url).content
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Not Implemented."""
raise NotImplementedError
class HttpWriter(IOWriter):
def write(self, url: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
files = {'file': io.BytesIO(data)}
response = requests.post(url, files=files)
assert 300 > response.status_code and response.status_code > 199
import boto3
from botocore.config import Config
from magic_pdf.data.io.base import IOReader, IOWriter
class S3Reader(IOReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def read(self, key: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(key)
def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
if limit > -1:
range_header = f'bytes={offset}-{offset+limit-1}'
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=range_header
)
else:
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
)
return res['Body'].read()
class S3Writer(IOWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def write(self, key: str, data: bytes):
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
import json
import os
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
def read_jsonl(
s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
"""Read the jsonl file and return the list of PymuDocDataset.
Args:
s3_path_or_local (str): local file or s3 path
s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.
Raises:
InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
EmptyData: if no pdf file location is provided in some line of jsonl file.
InvalidParams: if the file location is s3 path but s3_client is not provided
Returns:
list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
"""
bits_arr = []
if s3_path_or_local.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
jsonl_bits = s3_client.read(s3_path_or_local)
else:
jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
jsonl_d = [
json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
]
for d in jsonl_d[:5]:
pdf_path = d.get('file_location', '') or d.get('path', '')
if len(pdf_path) == 0:
raise EmptyData('pdf file location is empty')
if pdf_path.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
bits_arr.append(s3_client.read(pdf_path))
else:
bits_arr.append(FileBasedDataReader('').read(pdf_path))
return [PymuDocDataset(bits) for bits in bits_arr]
def read_local_pdfs(path: str) -> list[PymuDocDataset]:
"""Read pdf from path or directory.
Args:
path (str): pdf file path or directory that contains pdf files
Returns:
list[PymuDocDataset]: each pdf file will converted to a PymuDocDataset
"""
if os.path.isdir(path):
reader = FileBasedDataReader(path)
return [
PymuDocDataset(reader.read(doc_path.name))
for doc_path in Path(path).glob('*.pdf')
]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [PymuDocDataset(bits)]
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
Returns:
list[ImageDataset]: each image file will converted to a ImageDataset
"""
if os.path.isdir(path):
imgs_bits = []
s_suffixes = set(suffixes)
reader = FileBasedDataReader(path)
for root, _, files in os.walk(path):
for file in files:
suffix = file.split('.')
if suffix[-1] in s_suffixes:
imgs_bits.append(reader.read(file))
return [ImageDataset(bits) for bits in imgs_bits]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [ImageDataset(bits)]
from pydantic import BaseModel, Field
class S3Config(BaseModel):
bucket_name: str = Field(description='s3 bucket name', min_length=1)
access_key: str = Field(description='s3 access key', min_length=1)
secret_key: str = Field(description='s3 secret key', min_length=1)
endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
class PageInfo(BaseModel):
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
import fitz
import numpy as np
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
Args:
doc (_type_): pymudoc page
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
""" """根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
import json import json
import os import os
...@@ -12,36 +9,36 @@ from magic_pdf.libs.Constants import MODEL_NAME ...@@ -12,36 +9,36 @@ from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = "magic-pdf.json" CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config(): def read_config():
home_dir = os.path.expanduser("~") if os.path.isabs(CONFIG_FILE_NAME):
config_file = CONFIG_FILE_NAME
config_file = os.path.join(home_dir, CONFIG_FILE_NAME) else:
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise FileNotFoundError(f"{config_file} not found") raise FileNotFoundError(f'{config_file} not found')
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f) config = json.load(f)
return config return config
def get_s3_config(bucket_name: str): def get_s3_config(bucket_name: str):
""" """~/magic-pdf.json 读出来."""
~/magic-pdf.json 读出来
"""
config = read_config() config = read_config()
bucket_info = config.get("bucket_info") bucket_info = config.get('bucket_info')
if bucket_name not in bucket_info: if bucket_name not in bucket_info:
access_key, secret_key, storage_endpoint = bucket_info["[default]"] access_key, secret_key, storage_endpoint = bucket_info['[default]']
else: else:
access_key, secret_key, storage_endpoint = bucket_info[bucket_name] access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None: if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}") raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}") # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
...@@ -50,7 +47,7 @@ def get_s3_config(bucket_name: str): ...@@ -50,7 +47,7 @@ def get_s3_config(bucket_name: str):
def get_s3_config_dict(path: str): def get_s3_config_dict(path: str):
access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path)) access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint} return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
def get_bucket_name(path): def get_bucket_name(path):
...@@ -60,20 +57,20 @@ def get_bucket_name(path): ...@@ -60,20 +57,20 @@ def get_bucket_name(path):
def get_local_models_dir(): def get_local_models_dir():
config = read_config() config = read_config()
models_dir = config.get("models-dir") models_dir = config.get('models-dir')
if models_dir is None: if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default") logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return "/tmp/models" return '/tmp/models'
else: else:
return models_dir return models_dir
def get_local_layoutreader_model_dir(): def get_local_layoutreader_model_dir():
config = read_config() config = read_config()
layoutreader_model_dir = config.get("layoutreader-model-dir") layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir): if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser("~") home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader") layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default") logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path return layoutreader_at_modelscope_dir_path
else: else:
...@@ -82,17 +79,17 @@ def get_local_layoutreader_model_dir(): ...@@ -82,17 +79,17 @@ def get_local_layoutreader_model_dir():
def get_device(): def get_device():
config = read_config() config = read_config()
device = config.get("device-mode") device = config.get('device-mode')
if device is None: if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default") logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return "cpu" return 'cpu'
else: else:
return device return device
def get_table_recog_config(): def get_table_recog_config():
config = read_config() config = read_config()
table_config = config.get("table-config") table_config = config.get('table-config')
if table_config is None: if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default") logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}') return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
......
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
...@@ -62,7 +63,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox ...@@ -62,7 +63,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox
overlay=True, overlay=True,
) # Draw the rectangle ) # Draw the rectangle
page.insert_text( page.insert_text(
(x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle ) # Insert the index in the top left corner of the rectangle
...@@ -86,7 +87,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -86,7 +87,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts = [] texts = []
interequations = [] interequations = []
lists = [] lists = []
indexs = [] indices = []
for dropped_bbox in page['discarded_blocks']: for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox']) page_dropped_list.append(dropped_bbox['bbox'])
...@@ -122,7 +123,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -122,7 +123,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
elif block['type'] == BlockType.List: elif block['type'] == BlockType.List:
lists.append(bbox) lists.append(bbox)
elif block['type'] == BlockType.Index: elif block['type'] == BlockType.Index:
indexs.append(bbox) indices.append(bbox)
tables_list.append(tables) tables_list.append(tables)
tables_body_list.append(tables_body) tables_body_list.append(tables_body)
...@@ -136,7 +137,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -136,7 +137,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts_list.append(texts) texts_list.append(texts)
interequations_list.append(interequations) interequations_list.append(interequations)
lists_list.append(lists) lists_list.append(lists)
indexs_list.append(indexs) indexs_list.append(indices)
layout_bbox_list = [] layout_bbox_list = []
...@@ -151,30 +152,24 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -151,30 +152,24 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
True) draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True) # color !
draw_bbox_without_number(i, tables_list, page, [153, 153, 0], draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
True) # color ! draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_without_number(i, tables_footnote_list, page,
[229, 255, 204], True)
draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True) draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
True) draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
True)
draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True) draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True) draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False) draw_bbox_with_number(
i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf') pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
...@@ -275,7 +270,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -275,7 +270,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes) pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
...@@ -301,8 +296,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -301,8 +296,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_body.append(bbox) imgs_body.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageCaption: elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox) imgs_caption.append(bbox)
elif layout_det[ elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
'category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox) interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon: elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox) page_dropped_list.append(bbox)
...@@ -321,18 +315,15 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -321,18 +315,15 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_footnote_list.append(imgs_footnote) imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], draw_bbox_with_number(
True) # color ! i, dropped_bbox_list, page, [158, 158, 158], True
) # color !
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True) draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
True) draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
True)
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
True) draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
......
import json import json
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance, from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou, bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio, calculate_overlap_area_in_bbox1_area_ratio,
...@@ -24,7 +25,7 @@ class MagicModel: ...@@ -24,7 +25,7 @@ class MagicModel:
need_remove_list = [] need_remove_list = []
page_no = model_page_info['page_info']['page_no'] page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio( horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no] model_page_info, self.__docs.get_page(page_no)
) )
layout_dets = model_page_info['layout_dets'] layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets: for layout_det in layout_dets:
...@@ -99,7 +100,7 @@ class MagicModel: ...@@ -99,7 +100,7 @@ class MagicModel:
for need_remove in need_remove_list: for need_remove in need_remove_list:
layout_dets.remove(need_remove) layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: fitz.Document): def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list self.__model_list = model_list
self.__docs = docs self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)""" """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
...@@ -123,7 +124,8 @@ class MagicModel: ...@@ -123,7 +124,8 @@ class MagicModel:
l1 = bbox1[2] - bbox1[0] l1 = bbox1[2] - bbox1[0]
l2 = bbox2[2] - bbox2[0] l2 = bbox2[2] - bbox2[0]
if l2 > l1 and (l2 - l1) / l1 > 0.5: min_l, max_l = min(l1, l2), max(l1, l2)
if (max_l - min_l) * 1.0 / max_l > 0.4:
return float('inf') return float('inf')
return bbox_distance(bbox1, bbox2) return bbox_distance(bbox1, bbox2)
...@@ -213,9 +215,8 @@ class MagicModel: ...@@ -213,9 +215,8 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离 再求出筛选出的 subjects 和 object 的最短距离
""" """
def search_overlap_between_boxes(
subject_idx, object_idx def search_overlap_between_boxes(subject_idx, object_idx):
):
idxes = [subject_idx, object_idx] idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes] x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes] y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
...@@ -243,9 +244,9 @@ class MagicModel: ...@@ -243,9 +244,9 @@ class MagicModel:
for other_object in other_objects: for other_object in other_objects:
ratio = max( ratio = max(
ratio, ratio,
get_overlap_area( get_overlap_area(merged_bbox, other_object['bbox'])
merged_bbox, other_object['bbox'] * 1.0
) * 1.0 / box_area(all_bboxes[object_idx]['bbox']) / box_area(all_bboxes[object_idx]['bbox']),
) )
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO: if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break break
...@@ -363,12 +364,17 @@ class MagicModel: ...@@ -363,12 +364,17 @@ class MagicModel:
if all_bboxes[j]['category_id'] == subject_category_id: if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO: if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf') dis[i][j] = float('inf')
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
continue continue
dis[i][j] = self._bbox_distance(all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']) dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
used = set() used = set()
...@@ -584,6 +590,99 @@ class MagicModel: ...@@ -584,6 +590,99 @@ class MagicModel:
with_caption_subject.add(j) with_caption_subject.add(j)
return ret, total_subject_object_dis return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
self, page_no, subject_category_id, object_category_id
):
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
print(len(subjects), len(objects))
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
dis = [[float('inf')] * len(subjects) for _ in range(len(objects))]
for i, obj in enumerate(objects):
for j, sub in enumerate(subjects):
dis[i][j] = self._bbox_distance(sub['bbox'], obj['bbox'])
sub_obj_map_h = {i: [] for i in range(len(subjects))}
for i in range(len(objects)):
min_l_idx = 0
for j in range(1, len(subjects)):
if dis[i][j] == float('inf'):
continue
if dis[i][j] < dis[i][min_l_idx]:
min_l_idx = j
if dis[i][min_l_idx] < float('inf'):
sub_obj_map_h[min_l_idx].append(i)
else:
print(i, 'no nearest')
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': subjects[i]['bbox'],
'obj_bboxes': [objects[j]['bbox'] for j in sub_obj_map_h[i]],
'sub_idx': i,
}
)
return ret
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
with_footnotes = self.__tie_up_category_by_distance_v2(
page_no, 3, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
record = {
'image_bbox': v['sub_bbox'],
'image_caption_bbox_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_bbox_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
ret = []
for v in with_captions:
record = {
'table_bbox': v['sub_bbox'],
'table_caption_bbox_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_bbox_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_imgs(self, page_no: int): def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4) with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance( with_footnotes, _ = self.__tie_up_category_by_distance(
...@@ -717,10 +816,10 @@ class MagicModel: ...@@ -717,10 +816,10 @@ class MagicModel:
def get_page_size(self, page_no: int): # 获取页面宽高 def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象 # 获取当前页的page对象
page = self.__docs[page_no] page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高 # 获取当前页的宽高
page_w = page.rect.width page_w = page.w
page_h = page.rect.height page_h = page.h
return page_w, page_h return page_w, page_h
def __get_blocks_by_type( def __get_blocks_by_type(
......
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"ocr", SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
...@@ -9,10 +11,11 @@ def parse_pdf_by_txt( ...@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"txt", SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
import os import os
import statistics import statistics
import time import time
from loguru import logger
from typing import List from typing import List
import torch import torch
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
...@@ -19,27 +19,35 @@ from magic_pdf.libs.ocr_content_type import ContentType ...@@ -19,27 +19,35 @@ from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v3 import para_split from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 from magic_pdf.pre_proc.construct_page_dict import \
ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \ from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict combine_chars_to_pymudict, remove_chars_in_text_blocks,
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2 replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans, fix_discarded_block from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \ ocr_prepare_bboxes_for_layout_split_v2
remove_overlaps_low_confidence_spans from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap fix_block_spans,
fix_discarded_block)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes): def remove_horizontal_overlap_block_which_smaller(all_bboxes):
useful_blocks = [] useful_blocks = []
for bbox in all_bboxes: for bbox in all_bboxes:
useful_blocks.append({ useful_blocks.append({'bbox': bbox[:4]})
"bbox": bbox[:4] is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
}) check_useful_block_horizontal_overlap(useful_blocks)
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks) )
if is_useful_block_horz_overlap: if is_useful_block_horz_overlap:
logger.warning( logger.warning(
f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}") f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
) # noqa: E501
for bbox in all_bboxes.copy(): for bbox in all_bboxes.copy():
if smaller_bbox == bbox[:4]: if smaller_bbox == bbox[:4]:
all_bboxes.remove(bbox) all_bboxes.remove(bbox)
...@@ -47,27 +55,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes): ...@@ -47,27 +55,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
return is_useful_block_horz_overlap, all_bboxes return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str:str): def __replace_STX_ETX(text_str: str):
""" Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks. """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far. Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args: Args:
text_str (str): raw text text_str (str): raw text
Returns: Returns:
_type_: replaced text _type_: replaced text
""" """ # noqa: E501
if text_str: if text_str:
s = text_str.replace('\u0002', "'") s = text_str.replace('\u0002', "'")
s = s.replace("\u0003", "'") s = s.replace('\u0003', "'")
return s return s
return text_str return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations): def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"] text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[ char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
"blocks" 'blocks'
] ]
text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks) text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
text_blocks = replace_equations_in_textblock( text_blocks = replace_equations_in_textblock(
...@@ -77,54 +85,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations): ...@@ -77,54 +85,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_blocks = remove_chars_in_text_blocks(text_blocks) text_blocks = remove_chars_in_text_blocks(text_blocks)
spans = [] spans = []
for v in text_blocks: for v in text_blocks:
for line in v["lines"]: for line in v['lines']:
for span in line["spans"]: for span in line['spans']:
bbox = span["bbox"] bbox = span['bbox']
if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]): if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
continue continue
if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation): if span.get('type') not in (
ContentType.InlineEquation,
ContentType.InterlineEquation,
):
spans.append( spans.append(
{ {
"bbox": list(span["bbox"]), 'bbox': list(span['bbox']),
"content": __replace_STX_ETX(span["text"]), 'content': __replace_STX_ETX(span['text']),
"type": ContentType.Text, 'type': ContentType.Text,
"score": 1.0, 'score': 1.0,
} }
) )
return spans return spans
def replace_text_span(pymu_spans, ocr_spans): def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str): def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device('cuda')
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
supports_bfloat16 = True supports_bfloat16 = True
else: else:
supports_bfloat16 = False supports_bfloat16 = False
else: else:
device = torch.device("cpu") device = torch.device('cpu')
supports_bfloat16 = False supports_bfloat16 = False
if model_name == "layoutreader": if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在 # 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir() layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir): if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir) model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else: else:
logger.warning( logger.warning(
f"local layoutreader model not exists, use online model from huggingface") 'local layoutreader model not exists, use online model from huggingface'
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader") )
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
# 检查设备是否支持 bfloat16 # 检查设备是否支持 bfloat16
if supports_bfloat16: if supports_bfloat16:
model.bfloat16() model.bfloat16()
model.to(device).eval() model.to(device).eval()
else: else:
logger.error("model name not allow") logger.error('model name not allow')
exit(1) exit(1)
return model return model
...@@ -145,7 +162,9 @@ class ModelSingleton: ...@@ -145,7 +162,9 @@ class ModelSingleton:
def do_predict(boxes: List[List[int]], model) -> List[int]: def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
prepare_inputs)
inputs = boxes2inputs(boxes) inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model) inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0) logits = model(**inputs).logits.cpu().squeeze(0)
...@@ -193,21 +212,23 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -193,21 +212,23 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0 block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox # 如果block高度小于n行正文,则直接返回block的bbox
if line_height*3 < block_height: if line_height * 3 < block_height:
if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25: # 可能是双列结构,可以切细点 if (
lines = int(block_height/line_height)+1 block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
lines = int(block_height / line_height) + 1
else: else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行 # 如果block的宽度超过0.4页面宽度,则将block分成3行
if block_weight > page_w*0.4: if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3 line_height = (y1 - y0) / 3
lines = 3 lines = 3
elif block_weight > page_w*0.25: # 否则将block分成两行 elif block_weight > page_w * 0.25: # 否则将block分成两行
line_height = (y1 - y0) / 2 line_height = (y1 - y0) / 2
lines = 2 lines = 2
else: # 判断长宽比 else: # 判断长宽比
if block_height/block_weight > 1.2: # 细长的不分 if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]] return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行 else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2 line_height = (y1 - y0) / 2
lines = 2 lines = 2
...@@ -256,19 +277,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -256,19 +277,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
for left, top, right, bottom in page_line_list: for left, top, right, bottom in page_line_list:
if left < 0: if left < 0:
logger.warning( logger.warning(
f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0 left = 0
if right > page_w: if right > page_w:
logger.warning( logger.warning(
f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w right = page_w
if top < 0: if top < 0:
logger.warning( logger.warning(
f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0 top = 0
if bottom > page_h: if bottom > page_h:
logger.warning( logger.warning(
f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h bottom = page_h
left = round(left * x_scale) left = round(left * x_scale)
...@@ -276,11 +301,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -276,11 +301,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
right = round(right * x_scale) right = round(right * x_scale)
bottom = round(bottom * y_scale) bottom = round(bottom * y_scale)
assert ( assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0 1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}" ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom]) boxes.append([left, top, right, bottom])
model_manager = ModelSingleton() model_manager = ModelSingleton()
model = model_manager.get_model("layoutreader") model = model_manager.get_model('layoutreader')
with torch.no_grad(): with torch.no_grad():
orders = do_predict(boxes, model) orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders] sorted_bboxes = [page_line_list[i] for i in orders]
...@@ -294,146 +319,195 @@ def get_line_height(blocks): ...@@ -294,146 +319,195 @@ def get_line_height(blocks):
if block['type'] in ['text', 'title', 'interline_equation']: if block['type'] in ['text', 'title', 'interline_equation']:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
page_line_height_list.append(int(bbox[3]-bbox[1])) page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0: if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list) return statistics.median(page_line_height_list)
else: else:
return 10 return 10
def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode): def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
need_drop = False need_drop = False
drop_reason = [] drop_reason = []
'''从magic_model对象中获取后面会用到的区块信息''' """从magic_model对象中获取后面会用到的区块信息"""
img_blocks = magic_model.get_imgs(page_id) img_blocks = magic_model.get_imgs(page_id)
table_blocks = magic_model.get_tables(page_id) table_blocks = magic_model.get_tables(page_id)
discarded_blocks = magic_model.get_discarded(page_id) discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id) text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id) title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id) inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id) page_w, page_h = magic_model.get_page_size(page_id)
spans = magic_model.get_all_spans(page_id) spans = magic_model.get_all_spans(page_id)
'''根据parse_mode,构造spans''' """根据parse_mode,构造spans"""
if parse_mode == "txt": if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!""" """ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract( pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
pdf_docs[page_id], inline_equations, interline_equations
)
spans = replace_text_span(pymu_spans, spans) spans = replace_text_span(pymu_spans, spans)
elif parse_mode == "ocr": elif parse_mode == SupportedPdfParseMethod.OCR:
pass pass
else: else:
raise Exception("parse_mode must be txt or ocr") raise Exception('parse_mode must be txt or ocr')
'''删除重叠spans中置信度较低的那些''' """删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans) spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
'''删除重叠spans中较小的那些''' """删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans) spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
'''对image和table截图''' """对image和table截图"""
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter) spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
'''将所有区块的bbox整理到一起''' """将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上 # interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = [] interline_equation_blocks = []
if len(interline_equation_blocks) > 0: if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, img_blocks,
interline_equation_blocks, page_w, page_h) table_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else: else:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, img_blocks,
interline_equations, page_w, page_h) table_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
'''先处理不需要排版的discarded_blocks''' """先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4) discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans) fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
'''如果当前页面没有bbox则跳过''' """如果当前页面没有bbox则跳过"""
if len(all_bboxes) == 0: if len(all_bboxes) == 0:
logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}") logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], return ocr_construct_page_component_v2(
[], [], interline_equations, fix_discarded_blocks, [],
need_drop, drop_reason) [],
page_id,
page_w,
page_h,
[],
[],
[],
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
'''将span填入blocks中''' """将span填入blocks中"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5) block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
'''对block进行fix操作''' """对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks) fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
'''获取所有line并计算正文line的高度''' """获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks) line_height = get_line_height(fix_blocks)
'''获取所有line并对line排序''' """获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height) sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
'''根据line的中位数算block的序列关系''' """根据line的中位数算block的序列关系"""
fix_blocks = cal_block_index(fix_blocks, sorted_bboxes) fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
'''重排block''' """重排block"""
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index']) sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
'''获取QA需要外置的list''' """获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks) images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
'''构造pdf_info_dict''' """构造pdf_info_dict"""
page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [], page_info = ocr_construct_page_component_v2(
images, tables, interline_equations, fix_discarded_blocks, sorted_blocks,
need_drop, drop_reason) [],
page_id,
page_w,
page_h,
[],
images,
tables,
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
return page_info return page_info
def pdf_parse_union(pdf_bytes, def pdf_parse_union(
model_list, dataset: Dataset,
imageWriter, model_list,
parse_mode, imageWriter,
start_page_id=0, parse_mode,
end_page_id=None, start_page_id=0,
debug_mode=False, end_page_id=None,
): debug_mode=False,
pdf_bytes_md5 = compute_md5(pdf_bytes) ):
pdf_docs = fitz.open("pdf", pdf_bytes) pdf_bytes_md5 = compute_md5(dataset.data_bits())
'''初始化空的pdf_info_dict''' """初始化空的pdf_info_dict"""
pdf_info_dict = {} pdf_info_dict = {}
'''用model_list和docs对象初始化magic_model''' """用model_list和docs对象初始化magic_model"""
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, dataset)
'''根据输入的起始范围解析pdf''' """根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1 # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
if end_page_id > len(pdf_docs) - 1: if end_page_id > len(dataset) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length") logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(pdf_docs) - 1 end_page_id = len(dataset) - 1
'''初始化启动时间''' """初始化启动时间"""
start_time = time.time() start_time = time.time()
for page_id, page in enumerate(pdf_docs): for page_id, page in enumerate(dataset):
'''debug时输出每页解析的耗时''' """debug时输出每页解析的耗时."""
if debug_mode: if debug_mode:
time_now = time.time() time_now = time.time()
logger.info( logger.info(
f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}" f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
) )
start_time = time_now start_time = time_now
'''解析pdf中的每一页''' """解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id: if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode) page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
)
else: else:
page_w = page.rect.width page_info = page.get_page_info()
page_h = page.rect.height page_w = page_info.w
page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], page_h = page_info.h
[], [], [], [], page_info = ocr_construct_page_component_v2(
True, "skip page") [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
pdf_info_dict[f"page_{page_id}"] = page_info )
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段""" """分段"""
para_split(pdf_info_dict, debug_mode=debug_mode) para_split(pdf_info_dict, debug_mode=debug_mode)
...@@ -441,7 +515,7 @@ def pdf_parse_union(pdf_bytes, ...@@ -441,7 +515,7 @@ def pdf_parse_union(pdf_bytes,
"""dict转list""" """dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict) pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = { new_pdf_info_dict = {
"pdf_info": pdf_info_list, 'pdf_info': pdf_info_list,
} }
clean_memory() clean_memory()
......
...@@ -6,8 +6,8 @@ import click ...@@ -6,8 +6,8 @@ import click
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox, from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
draw_model_bbox, draw_line_sort_bbox) draw_model_bbox, draw_span_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe from magic_pdf.pipe.TXTPipe import TXTPipe
......
from loguru import logger
def ImportPIL(f):
try:
import PIL # noqa: F401
except ImportError:
logger.error('Pillow not installed, please install by pip.')
exit(1)
return f
{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"s3://sci-hub/enbook-scimag/78800000/libgen.scimag78872000-78872999/10.1017/cbo9780511770425.012.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment