Commit bd927919 authored by myhloli's avatar myhloli

refactor: rename init file and update app.py to enable parsing method

parent f5016508
import io
import requests
from magic_pdf.data.io.base import IOReader, IOWriter
class HttpReader(IOReader):
def read(self, url: str) -> bytes:
"""Read the file.
Args:
url (str): the url to read
Returns:
bytes: the content of the file
"""
return requests.get(url).content
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Not Implemented."""
raise NotImplementedError
class HttpWriter(IOWriter):
def write(self, url: str, data: bytes) -> None:
"""Write file with data.
Args:
url (str): the url to post the data to
data (bytes): the data to write
"""
files = {'file': io.BytesIO(data)}
response = requests.post(url, files=files)
assert 200 <= response.status_code < 300
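# Usage sketch (hypothetical URLs; both classes hit live endpoints, so treat
# this as illustrative rather than a definitive test):
if __name__ == '__main__':
    reader = HttpReader()
    pdf_bytes = reader.read('https://example.com/sample.pdf')  # placeholder url
    writer = HttpWriter()
    writer.write('https://example.com/upload', pdf_bytes)  # placeholder url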
import boto3
from botocore.config import Config
from magic_pdf.data.io.base import IOReader, IOWriter
class S3Reader(IOReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def read(self, key: str) -> bytes:
"""Read the file.
Args:
key (str): the s3 key to read
Returns:
bytes: the content of the file
"""
return self.read_at(key)
def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
key (str): the s3 key to read
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes to read. Defaults to -1 (read to the end).
Returns:
bytes: the content of file
"""
if limit > -1:
range_header = f'bytes={offset}-{offset+limit-1}'
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=range_header
)
else:
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
)
return res['Body'].read()
class S3Writer(IOWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def write(self, key: str, data: bytes):
"""Write file with data.
Args:
key (str): the s3 key to write
data (bytes): the data to write
"""
self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
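# Usage sketch (hypothetical bucket and credentials; read_at issues an S3 Range
# request, so only the requested byte slice is transferred):
if __name__ == '__main__':
    reader = S3Reader(bucket='my-bucket', ak='AK', sk='SK',
                      endpoint_url='https://s3.example.com')  # placeholders
    head = reader.read_at('docs/sample.pdf', offset=0, limit=1024)  # first 1 KiB
    writer = S3Writer(bucket='my-bucket', ak='AK', sk='SK',
                      endpoint_url='https://s3.example.com')  # placeholders
    writer.write('docs/copy.pdf', reader.read('docs/sample.pdf'))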
import json
import os
import tempfile
import shutil
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
from magic_pdf.utils.office_to_pdf import convert_file_to_pdf, ConvertToPdfError
def read_jsonl(
s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
"""Read the jsonl file and return the list of PymuDocDataset.
Args:
s3_path_or_local (str): local file or s3 path
s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.
Raises:
InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
EmptyData: if no pdf file location is provided in some line of jsonl file.
InvalidParams: if the file location is s3 path but s3_client is not provided
Returns:
list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
"""
bits_arr = []
if s3_path_or_local.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
jsonl_bits = s3_client.read(s3_path_or_local)
else:
jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
jsonl_d = [
json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
]
for d in jsonl_d:
pdf_path = d.get('file_location', '') or d.get('path', '')
if len(pdf_path) == 0:
raise EmptyData('pdf file location is empty')
if pdf_path.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
bits_arr.append(s3_client.read(pdf_path))
else:
bits_arr.append(FileBasedDataReader('').read(pdf_path))
return [PymuDocDataset(bits) for bits in bits_arr]
def read_local_pdfs(path: str) -> list[PymuDocDataset]:
"""Read pdf from path or directory.
Args:
path (str): pdf file path or directory that contains pdf files
Returns:
list[PymuDocDataset]: each pdf file will be converted to a PymuDocDataset
"""
if os.path.isdir(path):
reader = FileBasedDataReader()
ret = []
for root, _, files in os.walk(path):
for file in files:
if Path(file).suffix == '.pdf':
ret.append(PymuDocDataset(reader.read(os.path.join(root, file))))
return ret
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [PymuDocDataset(bits)]
def read_local_office(path: str) -> list[PymuDocDataset]:
"""Read ms-office file (ppt, pptx, doc, docx) from path or directory.
Args:
path (str): ms-office file or directory that contains ms-office files
Returns:
list[PymuDocDataset]: each ms-office file will be converted to a PymuDocDataset
Raises:
ConvertToPdfError: failed to convert the ms-office file to pdf via libreoffice
FileNotFoundError: the file was not found
Exception: Unknown Exception raised
"""
suffixes = ['.ppt', '.pptx', '.doc', '.docx']
fns = []
ret = []
if os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
suffix = Path(file).suffix
if suffix in suffixes:
fns.append((os.path.join(root, file)))
else:
fns.append(path)
reader = FileBasedDataReader()
temp_dir = tempfile.mkdtemp()
for fn in fns:
# ConvertToPdfError / FileNotFoundError propagate to the caller as documented
convert_file_to_pdf(fn, temp_dir)
fn_path = Path(fn)
pdf_fn = f"{temp_dir}/{fn_path.stem}.pdf"
ret.append(PymuDocDataset(reader.read(pdf_fn)))
shutil.rmtree(temp_dir)
return ret
def read_local_images(path: str, suffixes: list[str]=['.png', '.jpg', '.jpeg']) -> list[ImageDataset]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['.jpg', '.png']
Returns:
list[ImageDataset]: each image file will be converted to an ImageDataset
"""
if os.path.isdir(path):
imgs_bits = []
s_suffixes = set(suffixes)
reader = FileBasedDataReader()
for root, _, files in os.walk(path):
for file in files:
suffix = Path(file).suffix
if suffix in s_suffixes:
imgs_bits.append(reader.read(os.path.join(root, file)))
return [ImageDataset(bits) for bits in imgs_bits]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [ImageDataset(bits)]
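# Usage sketch (hypothetical paths; read_jsonl expects each line to carry the
# pdf location under the 'file_location' or 'path' key, e.g.
# {"file_location": "s3://bucket/sample.pdf"}, and read_local_office needs
# libreoffice installed for the pdf conversion):
if __name__ == '__main__':
    pdf_datasets = read_local_pdfs('/tmp/pdfs')  # placeholder directory
    office_datasets = read_local_office('/tmp/docs')  # placeholder directory
    image_datasets = read_local_images('/tmp/imgs', suffixes=['.png', '.jpg'])
    print(len(pdf_datasets), len(office_datasets), len(image_datasets))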
from pydantic import BaseModel, Field
class S3Config(BaseModel):
"""S3 config
"""
bucket_name: str = Field(description='s3 bucket name', min_length=1)
access_key: str = Field(description='s3 access key', min_length=1)
secret_key: str = Field(description='s3 secret key', min_length=1)
endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
class PageInfo(BaseModel):
"""The width and height of page
"""
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
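# Usage sketch (placeholder values; pydantic enforces the min_length constraints
# declared above and fills in the addressing_style default):
if __name__ == '__main__':
    cfg = S3Config(bucket_name='my-bucket', access_key='AK', secret_key='SK',
                   endpoint_url='https://s3.example.com')
    print(cfg.addressing_style)  # -> 'auto'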
import multiprocessing as mp
import threading
from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
as_completed)
import fitz
import numpy as np
from loguru import logger
def fitz_doc_to_image(page, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
Args:
page (_type_): pymudoc page
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
logger.warning('end_page_id is out of range, using the last page instead')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict)
return images
def convert_page(bytes_page):
pdfs = fitz.open('pdf', bytes_page)
page = pdfs[0]
return fitz_doc_to_image(page)
def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
"""Process PDF pages in parallel with serialization-safe approach."""
if num_workers is None:
num_workers = mp.cpu_count()
# Process the extracted page data in parallel
with ProcessPoolExecutor(max_workers=num_workers) as executor:
# Process the page data
results = list(
executor.map(convert_page, pages)
)
return results
def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
"""Process all pages of a PDF using multiple threads.
Parameters:
-----------
pdf_path : str
Path to the PDF file
num_threads : int
Number of threads to use
**kwargs :
Additional arguments for fitz_doc_to_image
Returns:
--------
images : list
List of processed images, in page order
"""
# Open the PDF
doc = fitz.open(pdf_path)
num_pages = len(doc)
# Create a list to store results in the correct order
results = [None] * num_pages
# Create a thread pool
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = {}
for page_num in range(num_pages):
page = doc[page_num]
future = executor.submit(fitz_doc_to_image, page, **kwargs)
futures[future] = page_num
# Process results as they complete
for future in as_completed(futures):
page_num = futures[future]
try:
results[page_num] = future.result()
except Exception as e:
print(f'Error processing page {page_num}: {e}')
results[page_num] = None
# Close the document
doc.close()
return results
if __name__ == '__main__':
pdf = fitz.open('/tmp/[MS-DOC].pdf')
pdf_page = [fitz.open() for i in range(pdf.page_count)]
[pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
pdf_page = [v.tobytes() for v in pdf_page]
results = parallel_process_pdf_safe(pdf_page, num_workers=16)
# threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
import re
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.config_reader import get_latex_delimiter_config
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.post_proc.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line):
"""Check if a line ends with one or more letters followed by a hyphen.
Args:
line (str): The line of text to check.
Returns:
bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
"""
# Use regex to check if the line ends with one or more letters followed by a hyphen
return bool(re.search(r'[A-Za-z]+-\s*$', line))
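# Worked example (doctest-style; the regex requires at least one letter
# immediately before the trailing hyphen, so a bare '-' does not match):
#
#     >>> __is_hyphen_at_line_end('is a well-known exam-')
#     True
#     >>> __is_hyphen_at_line_end('is a well-known exam')
#     False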
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = []
page_no = 0
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
markdown_with_para_and_pagination.append({
'page_no': page_no,
'md_content': '',
})
page_no += 1
continue
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({
'page_no': page_no,
'md_content': '\n\n'.join(page_markdown),
})
page_no += 1
return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path='',
):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
if mode == 'nlp':
continue
elif mode == 'mm':
# check whether an image footnote exists
has_image_footnote = any(block['type'] == BlockType.ImageFootnote for block in para_block['blocks'])
# if an image footnote exists, append it after the image body
if has_image_footnote:
for block in para_block['blocks']: # 1st: append image_caption
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd: append image_body
if block['type'] == BlockType.ImageBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd: append image_footnote
if block['type'] == BlockType.ImageFootnote:
para_text += ' \n' + merge_para_with_text(block)
else:
for block in para_block['blocks']: # 1st: append image_body
if block['type'] == BlockType.ImageBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 2nd: append image_caption
if block['type'] == BlockType.ImageCaption:
para_text += ' \n' + merge_para_with_text(block)
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
for block in para_block['blocks']: # 1st: append table_caption
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd: append table_body
if block['type'] == BlockType.TableBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
# if processed by table model
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd: append table_footnote
if block['type'] == BlockType.TableFootnote:
para_text += '\n' + merge_para_with_text(block) + ' '
if para_text.strip() == '':
continue
page_markdown.append(para_text.strip())
return page_markdown
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'unknown'
else:
return 'empty'
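# Worked example (doctest-style; 'hello world' is 11 characters of which 10 are
# ASCII letters, so the 0.5 ratio threshold is met):
#
#     >>> detect_language('hello world')
#     'en'
#     >>> detect_language('你好世界')
#     'unknown'
#     >>> detect_language('')
#     'empty'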
def full_to_half(text: str) -> str:
"""Convert full-width characters to half-width characters using code point manipulation.
Args:
text: String containing full-width characters
Returns:
String with full-width characters converted to half-width
"""
result = []
for char in text:
code = ord(char)
# Full-width letters and numbers (FF21-FF3A for A-Z, FF41-FF5A for a-z, FF10-FF19 for 0-9)
if (0xFF21 <= code <= 0xFF3A) or (0xFF41 <= code <= 0xFF5A) or (0xFF10 <= code <= 0xFF19):
result.append(chr(code - 0xFEE0)) # Shift to ASCII range
else:
result.append(char)
return ''.join(result)
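# Worked example (doctest-style; the 0xFEE0 shift maps full-width U+FF21 'Ａ'
# onto ASCII U+0041 'A', while characters outside the three ranges, such as the
# full-width comma, pass through unchanged):
#
#     >>> full_to_half('ＡＢＣ１２３')
#     'ABC123'
#     >>> full_to_half('ＡＢＣ，')
#     'ABC，'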
latex_delimiters_config = get_latex_delimiter_config()
default_delimiters = {
'display': {'left': '$$', 'right': '$$'},
'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
block_text = ''
for line in para_block['lines']:
for span in line['spans']:
if span['type'] in [ContentType.Text]:
span['content'] = full_to_half(span['content'])
block_text += span['content']
block_lang = detect_lang(block_text)
para_text = ''
for i, line in enumerate(para_block['lines']):
if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'
for j, span in enumerate(line['spans']):
span_type = span['type']
content = ''
if span_type == ContentType.Text:
content = ocr_escape_special_markdown_char(span['content'])
elif span_type == ContentType.InlineEquation:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.InterlineEquation:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
content = content.strip()
if content:
langs = ['zh', 'ja', 'ko']
# logger.info(f'block_lang: {block_lang}, content: {content}')
if block_lang in langs: # in Chinese/Japanese/Korean context, line breaks need no space separator, but a trailing inline equation still needs one
if j == len(line['spans']) - 1 and span_type not in [ContentType.InlineEquation]:
para_text += content
else:
para_text += f'{content} '
else:
if span_type in [ContentType.Text, ContentType.InlineEquation]:
# if the span is the last one in the line and ends with a hyphen, drop the hyphen and do not append a space
if j == len(line['spans'])-1 and span_type == ContentType.Text and __is_hyphen_at_line_end(content):
para_text += content[:-1]
else: # in Western-text context, contents are separated by spaces
para_text += f'{content} '
elif span_type == ContentType.InterlineEquation:
para_text += content
else:
continue
# split ligatures (currently disabled)
# para_text = __replace_ligatures(para_text)
return para_text
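# Sketch of the minimal para_block shape merge_para_with_text expects (field
# names taken from the accesses above; values are placeholders):
#
#     {'lines': [{'spans': [{'type': ContentType.Text, 'content': 'Hello'}]}]}
#
# Text spans are escaped and space-joined, inline equations are wrapped in the
# inline delimiters, and interline equations in the display delimiters.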
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
para_type = para_block['type']
para_content = {}
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.Title:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
}
title_level = get_title_level(para_block)
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.InterlineEquation:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text_format': 'latex',
}
elif para_type == BlockType.Image:
para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.Table:
para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
if span.get('latex', ''):
para_content['table_body'] = f"{span['latex']}"
elif span.get('html', ''):
para_content['table_body'] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'].append(merge_para_with_text(block))
para_content['page_idx'] = page_idx
if drop_reason is not None:
para_content['drop_reason'] = drop_reason
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = '',
):
output_content = []
for page_info in pdf_info_dict:
drop_reason_flag = False
drop_reason = None
if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE:
pass
elif drop_mode == DropMode.NONE_WITH_REASON:
drop_reason_flag = True
elif drop_mode == DropMode.WHOLE_PDF:
raise Exception(f'drop_mode is {DropMode.WHOLE_PDF}, drop_reason is {drop_reason}')
elif drop_mode == DropMode.SINGLE_PAGE:
logger.warning(f'drop_mode is {DropMode.SINGLE_PAGE}, drop_reason is {drop_reason}')
continue
else:
raise Exception(f'unsupported drop_mode: {drop_mode}')
paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get('page_idx')
if not paras_of_layout:
continue
if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
if drop_reason_flag:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx, drop_reason)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
elif make_mode == MakeMode.STANDARD_FORMAT:
return output_content
def get_title_level(block):
title_level = block.get('level', 1)
if title_level > 4:
title_level = 4
elif title_level < 1:
title_level = 0
return title_level
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get('_need_drop', False): # 如果返回了需要丢弃的标志,则抛出异常
raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta['is_encrypted']
is_needs_password = pdf_meta['is_needs_password']
if is_encrypted or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
else:
is_text_pdf, results = do_classify(
pdf_meta['total_page'],
pdf_meta['page_width_pts'],
pdf_meta['page_height_pts'],
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
# pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
return SupportedPdfParseMethod.TXT
else:
return SupportedPdfParseMethod.OCR
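# Usage sketch (hypothetical path; classify raises for encrypted or
# password-protected pdfs, so a caller would normally wrap it in try/except):
if __name__ == '__main__':
    with open('/tmp/sample.pdf', 'rb') as f:  # placeholder path
        parse_method = classify(f.read())
    print(parse_method)  # SupportedPdfParseMethod.TXT or SupportedPdfParseMethod.OCR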
"""输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
from collections import Counter
import fitz
from loguru import logger
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import get_top_percent_list, mymax
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.pdf_check import detect_invalid_chars_by_pymupdf, detect_invalid_chars
scan_max_page = 50
junk_limit_min = 10
def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts):
max_image_area_per_page = [
mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz])
for page_img_sz in result
]
page_area = int(page_width_pts) * int(page_height_pts)
max_image_area_per_page = [area / page_area for area in max_image_area_per_page]
max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6]
return max_image_area_per_page
def process_image(page, junk_img_bojids=[]):
page_result = []  # store the bbox quadruples of all images on this page
items = page.get_images()
dedup = set()
for img in items:
# img[0] is the image xref, which is globally unique within the pdf; if the
# same image appears over and over it is probably junk such as a watermark
# or a header/footer (the rects below give the size actually shown on the page)
img_bojid = img[0]
if img_bojid in junk_img_bojids:  # skip junk images
continue
recs = page.get_image_rects(img, transform=True)
if recs:
rec = recs[0][0]
x0, y0, x1, y1 = map(int, rec)
width = x1 - x0
height = y1 - y0
if (x0, y0, x1, y1, img_bojid) in dedup:  # duplicate bboxes show up here; keep only one
continue
if not all([width, height]):  # neither width nor height may be 0, otherwise the image is invisible and meaningless
continue
dedup.add((x0, y0, x1, y1, img_bojid))
page_result.append([x0, y0, x1, y1, img_bojid])
return page_result
def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
"""返回每个页面里的图片的四元组,每个页面多个图片。
:param doc:
:return:
"""
# count the occurrences of each img_bojid with a Counter
img_bojid_counter = Counter(img[0] for page in doc for img in page.get_images())
# find the img_bojids that appear on more than half of the pages
junk_limit = max(len(doc) * 0.5, junk_limit_min)  # exempt documents with few pages
junk_img_bojids = [
img_bojid
for img_bojid, count in img_bojid_counter.items()
if count >= junk_limit
]
# todo: add a check using only the first ten pages; junk images must satisfy two conditions: they appear often enough, they cover a large share of the page, and the images are all roughly the same size
# there are two kinds of scanned pdfs and one kind of text pdf, so misclassification is possible here
# scanned type 1: every page embeds all scanned page images; images cover most of the page, one shown per page
# scanned type 2: the number of embedded scanned images grows page by page; images cover most of the page, one shown per page; the junk list must be cleared and the first 50 pages rescanned for classification
# text type 1: every page embeds all images; images cover little of the page, and a page may show zero or several; such pdfs need the first 10 pages sampled for image size and count, and the junk list cleared if they match
imgs_len_list = [len(page.get_images()) for page in doc]
special_limit_pages = 10
# use the first ten pages uniformly for the decision
result = []
break_loop = False
for i, page in enumerate(doc):
if break_loop:
break
if i >= special_limit_pages:
break
# junk_img_bojids is deliberately not passed here: collect all images of the first ten pages for later analysis
page_result = process_image(page)
result.append(page_result)
for item in result:
if not any(item):  # some page has no images, so this is a text pdf; check whether it is the special text type
if (
max(imgs_len_list) == min(imgs_len_list)
and max(imgs_len_list) >= junk_limit_min
):  # special text type: clear the junk list and break
junk_img_bojids = []
else:  # ordinary text pdf with junk images present: keep the junk list
pass
break_loop = True
break
if not break_loop:
# take the top 80% of the list
top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8)
# check whether the top 80% of entries are all equal
if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min:
# # if all of the first 10 pages have images, use whether per-page image counts are equal to decide whether to clear the junk list
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
# the first 10 pages all have images with matching counts, so check how much of the page the images cover to decide whether to clear the junk list
max_image_area_per_page = calculate_max_image_area_per_page(
result, page_width_pts, page_height_pts
)
if (
len(max_image_area_per_page) < 0.8 * special_limit_pages
):  # the first 10 pages are not all large images, so this is probably a text pdf; clear the junk list
junk_img_bojids = []
else:  # the first 10 pages all have images, 80% of them large, with matching and high counts: scanned type 1, keep the junk list
pass
else:  # per-page image counts differ: clear the junk list and rescan the first 50 pages
junk_img_bojids = []
# now formally collect image info for the first 50 pages
result = []
for i, page in enumerate(doc):
if i >= scan_max_page:
break
page_result = process_image(page, junk_img_bojids)
# logger.info(f"page {i} img_len: {len(page_result)}")
result.append(page_result)
return result, junk_img_bojids
def get_pdf_page_size_pts(doc: fitz.Document):
page_cnt = len(doc)
l: int = min(page_cnt, 50)
# put all widths and heights into two lists and take each median (we once hit a pdf with landscape pages inside a portrait document, which swapped width and height)
page_width_list = []
page_height_list = []
for i in range(l):
page = doc[i]
page_rect = page.rect
page_width_list.append(page_rect.width)
page_height_list.append(page_rect.height)
page_width_list.sort()
page_height_list.sort()
median_width = page_width_list[len(page_width_list) // 2]
median_height = page_height_list[len(page_height_list) // 2]
return median_width, median_height
def get_pdf_textlen_per_page(doc: fitz.Document):
text_len_lst = []
for page in doc:
# get all blocks containing img and text
# text_block = page.get_text("blocks")
# get all text blocks
# text_block = page.get_text("words")
# text_block_len = sum([len(t[4]) for t in text_block])
# get all text as a str
text_block = page.get_text('text')
text_block_len = len(text_block)
# logger.info(f"page {page.number} text_block_len: {text_block_len}")
text_len_lst.append(text_block_len)
return text_len_lst
def get_pdf_text_layout_per_page(doc: fitz.Document):
"""根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
Args:
doc (fitz.Document): PDF文档对象。
Returns:
List[str]: 每一页的文本布局(横向、纵向、未知)。
"""
text_layout_list = []
for page_id, page in enumerate(doc):
if page_id >= scan_max_page:
break
# per-page counters for vertical and horizontal text lines
vertical_count = 0
horizontal_count = 0
text_dict = page.get_text('dict')
if 'blocks' in text_dict:
for block in text_dict['blocks']:
if 'lines' in block:
for line in block['lines']:
# get the bbox corner coordinates of the line
x0, y0, x1, y1 = line['bbox']
# compute the bbox width and height
width = x1 - x0
height = y1 - y0
# compute the bbox area
area = width * height
font_sizes = []
for span in line['spans']:
if 'size' in span:
font_sizes.append(span['size'])
if len(font_sizes) > 0:
average_font_size = sum(font_sizes) / len(font_sizes)
else:
average_font_size = 10  # some lines expose no font_size; fall back to a default of 10
if (
area <= average_font_size**2
):  # a bbox no larger than the square of the average font size is a single character, whose direction cannot be determined
continue
else:
if 'wmode' in line:  # determine text direction via wmode
if line['wmode'] == 1:  # vertical text
vertical_count += 1
elif line['wmode'] == 0:  # horizontal text
horizontal_count += 1
# if 'dir' in line:  # determine text direction from the rotation angle
#     # get the line's "dir" value
#     dir_value = line['dir']
#     cosine, sine = dir_value
#     # compute the angle
#     angle = math.degrees(math.acos(cosine))
#
#     # check for horizontal text
#     if abs(angle - 0) < 0.01 or abs(angle - 180) < 0.01:
#         # line_text = ' '.join(span['text'] for span in line['spans'])
#         # print('This line is horizontal:', line_text)
#         horizontal_count += 1
#     # check for vertical text
#     elif abs(angle - 90) < 0.01 or abs(angle - 270) < 0.01:
#         # line_text = ' '.join(span['text'] for span in line['spans'])
#         # print('This line is vertical:', line_text)
#         vertical_count += 1
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# determine the text layout of each page
if vertical_count == 0 and horizontal_count == 0:  # the page has no text, so the layout cannot be determined
text_layout_list.append('unknow')
continue
else:
if vertical_count > horizontal_count:  # more vertical than horizontal text lines on this page
text_layout_list.append('vertical')
else:  # more horizontal than vertical text lines on this page
text_layout_list.append('horizontal')
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
return text_layout_list
"""定义一个自定义异常用来抛出单页svg太多的pdf"""
class PageSvgsTooManyError(Exception):
def __init__(self, message='Page SVGs are too many'):
self.message = message
super().__init__(self.message)
def get_svgs_per_page(doc: fitz.Document):
svgs_len_list = []
for page_id, page in enumerate(doc):
# svgs = page.get_drawings()
svgs = page.get_cdrawings()  # switched to get_cdrawings for better efficiency
len_svgs = len(svgs)
if len_svgs >= 3000:
raise PageSvgsTooManyError()
else:
svgs_len_list.append(len_svgs)
# logger.info(f"page_id: {page_id}, svgs_len: {len(svgs)}")
return svgs_len_list
def get_imgs_per_page(doc: fitz.Document):
imgs_len_list = []
for page_id, page in enumerate(doc):
imgs = page.get_images()
imgs_len_list.append(len(imgs))
# logger.info(f"page_id: {page}, imgs_len: {len(imgs)}")
return imgs_len_list
def get_language(doc: fitz.Document):
"""
Get the language of a PDF document.
Args:
doc (fitz.Document): the PDF document object.
Returns:
str: the document language, e.g. "en-US".
"""
language_lst = []
for page_id, page in enumerate(doc):
if page_id >= scan_max_page:
break
# get all text as a str
text_block = page.get_text('text')
page_language = detect_lang(text_block)
language_lst.append(page_language)
# logger.info(f"page_id: {page_id}, page_language: {page_language}")
# count the occurrences of each language in language_lst
count_dict = Counter(language_lst)
# take the most frequent language in language_lst
language = max(count_dict, key=count_dict.get)
return language
def check_invalid_chars(pdf_bytes):
"""乱码检测."""
# return detect_invalid_chars_by_pymupdf(pdf_bytes)
return detect_invalid_chars(pdf_bytes)
def pdf_meta_scan(pdf_bytes: bytes):
"""
:param pdf_bytes: the binary data of the pdf file
Several dimensions are evaluated: whether it is encrypted, whether it needs a password, paper size, total pages, and whether text is extractable.
"""
doc = fitz.open('pdf', pdf_bytes)
is_needs_password = doc.needs_pass
is_encrypted = doc.is_encrypted
total_page = len(doc)
if total_page == 0:
logger.warning(f'drop this pdf, drop_reason: {DropReason.EMPTY_PDF}')
result = {'_need_drop': True, '_drop_reason': DropReason.EMPTY_PDF}
return result
else:
page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
# logger.info(f"page_width_pts: {page_width_pts}, page_height_pts: {page_height_pts}")
# svgs_per_page = get_svgs_per_page(doc)
# logger.info(f"svgs_per_page: {svgs_per_page}")
imgs_per_page = get_imgs_per_page(doc)
# logger.info(f"imgs_per_page: {imgs_per_page}")
image_info_per_page, junk_img_bojids = get_image_info(
doc, page_width_pts, page_height_pts
)
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page = get_pdf_textlen_per_page(doc)
# logger.info(f"text_len_per_page: {text_len_per_page}")
# text_layout_per_page = get_pdf_text_layout_per_page(doc)
# logger.info(f"text_layout_per_page: {text_layout_per_page}")
# text_language = get_language(doc)
# logger.info(f"text_language: {text_language}")
invalid_chars = check_invalid_chars(pdf_bytes)
# logger.info(f"invalid_chars: {invalid_chars}")
# finally, output one json record
res = {
'is_needs_password': is_needs_password,
'is_encrypted': is_encrypted,
'total_page': total_page,
'page_width_pts': int(page_width_pts),
'page_height_pts': int(page_height_pts),
'image_info_per_page': image_info_per_page,
'text_len_per_page': text_len_per_page,
# 'text_layout_per_page': text_layout_per_page,
# 'text_language': text_language,
# "svgs_per_page": svgs_per_page,
'imgs_per_page': imgs_per_page,  # add the per-page image count list
'junk_img_bojids': junk_img_bojids,  # add the junk image bojid list
'invalid_chars': invalid_chars,
'metadata': doc.metadata,
}
# logger.info(json.dumps(res, ensure_ascii=False))
return res
if __name__ == '__main__':
pass
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师-大乘无量寿.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","") # noqa: E501
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# doc = fitz.open("pdf", file_content)
# text_layout_lst = get_pdf_text_layout_per_page(doc)
# print(text_layout_lst)
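# Usage sketch (hypothetical path; pdf_meta_scan returns the dict assembled
# above, or a {'_need_drop': True, ...} record for empty pdfs):
#
#     with open('/tmp/sample.pdf', 'rb') as f:
#         meta = pdf_meta_scan(f.read())
#     print(meta['total_page'], meta['page_width_pts'], meta['page_height_pts'])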
import os
from pathlib import Path
from loguru import logger
from magic_pdf.integrations.rag.type import (ElementRelation, LayoutElements,
Node)
from magic_pdf.integrations.rag.utils import inference
class RagPageReader:
def __init__(self, pagedata: LayoutElements):
self.o = [
Node(
category_type=v.category_type,
text=v.text,
image_path=v.image_path,
anno_id=v.anno_id,
latex=v.latex,
html=v.html,
) for v in pagedata.layout_dets
]
self.pagedata = pagedata
def __iter__(self):
return iter(self.o)
def get_rel_map(self) -> list[ElementRelation]:
return self.pagedata.extra.element_relation
class RagDocumentReader:
def __init__(self, ragdata: list[LayoutElements]):
self.o = [RagPageReader(v) for v in ragdata]
def __iter__(self):
return iter(self.o)
class DataReader:
def __init__(self, path_or_directory: str, method: str, output_dir: str):
self.path_or_directory = path_or_directory
self.method = method
self.output_dir = output_dir
self.pdfs = []
if os.path.isdir(path_or_directory):
for doc_path in Path(path_or_directory).glob('*.pdf'):
self.pdfs.append(doc_path)
else:
assert path_or_directory.endswith('.pdf')
self.pdfs.append(Path(path_or_directory))
def get_documents_count(self) -> int:
"""Returns the number of documents in the directory."""
return len(self.pdfs)
def get_document_result(self, idx: int) -> RagDocumentReader | None:
"""
Args:
idx (int): the index of documents under the
directory path_or_directory
Returns:
RagDocumentReader | None: RagDocumentReader is an iterable object,
more details @RagDocumentReader
"""
if idx >= self.get_documents_count() or idx < 0:
logger.error(f'invalid idx: {idx}')
return None
res = inference(str(self.pdfs[idx]), self.output_dir, self.method)
if res is None:
logger.warning(f'failed to inference pdf {self.pdfs[idx]}')
return None
return RagDocumentReader(res)
def get_document_filename(self, idx: int) -> Path:
"""get the filename of the document."""
return self.pdfs[idx]
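# Usage sketch (hypothetical paths; get_document_result runs the full parse
# pipeline via inference, so the models must be configured):
if __name__ == '__main__':
    data_reader = DataReader('/tmp/pdfs', 'ocr', '/tmp/output')  # placeholders
    for i in range(data_reader.get_documents_count()):
        doc_result = data_reader.get_document_result(i)
        if doc_result is None:
            continue
        for page in doc_result:
            for node in page:
                print(node.category_type, node.anno_id)
            print(page.get_rel_map())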
from enum import Enum
from pydantic import BaseModel, Field
# rag
class CategoryType(Enum):  # py310 does not support StrEnum
text = 'text'
title = 'title'
interline_equation = 'interline_equation'
image = 'image'
image_body = 'image_body'
image_caption = 'image_caption'
table = 'table'
table_body = 'table_body'
table_caption = 'table_caption'
table_footnote = 'table_footnote'
class ElementRelType(Enum):
sibling = 'sibling'
class PageInfo(BaseModel):
page_no: int = Field(description='the index of page, start from zero',
ge=0)
height: int = Field(description='the height of page', gt=0)
width: int = Field(description='the width of page', ge=0)
image_path: str | None = Field(description='the image of this page',
default=None)
class ContentObject(BaseModel):
category_type: CategoryType = Field(description='category')
poly: list[float] = Field(
description=('Coordinates, need to convert back to PDF coordinates,'
' order is top-left, top-right, bottom-right, bottom-left'
' x,y coordinates'))
ignore: bool = Field(description='whether ignore this object',
default=False)
text: str | None = Field(description='text content of the object',
default=None)
image_path: str | None = Field(description='path of embedded image',
default=None)
order: int = Field(description='the order of this object within a page',
default=-1)
anno_id: int = Field(description='unique id', default=-1)
latex: str | None = Field(description='latex result', default=None)
html: str | None = Field(description='html result', default=None)
class ElementRelation(BaseModel):
source_anno_id: int = Field(description='unique id of the source object',
default=-1)
target_anno_id: int = Field(description='unique id of the target object',
default=-1)
relation: ElementRelType = Field(
description='the relation between source and target element')
class LayoutElementsExtra(BaseModel):
element_relation: list[ElementRelation] = Field(
description='the relation between source and target element')
class LayoutElements(BaseModel):
layout_dets: list[ContentObject] = Field(
description='layout element details')
page_info: PageInfo = Field(description='page info')
extra: LayoutElementsExtra = Field(description='extra information')
# iter data format
class Node(BaseModel):
category_type: CategoryType = Field(description='category')
text: str | None = Field(description='text content of the object',
default=None)
image_path: str | None = Field(description='path of embedded image',
default=None)
anno_id: int = Field(description='unique id', default=-1)
latex: str | None = Field(description='latex result', default=None)
html: str | None = Field(description='html result', default=None)
import json
import os
from pathlib import Path
from loguru import logger
import magic_pdf.model as model_config
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
ElementRelation, ElementRelType,
LayoutElements,
LayoutElementsExtra, PageInfo)
from magic_pdf.tools.common import do_parse, prepare_env
def convert_middle_json_to_layout_elements(
json_data: dict,
output_dir: str,
) -> list[LayoutElements]:
uniq_anno_id = 0
res: list[LayoutElements] = []
for page_no, page_data in enumerate(json_data['pdf_info']):
order_id = 0
page_info = PageInfo(
height=int(page_data['page_size'][1]),
width=int(page_data['page_size'][0]),
page_no=page_no,
)
layout_dets: list[ContentObject] = []
extra_element_relation: list[ElementRelation] = []
for para_block in page_data['para_blocks']:
para_text = ''
para_type = para_block['type']
if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block)
x0, y0, x1, y1 = para_block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.text,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
elif para_type == BlockType.Title:
para_text = merge_para_with_text(para_block)
x0, y0, x1, y1 = para_block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.title,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
x0, y0, x1, y1 = para_block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.interline_equation,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
elif para_type == BlockType.Image:
body_anno_id = -1
caption_anno_id = -1
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
x0, y0, x1, y1 = block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.image_body,
image_path=os.path.join(
output_dir, span['image_path']),
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
body_anno_id = uniq_anno_id
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
for block in para_block['blocks']:
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block)
x0, y0, x1, y1 = block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.image_caption,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
caption_anno_id = uniq_anno_id
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
if body_anno_id != -1 and caption_anno_id != -1:
element_relation = ElementRelation(
relation=ElementRelType.sibling,
source_anno_id=body_anno_id,
target_anno_id=caption_anno_id,
)
extra_element_relation.append(element_relation)
elif para_type == BlockType.Table:
body_anno_id, caption_anno_id, footnote_anno_id = -1, -1, -1
for block in para_block['blocks']:
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block)
x0, y0, x1, y1 = block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.table_caption,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
caption_anno_id = uniq_anno_id
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
x0, y0, x1, y1 = para_block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.table_body,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
body_anno_id = uniq_anno_id
uniq_anno_id += 1
order_id += 1
# if processed by table model
if span.get('latex', ''):
content.latex = span['latex']
else:
content.image_path = os.path.join(
output_dir, span['image_path'])
layout_dets.append(content)
for block in para_block['blocks']:
if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block)
x0, y0, x1, y1 = block['bbox']
content = ContentObject(
anno_id=uniq_anno_id,
category_type=CategoryType.table_footnote,
text=para_text,
order=order_id,
poly=[x0, y0, x1, y0, x1, y1, x0, y1],
)
footnote_anno_id = uniq_anno_id
uniq_anno_id += 1
order_id += 1
layout_dets.append(content)
if caption_anno_id != -1 and body_anno_id != -1:
element_relation = ElementRelation(
relation=ElementRelType.sibling,
source_anno_id=body_anno_id,
target_anno_id=caption_anno_id,
)
extra_element_relation.append(element_relation)
if footnote_anno_id != -1 and body_anno_id != -1:
element_relation = ElementRelation(
relation=ElementRelType.sibling,
source_anno_id=body_anno_id,
target_anno_id=footnote_anno_id,
)
extra_element_relation.append(element_relation)
res.append(
LayoutElements(
page_info=page_info,
layout_dets=layout_dets,
extra=LayoutElementsExtra(
element_relation=extra_element_relation),
))
return res
def inference(path, output_dir, method):
model_config.__use_inside_model__ = True
model_config.__model_mode__ = 'full'
if output_dir == '':
if os.path.isdir(path):
output_dir = os.path.join(path, 'output')
else:
output_dir = os.path.join(os.path.dirname(path), 'output')
local_image_dir, local_md_dir = prepare_env(output_dir,
str(Path(path).stem), method)
def read_fn(path):
disk_rw = FileBasedDataReader(os.path.dirname(path))
return disk_rw.read(os.path.basename(path))
def parse_doc(doc_path: str):
try:
file_name = str(Path(doc_path).stem)
pdf_data = read_fn(doc_path)
do_parse(
output_dir,
file_name,
pdf_data,
[],
method,
False,
f_draw_span_bbox=False,
f_draw_layout_bbox=False,
f_dump_md=False,
f_dump_middle_json=True,
f_dump_model_json=False,
f_dump_orig_pdf=False,
f_dump_content_list=False,
f_draw_model_bbox=False,
)
middle_json_fn = os.path.join(local_md_dir,
f'{file_name}_middle.json')
with open(middle_json_fn) as fd:
jso = json.load(fd)
os.remove(middle_json_fn)
return convert_middle_json_to_layout_elements(jso, local_image_dir)
except Exception as e:
logger.exception(e)
return parse_doc(path)
if __name__ == '__main__':
import pprint
base_dir = '/opt/data/pdf/resources/samples/'
if 0:
with open(base_dir + 'json_outputs/middle.json') as f:
d = json.load(f)
result = convert_middle_json_to_layout_elements(d, '/tmp')
pprint.pp(result)
if 0:
with open(base_dir + 'json_outputs/middle.3.json') as f:
d = json.load(f)
result = convert_middle_json_to_layout_elements(d, '/tmp')
pprint.pp(result)
if 1:
res = inference(
base_dir + 'samples/pdf/one_page_with_table_image.pdf',
'/tmp/output',
'ocr',
)
pprint.pp(res)
import math
def _is_in_or_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠或者包含."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return not (x1_1 < x0_2 or  # box1 is left of box2
x0_1 > x1_2 or  # box1 is right of box2
y1_1 < y0_2 or  # box1 is above box2
y0_1 > y1_2)  # box1 is below box2
def _is_in_or_part_overlap_with_area_ratio(box1,
box2,
area_ratio_threshold=0.6):
"""判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
if not _is_in_or_part_overlap(box1, box2):
return False
# compute the overlap area
x_left = max(x0_1, x0_2)
y_top = max(y0_1, y0_2)
x_right = min(x1_1, x1_2)
y_bottom = min(y1_1, y1_2)
overlap_area = (x_right - x_left) * (y_bottom - y_top)
# compute the area of box1
box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
return overlap_area / box1_area > area_ratio_threshold
def _is_in(box1, box2) -> bool:
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (x0_1 >= x0_2 and  # box1's left edge is not outside box2's left edge
y0_1 >= y0_2 and  # box1's top edge is not outside box2's top edge
x1_1 <= x1_2 and  # box1's right edge is not outside box2's right edge
y1_1 <= y1_2)  # box1's bottom edge is not outside box2's bottom edge
def _is_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠,但不完全包含."""
if box1 is None or box2 is None:
return False
return _is_in_or_part_overlap(box1, box2) and not _is_in(box1, box2)
def _left_intersect(left_box, right_box):
"""检查两个box的左边界是否有交集,也就是left_box的右边界是否在right_box的左边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x1_1 > x0_2 and x0_1 < x0_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _right_intersect(left_box, right_box):
"""检查box是否在右侧边界有交集,也就是left_box的左边界是否在right_box的右边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x0_1 < x1_2 and x1_1 > x1_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _is_vertical_full_overlap(box1, box2, x_torlence=2):
"""x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含 y方向上:box1和box2有重叠."""
# 解析box的坐标
x11, y11, x12, y12 = box1 # 左上角和右下角的坐标 (x1, y1, x2, y2)
x21, y21, x22, y22 = box2
# 在x轴方向上,box1是否包含box2 或 box2包含box1
contains_in_x = (x11 - x_torlence <= x21 and x12 + x_torlence >= x22) or (
x21 - x_torlence <= x11 and x22 + x_torlence >= x12)
# on the y axis, do box1 and box2 overlap
overlap_in_y = not (y12 < y21 or y11 > y22)
return contains_in_x and overlap_in_y
def _is_bottom_full_overlap(box1, box2, y_tolerance=2):
"""检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制 这个函数和_is_vertical-
full_overlap的区别是,这个函数允许box1和box2在x方向上有轻微的重叠,允许一定的模糊度."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
tolerance_margin = 2
is_xdir_full_overlap = (
(x0_1 - tolerance_margin <= x0_2 <= x1_1 + tolerance_margin
and x0_1 - tolerance_margin <= x1_2 <= x1_1 + tolerance_margin)
or (x0_2 - tolerance_margin <= x0_1 <= x1_2 + tolerance_margin
and x0_2 - tolerance_margin <= x1_1 <= x1_2 + tolerance_margin))
return y0_2 < y1_1 and 0 < (y1_1 -
y0_2) < y_tolerance and is_xdir_full_overlap
def _is_left_overlap(
box1,
box2,
):
"""检查box1的左侧是否和box2有重叠 在Y方向上可以是部分重叠或者是完全重叠。不分box1和box2的上下关系,也就是无论box1在box2下
方还是box2在box1下方,都可以检测到重叠。 X方向上."""
def __overlap_y(Ay1, Ay2, By1, By2):
return max(0, min(Ay2, By2) - max(Ay1, By1))
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
y_overlap_len = __overlap_y(y0_1, y1_1, y0_2, y1_2)
ratio_1 = 1.0 * y_overlap_len / (y1_1 - y0_1) if y1_1 - y0_1 != 0 else 0
ratio_2 = 1.0 * y_overlap_len / (y1_2 - y0_2) if y1_2 - y0_2 != 0 else 0
vertical_overlap_cond = ratio_1 >= 0.5 or ratio_2 >= 0.5
# vertical_overlap_cond = y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1 or y0_2<=y0_1<=y1_2 or y0_2<=y1_1<=y1_2
return x0_1 <= x0_2 <= x1_1 and vertical_overlap_cond
def __is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
_, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2
overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
height1, height2 = y1_1 - y0_1, y1_2 - y0_2
# max_height = max(height1, height2)
min_height = min(height1, height2)
return (overlap / min_height) > overlap_ratio_threshold
def calculate_iou(bbox1, bbox2):
"""计算两个边界框的交并比(IOU)。
Args:
bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (list[float]): 第二个边界框的坐标,格式与 `bbox1` 相同。
Returns:
float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
"""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# The area of both rectangles
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
if any([bbox1_area == 0, bbox2_area == 0]):
return 0
# Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area
iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
return iou
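# Worked example (doctest-style; two 2x2 boxes overlapping in a 1x1 region give
# iou = 1 / (4 + 4 - 1)):
#
#     >>> round(calculate_iou([0, 0, 2, 2], [1, 1, 3, 3]), 4)
#     0.1429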
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
if bbox1_area == 0:
return 0
else:
return intersection_area / bbox1_area
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
if overlap_ratio > ratio:
if area1 <= area2:
return bbox1
else:
return bbox2
else:
return None
def get_bbox_in_boundary(bboxes: list, boundary: tuple) -> list:
x0, y0, x1, y1 = boundary
new_boxes = [
box for box in bboxes
if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1
]
return new_boxes
def is_vbox_on_side(bbox, width, height, side_threshold=0.2):
"""判断一个bbox是否在pdf页面的边缘."""
x0, x1 = bbox[0], bbox[2]
if x1 <= width * side_threshold or x0 >= width * (1 - side_threshold):
return True
return False
def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
tolerance_margin = 4
top_boxes = [
box for box in pymu_blocks
if obj_bbox[1] - box['bbox'][3] >= -tolerance_margin
and not _is_in(box['bbox'], obj_bbox)
]
# then keep those overlapping in the x direction
top_boxes = [
box for box in top_boxes if any([
obj_bbox[0] - tolerance_margin <= box['bbox'][0] <= obj_bbox[2] +
tolerance_margin, obj_bbox[0] -
tolerance_margin <= box['bbox'][2] <= obj_bbox[2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[0] <= box['bbox'][2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[2] <= box['bbox'][2] +
tolerance_margin
])
]
# then take the one with the largest y1
if len(top_boxes) > 0:
top_boxes.sort(key=lambda x: x['bbox'][3], reverse=True)
return top_boxes[0]
else:
return None
def find_bottom_nearest_text_bbox(pymu_blocks, obj_bbox):
bottom_boxes = [
box for box in pymu_blocks if box['bbox'][1] -
obj_bbox[3] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# then keep those overlapping in the x direction
bottom_boxes = [
box for box in bottom_boxes if any([
obj_bbox[0] - 2 <= box['bbox'][0] <= obj_bbox[2] + 2, obj_bbox[0] -
2 <= box['bbox'][2] <= obj_bbox[2] + 2, box['bbox'][0] -
2 <= obj_bbox[0] <= box['bbox'][2] + 2, box['bbox'][0] -
2 <= obj_bbox[2] <= box['bbox'][2] + 2
])
]
# then take the one with the smallest y0
if len(bottom_boxes) > 0:
bottom_boxes.sort(key=lambda x: x['bbox'][1], reverse=False)
return bottom_boxes[0]
else:
return None
def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找左侧最近的文本block."""
left_boxes = [
box for box in pymu_blocks if obj_bbox[0] -
box['bbox'][2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# then keep those overlapping in the y direction
left_boxes = [
box for box in left_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# Then take the one with the largest x1 (the closest box on the left)
if len(left_boxes) > 0:
left_boxes.sort(key=lambda x: x['bbox'][2], reverse=True)
return left_boxes[0]
else:
return None
def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找右侧最近的文本block."""
right_boxes = [
box for box in pymu_blocks if box['bbox'][0] -
obj_bbox[2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# Keep boxes that overlap obj_bbox in the Y direction
right_boxes = [
box for box in right_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# Then take the one with the smallest x0 (the closest box on the right)
if len(right_boxes) > 0:
right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False)
return right_boxes[0]
else:
return None
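# The four find_*_nearest_text_bbox helpers above expect pymupdf-style block
# dicts carrying a 'bbox' key (and assume the module-level `_is_in` helper).
# A minimal sketch of the expected input shape, with values chosen by the
# editor for illustration:
#   >>> blocks = [{'bbox': (50, 10, 300, 30)}, {'bbox': (50, 100, 300, 130)}]
#   >>> find_top_nearest_text_bbox(blocks, (60, 50, 200, 80))
#   {'bbox': (50, 10, 300, 30)}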
def bbox_relative_pos(bbox1, bbox2):
"""判断两个矩形框的相对位置关系.
Args:
bbox1: 一个四元组,表示第一个矩形框的左上角和右下角的坐标,格式为(x1, y1, x1b, y1b)
bbox2: 一个四元组,表示第二个矩形框的左上角和右下角的坐标,格式为(x2, y2, x2b, y2b)
Returns:
一个四元组,表示矩形框1相对于矩形框2的位置关系,格式为(left, right, bottom, top)
其中,left表示矩形框1是否在矩形框2的左侧,right表示矩形框1是否在矩形框2的右侧,
bottom表示矩形框1是否在矩形框2的下方,top表示矩形框1是否在矩形框2的上方
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
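# Usage sketch (illustrative values, not from the source): bbox2 starts at
# x=2 while bbox1 ends at x=1, so only the `right` flag is set.
#   >>> bbox_relative_pos((0, 0, 1, 1), (2, 0, 3, 1))
#   (False, True, False, False)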
def bbox_distance(bbox1, bbox2):
"""计算两个矩形框的距离。
Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (tuple): 第二个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
Returns:
float: 矩形框之间的距离。
"""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0])**2 +
(point1[1] - point2[1])**2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0
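# Usage sketch (illustrative values, not from the source; assumes `math` is
# imported at the top of this module): the boxes are separated to the
# lower-right, so the distance runs between corners (1, 1) and (4, 5),
# a 3-4-5 right triangle.
#   >>> bbox_distance((0, 0, 1, 1), (4, 5, 6, 7))
#   5.0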
def box_area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def get_overlap_area(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# Area of the intersection rectangle
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
Args:
block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
Returns:
float: The proportion of the x-axis covered by the vertical projection of the two blocks.
"""
x0_1, _, x1_1, _ = block1
x0_2, _, x1_2, _ = block2
# Calculate the intersection of the x-coordinates
x_left = max(x0_1, x0_2)
x_right = min(x1_1, x1_2)
if x_right < x_left:
return 0.0
# Length of the intersection
intersection_length = x_right - x_left
# Length of the x-axis projection of the first block
block1_length = x1_1 - x0_1
if block1_length == 0:
return 0.0
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return intersection_length / block1_length
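# Usage sketch (illustrative values, not from the source): the x-ranges
# [0, 10] and [5, 20] intersect over [5, 10], covering half of block1's width.
#   >>> calculate_vertical_projection_overlap_ratio((0, 0, 10, 5), (5, 20, 20, 30))
#   0.5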
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc
def clean_memory(device='cuda'):
if device == 'cuda':
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
torch_npu.npu.empty_cache()
elif str(device).startswith("mps"):
torch.mps.empty_cache()
gc.collect()
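# Usage sketch (editor's illustration, not from the source): call this between
# pipeline stages to release cached accelerator memory, e.g.
#   clean_memory('cuda')   # after a CUDA inference pass
#   clean_memory('npu:0')  # on Ascend devices with torch_npu installed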
def join_path(*args):
return '/'.join(str(s).rstrip('/') for s in args)
def get_top_percent_list(num_list, percent):
"""
获取列表中前百分之多少的元素
:param num_list:
:param percent:
:return:
"""
if len(num_list) == 0:
top_percent_list = []
else:
# 对imgs_len_list排序
sorted_imgs_len_list = sorted(num_list, reverse=True)
# 计算 percent 的索引
top_percent_index = int(len(sorted_imgs_len_list) * percent)
# 取前80%的元素
top_percent_list = sorted_imgs_len_list[:top_percent_index]
return top_percent_list
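# Usage sketch (illustrative values, not from the source): with percent=0.5,
# int(4 * 0.5) = 2, so the two largest values are returned.
#   >>> get_top_percent_list([5, 3, 8, 1], 0.5)
#   [8, 5]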
def mymax(alist: list):
if len(alist) == 0:
return 0  # an empty list counts as 0; a 0x0 box also has size 0
else:
return max(alist)
def parse_bucket_key(s3_full_path: str):
"""
输入 s3://bucket/path/to/my/file.txt
输出 bucket, path/to/my/file.txt
"""
s3_full_path = s3_full_path.strip()
if s3_full_path.startswith("s3://"):
s3_full_path = s3_full_path[5:]
if s3_full_path.startswith("/"):
s3_full_path = s3_full_path[1:]
bucket, key = s3_full_path.split("/", 1)
return bucket, key
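# Usage sketch (illustrative path, not from the source):
#   >>> parse_bucket_key('s3://my-bucket/path/to/my/file.txt')
#   ('my-bucket', 'path/to/my/file.txt')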
"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
import json
import os
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# Name of the config file; can be overridden via the MINERU_TOOLS_CONFIG_JSON environment variable
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config():
if os.path.isabs(CONFIG_FILE_NAME):
config_file = CONFIG_FILE_NAME
else:
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file):
raise FileNotFoundError(f'{config_file} not found')
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
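# Expected shape of the config file (a minimal sketch inferred from the
# accessors below; all values are placeholders, not from the source):
# {
#     "bucket_info": {"[default]": ["<access_key>", "<secret_key>", "<endpoint_url>"]},
#     "models-dir": "/tmp/models",
#     "device-mode": "cpu"
# }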
def get_s3_config(bucket_name: str):
"""~/magic-pdf.json 读出来."""
config = read_config()
bucket_info = config.get('bucket_info')
if bucket_name not in bucket_info:
access_key, secret_key, storage_endpoint = bucket_info['[default]']
else:
access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
return access_key, secret_key, storage_endpoint
def get_s3_config_dict(path: str):
access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
def get_bucket_name(path):
bucket, key = parse_bucket_key(path)
return bucket
def get_local_models_dir():
config = read_config()
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return '/tmp/models'
else:
return models_dir
def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def get_device():
config = read_config()
device = config.get('device-mode')
if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return 'cpu'
else:
return device
def get_table_recog_config():
config = read_config()
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
else:
return table_config
def get_layout_config():
config = read_config()
layout_config = config.get('layout-config')
if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
else:
return layout_config
def get_formula_config():
config = read_config()
formula_config = config.get('formula-config')
if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
else:
return formula_config
def get_llm_aided_config():
config = read_config()
llm_aided_config = config.get('llm-aided-config')
if llm_aided_config is None:
logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return llm_aided_config
def get_latex_delimiter_config():
config = read_config()
latex_delimiter_config = config.get('latex-delimiter-config')
if latex_delimiter_config is None:
logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return latex_delimiter_config
if __name__ == '__main__':
ak, sk, endpoint = get_s3_config('llm-raw')