Commit d13f3c6d authored by icecraft

refactor: remove unused method in MagicModel class

parent ad9abc32
@@ -3,12 +3,9 @@ import enum
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
-                                    bbox_relative_pos, box_area, calculate_iou,
-                                    calculate_overlap_area_in_bbox1_area_ratio,
-                                    get_overlap_area)
+from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
+                                    calculate_iou)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
-from magic_pdf.libs.local_math import float_gt
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6
@@ -208,393 +205,6 @@ class MagicModel:
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id
):
"""假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
只能属于一个 subject."""
ret = []
MAX_DIS_OF_POINT = 10**9 + 7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(subject_idx, object_idx):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0
other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(merged_bbox, other_object['bbox'])
* 1.0
/ box_area(all_bboxes[object_idx]['bbox']),
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
return ratio
def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf')
x0 = min(
all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
)
y0 = min(
all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
)
x1 = max(
all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
)
y1 = max(
all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
)
object_area = abs(
all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
) * abs(
all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
)
for i in range(len(all_bboxes)):
if (
i == subject_idx
or all_bboxes[i]['category_id'] != subject_category_id
):
continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
all_bboxes[i]['bbox'], [x0, y0, x1, y1]
):
i_area = abs(
all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
if i_area >= object_area:
ret = min(ret, dis[i][object_idx])
return ret
def expand_bbbox(idxes):
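    # returns the union bbox covering every detection in idxes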
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
return min(x0s), min(y0s), max(x1s), max(y1s)
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
subject_object_relation_map = {}
subjects.sort(
key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
)  # sort subjects by the squared distance of their top-left corner from the origin
all_bboxes = []
for v in subjects:
all_bboxes.append(
{
'category_id': subject_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
for v in objects:
all_bboxes.append(
{
'category_id': object_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
N = len(all_bboxes)
dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
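# dis[i][j] holds the pairwise bbox distance; MAX_DIS_OF_POINT marks pairs that
# are never computed (e.g. subject-subject), while inf marks pairings vetoed below.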
for i in range(N):
for j in range(i):
if (
all_bboxes[i]['category_id'] == subject_category_id
and all_bboxes[j]['category_id'] == subject_category_id
):
continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j]
used = set()
for i in range(N):
# find the objects associated with the i-th subject
if all_bboxes[i]['category_id'] != subject_category_id:
continue
seen = set()
candidates = []
arr = []
for j in range(N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
),
)
)
)
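# pos_flag_count counts how many of the four relative-position flags are set;
# more than one means the two boxes sit diagonally, and such pairs are skipped.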
if pos_flag_count > 1:
continue
if (
all_bboxes[j]['category_id'] != object_category_id
or j in used
or dis[i][j] == MAX_DIS_OF_POINT
):
continue
left, right, _, _ = bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)  # correctness here is guaranteed by the pos_flag_count check above
if left or right:
one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
else:
one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
if dis[i][j] > one_way_dis:
continue
arr.append((dis[i][j], j))
arr.sort(key=lambda x: x[0])
if len(arr) > 0:
"""
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1])
seen.add(arr[0][1])
# the initial seeds have been collected
for j in set(candidates):
tmp = []
for k in range(i + 1, N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
if (
all_bboxes[k]['category_id'] != object_category_id
or k in used
or k in seen
or dis[j][k] == MAX_DIS_OF_POINT
or dis[j][k] > dis[i][j]
):
continue
is_nearest = True
for ni in range(i + 1, N):
if ni in (j, k) or ni in used or ni in seen:
continue
if not float_gt(dis[ni][k], dis[j][k]):
is_nearest = False
break
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis):
continue
tmp.append(k)
seen.add(k)
candidates = tmp
if len(candidates) == 0:
break
# We now have every caption closest to a given figure, plus the captions
# closest to those captions. First expand the bbox,
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
# split into four candidate regions; for each region, compute the merged rectangular area of the objects that fall inside it
caption_poses = [
[ox0, oy0, ix0, oy1],
[ox0, oy0, ox1, iy0],
[ox0, iy1, ox1, oy1],
[ix1, oy0, ox1, oy1],
]
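# The four strips around the subject (ix*/iy* = subject bbox, ox*/oy* =
# expanded bbox): left [ox0, oy0, ix0, oy1], top [ox0, oy0, ox1, iy0],
# bottom [ox0, iy1, ox1, oy1], right [ix1, oy0, ox1, oy1].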
caption_areas = []
for bbox in caption_poses:
embed_arr = []
for idx in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]['bbox'], bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
embed_arr.append(idx)
if len(embed_arr) > 0:
embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
caption_areas.append(
int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
)
else:
caption_areas.append(0)
subject_object_relation_map[i] = []
if max(caption_areas) > 0:
max_area_idx = caption_areas.index(max(caption_areas))
caption_bbox = caption_poses[max_area_idx]
for j in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]['bbox'], caption_bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
used.add(j)
subject_object_relation_map[i].append(j)
for i in sorted(subject_object_relation_map.keys()):
result = {
'subject_body': all_bboxes[i]['bbox'],
'all': all_bboxes[i]['bbox'],
'score': all_bboxes[i]['score'],
}
if len(subject_object_relation_map[i]) > 0:
x0 = min(
[all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
)
y0 = min(
[all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
)
x1 = max(
[all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
)
y1 = max(
[all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
)
result['object_body'] = [x0, y0, x1, y1]
result['all'] = [
min(x0, all_bboxes[i]['bbox'][0]),
min(y0, all_bboxes[i]['bbox'][1]),
max(x1, all_bboxes[i]['bbox'][2]),
max(y1, all_bboxes[i]['bbox'][3]),
]
ret.append(result)
total_subject_object_dis = 0
# accumulate the distances of the pairs already matched
for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)
# distances between unmatched subjects and objects (approximate)
with_caption_subject = set(
    [
        key
        for key in subject_object_relation_map.keys()
        if len(subject_object_relation_map[key]) > 0
    ]
)
for i in range(N):
if all_bboxes[i]['category_id'] != object_category_id or i in used:
continue
candidates = []
for j in range(N):
if (
all_bboxes[j]['category_id'] != subject_category_id
or j in with_caption_subject
):
continue
candidates.append((dis[i][j], j))
if len(candidates) > 0:
candidates.sort(key=lambda x: x[0])
total_subject_object_dis += candidates[0][0]
with_caption_subject.add(candidates[0][1])
return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
    self,
    page_no: int,
@@ -879,52 +489,12 @@ class MagicModel:
        return ret
    def get_imgs(self, page_no: int):
-        with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
-        with_footnotes, _ = self.__tie_up_category_by_distance(
-            page_no, 3, CategoryId.ImageFootnote
-        )
-        ret = []
-        N, M = len(with_captions), len(with_footnotes)
-        assert N == M
-        for i in range(N):
-            record = {
-                'score': with_captions[i]['score'],
-                'img_caption_bbox': with_captions[i].get('object_body', None),
-                'img_body_bbox': with_captions[i]['subject_body'],
-                'img_footnote_bbox': with_footnotes[i].get('object_body', None),
-            }
-            x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
-            y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
-            x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
-            y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
-            record['bbox'] = [x0, y0, x1, y1]
-            ret.append(record)
-        return ret
+        return self.get_imgs_v2(page_no)
    def get_tables(
        self, page_no: int
    ) -> list:  # three bboxes: caption, table body, table footnote
-        with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
-        with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
-        ret = []
-        N, M = len(with_captions), len(with_footnotes)
-        assert N == M
-        for i in range(N):
-            record = {
-                'score': with_captions[i]['score'],
-                'table_caption_bbox': with_captions[i].get('object_body', None),
-                'table_body_bbox': with_captions[i]['subject_body'],
-                'table_footnote_bbox': with_footnotes[i].get('object_body', None),
-            }
-            x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
-            y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
-            x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
-            y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
-            record['bbox'] = [x0, y0, x1, y1]
-            ret.append(record)
-        return ret
+        return self.get_tables_v2(page_no)
    def get_equations(self, page_no: int) -> list:  # has coordinates and text
        inline_equations = self.__get_blocks_by_type(
@@ -1043,4 +613,3 @@ class MagicModel:
    def get_model_list(self, page_no):
        return self.__model_list[page_no]
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.OCR,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
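# A minimal, hypothetical smoke test (the file names, FileBasedDataWriter and
# PymuDocDataset usage below are illustrative assumptions, not part of this
# module; model_list would normally come from a prior inference step):
if __name__ == '__main__':
    import json

    from magic_pdf.data.data_reader_writer import FileBasedDataWriter
    from magic_pdf.data.dataset import PymuDocDataset

    with open('demo.pdf', 'rb') as f:
        ds = PymuDocDataset(f.read())
    with open('demo_model.json') as f:
        model_list = json.load(f)
    image_writer = FileBasedDataWriter('output/images')
    pdf_info = parse_pdf_by_ocr(ds, model_list, image_writer, debug_mode=True)
    print(list(pdf_info.keys()))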
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt(
dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
from abc import ABC, abstractmethod
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from magic_pdf.libs.json_compressor import JsonCompressor
class AbsPipe(ABC):
"""txt和ocr处理的抽象类."""
PIP_OCR = 'ocr'
PIP_TXT = 'txt'
def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.dataset = dataset
self.model_list = model_list
self.image_writer = image_writer
self.pdf_mid_data = None  # uncompressed
self.is_debug = is_debug
self.start_page_id = start_page_id
self.end_page_id = end_page_id
self.lang = lang
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data)
@abstractmethod
def pipe_classify(self):
"""有状态的分类."""
raise NotImplementedError
@abstractmethod
def pipe_analyze(self):
"""有状态的跑模型分析."""
raise NotImplementedError
@abstractmethod
def pipe_parse(self):
"""有状态的解析."""
raise NotImplementedError
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
content_list = AbsPipe.mk_uni_format(self.get_compress_pdf_mid_data(), img_parent_path, drop_mode)
return content_list
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
md_content = AbsPipe.mk_markdown(self.get_compress_pdf_mid_data(), img_parent_path, drop_mode, md_make_mode)
return md_content
@staticmethod
def classify(pdf_bytes: bytes) -> str:
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get('_need_drop', False):  # raise if the meta scan flagged the pdf to be dropped
    raise Exception(f"pdf meta_scan need_drop, reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta['is_encrypted']
is_needs_password = pdf_meta['is_needs_password']
if is_encrypted or is_needs_password:  # encrypted, password-protected, or page-less pdfs are not processed
    raise Exception(f'pdf meta_scan need_drop, reason is {DropReason.ENCRYPTED}')
else:
is_text_pdf, results = classify(
pdf_meta['total_page'],
pdf_meta['page_width_pts'],
pdf_meta['page_height_pts'],
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
return AbsPipe.PIP_TXT
else:
return AbsPipe.PIP_OCR
@staticmethod
def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
"""根据pdf类型,生成统一格式content_list."""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data['pdf_info']
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
return content_list
@staticmethod
def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
"""根据pdf类型,markdown."""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data['pdf_info']
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
return md_content
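# Typical call order for a concrete subclass (a sketch; OCRPipe and its
# constructor arguments are taken from the sibling pipe modules, not defined
# in this file):
#
#     pipe = OCRPipe(dataset, model_list, image_writer)
#     pipe.pipe_classify()                        # stateful classification
#     pipe.pipe_analyze()                         # run the layout/OCR models
#     pipe.pipe_parse()                           # fill self.pdf_mid_data
#     md_content = pipe.pipe_mk_markdown('imgs')
#     content_list = pipe.pipe_mk_uni_format('imgs')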
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
model_list: list,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
super().__init__(
dataset,
model_list,
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.infer_res = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.infer_res,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('ocr_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'ocr_pipe mk {md_make_mode} finished')
return result
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.model_list = doc_analyze(self.dataset, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('txt_pipe mk content list finished')
return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'txt_pipe mk {md_make_mode} finished')
return result
import json
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import join_path
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
class UNIPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
jso_useful_key: dict,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
self.pdf_type = jso_useful_key['_pdf_type']
super().__init__(
dataset,
jso_useful_key['model_list'],
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
self.input_model_is_empty = False
def pipe_classify(self):
self.pdf_type = AbsPipe.classify(self.dataset.data_bits())  # classify from the raw pdf bytes held by the dataset
def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(
self.dataset,
ocr=False,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
)
def pipe_mk_uni_format(
self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('uni_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'uni_pipe mk {md_make_mode} finished')
return result
if __name__ == '__main__':
# test
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               FileBasedDataWriter)
from magic_pdf.data.dataset import PymuDocDataset
drw = FileBasedDataReader(r'D:/project/20231108code-clean')
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00.json'
pdf_bytes = drw.read(pdf_file_path)
model_json_txt = drw.read(model_file_path).decode()
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
# pdf_type = UNIPipe.classify(pdf_bytes)
# jso_useful_key = {
# "_pdf_type": pdf_type,
# "model_list": model_list
# }
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(PymuDocDataset(pdf_bytes), jso_useful_key, img_writer)
pipe.pipe_classify()
pipe.pipe_parse()
md_content = pipe.pipe_mk_markdown(img_bucket_path)
content_list = pipe.pipe_mk_uni_format(img_bucket_path)
md_writer = FileBasedDataWriter(write_path)
md_writer.write_string('19983-00.md', md_content)
md_writer.write_string(
'19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
md_writer.write_string('19983-00.txt', str(content_list))
from abc import ABC, abstractmethod
class AbsReaderWriter(ABC):
MODE_TXT = "text"
MODE_BIN = "binary"
@abstractmethod
def read(self, path: str, mode=MODE_TXT):
raise NotImplementedError
@abstractmethod
def write(self, content: str, path: str, mode=MODE_TXT):
raise NotImplementedError
@abstractmethod
def read_offset(self, path: str, offset=0, limit=None) -> bytes:
raise NotImplementedError
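# A minimal, hypothetical in-memory implementation of the contract above
# (illustration only, not part of magic_pdf):
class InMemoryReaderWriter(AbsReaderWriter):
    def __init__(self):
        self.store = {}  # path -> str (MODE_TXT) or bytes (MODE_BIN)

    def read(self, path: str, mode=AbsReaderWriter.MODE_TXT):
        return self.store[path]

    def write(self, content, path: str, mode=AbsReaderWriter.MODE_TXT):
        self.store[path] = content

    def read_offset(self, path: str, offset=0, limit=None) -> bytes:
        # assumes bytes were stored at this path
        data = self.store[path]
        end = None if limit is None else offset + limit
        return data[offset:end]
# Usage: rw = InMemoryReaderWriter(); rw.write('hi', 'a.txt'); rw.read('a.txt') -> 'hi'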
import os
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from loguru import logger
class DiskReaderWriter(AbsReaderWriter):
def __init__(self, parent_path, encoding="utf-8"):
self.path = parent_path
self.encoding = encoding
def read(self, path, mode=AbsReaderWriter.MODE_TXT):
if os.path.isabs(path):
abspath = path
else:
abspath = os.path.join(self.path, path)
if not os.path.exists(abspath):
logger.error(f"file {abspath} not exists")
raise Exception(f"file {abspath} no exists")
if mode == AbsReaderWriter.MODE_TXT:
with open(abspath, "r", encoding=self.encoding) as f:
return f.read()
elif mode == AbsReaderWriter.MODE_BIN:
with open(abspath, "rb") as f:
return f.read()
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
def write(self, content, path, mode=AbsReaderWriter.MODE_TXT):
if os.path.isabs(path):
abspath = path
else:
abspath = os.path.join(self.path, path)
directory_path = os.path.dirname(abspath)
if not os.path.exists(directory_path):
os.makedirs(directory_path)
if mode == AbsReaderWriter.MODE_TXT:
with open(abspath, "w", encoding=self.encoding, errors="replace") as f:
f.write(content)
elif mode == AbsReaderWriter.MODE_BIN:
with open(abspath, "wb") as f:
f.write(content)
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
def read_offset(self, path: str, offset=0, limit=None):
abspath = path
if not os.path.isabs(path):
abspath = os.path.join(self.path, path)
with open(abspath, "rb") as f:
f.seek(offset)
return f.read(limit)
if __name__ == "__main__":
if 0:
file_path = "io/test/example.txt"
drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
# 写入内容到文件
drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
# 从文件读取内容
content = drw.read(path=file_path)
if content:
logger.info(f"从 {file_path} 读取的内容: {content}")
if 1:
drw = DiskReaderWriter("/opt/data/pdf/resources/test/io/")
content_bin = drw.read_offset("1.txt")
assert content_bin == b"ABCD!"
content_bin = drw.read_offset("1.txt", offset=1, limit=2)
assert content_bin == b"BC"
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.commons import parse_bucket_key, join_path
import boto3
from loguru import logger
from botocore.config import Config
class S3ReaderWriter(AbsReaderWriter):
def __init__(
self,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = "auto",
parent_path: str = "",
):
self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
self.path = parent_path
def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
s3_client = boto3.client(
service_name="s3",
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={"addressing_style": addressing_style},
retries={"max_attempts": 5, "mode": "standard"},
),
)
return s3_client
def read(self, s3_relative_path, mode=AbsReaderWriter.MODE_TXT, encoding="utf-8"):
if s3_relative_path.startswith("s3://"):
s3_path = s3_relative_path
else:
s3_path = join_path(self.path, s3_relative_path)
bucket_name, key = parse_bucket_key(s3_path)
res = self.client.get_object(Bucket=bucket_name, Key=key)
body = res["Body"].read()
if mode == AbsReaderWriter.MODE_TXT:
data = body.decode(encoding) # Decode bytes to text
elif mode == AbsReaderWriter.MODE_BIN:
data = body
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
return data
def write(self, content, s3_relative_path, mode=AbsReaderWriter.MODE_TXT, encoding="utf-8"):
if s3_relative_path.startswith("s3://"):
s3_path = s3_relative_path
else:
s3_path = join_path(self.path, s3_relative_path)
if mode == AbsReaderWriter.MODE_TXT:
body = content.encode(encoding) # Encode text data as bytes
elif mode == AbsReaderWriter.MODE_BIN:
body = content
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
bucket_name, key = parse_bucket_key(s3_path)
self.client.put_object(Body=body, Bucket=bucket_name, Key=key)
logger.info(f"内容已写入 {s3_path} ")
def read_offset(self, path: str, offset=0, limit=None) -> bytes:
if path.startswith("s3://"):
s3_path = path
else:
s3_path = join_path(self.path, path)
bucket_name, key = parse_bucket_key(s3_path)
range_header = (
f"bytes={offset}-{offset+limit-1}" if limit else f"bytes={offset}-"
)
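        # Note: HTTP Range is inclusive, so offset=0, limit=10 requests
        # "bytes=0-9" (the first 10 bytes); without a limit it is open-ended.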
res = self.client.get_object(Bucket=bucket_name, Key=key, Range=range_header)
return res["Body"].read()
if __name__ == "__main__":
if 0:
# Config the connection info
ak = ""
sk = ""
endpoint_url = ""
addressing_style = "auto"
bucket_name = ""
# Create an S3ReaderWriter object
s3_reader_writer = S3ReaderWriter(
ak, sk, endpoint_url, addressing_style, "s3://bucket_name/"
)
# Write text data to S3
text_data = "This is some text data"
s3_reader_writer.write(
text_data,
s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
mode=AbsReaderWriter.MODE_TXT,
)
# Read text data from S3
text_data_read = s3_reader_writer.read(
s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=AbsReaderWriter.MODE_TXT
)
logger.info(f"Read text data from S3: {text_data_read}")
# Write binary data to S3
binary_data = b"This is some binary data"
s3_reader_writer.write(
            binary_data,
s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
mode=AbsReaderWriter.MODE_BIN,
)
# Read binary data from S3
binary_data_read = s3_reader_writer.read(
s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json", mode=AbsReaderWriter.MODE_BIN
)
logger.info(f"Read binary data from S3: {binary_data_read}")
        # Range-read data from S3
binary_data_read = s3_reader_writer.read_offset(
path=f"s3://{bucket_name}/ebook/test/test.json", offset=0, limit=10
)
logger.info(f"Read binary data from S3: {binary_data_read}")
if 1:
import os
import json
ak = os.getenv("AK", "")
sk = os.getenv("SK", "")
endpoint_url = os.getenv("ENDPOINT", "")
bucket = os.getenv("S3_BUCKET", "")
prefix = os.getenv("S3_PREFIX", "")
key_basename = os.getenv("S3_KEY_BASENAME", "")
s3_reader_writer = S3ReaderWriter(
ak, sk, endpoint_url, "auto", f"s3://{bucket}/{prefix}"
)
content_bin = s3_reader_writer.read_offset(key_basename)
assert content_bin[:10] == b'{"track_id'
assert content_bin[-10:] == b'r":null}}\n'
content_bin = s3_reader_writer.read_offset(key_basename, offset=424, limit=426)
jso = json.dumps(content_bin.decode("utf-8"))
print(jso)
"""用户输入: model数组,每个元素代表一个页面 pdf在s3的路径 截图保存的s3位置.
然后:
1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
"""
from loguru import logger
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
from magic_pdf.config.constants import PARSE_TYPE_TXT, PARSE_TYPE_OCR
def parse_txt_pdf(
dataset: Dataset,
model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析文本类pdf."""
pdf_info_dict = parse_pdf_by_txt(
dataset,
model_list,
imageWriter,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
pdf_info_dict['_version_name'] = __version__
if lang is not None:
pdf_info_dict['_lang'] = lang
return pdf_info_dict
def parse_ocr_pdf(
dataset: Dataset,
model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析ocr类pdf."""
pdf_info_dict = parse_pdf_by_ocr(
dataset,
model_list,
imageWriter,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
pdf_info_dict['_version_name'] = __version__
if lang is not None:
pdf_info_dict['_lang'] = lang
return pdf_info_dict
def parse_union_pdf(
dataset: Dataset,
model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""ocr和文本混合的pdf,全部解析出来."""
def parse_pdf(method):
try:
return method(
dataset,
model_list,
imageWriter,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
except Exception as e:
logger.exception(e)
return None
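    # Try the txt route first; if it fails or is flagged _need_drop, fall back
    # to the ocr route, running model inference first when model_list is empty.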
pdf_info_dict = parse_pdf(parse_pdf_by_txt)
if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
if len(model_list) == 0:
layout_model = kwargs.get('layout_model', None)
formula_enable = kwargs.get('formula_enable', None)
table_enable = kwargs.get('table_enable', None)
infer_res = doc_analyze(
dataset,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
model_list = infer_res.get_infer_res()
pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None:
raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
else:
pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
else:
pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
pdf_info_dict['_version_name'] = __version__
if lang is not None:
pdf_info_dict['_lang'] = lang
return pdf_info_dict
import pytest
import os
from conf import conf
import json
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from lib import calculate_score
import shutil
pdf_res_path = conf.conf["pdf_res_path"]