"vscode:/vscode.git/clone" did not exist on "6b481595f096254817902d1dc0e1ead18e5610ca"
Commit bd927919 authored by myhloli's avatar myhloli
Browse files

refactor: rename init file and update app.py to enable parsing method

parent f5016508
"""span维度自定义字段."""
# span是否是跨页合并的
CROSS_PAGE = 'cross_page'
"""
block维度自定义字段
"""
# block中lines是否被删除
LINES_DELETED = 'lines_deleted'
# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400
# pp_table_result_max_length
TABLE_MAX_LEN = 480
# table master structure dict
TABLE_MASTER_DICT = 'table_master_structure_dict.txt'
# table master dir
TABLE_MASTER_DIR = 'table_structure_tablemaster_infer/'
# pp detect model dir
DETECT_MODEL_DIR = 'ch_PP-OCRv4_det_infer'
# pp rec model dir
REC_MODEL_DIR = 'ch_PP-OCRv4_rec_infer'
# pp rec char dict path
REC_CHAR_DICT = 'ppocr_keys_v1.txt'
# pp rec copy rec directory
PP_REC_DIRECTORY = '.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer'
# pp rec copy det directory
PP_DET_DIRECTORY = '.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer'
class MODEL_NAME:
# pp table structure algorithm
TABLE_MASTER = 'tablemaster'
# struct eqtable
STRUCT_EQTABLE = 'struct_eqtable'
DocLayout_YOLO = 'doclayout_yolo'
LAYOUTLMv3 = 'layoutlmv3'
YOLO_V8_MFD = 'yolo_v8_mfd'
UniMerNet_v2_Small = 'unimernet_small'
RAPID_TABLE = 'rapid_table'
YOLO_V11_LangDetect = 'yolo_v11n_langdetect'
PARSE_TYPE_TXT = 'txt'
PARSE_TYPE_OCR = 'ocr'
class DropReason:
TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap' # 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP = (
'useful_block_horizontal_overlap' # 需保留的block水平覆盖
)
COMPLICATED_LAYOUT = 'complicated_layout' # 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns' # 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box' # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS = (
'high_computational_load_by_imgs' # 含特殊图片,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = (
'high_computational_load_by_svgs' # 特殊的SVG图,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages' # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result' # 版面分析失败
Exception = '_exception' # 解析中发生异常
ENCRYPTED = 'encrypted' # PDF是加密的
EMPTY_PDF = 'total_page=0' # PDF页面总数为0
NOT_IS_TEXT_PDF = 'not_is_text_pdf' # 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block' # 无法清晰的分段
TITLE_DETECTION_FAILED = 'title_detection_failed' # 探测标题失败
TITLE_LEVEL_FAILED = (
'title_level_failed' # 分析标题级别失败(例如一级、二级、三级标题)
)
PARA_SPLIT_FAILED = 'para_split_failed' # 识别段落失败
PARA_MERGE_FAILED = 'para_merge_failed' # 段落合并失败
NOT_ALLOW_LANGUAGE = 'not_allow_language' # 不支持的语种
SPECIAL_PDF = 'special_pdf'
PSEUDO_SINGLE_COLUMN = 'pseudo_single_column' # 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout' # 无法分析页面的版面
NEGATIVE_BBOX_AREA = 'negative_bbox_area' # 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION = (
'overlap_blocks_can_t_separation' # 无法分离重叠的block
)
COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'
PAGE_NO = 'page-no' # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = 'page_no'
HEADER = 'header'
FOOTER = 'footer'
FOOTNOTE = 'footnote'
NOT_IN_LAYOUT = 'not_in_layout'
SPAN_OVERLAP = 'span_overlap'
BLOCK_OVERLAP = 'block_overlap'
import enum
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
class CUDA_NOT_AVAILABLE(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'CUDA not available: {self.msg}'
\ No newline at end of file
class MakeMode:
MM_MD = 'mm_markdown'
NLP_MD = 'nlp_markdown'
STANDARD_FORMAT = 'standard_format'
class DropMode:
WHOLE_PDF = 'whole_pdf'
SINGLE_PAGE = 'single_page'
NONE = 'none'
NONE_WITH_REASON = 'none_with_reason'
from enum import Enum
class ModelBlockTypeEnum(Enum):
TITLE = 0
PLAIN_TEXT = 1
ABANDON = 2
ISOLATE_FORMULA = 8
EMBEDDING = 13
ISOLATED = 14
class ContentType:
Image = 'image'
Table = 'table'
Text = 'text'
InlineEquation = 'inline_equation'
InterlineEquation = 'interline_equation'
class BlockType:
Image = 'image'
ImageBody = 'image_body'
ImageCaption = 'image_caption'
ImageFootnote = 'image_footnote'
Table = 'table'
TableBody = 'table_body'
TableCaption = 'table_caption'
TableFootnote = 'table_footnote'
Text = 'text'
Title = 'title'
InterlineEquation = 'interline_equation'
Footnote = 'footnote'
Discarded = 'discarded'
List = 'list'
Index = 'index'
class CategoryId:
Title = 0
Text = 1
Abandon = 2
ImageBody = 3
ImageCaption = 4
TableBody = 5
TableCaption = 6
TableFootnote = 7
InterlineEquation_Layout = 8
InlineEquation = 13
InterlineEquation_YOLO = 14
OcrText = 15
ImageFootnote = 101
import concurrent.futures
import fitz
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image # PyMuPDF
def partition_array_greedy(arr, k):
"""Partition an array into k parts using a simple greedy approach.
Parameters:
-----------
arr : list
The input array of integers
k : int
Number of partitions to create
Returns:
--------
partitions : list of lists
The k partitions of the array
"""
# Handle edge cases
if k <= 0:
raise ValueError('k must be a positive integer')
if k > len(arr):
k = len(arr) # Adjust k if it's too large
if k == 1:
return [list(range(len(arr)))]
if k == len(arr):
return [[i] for i in range(len(arr))]
# Sort the array in descending order
sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
# Initialize k empty partitions
partitions = [[] for _ in range(k)]
partition_sums = [0] * k
# Assign each element to the partition with the smallest current sum
for idx in sorted_indices:
# Find the partition with the smallest sum
min_sum_idx = partition_sums.index(min(partition_sums))
# Add the element to this partition
partitions[min_sum_idx].append(idx) # Store the original index
partition_sums[min_sum_idx] += arr[idx][1]
return partitions
def process_pdf_batch(pdf_jobs, idx):
"""Process a batch of PDF pages using multiple threads.
Parameters:
-----------
pdf_jobs : list of tuples
List of (pdf_path, page_num) tuples
output_dir : str or None
Directory to save images to
num_threads : int
Number of threads to use
**kwargs :
Additional arguments for process_pdf_page
Returns:
--------
images : list
List of processed images
"""
images = []
for pdf_path, _ in pdf_jobs:
doc = fitz.open(pdf_path)
tmp = []
for page_num in range(len(doc)):
page = doc[page_num]
tmp.append(fitz_doc_to_image(page))
images.append(tmp)
return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None):
"""Process multiple PDFs by partitioning them into k balanced parts and
processing each part in parallel.
Parameters:
-----------
pdf_paths : list
List of paths to PDF files
k : int
Number of partitions to create
output_dir : str or None
Directory to save images to
threads_per_worker : int
Number of threads to use per worker
**kwargs :
Additional arguments for process_pdf_page
Returns:
--------
all_images : list
List of all processed images
"""
results = []
for pdf_path in pdf_paths:
with open(pdf_path, 'rb') as f:
pdf_bytes = f.read()
dataset = PymuDocDataset(pdf_bytes, lang=lang)
results.append(dataset)
return results
#
# # Get page counts for each PDF
# pdf_info = []
# total_pages = 0
#
# for pdf_path in pdf_paths:
# try:
# doc = fitz.open(pdf_path)
# num_pages = len(doc)
# pdf_info.append((pdf_path, num_pages))
# total_pages += num_pages
# doc.close()
# except Exception as e:
# print(f'Error opening {pdf_path}: {e}')
#
# # Partition the jobs based on page countEach job has 1 page
# partitions = partition_array_greedy(pdf_info, k)
#
# # Process each partition in parallel
# all_images_h = {}
#
# with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
# # Submit one task per partition
# futures = []
# for sn, partition in enumerate(partitions):
# # Get the jobs for this partition
# partition_jobs = [pdf_info[idx] for idx in partition]
#
# # Submit the task
# future = executor.submit(
# process_pdf_batch,
# partition_jobs,
# sn
# )
# futures.append(future)
# # Process results as they complete
# for i, future in enumerate(concurrent.futures.as_completed(futures)):
# try:
# idx, images = future.result()
# all_images_h[idx] = images
# except Exception as e:
# print(f'Error processing partition: {e}')
# results = [None] * len(pdf_paths)
# for i in range(len(partitions)):
# partition = partitions[i]
# for j in range(len(partition)):
# with open(pdf_info[partition[j]][0], 'rb') as f:
# pdf_bytes = f.read()
# dataset = PymuDocDataset(pdf_bytes, lang=lang)
# dataset.set_images(all_images_h[i][j])
# results[partition[j]] = dataset
# return results
\ No newline at end of file
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataReader # noqa: F401
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataWriter # noqa: F401
\ No newline at end of file
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
def safe_encode(data: str, method: str):
try:
bit_data = data.encode(encoding=method, errors='replace')
return bit_data, True
except: # noqa
return None, False
for method in ['utf-8', 'ascii']:
bit_data, flag = safe_encode(data, method)
if flag:
self.write(path, bit_data)
break
import os
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
if not os.path.exists(os.path.dirname(fn_path)) and os.path.dirname(fn_path) != "":
os.makedirs(os.path.dirname(fn_path), exist_ok=True)
with open(fn_path, 'wb') as f:
f.write(data)
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
from magic_pdf.data.schemas import S3Config
from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
remove_non_official_s3_args)
class MultiS3Mixin:
def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_prefix (str): the default prefix of the relative path. for example, {some_bucket}/{some_prefix} or {some_bucket}
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs.
InvalidConfig: bucket name not unique in s3_configs.
InvalidConfig: default bucket must be provided.
"""
if len(default_prefix) == 0:
raise InvalidConfig('default_prefix must be provided')
arr = default_prefix.strip('/').split('/')
self.default_bucket = arr[0]
self.default_prefix = '/'.join(arr[1:])
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == self.default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the bucket, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit.
for example: s3://bucket_name/path?0,100.
Returns:
bytes: the content of s3 file.
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the bucket.
Args:
path (str): the file path.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content.
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the bucket.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write.
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_writer.write(path, data)
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
MultiBucketS3DataReader, MultiBucketS3DataWriter)
from magic_pdf.data.schemas import S3Config
class S3DataReader(MultiBucketS3DataReader):
def __init__(
self,
default_prefix_without_bucket: str,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
default_prefix_without_bucket: prefix that not contains bucket
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
f'{bucket}/{default_prefix_without_bucket}',
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
class S3DataWriter(MultiBucketS3DataWriter):
def __init__(
self,
default_prefix_without_bucket: str,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 writer client.
Args:
default_prefix_without_bucket: prefix that not contains bucket
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
f'{bucket}/{default_prefix_without_bucket}',
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
import os
from abc import ABC, abstractmethod
from typing import Callable, Iterator
import fitz
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify
class PageableData(ABC):
@abstractmethod
def get_image(self) -> dict:
"""Transform data to image."""
pass
@abstractmethod
def get_doc(self) -> fitz.Page:
"""Get the pymudoc page."""
pass
@abstractmethod
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
pass
@abstractmethod
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
pass
@abstractmethod
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
pass
class Dataset(ABC):
@abstractmethod
def __len__(self) -> int:
"""The length of the dataset."""
pass
@abstractmethod
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page data."""
pass
@abstractmethod
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The methods that this dataset support.
Returns:
list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
"""
pass
@abstractmethod
def data_bits(self) -> bytes:
"""The bits used to create this dataset."""
pass
@abstractmethod
def get_page(self, page_id: int) -> PageableData:
"""Get the page indexed by page_id.
Args:
page_id (int): the index of the page
Returns:
PageableData: the page doc object
"""
pass
@abstractmethod
def dump_to_file(self, file_path: str):
"""Dump the file.
Args:
file_path (str): the file path
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(self, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
"""
pass
@abstractmethod
def clone(self):
"""clone this dataset."""
pass
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes, lang=None):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the pdf
"""
self._raw_fitz = fitz.open('pdf', bits)
self._records = [Doc(v) for v in self._raw_fitz]
self._data_bits = bits
self._raw_data = bits
self._classify_result = None
if lang == '':
self._lang = None
elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import \
auto_detect_lang
self._lang = auto_detect_lang(self._data_bits)
logger.info(f'lang: {lang}, detect_lang: {self._lang}')
else:
self._lang = lang
logger.info(f'lang: {lang}')
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page doc object."""
return iter(self._records)
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file.
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
if 'lang' in kwargs and self._lang is not None:
kwargs['lang'] = self._lang
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
"""
if self._classify_result is None:
self._classify_result = classify(self._data_bits)
return self._classify_result
def clone(self):
"""clone this dataset."""
return PymuDocDataset(self._raw_data)
def set_images(self, images):
for i in range(len(self._records)):
self._records[i].set_image(images[i])
class ImageDataset(Dataset):
def __init__(self, bits: bytes, lang=None):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
"""
pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._raw_fitz = fitz.open('pdf', pdf_bytes)
self._records = [Doc(v) for v in self._raw_fitz]
self._raw_data = bits
self._data_bits = pdf_bytes
if lang == '':
self._lang = None
elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import \
auto_detect_lang
self._lang = auto_detect_lang(self._data_bits)
logger.info(f'lang: {lang}, detect_lang: {self._lang}')
else:
self._lang = lang
logger.info(f'lang: {lang}')
def __len__(self) -> int:
"""The length of the dataset."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page object."""
return iter(self._records)
def supported_methods(self):
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file.
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset.
Returns:
SupportedPdfParseMethod: _description_
"""
return SupportedPdfParseMethod.OCR
def clone(self):
"""clone this dataset."""
return ImageDataset(self._raw_data)
def set_images(self, images):
for i in range(len(self._records)):
self._records[i].set_image(images[i])
class Doc(PageableData):
"""Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page):
self._doc = doc
self._img = None
def get_image(self):
"""Return the image info.
Returns:
dict: {
img: np.ndarray,
width: int,
height: int
}
"""
if self._img is None:
self._img = fitz_doc_to_image(self._doc)
return self._img
def set_image(self, img):
"""
Args:
img (np.ndarray): the image
"""
if self._img is None:
self._img = img
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
Returns:
fitz.Page: the pymudoc object
"""
return self._doc
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
page_w = self._doc.rect.width
page_h = self._doc.rect.height
return PageInfo(w=page_w, h=page_h)
def __getattr__(self, name):
if hasattr(self._doc, name):
return getattr(self._doc, name)
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
self._doc.draw_rect(
rect_coords,
color=color,
fill=fill,
fill_opacity=fill_opacity,
width=width,
overlay=overlay,
)
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
\ No newline at end of file
from magic_pdf.data.io.base import IOReader, IOWriter # noqa: F401
from magic_pdf.data.io.http import HttpReader, HttpWriter # noqa: F401
from magic_pdf.data.io.s3 import S3Reader, S3Writer # noqa: F401
__all__ = ['IOReader', 'IOWriter', 'HttpReader', 'HttpWriter', 'S3Reader', 'S3Writer']
\ No newline at end of file
from abc import ABC, abstractmethod
class IOReader(ABC):
@abstractmethod
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
pass
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
pass
class IOWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment