Unverified commit 9d689790, authored by linfeng, committed by GitHub

Merge branch 'opendatalab:dev' into dev

parents bcef0868 fb383ba6
...@@ -80,6 +80,7 @@ body:
      -
        - "0.6.x"
        - "0.7.x"
+       - "0.8.x"
  validations:
    required: true
......
...@@ -37,12 +37,13 @@ jobs:
      run: |
        echo $GITHUB_WORKSPACE && sh tests/retry_env.sh
    - name: unit test
      run: |
-       cd $GITHUB_WORKSPACE && export PYTHONPATH=. && coverage run -m pytest tests/test_unit.py --cov=magic_pdf/ --cov-report term-missing --cov-report html
+       cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
+       cd $GITHUB_WORKSPACE && export PYTHONPATH=. && coverage run -m pytest tests/unittest --cov=magic_pdf/ --cov-report term-missing --cov-report html
        cd $GITHUB_WORKSPACE && python tests/get_coverage.py
    - name: cli test
      run: |
-       cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py
+       source ~/.bashrc && cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli.py
  notify_to_feishu:
    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
......
...@@ -659,3 +659,4 @@ specific requirements.
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
This diff is collapsed.
...@@ -44,3 +44,11 @@ pip uninstall fairscale
pip install fairscale
```
Reference: https://github.com/opendatalab/MinerU/issues/411
+### 6. On some newer devices, such as the H100, text parsed during CUDA-accelerated OCR is garbled.
+CUDA 11 has poor compatibility with newer graphics cards, so the CUDA version used by Paddle needs to be upgraded:
+```bash
+pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
+```
+Reference: https://github.com/opendatalab/MinerU/issues/558
...@@ -41,3 +41,11 @@ pip uninstall fairscale
pip install fairscale
```
Reference: https://github.com/opendatalab/MinerU/issues/411
+### 6. On some newer devices, such as the H100, text parsed during CUDA-accelerated OCR is garbled.
+CUDA 11 has poor compatibility with newer graphics cards, so the CUDA version used by Paddle needs to be upgraded:
+```bash
+pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
+```
+Reference: https://github.com/opendatalab/MinerU/issues/558
from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
print(f"model dir is: {model_dir}/models")
...@@ -6,58 +6,8 @@ wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py
python download_models_hf.py
```
After the Python script finishes executing, it will output the directory where the models were downloaded.
-### 2. Additional steps
-#### 1. Check whether the model directory is downloaded completely.
+### 2. Modify the model path in the configuration file
The structure of the model folder is as follows, including configuration files and weight files of different components:
```
../
├── Layout
│ ├── config.json
│ └── model_final.pth
├── MFD
│ └── weights.pt
├── MFR
│ └── UniMERNet
│ ├── config.json
│ ├── preprocessor_config.json
│ ├── pytorch_model.bin
│ ├── README.md
│ ├── tokenizer_config.json
│ └── tokenizer.json
│── TabRec
│ └─StructEqTable
│ ├── config.json
│ ├── generation_config.json
│ ├── model.safetensors
│ ├── preprocessor_config.json
│ ├── special_tokens_map.json
│ ├── spiece.model
│ ├── tokenizer.json
│ └── tokenizer_config.json
│ └─ TableMaster
│ └─ ch_PP-OCRv3_det_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ ch_PP-OCRv3_rec_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ table_structure_tablemaster_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ ├── ppocr_keys_v1.txt
│ └── table_master_structure_dict.txt
└── README.md
```
#### 2. Check whether the model file is fully downloaded.
Check that the size of each model file in the directory matches the description on the web page; if possible, verify the download's integrity via sha256.
#### 3. Modify the model path in `~/magic-pdf.json`
Additionally, in `~/magic-pdf.json`, update the model directory path to the absolute path of the `models` directory output by the previous Python script. Otherwise, you will encounter an error indicating that the model cannot be loaded.
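Updating the path can be scripted. A minimal sketch, assuming the config lives at `~/magic-pdf.json` and uses a `models-dir` key (check your installed version's config for the exact key; `set_models_dir` is an illustrative helper, not part of MinerU):

```python
import json
from pathlib import Path


def set_models_dir(config_path: str, models_dir: str) -> dict:
    """Point the config's model directory at an absolute path."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8")) if path.exists() else {}
    # "models-dir" is the assumed key name; MinerU reports a load error if this path is wrong.
    config["models-dir"] = str(Path(models_dir).expanduser().resolve())
    path.write_text(json.dumps(config, indent=4, ensure_ascii=False), encoding="utf-8")
    return config
```

Using an absolute, resolved path here avoids the "model cannot be loaded" error mentioned above.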
...@@ -21,55 +21,7 @@ wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py
python download_models.py
```
After the Python script finishes executing, it will output the model download directory.
## [❗️Must do❗️] Additional steps (be sure to complete the following after the model download finishes)
### 1. Check that the model directory downloaded completely
The model folder structure is as follows, containing the configuration files and weight files of the different components:
```
./
├── Layout # layout detection model
│ ├── config.json
│ └── model_final.pth
├── MFD # formula detection
│ └── weights.pt
├── MFR # formula recognition model
│ └── UniMERNet
│ ├── config.json
│ ├── preprocessor_config.json
│ ├── pytorch_model.bin
│ ├── README.md
│ ├── tokenizer_config.json
│ └── tokenizer.json
│── TabRec # table recognition model
│ └─StructEqTable
│ ├── config.json
│ ├── generation_config.json
│ ├── model.safetensors
│ ├── preprocessor_config.json
│ ├── special_tokens_map.json
│ ├── spiece.model
│ ├── tokenizer.json
│ └── tokenizer_config.json
│ └─ TableMaster
│ └─ ch_PP-OCRv3_det_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ ch_PP-OCRv3_rec_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ table_structure_tablemaster_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ ├── ppocr_keys_v1.txt
│ └── table_master_structure_dict.txt
└── README.md
```
### 2. Check that the model files downloaded completely
Check that the size of each model file in the directory matches the description on the web page; if possible, verify the download's integrity via sha256.
-### 3. Modify the model path in magic-pdf.json
+## After downloading: modify the model path in magic-pdf.json
-In addition, modify the model directory in `~/magic-pdf.json` to point to the absolute path of the `models` directory output by the earlier Python script; otherwise an error that the model cannot be loaded will be reported.
+Modify the model directory in `~/magic-pdf.json` to point to the absolute path of the `models` directory output by the previous step's script; otherwise an error that the model cannot be loaded will be reported.
...@@ -116,17 +116,20 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                      mode,
-                                     img_buket_path=''):
+                                     img_buket_path='',
+                                     parse_type="auto",
+                                     lang=None
+                                     ):
    page_markdown = []
    for para_block in paras_of_layout:
        para_text = ''
        para_type = para_block['type']
        if para_type == BlockType.Text:
-           para_text = merge_para_with_text(para_block)
+           para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
        elif para_type == BlockType.Title:
-           para_text = f'# {merge_para_with_text(para_block)}'
+           para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
        elif para_type == BlockType.InterlineEquation:
-           para_text = merge_para_with_text(para_block)
+           para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
        elif para_type == BlockType.Image:
            if mode == 'nlp':
                continue
...@@ -139,17 +142,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                        para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
                for block in para_block['blocks']:  # 2nd. append image_caption
                    if block['type'] == BlockType.ImageCaption:
-                       para_text += merge_para_with_text(block)
+                       para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
                for block in para_block['blocks']:  # 3rd. append image_footnote
                    if block['type'] == BlockType.ImageFootnote:
-                       para_text += merge_para_with_text(block)
+                       para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
        elif para_type == BlockType.Table:
            if mode == 'nlp':
                continue
            elif mode == 'mm':
                for block in para_block['blocks']:  # 1st. append table_caption
                    if block['type'] == BlockType.TableCaption:
-                       para_text += merge_para_with_text(block)
+                       para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
                for block in para_block['blocks']:  # 2nd. append table_body
                    if block['type'] == BlockType.TableBody:
                        for line in block['lines']:
...@@ -164,7 +167,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
                for block in para_block['blocks']:  # 3rd. append table_footnote
                    if block['type'] == BlockType.TableFootnote:
-                       para_text += merge_para_with_text(block)
+                       para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
        if para_text.strip() == '':
            continue
...@@ -174,7 +177,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
    return page_markdown
-def merge_para_with_text(para_block):
+def merge_para_with_text(para_block, parse_type="auto", lang=None):

    def detect_language(text):
        en_pattern = r'[a-zA-Z]+'
...@@ -205,11 +208,15 @@ def merge_para_with_text(para_block):
                content = span['content']
                # language = detect_lang(content)
                language = detect_language(content)
-               if language == 'en':  # only split long English words; splitting Chinese would lose text
-                   content = ocr_escape_special_markdown_char(
-                       split_long_words(content))
-               else:
-                   content = ocr_escape_special_markdown_char(content)
+               # check for a minority (non-English) language
+               if lang is not None and lang != 'en':
+                   content = ocr_escape_special_markdown_char(content)
+               else:  # non-minority-language logic
+                   if language == 'en' and parse_type == 'ocr':  # only split long English words; splitting Chinese would lose text
+                       content = ocr_escape_special_markdown_char(
+                           split_long_words(content))
+                   else:
+                       content = ocr_escape_special_markdown_char(content)
            elif span_type == ContentType.InlineEquation:
                content = f" ${span['content']}$ "
            elif span_type == ContentType.InterlineEquation:
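The new language gate above, reduced to a standalone sketch (the helpers below are stand-ins for MinerU's `ocr_escape_special_markdown_char`, `split_long_words`, and `detect_language`; the English check is simplified to a letter regex):

```python
import re


def split_long_words(text, limit=10):
    # Stand-in: break words longer than `limit` characters into chunks.
    return ' '.join(
        ' '.join(word[i:i + limit] for i in range(0, len(word), limit))
        for word in text.split(' ')
    )


def render_span(content, parse_type="auto", lang=None):
    """Non-English (minority) languages now skip word-splitting entirely."""
    if lang is not None and lang != 'en':
        return content  # minority language: leave the OCR text intact
    if parse_type == 'ocr' and re.search(r'[a-zA-Z]', content):
        return split_long_words(content)  # OCR English: split very long words
    return content
```

The point of the change: before, long words were split whenever the span looked English; now splitting also requires OCR mode, and an explicitly non-English `lang` bypasses it.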
...@@ -265,41 +272,39 @@ def para_to_standard_format(para, img_buket_path):
    return para_content


-def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
+def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None, drop_reason=None):
    para_type = para_block['type']
+   para_content = {}
    if para_type == BlockType.Text:
        para_content = {
            'type': 'text',
-           'text': merge_para_with_text(para_block),
-           'page_idx': page_idx,
+           'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
        }
    elif para_type == BlockType.Title:
        para_content = {
            'type': 'text',
-           'text': merge_para_with_text(para_block),
+           'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
            'text_level': 1,
-           'page_idx': page_idx,
        }
    elif para_type == BlockType.InterlineEquation:
        para_content = {
            'type': 'equation',
-           'text': merge_para_with_text(para_block),
+           'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
            'text_format': 'latex',
-           'page_idx': page_idx,
        }
    elif para_type == BlockType.Image:
-       para_content = {'type': 'image', 'page_idx': page_idx}
+       para_content = {'type': 'image'}
        for block in para_block['blocks']:
            if block['type'] == BlockType.ImageBody:
                para_content['img_path'] = join_path(
                    img_buket_path,
                    block['lines'][0]['spans'][0]['image_path'])
            if block['type'] == BlockType.ImageCaption:
-               para_content['img_caption'] = merge_para_with_text(block)
+               para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
            if block['type'] == BlockType.ImageFootnote:
-               para_content['img_footnote'] = merge_para_with_text(block)
+               para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
    elif para_type == BlockType.Table:
-       para_content = {'type': 'table', 'page_idx': page_idx}
+       para_content = {'type': 'table'}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TableBody:
                if block["lines"][0]["spans"][0].get('latex', ''):
...@@ -308,9 +313,14 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
                    para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
            if block['type'] == BlockType.TableCaption:
-               para_content['table_caption'] = merge_para_with_text(block)
+               para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
            if block['type'] == BlockType.TableFootnote:
-               para_content['table_footnote'] = merge_para_with_text(block)
+               para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
+
+   para_content['page_idx'] = page_idx
+
+   if drop_reason is not None:
+       para_content['drop_reason'] = drop_reason

    return para_content
...@@ -394,13 +404,19 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list):
def union_make(pdf_info_dict: list,
               make_mode: str,
               drop_mode: str,
-              img_buket_path: str = ''):
+              img_buket_path: str = '',
+              parse_type: str = "auto",
+              lang=None):
    output_content = []
    for page_info in pdf_info_dict:
+       drop_reason_flag = False
+       drop_reason = None
        if page_info.get('need_drop', False):
            drop_reason = page_info.get('drop_reason')
            if drop_mode == DropMode.NONE:
                pass
+           elif drop_mode == DropMode.NONE_WITH_REASON:
+               drop_reason_flag = True
            elif drop_mode == DropMode.WHOLE_PDF:
                raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
                                 f'drop_reason is {drop_reason}'))
...@@ -417,16 +433,20 @@ def union_make(pdf_info_dict: list,
            continue
        if make_mode == MakeMode.MM_MD:
            page_markdown = ocr_mk_markdown_with_para_core_v2(
-               paras_of_layout, 'mm', img_buket_path)
+               paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
            output_content.extend(page_markdown)
        elif make_mode == MakeMode.NLP_MD:
            page_markdown = ocr_mk_markdown_with_para_core_v2(
-               paras_of_layout, 'nlp')
+               paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
            output_content.extend(page_markdown)
        elif make_mode == MakeMode.STANDARD_FORMAT:
            for para_block in paras_of_layout:
-               para_content = para_to_standard_format_v2(
-                   para_block, img_buket_path, page_idx)
+               if drop_reason_flag:
+                   para_content = para_to_standard_format_v2(
+                       para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang, drop_reason=drop_reason)
+               else:
+                   para_content = para_to_standard_format_v2(
+                       para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
                output_content.append(para_content)
    if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
        return '\n\n'.join(output_content)
......
...@@ -8,3 +8,4 @@ class DropMode:
    WHOLE_PDF = "whole_pdf"
    SINGLE_PAGE = "single_page"
    NONE = "none"
+   NONE_WITH_REASON = "none_with_reason"
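A standalone sketch of what the new mode changes downstream (simplified; `tag_para` and `make_page` are illustrative stand-ins for `para_to_standard_format_v2` and the `union_make` loop, not MinerU's actual functions):

```python
class DropMode:
    WHOLE_PDF = "whole_pdf"
    SINGLE_PAGE = "single_page"
    NONE = "none"
    NONE_WITH_REASON = "none_with_reason"


def tag_para(para, page_idx, drop_reason=None):
    # Mirrors the new behaviour: page_idx is always attached,
    # drop_reason only when the caller passes one.
    content = {"type": para["type"], "page_idx": page_idx}
    if drop_reason is not None:
        content["drop_reason"] = drop_reason
    return content


def make_page(paras, page_idx, need_drop, drop_reason, drop_mode):
    # Only NONE_WITH_REASON propagates the reason into each paragraph.
    reason = drop_reason if (need_drop and drop_mode == DropMode.NONE_WITH_REASON) else None
    return [tag_para(p, page_idx, drop_reason=reason) for p in paras]
```

So a page flagged `need_drop` is still emitted in both NONE modes, but only NONE_WITH_REASON annotates its paragraphs with why it would have been dropped.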
...@@ -426,3 +426,22 @@ def bbox_distance(bbox1, bbox2):
    elif top:
        return y2 - y1b
    return 0.0
+
+
+def box_area(bbox):
+   return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+
+def get_overlap_area(bbox1, bbox2):
+   """Compute the overlap area of bbox1 and bbox2."""
+   # Determine the coordinates of the intersection rectangle
+   x_left = max(bbox1[0], bbox2[0])
+   y_top = max(bbox1[1], bbox2[1])
+   x_right = min(bbox1[2], bbox2[2])
+   y_bottom = min(bbox1[3], bbox2[3])
+
+   if x_right < x_left or y_bottom < y_top:
+       return 0.0
+
+   # The area of the overlap
+   return (x_right - x_left) * (y_bottom - y_top)
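The two helpers are used together later in this diff (in `MagicModel`) to form an overlap-to-area ratio that is compared against `MERGE_BOX_OVERLAP_AREA_RATIO`. Copied here verbatim with a small usage example so they can be checked in isolation:

```python
def box_area(bbox):
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def get_overlap_area(bbox1, bbox2):
    """Compute the overlap area of bbox1 and bbox2 (bboxes are [x0, y0, x1, y1])."""
    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])
    if x_right < x_left or y_bottom < y_top:
        return 0.0  # boxes do not intersect
    return (x_right - x_left) * (y_bottom - y_top)


# Two 10x10 boxes offset by 5 in each direction overlap in a 5x5 square.
a, b = [0, 0, 10, 10], [5, 5, 15, 15]
ratio = get_overlap_area(a, b) / box_area(a)  # 25 / 100 = 0.25
```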
-__version__ = "0.7.1"
+__version__ = "0.8.0"
...@@ -57,14 +57,14 @@ class ModelSingleton:
            cls._instance = super().__new__(cls)
        return cls._instance

-   def get_model(self, ocr: bool, show_log: bool):
-       key = (ocr, show_log)
+   def get_model(self, ocr: bool, show_log: bool, lang=None):
+       key = (ocr, show_log, lang)
        if key not in self._models:
-           self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
+           self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
        return self._models[key]


-def custom_model_init(ocr: bool = False, show_log: bool = False):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
    model = None

    if model_config.__model_mode__ == "lite":
...@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
    model_init_start = time.time()
    if model == MODEL.Paddle:
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
-       custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
+       custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel
        # read model-dir and device from the config file
...@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
                       "show_log": show_log,
                       "models_dir": local_models_dir,
                       "device": device,
-                      "table_config": table_config}
+                      "table_config": table_config,
+                      "lang": lang,
+                      }
        custom_model = CustomPEKModel(**model_input)
    else:
        logger.error("Not allow model_name!")
...@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):


def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-               start_page_id=0, end_page_id=None):
+               start_page_id=0, end_page_id=None, lang=None):
    model_manager = ModelSingleton()
-   custom_model = model_manager.get_model(ocr, show_log)
+   custom_model = model_manager.get_model(ocr, show_log, lang)

    images = load_images_from_pdf(pdf_bytes)
......
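The important detail in the singleton change above is that `lang` joins the cache key, so models initialized for different OCR languages are never shared. The pattern in miniature (illustrative names; a placeholder object stands in for real model initialization):

```python
class ModelCache:
    """Process-wide singleton caching one model per (ocr, show_log, lang) key."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._models = {}
        return cls._instance

    def get_model(self, ocr, show_log, lang=None):
        key = (ocr, show_log, lang)  # lang is now part of the key
        if key not in self._models:
            self._models[key] = object()  # stand-in for expensive model init
        return self._models[key]
```

Had `lang` been left out of the key, the first language requested would be silently reused for every later request.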
import json

from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
-                                   bbox_relative_pos, calculate_iou,
-                                   calculate_overlap_area_in_bbox1_area_ratio)
+                                   bbox_relative_pos, box_area, calculate_iou,
+                                   calculate_overlap_area_in_bbox1_area_ratio,
+                                   get_overlap_area)
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
...@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter

CAPATION_OVERLAP_AREA_RATIO = 0.6
+MERGE_BOX_OVERLAP_AREA_RATIO = 1.1


class MagicModel:
...@@ -124,49 +126,51 @@ class MagicModel:
                tables.append(obj)
            if len(footnotes) * len(figures) == 0:
                continue
            dis_figure_footnote = {}
            dis_table_footnote = {}

            for i in range(len(footnotes)):
                for j in range(len(figures)):
                    pos_flag_count = sum(
                        list(
                            map(
                                lambda x: 1 if x else 0,
                                bbox_relative_pos(
                                    footnotes[i]['bbox'], figures[j]['bbox']
                                ),
                            )
                        )
                    )
                    if pos_flag_count > 1:
                        continue
                    dis_figure_footnote[i] = min(
                        bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                        dis_figure_footnote.get(i, float('inf')),
                    )
            for i in range(len(footnotes)):
                for j in range(len(tables)):
                    pos_flag_count = sum(
                        list(
                            map(
                                lambda x: 1 if x else 0,
                                bbox_relative_pos(
                                    footnotes[i]['bbox'], tables[j]['bbox']
                                ),
                            )
                        )
                    )
                    if pos_flag_count > 1:
                        continue
                    dis_table_footnote[i] = min(
                        bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                        dis_table_footnote.get(i, float('inf')),
                    )
            for i in range(len(footnotes)):
+               if i not in dis_figure_footnote:
+                   continue
                if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
                    footnotes[i]['category_id'] = CategoryId.ImageFootnote
    def __reduct_overlap(self, bboxes):
        N = len(bboxes)
...@@ -191,6 +195,44 @@ class MagicModel:
        Filter all subjects whose overlap with the merged bbox is larger than the object's own area,
        then take the shortest distance between the filtered subjects and the object.
        """
def search_overlap_between_boxes(
subject_idx, object_idx
):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0
other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
return ratio
        def may_find_other_nearest_bbox(subject_idx, object_idx):
            ret = float('inf')
...@@ -299,6 +341,15 @@ class MagicModel:
                ):
                    continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
                dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
                dis[j][i] = dis[i][j]
...@@ -627,13 +678,13 @@ class MagicModel:
                    span['type'] = ContentType.Image
                elif category_id == 5:
                    # get the table model result
-                   latex = layout_det.get("latex", None)
-                   html = layout_det.get("html", None)
+                   latex = layout_det.get('latex', None)
+                   html = layout_det.get('html', None)
                    if latex:
-                       span["latex"] = latex
+                       span['latex'] = latex
                    elif html:
-                       span["html"] = html
+                       span['html'] = html
-                   span["type"] = ContentType.Table
+                   span['type'] = ContentType.Table
                elif category_id == 13:
                    span['content'] = layout_det['latex']
                    span['type'] = ContentType.InlineEquation
......
...@@ -58,7 +58,7 @@ def mfd_model_init(weight):
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
-   cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+   cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir
    task = tasks.setup_task(cfg)
...@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device):
    return model


-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
-   model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
+   if lang is not None:
+       model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
+   else:
+       model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
    return model
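The `if lang is not None` branch above exists so that `ModifiedPaddleOCR`'s own default language is kept when none is given. The same effect can be had without duplicating the call, via a kwargs dict; a sketch where `make_ocr` is a stand-in that records its arguments (its `"ch"` default is an assumption for illustration):

```python
def make_ocr(show_log=False, det_db_box_thresh=0.3, lang="ch"):
    # Stand-in for ModifiedPaddleOCR: just records what it was called with.
    return {"show_log": show_log, "det_db_box_thresh": det_db_box_thresh, "lang": lang}


def ocr_model_init(show_log=False, det_db_box_thresh=0.3, lang=None):
    kwargs = {"show_log": show_log, "det_db_box_thresh": det_db_box_thresh}
    if lang is not None:
        kwargs["lang"] = lang  # only override the library default when explicitly set
    return make_ocr(**kwargs)
```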
...@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs):
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get("ocr_show_log"),
-           kwargs.get("det_db_box_thresh")
+           kwargs.get("det_db_box_thresh"),
+           kwargs.get("lang")
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
...@@ -177,9 +181,10 @@ class CustomPEKModel:
        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
        self.apply_ocr = ocr
+       self.lang = kwargs.get("lang", None)
        logger.info(
-           "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
-               self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
+           "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
+               self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
            )
        )
        assert self.apply_layout, "DocAnalysis must contain layout model."
...@@ -225,11 +230,13 @@ class CustomPEKModel:
            )
        # initialize OCR
        if self.apply_ocr:
            # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
            self.ocr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                ocr_show_log=show_log,
-               det_db_box_thresh=0.3
+               det_db_box_thresh=0.3,
+               lang=self.lang
            )
        # init table model
        if self.apply_table:
...@@ -243,6 +250,7 @@ class CustomPEKModel:
                table_max_time=self.table_max_time,
                device=self.device
            )
        logger.info('DocAnalysis init done!')

    def __call__(self, image):
...@@ -383,6 +391,7 @@ class CustomPEKModel:
                    latex_code = self.table_model.image2latex(new_image)[0]
                else:
                    html_code = self.table_model.img2html(new_image)
                run_time = time.time() - single_table_start_time
                logger.info(f"------------table recognition processing ends within {run_time}s-----")
                if run_time > self.table_max_time:
......
...@@ -18,8 +18,11 @@ def region_to_bbox(region):
class CustomPaddleModel:
-   def __init__(self, ocr: bool = False, show_log: bool = False):
-       self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+   def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+       if lang is not None:
+           self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+       else:
+           self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)

    def __call__(self, img):
        try:
......