"vscode:/vscode.git/clone" did not exist on "fb6bc29b236f5eb1cbb1dd3d250b1a84d0327704"
Unverified Commit 3a42ebbf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
...@@ -3,82 +3,107 @@ ...@@ -3,82 +3,107 @@
## 1. 安装cuda和cuDNN ## 1. 安装cuda和cuDNN
需要安装的版本 CUDA 11.8 + cuDNN 8.7.0 需要安装的版本 CUDA 11.8 + cuDNN 8.7.0
- CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive - CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive
- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive - cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive
## 2. 安装anaconda ## 2. 安装anaconda
如果已安装conda,可以跳过本步骤 如果已安装conda,可以跳过本步骤
下载链接: 下载链接:
https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
## 3. 使用conda 创建环境 ## 3. 使用conda 创建环境
需指定python版本为3.10 需指定python版本为3.10
```bash ```bash
conda create -n MinerU python=3.10 conda create -n MinerU python=3.10
conda activate MinerU conda activate MinerU
``` ```
## 4. 安装应用 ## 4. 安装应用
```bash ```bash
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
``` ```
> ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确 > ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
> >
> ```bash > ```bash
> magic-pdf --version > magic-pdf --version
>``` > ```
>
> 如果版本号小于0.7.0,请到issue中向我们反馈 > 如果版本号小于0.7.0,请到issue中向我们反馈
## 5. 下载模型 ## 5. 下载模型
详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md) 详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
## 6. 了解配置文件存放的位置 ## 6. 了解配置文件存放的位置
完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。 完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
您可在【用户目录】下找到magic-pdf.json文件。 您可在【用户目录】下找到magic-pdf.json文件。
> windows用户目录为 "C:/Users/用户名" > windows用户目录为 "C:/Users/用户名"
## 7. 第一次运行 ## 7. 第一次运行
从仓库中下载样本文件,并测试 从仓库中下载样本文件,并测试
```powershell ```powershell
(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf') wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
magic-pdf -p small_ocr.pdf magic-pdf -p small_ocr.pdf
``` ```
## 8. 测试CUDA加速 ## 8. 测试CUDA加速
如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
**1.覆盖安装支持cuda的torch和torchvision** **1.覆盖安装支持cuda的torch和torchvision**
```bash ```bash
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118 pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
``` ```
> ❗️务必在命令中指定以下版本 > ❗️务必在命令中指定以下版本
>
> ```bash > ```bash
> torch==2.3.1 torchvision==0.18.1 > torch==2.3.1 torchvision==0.18.1
> ``` > ```
>
> 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行 > 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
**2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值** **2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
```json ```json
{ {
"device-mode":"cuda" "device-mode":"cuda"
} }
``` ```
**3.运行以下命令测试cuda加速效果** **3.运行以下命令测试cuda加速效果**
```bash ```bash
magic-pdf -p small_ocr.pdf magic-pdf -p small_ocr.pdf
``` ```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段的耗时来简单判断,通常情况下,`layout detection time` 和 `mfr time` 应提速10倍以上。
## 9. 为ocr开启cuda加速 ## 9. 为ocr开启cuda加速
> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
**1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速** **1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
```bash ```bash
pip install paddlepaddle-gpu==2.6.1 pip install paddlepaddle-gpu==2.6.1
``` ```
**2.运行以下命令测试ocr加速效果** **2.运行以下命令测试ocr加速效果**
```bash ```bash
magic-pdf -p small_ocr.pdf magic-pdf -p small_ocr.pdf
``` ```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr time`应提速10倍以上。
import json
import os import os
import requests import requests
import json
from modelscope import snapshot_download from modelscope import snapshot_download
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications): def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename): if os.path.exists(local_filename):
data = json.load(open(local_filename)) data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else: else:
# 下载JSON文件 data = download_json(url)
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
# 解析JSON内容
data = response.json()
# 修改内容 # 修改内容
for key, value in modifications.items(): for key, value in modifications.items():
...@@ -25,15 +33,25 @@ def download_and_modify_json(url, local_filename, modifications): ...@@ -25,15 +33,25 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__': if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader') layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + "/models" model_dir = model_dir + '/models'
print(f"model_dir is: {model_dir}") print(f'model_dir is: {model_dir}')
print(f"layoutreader_model_dir is: {layoutreader_model_dir}") print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
config_file_name = "magic-pdf.json"
home_dir = os.path.expanduser("~")
config_file = os.path.join(home_dir, config_file_name) config_file = os.path.join(home_dir, config_file_name)
json_mods = { json_mods = {
...@@ -42,4 +60,6 @@ if __name__ == '__main__': ...@@ -42,4 +60,6 @@ if __name__ == '__main__':
} }
download_and_modify_json(json_url, config_file, json_mods) download_and_modify_json(json_url, config_file, json_mods)
print(f"The configuration file has been configured successfully, the path is: {config_file}")
\ No newline at end of file print(f'The configuration file has been configured successfully, the path is: {config_file}')
import json
import os import os
import requests import requests
import json
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications): def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename): if os.path.exists(local_filename):
data = json.load(open(local_filename)) data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else: else:
# 下载JSON文件 data = download_json(url)
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
# 解析JSON内容
data = response.json()
# 修改内容 # 修改内容
for key, value in modifications.items(): for key, value in modifications.items():
...@@ -25,15 +32,31 @@ def download_and_modify_json(url, local_filename, modifications): ...@@ -25,15 +32,31 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__': if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader') mineru_patterns = [
model_dir = model_dir + "/models" "models/Layout/LayoutLMv3/*",
print(f"model_dir is: {model_dir}") "models/Layout/YOLO/*",
print(f"layoutreader_model_dir is: {layoutreader_model_dir}") "models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json' "models/TabRec/TableMaster/*",
config_file_name = "magic-pdf.json" "models/TabRec/StructEqTable/*",
home_dir = os.path.expanduser("~") ]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_pattern = [
"*.json",
"*.safetensors",
]
layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name) config_file = os.path.join(home_dir, config_file_name)
json_mods = { json_mods = {
...@@ -42,4 +65,6 @@ if __name__ == '__main__': ...@@ -42,4 +65,6 @@ if __name__ == '__main__':
} }
download_and_modify_json(json_url, config_file, json_mods) download_and_modify_json(json_url, config_file, json_mods)
print(f"The configuration file has been configured successfully, the path is: {config_file}")
\ No newline at end of file print(f'The configuration file has been configured successfully, the path is: {config_file}')
...@@ -3,7 +3,8 @@ Model downloads are divided into initial downloads and updates to the model dire ...@@ -3,7 +3,8 @@ Model downloads are divided into initial downloads and updates to the model dire
# Initial download of model files # Initial download of model files
### 1. Download the Model from Hugging Face ### Download the Model from Hugging Face
Use a Python Script to Download Model Files from Hugging Face Use a Python Script to Download Model Files from Hugging Face
```bash ```bash
pip install huggingface_hub pip install huggingface_hub
...@@ -14,13 +15,16 @@ The Python script will automatically download the model files and configure the ...@@ -14,13 +15,16 @@ The Python script will automatically download the model files and configure the
The configuration file can be found in the user directory, with the filename `magic-pdf.json`. The configuration file can be found in the user directory, with the filename `magic-pdf.json`.
# How to update models previously downloaded # How to update models previously downloaded
## 1. Models downloaded via Git LFS ## 1. Models downloaded via Git LFS
>Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended. > Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model. > For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
## 2. Models downloaded via Hugging Face or Model Scope ## 2. Models downloaded via Hugging Face or Model Scope
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
<pre><code>pip install huggingface_hub <pre><code>pip install huggingface_hub
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
python download_models_hf.py</code></pre> python download_models_hf.py</code></pre>
<p>python脚本执行完毕后,会输出模型下载目录</p> <p>python脚本会自动下载模型文件并配置好配置文件中的模型目录</p>
</details> </details>
## 方法二:从 ModelScope 下载模型 ## 方法二:从 ModelScope 下载模型
...@@ -25,6 +25,7 @@ python download_models.py ...@@ -25,6 +25,7 @@ python download_models.py
python脚本会自动下载模型文件并配置好配置文件中的模型目录 python脚本会自动下载模型文件并配置好配置文件中的模型目录
配置文件可以在用户目录中找到,文件名为`magic-pdf.json` 配置文件可以在用户目录中找到,文件名为`magic-pdf.json`
> windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名" > windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
...@@ -32,16 +33,12 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目 ...@@ -32,16 +33,12 @@ python脚本会自动下载模型文件并配置好配置文件中的模型目
## 1. 通过git lfs下载过模型 ## 1. 通过git lfs下载过模型
>由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。 > 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型 > 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新
> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
>
>```
>from modelscope import snapshot_download
>snapshot_download('ppaanngggg/layoutreader')
>```
## 2. 通过 Hugging Face 或 Model Scope 下载过模型 ## 2. 通过 Hugging Face 或 Model Scope 下载过模型
......
...@@ -175,11 +175,14 @@ Detailed explanation of second-level block types ...@@ -175,11 +175,14 @@ Detailed explanation of second-level block types
| :----------------- | :--------------------- | | :----------------- | :--------------------- |
| image_body | Main body of the image | | image_body | Main body of the image |
| image_caption | Image description text | | image_caption | Image description text |
| image_footnote | Image footnote |
| table_body | Main body of the table | | table_body | Main body of the table |
| table_caption | Table description text | | table_caption | Table description text |
| table_footnote | Table footnote | | table_footnote | Table footnote |
| text | Text block | | text | Text block |
| title | Title block | | title | Title block |
| index | Index block |
| list | List block |
| interline_equation | Block formula | | interline_equation | Block formula |
<br> <br>
......
...@@ -174,11 +174,14 @@ poly 坐标的格式 \[x0, y0, x1, y1, x2, y2, x3, y3\], 分别表示左上、 ...@@ -174,11 +174,14 @@ poly 坐标的格式 \[x0, y0, x1, y1, x2, y2, x3, y3\], 分别表示左上、
| :----------------- | :------------- | | :----------------- | :------------- |
| image_body | 图像的本体 | | image_body | 图像的本体 |
| image_caption | 图像的描述文本 | | image_caption | 图像的描述文本 |
| image_footnote | 图像的脚注 |
| table_body | 表格本体 | | table_body | 表格本体 |
| table_caption | 表格的描述文本 | | table_caption | 表格的描述文本 |
| table_footnote | 表格的脚注 | | table_footnote | 表格的脚注 |
| text | 文本块 | | text | 文本块 |
| title | 标题块 | | title | 标题块 |
| index | 目录块 |
| list | 列表块 |
| interline_equation | 行间公式块 | | interline_equation | 行间公式块 |
<br> <br>
......
...@@ -6,9 +6,18 @@ ...@@ -6,9 +6,18 @@
"models-dir":"/tmp/models", "models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader", "layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu", "device-mode":"cpu",
"layout-config": {
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true
},
"table-config": { "table-config": {
"model": "TableMaster", "model": "tablemaster",
"is_table_recog_enable": false, "enable": false,
"max_time": 400 "max_time": 400
} },
"config_version": "1.0.0"
} }
\ No newline at end of file
import enum
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataReader # noqa: F401
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataWriter # noqa: F401
\ No newline at end of file
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
self.write(path, data.encode())
import os
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'wb') as f:
f.write(data)
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
from magic_pdf.data.schemas import S3Config
from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
remove_non_official_s3_args)
class MultiS3Mixin:
def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_bucket (str): the default bucket name of the relative path
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs
InvalidConfig: bucket name not unique in s3_configs
InvalidConfig: default bucket must be provided
"""
if len(default_bucket) == 0:
raise InvalidConfig('default_bucket must be provided')
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.default_bucket = default_bucket
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the path, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
for example: s3://bucket_name/path?0,100
Returns:
bytes: the content of s3 file
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the path.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the path.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
return s3_writer.write(path, data)
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
MultiBucketS3DataReader, MultiBucketS3DataWriter)
from magic_pdf.data.schemas import S3Config
class S3DataReader(MultiBucketS3DataReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
class S3DataWriter(MultiBucketS3DataWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 writer client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
from abc import ABC, abstractmethod
from typing import Iterator
import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
class PageableData(ABC):
@abstractmethod
def get_image(self) -> dict:
"""Transform data to image."""
pass
@abstractmethod
def get_doc(self) -> fitz.Page:
"""Get the pymudoc page."""
pass
@abstractmethod
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
pass
class Dataset(ABC):
@abstractmethod
def __len__(self) -> int:
"""The length of the dataset."""
pass
@abstractmethod
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page data."""
pass
@abstractmethod
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The methods that this dataset support.
Returns:
list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
"""
pass
@abstractmethod
def data_bits(self) -> bytes:
"""The bits used to create this dataset."""
pass
@abstractmethod
def get_page(self, page_id: int) -> PageableData:
"""Get the page indexed by page_id.
Args:
page_id (int): the index of the page
Returns:
PageableData: the page doc object
"""
pass
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the pdf
"""
self._records = [Doc(v) for v in fitz.open('pdf', bits)]
self._data_bits = bits
self._raw_data = bits
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page doc object."""
return iter(self._records)
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class ImageDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
"""
pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
self._raw_data = bits
self._data_bits = pdf_bytes
def __len__(self) -> int:
"""The length of the dataset."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page object."""
return iter(self._records)
def supported_methods(self):
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class Doc(PageableData):
"""Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page):
self._doc = doc
def get_image(self):
"""Return the imge info.
Returns:
dict: {
img: np.ndarray,
width: int,
height: int
}
"""
return fitz_doc_to_image(self._doc)
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
Returns:
fitz.Page: the pymudoc object
"""
return self._doc
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
page_w = self._doc.rect.width
page_h = self._doc.rect.height
return PageInfo(w=page_w, h=page_h)
def __getattr__(self, name):
if hasattr(self._doc, name):
return getattr(self._doc, name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment