Unverified commit 3a42ebbf, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
-from struct_eqtable.model import StructTable
+from loguru import logger
+
+try:
+    from struct_eqtable.model import StructTable
+except ImportError:
+    logger.error("StructEqTable is under upgrade, the current version does not support it.")
 from pypandoc import convert_text


 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init
...
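The import is now guarded so the module loads even while StructEqTable is mid-upgrade. For context, a minimal sketch of the pypandoc call this module keeps either way, converting a recognized LaTeX table to HTML (the LaTeX string is a hypothetical recognizer output, and pypandoc requires the pandoc binary to be installed):

from pypandoc import convert_text

latex_table = r"\begin{tabular}{ll} a & b \\ c & d \\ \end{tabular}"  # hypothetical recognizer output
html_table = convert_text(latex_table, "html", format="latex")  # LaTeX -> HTML via pandoc
print(html_table)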
@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,
...
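A quick sketch of why the startswith change matters: device strings that carry an explicit index, such as "cuda:0", failed the old equality test and silently fell back to CPU.

# Plain-Python illustration of the old vs. new GPU check:
for device in ("cpu", "cuda", "cuda:0"):
    old = device == "cuda"
    new = device.startswith("cuda")
    print(device, "old:", old, "new:", new)
# cpu   old: False new: False
# cuda  old: True  new: True
# cuda:0 old: False new: True   <- the case the fix addresses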
@@ -18,7 +18,10 @@ def region_to_bbox(region):


 class CustomPaddleModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+    def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+        if lang is not None:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+        else:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)

     def __call__(self, img):
...
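A hedged usage sketch of the new optional parameter (the language value is illustrative): lang is forwarded to PPStructure only when set, so existing callers that omit it behave exactly as before.

model = CustomPaddleModel(ocr=True, show_log=False, lang="en")  # illustrative lang value
# model = CustomPaddleModel(ocr=True, show_log=False)  # legacy call, unchanged behavior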
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2
class DataCollator:
def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
bbox = []
labels = []
input_ids = []
attention_mask = []
# clip bbox and labels to max length, build input_ids and attention_mask
for feature in features:
_bbox = feature["source_boxes"]
if len(_bbox) > MAX_LEN:
_bbox = _bbox[:MAX_LEN]
_labels = feature["target_index"]
if len(_labels) > MAX_LEN:
_labels = _labels[:MAX_LEN]
_input_ids = [UNK_TOKEN_ID] * len(_bbox)
_attention_mask = [1] * len(_bbox)
assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
bbox.append(_bbox)
labels.append(_labels)
input_ids.append(_input_ids)
attention_mask.append(_attention_mask)
# add CLS and EOS tokens
for i in range(len(bbox)):
bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
labels[i] = [-100] + labels[i] + [-100]
input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
attention_mask[i] = [1] + attention_mask[i] + [1]
# padding to max length
max_len = max(len(x) for x in bbox)
for i in range(len(bbox)):
bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
attention_mask[i] = attention_mask[i] + [0] * (
max_len - len(attention_mask[i])
)
ret = {
"bbox": torch.tensor(bbox),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
"input_ids": torch.tensor(input_ids),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret["labels"][ret["labels"] > MAX_LEN] = -100
# set label > 0 to label-1, because original labels are 1-indexed
ret["labels"][ret["labels"] > 0] -= 1
return ret
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
attention_mask = [1] + [1] * len(boxes) + [1]
return {
"bbox": torch.tensor([bbox]),
"attention_mask": torch.tensor([attention_mask]),
"input_ids": torch.tensor([input_ids]),
}
def prepare_inputs(
inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
ret = {}
for k, v in inputs.items():
v = v.to(model.device)
if torch.is_floating_point(v):
v = v.to(model.dtype)
ret[k] = v
return ret
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Parse logits into a reading order.
    :param logits: logits from the model, shape (seq_len, num_labels)
    :param length: number of input boxes
    :return: predicted order index for each box
    """
    # drop the CLS/EOS positions and restrict candidate orders to the input length
    logits = logits[1 : length + 1, :length]
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]  # start each box at its highest-scoring order
    while True:
        # group box indexes by the order they currently claim
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # keep only orders claimed by more than one box (conflicts)
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # resolve each conflict
        for order, idxes in order_to_idxes.items():
            # look up the original logits of the conflicting boxes at this order
            idxes_to_logit = {}
            for idx in idxes:
                idxes_to_logit[idx] = logits[idx, order]
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            # the highest-logit box keeps the order; the rest move to their next candidate
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret


def check_duplicate(a: List[int]) -> bool:
    return len(a) != len(set(a))
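Taken together, these helpers form a small reading-order pipeline: boxes2inputs builds the tensors, prepare_inputs moves them onto the model's device and dtype, and parse_logits turns the logits into a conflict-free order. A minimal sketch under stated assumptions — the checkpoint name is hypothetical, and boxes are (x0, y0, x1, y1) already normalized to the 0-1000 grid LayoutLMv3 expects:

import torch
from transformers import LayoutLMv3ForTokenClassification

# hypothetical checkpoint name; substitute whichever reading-order model you use
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")

boxes = [[50, 50, 400, 80], [50, 100, 400, 130], [50, 150, 400, 180]]  # assumed 0-1000 coords
inputs = prepare_inputs(boxes2inputs(boxes), model)
with torch.no_grad():
    logits = model(**inputs).logits.cpu().squeeze(0)
orders = parse_logits(logits, len(boxes))
assert not check_duplicate(orders)  # parse_logits guarantees a duplicate-free order
print(orders)  # one order index per input box, e.g. [0, 1, 2]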
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


 def parse_pdf_by_ocr(pdf_bytes,
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
                      end_page_id=None,
                      debug_mode=False,
                      ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "ocr",
+                           SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,
...
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


 def parse_pdf_by_txt(
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
     end_page_id=None,
     debug_mode=False,
 ):
-    return pdf_parse_union(pdf_bytes,
+    dataset = PymuDocDataset(pdf_bytes)
+    return pdf_parse_union(dataset,
                            model_list,
                            imageWriter,
-                           "txt",
+                           SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,
...
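Both entry points now share one convention: the raw bytes are wrapped in a PymuDocDataset and the parse method is a SupportedPdfParseMethod enum member instead of a bare "ocr"/"txt" string. A minimal sketch (the file path is hypothetical):

from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset

with open("demo.pdf", "rb") as f:  # hypothetical input file
    dataset = PymuDocDataset(f.read())
# the enum replaces the old magic strings passed to pdf_parse_union
print(SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT)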
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"

     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -25,6 +25,10 @@ class AbsPipe(ABC):
         self.is_debug = is_debug
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
+        self.lang = lang
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable

     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
...
@@ -10,19 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf

 class OCRPipe(AbsPipe):

     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)

     def pipe_classify(self):
         pass

     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
...
@@ -11,19 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf

 class TXTPipe(AbsPipe):

     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)

     def pipe_classify(self):
         pass

     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf

 class UNIPipe(AbsPipe):

     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+                         lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,22 +30,29 @@ class UNIPipe(AbsPipe):

     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
-                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                                lang=self.lang, layout_model=self.layout_model,
+                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,
-                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                              lang=self.lang)

-    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
+    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info("uni_pipe mk content list finished")
         return result
...
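A hedged end-to-end sketch of the extended pipeline surface. It assumes pdf_bytes (the PDF's bytes) and image_writer (an AbsReaderWriter) are already prepared as in the project demos; the lang and layout_model values are illustrative:

from magic_pdf.pipe.UNIPipe import UNIPipe  # module path as in this release

jso_useful_key = {"_pdf_type": "", "model_list": []}  # empty model_list: pipe_analyze runs the models
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer,
               lang="en", layout_model="doclayout_yolo",  # illustrative values
               formula_enable=True, table_enable=False)
pipe.pipe_classify()
pipe.pipe_analyze()
pipe.pipe_parse()
content_list = pipe.pipe_mk_uni_format("images")  # default drop_mode is now NONE_WITH_REASON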
@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
             continue

         # If the current span overlaps the last span of the current line on the y-axis, append it to the current line
-        if __is_overlaps_y_exceeds_threshold(span['bbox'],
-                                             current_line[-1]['bbox']):
+        if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
             current_line.append(span)
         else:
             # Otherwise, start a new line
@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks


+def fix_block_spans_v2(block_with_spans):
+    """1. Because img_block and table_block contain captions and footnotes, blocks can be
+    nested: the caption and footnote text spans must be moved into the caption_block and
+    footnote_block inside the corresponding img_block and table_block.
+    2. The spans field must also be removed from each block."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans:
...
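For illustration, the assumed shape of grouped blocks after fill_spans_in_blocks, with group_id tying an image or table body to its caption and footnote (import path and field values are hypothetical):

from magic_pdf.libs.ocr_content_type import BlockType  # assumed import path

# Hypothetical block dicts sharing one group_id:
image_body = {'type': BlockType.ImageBody, 'bbox': [100, 100, 500, 400], 'group_id': 0}
image_caption = {'type': BlockType.ImageCaption, 'bbox': [100, 410, 500, 440], 'group_id': 0}
# fix_block_spans_v2 later folds the caption/footnote spans into the body's nested
# caption_block/footnote_block and strips the raw 'spans' field from each block.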
@@ -2,13 +2,13 @@ model:
   arch: unimernet
   model_type: unimernet
   model_config:
-    model_name: ./models
-    max_seq_len: 1024
-    length_aware: False
+    model_name: ./models/unimernet_base
+    max_seq_len: 1536

   load_pretrained: True
-  pretrained: ./models/pytorch_model.bin
+  pretrained: './models/unimernet_base/pytorch_model.pth'

   tokenizer_config:
-    path: ./models
+    path: ./models/unimernet_base

 datasets:
   formula_rec_eval:
...
-config:
-  device: cpu
-  layout: True
-  formula: True
-  table_config:
-    model: TableMaster
-    is_table_recog_enable: False
-    max_time: 400
 weights:
-  layout: Layout/model_final.pth
-  mfd: MFD/weights.pt
-  mfr: MFR/UniMERNet
+  layoutlmv3: Layout/LayoutLMv3/model_final.pth
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
+  unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  TableMaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster
\ No newline at end of file