Unverified Commit 3a42ebbf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
......@@ -3,82 +3,107 @@
## 1. 安装cuda和cuDNN
需要安装的版本 CUDA 11.8 + cuDNN 8.7.0
- CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive
- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive
## 2. 安装anaconda
如果已安装conda,可以跳过本步骤
下载链接:
https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
## 3. 使用conda 创建环境
需指定python版本为3.10
```bash
conda create -n MinerU python=3.10
conda activate MinerU
```
## 4. 安装应用
```bash
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
```
> ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
>
>
> ```bash
> magic-pdf --version
>```
> ```
>
> 如果版本号小于0.7.0,请到issue中向我们反馈
## 5. 下载模型
详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
## 6. 了解配置文件存放的位置
完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
您可在【用户目录】下找到magic-pdf.json文件。
> windows用户目录为 "C:/Users/用户名"
## 7. 第一次运行
从仓库中下载样本文件,并测试
```powershell
(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
magic-pdf -p small_ocr.pdf
wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
## 8. 测试CUDA加速
如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
**1.覆盖安装支持cuda的torch和torchvision**
```bash
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
```
> ❗️务必在命令中指定以下版本
>
> ```bash
> torch==2.3.1 torchvision==0.18.1
> torch==2.3.1 torchvision==0.18.1
> ```
>
> 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
**2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
```json
{
"device-mode":"cuda"
}
```
**3.运行以下命令测试cuda加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段的耗时来简单判断,通常情况下,`layout detection time` 和 `mfr time` 应提速10倍以上。
## 9. 为ocr开启cuda加速
> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
**1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
```bash
pip install paddlepaddle-gpu==2.6.1
```
**2.运行以下命令测试ocr加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr time`应提速10倍以上。
import json
import os
import requests
import json
from modelscope import snapshot_download
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
data = download_json(url)
# 解析JSON内容
data = response.json()
# 修改内容
for key, value in modifications.items():
......@@ -25,15 +33,25 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + "/models"
print(f"model_dir is: {model_dir}")
print(f"layoutreader_model_dir is: {layoutreader_model_dir}")
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
config_file_name = "magic-pdf.json"
home_dir = os.path.expanduser("~")
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
......@@ -42,4 +60,6 @@ if __name__ == '__main__':
}
download_and_modify_json(json_url, config_file, json_mods)
print(f"The configuration file has been configured successfully, the path is: {config_file}")
\ No newline at end of file
print(f'The configuration file has been configured successfully, the path is: {config_file}')
import json
import os
import requests
import json
from huggingface_hub import snapshot_download
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
data = download_json(url)
# 解析JSON内容
data = response.json()
# 修改内容
for key, value in modifications.items():
......@@ -25,15 +32,31 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
model_dir = model_dir + "/models"
print(f"model_dir is: {model_dir}")
print(f"layoutreader_model_dir is: {layoutreader_model_dir}")
json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
config_file_name = "magic-pdf.json"
home_dir = os.path.expanduser("~")
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_pattern = [
"*.json",
"*.safetensors",
]
layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
......@@ -42,4 +65,6 @@ if __name__ == '__main__':
}
download_and_modify_json(json_url, config_file, json_mods)
print(f"The configuration file has been configured successfully, the path is: {config_file}")
\ No newline at end of file
print(f'The configuration file has been configured successfully, the path is: {config_file}')
......@@ -3,7 +3,8 @@ Model downloads are divided into initial downloads and updates to the model dire
# Initial download of model files
### 1. Download the Model from Hugging Face
### Download the Model from Hugging Face
Use a Python Script to Download Model Files from Hugging Face
```bash
pip install huggingface_hub
......@@ -14,14 +15,17 @@ The Python script will automatically download the model files and configure the
The configuration file can be found in the user directory, with the filename `magic-pdf.json`.
# How to update models previously downloaded
## 1. Models downloaded via Git LFS
>Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
> Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
## 2. Models downloaded via Hugging Face or Model Scope
If you previously downloaded models via Hugging Face or Model Scope, you can rerun the Python script used for the initial download. This will automatically update the model directory to the latest version.
\ No newline at end of file
If you previously downloaded models via Hugging Face or Model Scope, you can rerun the Python script used for the initial download. This will automatically update the model directory to the latest version.
......@@ -10,7 +10,7 @@
<pre><code>pip install huggingface_hub
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
python download_models_hf.py</code></pre>
<p>python脚本执行完毕后,会输出模型下载目录</p>
<p>python脚本会自动下载模型文件并配置好配置文件中的模型目录</p>
</details>
## 方法二:从 ModelScope 下载模型
......@@ -25,6 +25,7 @@ python download_models.py
python脚本会自动下载模型文件并配置好配置文件中的模型目录
配置文件可以在用户目录中找到,文件名为`magic-pdf.json`
> windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
......@@ -32,17 +33,13 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目
## 1. 通过git lfs下载过模型
>由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
> 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型
> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新
> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
>
>```
>from modelscope import snapshot_download
>snapshot_download('ppaanngggg/layoutreader')
>```
## 2. 通过 Hugging Face 或 Model Scope 下载过模型
如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。
\ No newline at end of file
如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。
......@@ -175,11 +175,14 @@ Detailed explanation of second-level block types
| :----------------- | :--------------------- |
| image_body | Main body of the image |
| image_caption | Image description text |
| image_footnote | Image footnote |
| table_body | Main body of the table |
| table_caption | Table description text |
| table_footnote | Table footnote |
| text | Text block |
| title | Title block |
| index | Index block |
| list | List block |
| interline_equation | Block formula |
<br>
......
......@@ -174,12 +174,15 @@ poly 坐标的格式 \[x0, y0, x1, y1, x2, y2, x3, y3\], 分别表示左上、
| :----------------- | :------------- |
| image_body | 图像的本体 |
| image_caption | 图像的描述文本 |
| table_body | 表格本体 |
| image_footnote | 图像的脚注 |
| table_body | 表格本体 |
| table_caption | 表格的描述文本 |
| table_footnote | 表格的脚注 |
| text | 文本块 |
| title | 标题块 |
| interline_equation | 行间公式块 |
| table_footnote | 表格的脚注 |
| text | 文本块 |
| title | 标题块 |
| index | 目录块 |
| list | 列表块 |
| interline_equation | 行间公式块 |
<br>
......
......@@ -6,9 +6,18 @@
"models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true
},
"table-config": {
"model": "TableMaster",
"is_table_recog_enable": false,
"model": "tablemaster",
"enable": false,
"max_time": 400
}
},
"config_version": "1.0.0"
}
\ No newline at end of file
import enum
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataReader # noqa: F401
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataWriter # noqa: F401
\ No newline at end of file
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
self.write(path, data.encode())
import os
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'wb') as f:
f.write(data)
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
from magic_pdf.data.schemas import S3Config
from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
remove_non_official_s3_args)
class MultiS3Mixin:
def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_bucket (str): the default bucket name of the relative path
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs
InvalidConfig: bucket name not unique in s3_configs
InvalidConfig: default bucket must be provided
"""
if len(default_bucket) == 0:
raise InvalidConfig('default_bucket must be provided')
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.default_bucket = default_bucket
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the path, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
for example: s3://bucket_name/path?0,100
Returns:
bytes: the content of s3 file
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the path.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the path.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
return s3_writer.write(path, data)
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
MultiBucketS3DataReader, MultiBucketS3DataWriter)
from magic_pdf.data.schemas import S3Config
class S3DataReader(MultiBucketS3DataReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
class S3DataWriter(MultiBucketS3DataWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 writer client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
from abc import ABC, abstractmethod
from typing import Iterator
import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
class PageableData(ABC):
@abstractmethod
def get_image(self) -> dict:
"""Transform data to image."""
pass
@abstractmethod
def get_doc(self) -> fitz.Page:
"""Get the pymudoc page."""
pass
@abstractmethod
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
pass
class Dataset(ABC):
@abstractmethod
def __len__(self) -> int:
"""The length of the dataset."""
pass
@abstractmethod
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page data."""
pass
@abstractmethod
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The methods that this dataset support.
Returns:
list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
"""
pass
@abstractmethod
def data_bits(self) -> bytes:
"""The bits used to create this dataset."""
pass
@abstractmethod
def get_page(self, page_id: int) -> PageableData:
"""Get the page indexed by page_id.
Args:
page_id (int): the index of the page
Returns:
PageableData: the page doc object
"""
pass
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the pdf
"""
self._records = [Doc(v) for v in fitz.open('pdf', bits)]
self._data_bits = bits
self._raw_data = bits
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page doc object."""
return iter(self._records)
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class ImageDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
"""
pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
self._raw_data = bits
self._data_bits = pdf_bytes
def __len__(self) -> int:
"""The length of the dataset."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page object."""
return iter(self._records)
def supported_methods(self):
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class Doc(PageableData):
"""Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page):
self._doc = doc
def get_image(self):
"""Return the imge info.
Returns:
dict: {
img: np.ndarray,
width: int,
height: int
}
"""
return fitz_doc_to_image(self._doc)
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
Returns:
fitz.Page: the pymudoc object
"""
return self._doc
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
page_w = self._doc.rect.width
page_h = self._doc.rect.height
return PageInfo(w=page_w, h=page_h)
def __getattr__(self, name):
if hasattr(self._doc, name):
return getattr(self._doc, name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment