Unverified commit 9d689790 authored by linfeng, committed by GitHub

Merge branch 'opendatalab:dev' into dev

parents bcef0868 fb383ba6
......@@ -80,6 +80,7 @@ body:
-
- "0.6.x"
- "0.7.x"
- "0.8.x"
validations:
required: true
......
......@@ -37,12 +37,13 @@ jobs:
run: |
echo $GITHUB_WORKSPACE && sh tests/retry_env.sh
- name: unit test
run: |
cd $GITHUB_WORKSPACE && export PYTHONPATH=. && coverage run -m pytest tests/test_unit.py --cov=magic_pdf/ --cov-report term-missing --cov-report html
run: |
cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
cd $GITHUB_WORKSPACE && export PYTHONPATH=. && coverage run -m pytest tests/unittest --cov=magic_pdf/ --cov-report term-missing --cov-report html
cd $GITHUB_WORKSPACE && python tests/get_coverage.py
- name: cli test
run: |
cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py
source ~/.bashrc && cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli.py
notify_to_feishu:
if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
......
......@@ -659,3 +659,4 @@ specific requirements.
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
......@@ -44,3 +44,11 @@ pip uninstall fairscale
pip install fairscale
```
Reference: https://github.com/opendatalab/MinerU/issues/411
### 6. On some newer devices like the H100, the text parsed during OCR using CUDA acceleration is garbled.
CUDA 11 has poor compatibility with newer GPUs, so the CUDA version used by PaddlePaddle needs to be upgraded.
```bash
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
```
Reference: https://github.com/opendatalab/MinerU/issues/558
......@@ -41,3 +41,11 @@ pip uninstall fairscale
pip install fairscale
```
Reference: https://github.com/opendatalab/MinerU/issues/411
### 6. On some newer devices such as the H100, text parsed with CUDA-accelerated OCR is garbled.
CUDA 11 has poor compatibility with newer GPUs, so the CUDA version used by PaddlePaddle needs to be upgraded.
```bash
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
```
Reference: https://github.com/opendatalab/MinerU/issues/558
from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
print(f"model dir is: {model_dir}/models")
......@@ -6,58 +6,8 @@ wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py
python download_models_hf.py
```
After the Python script finishes executing, it will output the directory where the models are downloaded.
### 2. Additional steps
#### 1. Check whether the model directory is downloaded completely.
The structure of the model folder is as follows, including configuration files and weight files of different components:
```
../
├── Layout
│ ├── config.json
│ └── model_final.pth
├── MFD
│ └── weights.pt
├── MFR
│ └── UniMERNet
│ ├── config.json
│ ├── preprocessor_config.json
│ ├── pytorch_model.bin
│ ├── README.md
│ ├── tokenizer_config.json
│ └── tokenizer.json
│── TabRec
│ └─StructEqTable
│ ├── config.json
│ ├── generation_config.json
│ ├── model.safetensors
│ ├── preprocessor_config.json
│ ├── special_tokens_map.json
│ ├── spiece.model
│ ├── tokenizer.json
│ └── tokenizer_config.json
│ └─ TableMaster
│ └─ ch_PP-OCRv3_det_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ ch_PP-OCRv3_rec_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ table_structure_tablemaster_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ ├── ppocr_keys_v1.txt
│ └── table_master_structure_dict.txt
└── README.md
```
#### 2. Check whether the model file is fully downloaded.
Check that the model file sizes in the directory match those described on the download page. If possible, verify that the models downloaded completely via sha256 checksums.
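The sha256 check suggested above can be scripted. This is a minimal standalone sketch (the `models` path in the usage line is hypothetical; compare the resulting digests against the values published on the model page):

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through sha256 so large weight files never load fully into RAM."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def checksum_tree(model_dir):
    """Map every file under model_dir (relative path -> sha256 hex digest)."""
    root = Path(model_dir)
    return {str(p.relative_to(root)): sha256_of(p)
            for p in sorted(root.rglob("*")) if p.is_file()}

# Usage (hypothetical path): checksum_tree("models")
```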
### 2. Modify the model path in the configuration file
Additionally, in `~/magic-pdf.json`, update the model directory path to the absolute path of the `models` directory output by the previous Python script. Otherwise, you will encounter an error indicating that the model cannot be loaded.
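A minimal sketch of that config edit, assuming the models directory key in `~/magic-pdf.json` is named `models-dir` (verify the key name against your own config file):

```python
import json
import os

def set_models_dir(config_path, models_dir):
    """Rewrite the models directory entry in a magic-pdf.json-style config."""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config["models-dir"] = models_dir  # key name is an assumption; check your config
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=4)
    return config

# Usage (hypothetical): set_models_dir(os.path.expanduser("~/magic-pdf.json"), "/abs/path/to/models")
```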
......@@ -21,55 +21,7 @@ wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py
python download_models.py
```
After the Python script finishes, it prints the directory the models were downloaded to.
## 【❗️Required❗️】Additional steps (be sure to complete the following after the models finish downloading)
### 1. Check that the model directory downloaded completely
The model folder structure is as follows, containing the configuration files and weight files of the different components:
```
./
├── Layout # layout detection model
│ ├── config.json
│ └── model_final.pth
├── MFD # formula detection model
│ └── weights.pt
├── MFR # formula recognition model
│ └── UniMERNet
│ ├── config.json
│ ├── preprocessor_config.json
│ ├── pytorch_model.bin
│ ├── README.md
│ ├── tokenizer_config.json
│ └── tokenizer.json
│── TabRec # table recognition model
│ └─StructEqTable
│ ├── config.json
│ ├── generation_config.json
│ ├── model.safetensors
│ ├── preprocessor_config.json
│ ├── special_tokens_map.json
│ ├── spiece.model
│ ├── tokenizer.json
│ └── tokenizer_config.json
│ └─ TableMaster
│ └─ ch_PP-OCRv3_det_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ ch_PP-OCRv3_rec_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ └─ table_structure_tablemaster_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
│ ├── ppocr_keys_v1.txt
│ └── table_master_structure_dict.txt
└── README.md
```
### 2. Check that the model files downloaded completely
Check that the model file sizes in the directory match those described on the web page; if possible, verify the downloads via sha256 checksums.
### 3. Modify the model path in magic-pdf.json
Also, in `~/magic-pdf.json`, point the model directory to the absolute path of the `models` directory printed by the previous Python script; otherwise an error that the model cannot be loaded will be raised.
## After the download completes: modify the model path in magic-pdf.json
In `~/magic-pdf.json`, point the model directory to the absolute path of the `models` directory printed in the previous step; otherwise an error that the model cannot be loaded will be raised.
......@@ -116,17 +116,20 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path=''):
img_buket_path='',
parse_type="auto",
lang=None
):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block)
para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block)}'
para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
elif para_type == BlockType.Image:
if mode == 'nlp':
continue
......@@ -139,17 +142,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 2nd: append image_caption
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
for block in para_block['blocks']: # 3rd: append image_footnote
if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
for block in para_block['blocks']: # 1st: append table_caption
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
for block in para_block['blocks']: # 2nd: append table_body
if block['type'] == BlockType.TableBody:
for line in block['lines']:
......@@ -164,7 +167,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd: append table_footnote
if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block)
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
if para_text.strip() == '':
continue
......@@ -174,7 +177,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown
def merge_para_with_text(para_block):
def merge_para_with_text(para_block, parse_type="auto", lang=None):
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
......@@ -205,11 +208,15 @@ def merge_para_with_text(para_block):
content = span['content']
# language = detect_lang(content)
language = detect_language(content)
if language == 'en': # only split long English words; splitting Chinese text would lose characters
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
# check whether lang is a less-common (non-English) language
if lang is not None and lang != 'en':
content = ocr_escape_special_markdown_char(content)
else: # logic for non-minor languages
if language == 'en' and parse_type == 'ocr': # only split long English words; splitting Chinese text would lose characters
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f" ${span['content']}$ "
elif span_type == ContentType.InterlineEquation:
......@@ -265,41 +272,39 @@ def para_to_standard_format(para, img_buket_path):
return para_content
def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None, drop_reason=None):
para_type = para_block['type']
para_content = {}
if para_type == BlockType.Text:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'page_idx': page_idx,
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
}
elif para_type == BlockType.Title:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
'text_level': 1,
'page_idx': page_idx,
}
elif para_type == BlockType.InterlineEquation:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
'text_format': 'latex',
'page_idx': page_idx,
}
elif para_type == BlockType.Image:
para_content = {'type': 'image', 'page_idx': page_idx}
para_content = {'type': 'image'}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path(
img_buket_path,
block['lines'][0]['spans'][0]['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block)
para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block)
para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
elif para_type == BlockType.Table:
para_content = {'type': 'table', 'page_idx': page_idx}
para_content = {'type': 'table'}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''):
......@@ -308,9 +313,14 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block)
para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'] = merge_para_with_text(block)
para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_content['page_idx'] = page_idx
if drop_reason is not None:
para_content['drop_reason'] = drop_reason
return para_content
......@@ -394,13 +404,19 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list):
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = ''):
img_buket_path: str = '',
parse_type: str = "auto",
lang=None):
output_content = []
for page_info in pdf_info_dict:
drop_reason_flag = False
drop_reason = None
if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE:
pass
elif drop_mode == DropMode.NONE_WITH_REASON:
drop_reason_flag = True
elif drop_mode == DropMode.WHOLE_PDF:
raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}'))
......@@ -417,16 +433,20 @@ def union_make(pdf_info_dict: list,
continue
if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
if drop_reason_flag:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang, drop_reason=drop_reason)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
......
......@@ -8,3 +8,4 @@ class DropMode:
WHOLE_PDF = "whole_pdf"
SINGLE_PAGE = "single_page"
NONE = "none"
NONE_WITH_REASON = "none_with_reason"
......@@ -426,3 +426,22 @@ def bbox_distance(bbox1, bbox2):
elif top:
return y2 - y1b
return 0.0
def box_area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def get_overlap_area(bbox1, bbox2):
"""Compute the overlap (intersection) area of bbox1 and bbox2."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# Area of the intersection rectangle
return (x_right - x_left) * (y_bottom - y_top)
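A standalone sketch of the two helpers above with a worked example: two 10×10 boxes offset by 5 share a 5×5 intersection, so the overlap area is 25 and its ratio to the first box's area is 0.25.

```python
def box_area(bbox):
    # (x0, y0, x1, y1) -> width * height
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

def get_overlap_area(bbox1, bbox2):
    """Absolute intersection area of bbox1 and bbox2 (0.0 if disjoint)."""
    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    return (x_right - x_left) * (y_bottom - y_top)

a = (0, 0, 10, 10)
b = (5, 5, 15, 15)
overlap = get_overlap_area(a, b)   # 5x5 intersection -> 25
ratio = overlap / box_area(a)      # 25 / 100 -> 0.25
```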
__version__ = "0.7.1"
__version__ = "0.8.0"
......@@ -57,14 +57,14 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool):
key = (ocr, show_log)
def get_model(self, ocr: bool, show_log: bool, lang=None):
key = (ocr, show_log, lang)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
model = None
if model_config.__model_mode__ == "lite":
......@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# read model-dir and device from the config file
......@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config}
"table_config": table_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
......@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None):
start_page_id=0, end_page_id=None, lang=None):
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log)
custom_model = model_manager.get_model(ocr, show_log, lang)
images = load_images_from_pdf(pdf_bytes)
......
import json
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio)
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
get_overlap_area)
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
......@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class MagicModel:
......@@ -124,49 +126,51 @@ class MagicModel:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes):
N = len(bboxes)
......@@ -191,6 +195,44 @@ class MagicModel:
Select all subjects that overlap the merged bbox with an overlap area larger than the object's own area.
Then compute the minimum distance between the selected subjects and the object.
"""
def search_overlap_between_boxes(
subject_idx, object_idx
):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0
other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
return ratio
def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf')
......@@ -299,6 +341,15 @@ class MagicModel:
):
continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j]
......@@ -627,13 +678,13 @@ class MagicModel:
span['type'] = ContentType.Image
elif category_id == 5:
# get the table model result
latex = layout_det.get("latex", None)
html = layout_det.get("html", None)
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
if latex:
span["latex"] = latex
span['latex'] = latex
elif html:
span["html"] = html
span["type"] = ContentType.Table
span['html'] = html
span['type'] = ContentType.Table
elif category_id == 13:
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
......
......@@ -58,7 +58,7 @@ def mfd_model_init(weight):
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
cfg.config.model.model_config.model_name = weight_dir
cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg)
......@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device):
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
if lang is not None:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
else:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model
......@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs):
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh")
kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
......@@ -177,9 +181,10 @@ class CustomPEKModel:
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
"DocAnalysis init, this may take some time. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
)
)
assert self.apply_layout, "DocAnalysis must contain layout model."
......@@ -225,11 +230,13 @@ class CustomPEKModel:
)
# initialize OCR
if self.apply_ocr:
# self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
......@@ -243,6 +250,7 @@ class CustomPEKModel:
table_max_time=self.table_max_time,
device=self.device
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
......@@ -383,6 +391,7 @@ class CustomPEKModel:
latex_code = self.table_model.image2latex(new_image)[0]
else:
html_code = self.table_model.img2html(new_image)
run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
......
......@@ -18,8 +18,11 @@ def region_to_bbox(region):
class CustomPaddleModel:
def __init__(self, ocr: bool = False, show_log: bool = False):
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
if lang is not None:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
else:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __call__(self, img):
try:
......