Commit 8f1f9abe authored by myhloli
Browse files

refactor: enhance bounding box utilities and add configuration reader for S3 integration

parent 7285ea92
# Copyright (c) Opendatalab. All rights reserved.
import json
import os
from loguru import logger
# Name of the config file; can be overridden with the MINERU_TOOLS_CONFIG_JSON environment variable
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config():
    """Load the JSON configuration file and return its parsed content.

    ``CONFIG_FILE_NAME`` is used as-is when it is an absolute path; otherwise
    it is resolved relative to the current user's home directory.

    Raises:
        FileNotFoundError: If the resolved configuration file does not exist.
    """
    config_file = (
        CONFIG_FILE_NAME
        if os.path.isabs(CONFIG_FILE_NAME)
        else os.path.join(os.path.expanduser('~'), CONFIG_FILE_NAME)
    )

    if not os.path.exists(config_file):
        raise FileNotFoundError(f'{config_file} not found')

    with open(config_file, 'r', encoding='utf-8') as fp:
        return json.load(fp)
def get_s3_config(bucket_name: str):
    """Return the S3 credentials configured for ``bucket_name``.

    The credentials are read from the ``bucket_info`` section of the config
    file (``~/magic-pdf.json`` by default). When ``bucket_name`` has no entry
    of its own, the ``'[default]'`` entry is used instead.

    Args:
        bucket_name: Name of the bucket whose credentials are wanted.

    Returns:
        A ``(access_key, secret_key, storage_endpoint)`` tuple.

    Raises:
        Exception: If ``bucket_info`` is missing from the config, or if any of
            the three values is ``None``.
        KeyError: If neither ``bucket_name`` nor ``'[default]'`` is present in
            ``bucket_info`` (assuming ``bucket_info`` is a mapping).
    """
    config = read_config()

    bucket_info = config.get('bucket_info')
    if bucket_info is None:
        # Without this guard the membership test below fails with an opaque
        # "argument of type 'NoneType' is not iterable" TypeError.
        raise Exception(f'bucket_info not found in {CONFIG_FILE_NAME}')

    entry_name = bucket_name if bucket_name in bucket_info else '[default]'
    access_key, secret_key, storage_endpoint = bucket_info[entry_name]

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')

    return access_key, secret_key, storage_endpoint
def get_s3_config_dict(path: str):
    """Return the S3 credentials for the bucket named in ``path`` as a dict.

    The bucket name is taken from ``path`` via ``get_bucket_name`` and the
    credentials are looked up with ``get_s3_config``.

    Returns:
        A dict with the keys ``'ak'``, ``'sk'`` and ``'endpoint'``.
    """
    bucket = get_bucket_name(path)
    ak, sk, endpoint = get_s3_config(bucket)
    return {'ak': ak, 'sk': sk, 'endpoint': endpoint}
def get_bucket_name(path):
    """Return only the bucket component of an S3 path (see ``parse_bucket_key``)."""
    bucket_name, _unused_key = parse_bucket_key(path)
    return bucket_name
def parse_bucket_key(s3_full_path: str):
    """Split an S3 path into its bucket name and object key.

    Surrounding whitespace, a leading ``s3://`` scheme and then a single
    leading ``/`` are stripped before splitting on the first ``/``.

    Example:
        ``s3://bucket/path/to/my/file.txt`` -> ``('bucket', 'path/to/my/file.txt')``

    Args:
        s3_full_path: Path such as ``s3://bucket/key`` or ``/bucket/key``.

    Returns:
        A ``(bucket, key)`` tuple. ``key`` is the empty string when the path
        names only a bucket (e.g. ``s3://bucket``).
    """
    s3_full_path = s3_full_path.strip()
    if s3_full_path.startswith("s3://"):
        s3_full_path = s3_full_path[5:]
    if s3_full_path.startswith("/"):
        s3_full_path = s3_full_path[1:]
    # partition() never fails when there is no '/', unlike unpacking the
    # result of split('/', 1), which raised ValueError for bucket-only paths.
    bucket, _, key = s3_full_path.partition("/")
    return bucket, key
def get_local_models_dir():
    """Return the ``models-dir`` value from the config file.

    Falls back to ``'/tmp/models'`` (and logs a warning) when the key is absent.
    """
    configured_dir = read_config().get('models-dir')
    if configured_dir is not None:
        return configured_dir

    logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
    return '/tmp/models'
def get_local_layoutreader_model_dir():
    """Return the directory of the layoutreader model.

    The ``layoutreader-model-dir`` config value is returned when it is set and
    the path exists. Otherwise a warning is logged and the path
    ``~/.cache/modelscope/hub/ppaanngggg/layoutreader`` is returned.
    """
    configured_dir = read_config().get('layoutreader-model-dir')
    if configured_dir is not None and os.path.exists(configured_dir):
        return configured_dir

    fallback_dir = os.path.join(
        os.path.expanduser('~'),
        '.cache/modelscope/hub/ppaanngggg/layoutreader',
    )
    logger.warning(f"'layoutreader-model-dir' not exists, use {fallback_dir} as default")
    return fallback_dir
def get_device():
    """Return the ``device-mode`` value from the config file.

    Falls back to ``'cpu'`` (and logs a warning) when the key is absent.
    """
    device_mode = read_config().get('device-mode')
    if device_mode is not None:
        return device_mode

    logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
    return 'cpu'
def get_table_recog_config():
    """Return the ``table-config`` section of the config file.

    Returns:
        The configured value, or ``{'enable': True}`` (with a logged warning)
        when the key is absent.
    """
    table_config = read_config().get('table-config')
    if table_config is None:
        # The fallback has "enable" set to true; the previous log message
        # claimed 'False', contradicting the value actually returned.
        logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        return {'enable': True}
    return table_config
def get_formula_config():
    """Return the ``formula-config`` section of the config file.

    Returns:
        The configured value, or a fresh ``{'enable': True}`` dict (with a
        logged warning) when the key is absent.
    """
    formula_config = read_config().get('formula-config')
    if formula_config is not None:
        return formula_config

    logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
    return {'enable': True}
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
def result_to_middle_json(model_json, images_list, pdf_doc, image_writer): from mineru.utils.pipeline_magic_model import MagicModel
pass from mineru.version import __version__
\ No newline at end of file from mineru.utils.hash_utils import str_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, lang=None, ocr=False):
    """Prepare per-page data from one page's model output.

    Reads ``scale``, ``img_pil`` and ``img_base64`` from ``image_dict``, takes
    the page size from ``page.get_size()`` and wraps ``page_model_info`` in a
    ``MagicModel``.

    NOTE(review): as visible here the function has no ``return`` statement,
    so callers receive ``None``, and ``image_writer``, ``page_index``,
    ``lang`` and ``ocr`` are not used yet. Presumably unfinished — confirm.
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    # Hash of the page image's base64 payload (presumably an MD5 hex digest,
    # judging by the helper's name — confirm).
    page_img_md5 = str_md5(image_dict["img_base64"])
    # Page size, each component converted with int().
    width, height = map(int, page.get_size())
    magic_model = MagicModel(page_model_info, scale)
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
    """Assemble the "middle JSON" document from per-page model results.

    For every entry of ``model_list`` the page at the same index of
    ``pdf_doc`` and the image dict at the same index of ``images_list`` are
    passed to ``page_model_info_to_page_info``; the results are collected
    under ``"pdf_info"`` in page order.

    Returns:
        A dict with the keys ``"pdf_info"``, ``"_backend"`` and
        ``"_version_name"``.
    """
    page_infos = []
    # NOTE(review): this function appears to be imported by the pipeline
    # analyzer, yet the backend tag reads "vlm" — confirm this is intended.
    middle_json = {"pdf_info": page_infos, "_backend":"vlm", "_version_name": __version__}

    for idx, model_info in enumerate(model_list):
        page = pdf_doc[idx]
        image_dict = images_list[idx]
        page_infos.append(
            page_model_info_to_page_info(
                model_info, image_dict, page, image_writer, idx, lang=lang, ocr=ocr
            )
        )

    return middle_json
\ No newline at end of file
...@@ -2,9 +2,9 @@ import os ...@@ -2,9 +2,9 @@ import os
import time import time
import numpy as np import numpy as np
import torch import torch
from pypdfium2 import PdfDocument
from mineru.backend.pipeline.model_init import MineruPipelineModel from .model_init import MineruPipelineModel
from .config_reader import get_local_models_dir, get_device, get_formula_config, get_table_recog_config
from .model_json_to_middle_json import result_to_middle_json from .model_json_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter from ...data.data_reader_writer import DataWriter
from ...utils.pdf_classify import classify from ...utils.pdf_classify import classify
...@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf ...@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf
from loguru import logger from loguru import logger
from ...utils.model_utils import get_vram, clean_memory from ...utils.model_utils import get_vram, clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_layout_config,
get_local_models_dir,
get_table_recog_config)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
...@@ -109,6 +104,7 @@ def doc_analyze( ...@@ -109,6 +104,7 @@ def doc_analyze(
all_image_lists = [] all_image_lists = []
all_pdf_docs = [] all_pdf_docs = []
ocr_enabled_list = []
for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list): for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
# 确定OCR设置 # 确定OCR设置
_ocr = False _ocr = False
...@@ -118,6 +114,7 @@ def doc_analyze( ...@@ -118,6 +114,7 @@ def doc_analyze(
elif parse_method == 'ocr': elif parse_method == 'ocr':
_ocr = True _ocr = True
ocr_enabled_list[pdf_idx] = _ocr
_lang = lang_list[pdf_idx] _lang = lang_list[pdf_idx]
# 收集每个数据集中的页面 # 收集每个数据集中的页面
...@@ -152,23 +149,23 @@ def doc_analyze( ...@@ -152,23 +149,23 @@ def doc_analyze(
results.extend(batch_results) results.extend(batch_results)
# 构建返回结果 # 构建返回结果
infer_results = []
# 多数据集模式:按数据集分组结果
infer_results = [[] for _ in datasets]
for i, page_info in enumerate(all_pages_info): for i, page_info in enumerate(all_pages_info):
pdf_idx, page_idx, pil_img, _, _ = page_info pdf_idx, page_idx, pil_img, _, _ = page_info
result = results[i] result = results[i]
page_info_dict = {'page_no': page_idx, 'width': pil_img.get_width(), 'height': pil_img.get_height()} page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
page_dict = {'layout_dets': result, 'page_info': page_info_dict} page_dict = {'layout_dets': result, 'page_info': page_info_dict}
infer_results[pdf_idx].append(page_dict) infer_results[pdf_idx][page_idx] = page_dict
middle_json_list = [] middle_json_list = []
for pdf_idx, model_json in enumerate(infer_results): for pdf_idx, model_list in enumerate(infer_results):
images_list = all_image_lists[pdf_idx] images_list = all_image_lists[pdf_idx]
pdf_doc = all_pdf_docs[pdf_idx] pdf_doc = all_pdf_docs[pdf_idx]
middle_json = result_to_middle_json(model_json, images_list, pdf_doc, image_writer) _lang = lang_list[pdf_idx]
_ocr = ocr_enabled_list[pdf_idx]
middle_json = result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr)
middle_json_list.append(middle_json) middle_json_list.append(middle_json)
return middle_json_list, infer_results return middle_json_list, infer_results
......
...@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer): def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_version_name": __version__} middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
for index, token in enumerate(token_list): for index, token in enumerate(token_list):
page = pdf_doc[index] page = pdf_doc[index]
image_dict = images_list[index] image_dict = images_list[index]
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
...@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2): ...@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2):
elif top: elif top:
return y2 - y1b return y2 - y1b
return 0.0 return 0.0
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
    """Return the smaller of two boxes when they overlap strongly enough.

    The overlap is measured with ``calculate_overlap_area_2_minbox_area_ratio``
    (intersection area divided by the area of the smaller box).

    Args:
        bbox1: First box as ``(x_min, y_min, x_max, y_max)``.
        bbox2: Second box, same layout.
        ratio: Threshold the overlap ratio must exceed.

    Returns:
        The box with the smaller area (``bbox1`` on a tie) when the overlap
        ratio is greater than ``ratio``; otherwise ``None``.
    """
    left1, top1, right1, bottom1 = bbox1
    left2, top2, right2, bottom2 = bbox2
    first_area = (right1 - left1) * (bottom1 - top1)
    second_area = (right2 - left2) * (bottom2 - top2)

    overlap = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
    if not (overlap > ratio):
        return None
    return bbox1 if first_area <= second_area else bbox2
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
    """Return the intersection area divided by the area of the smaller box.

    Boxes are indexed as ``[x_min, y_min, x_max, y_max]``.

    Returns:
        ``0.0`` when the boxes do not intersect, ``0`` when the smaller box
        has zero area, otherwise ``intersection_area / min_box_area``.
    """
    inter_x0 = max(bbox1[0], bbox2[0])
    inter_y0 = max(bbox1[1], bbox2[1])
    inter_x1 = min(bbox1[2], bbox2[2])
    inter_y1 = min(bbox1[3], bbox2[3])

    # An inverted intersection rectangle means the boxes are disjoint.
    if inter_x1 < inter_x0 or inter_y1 < inter_y0:
        return 0.0

    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area2 = (bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])
    smaller_area = min(area1, area2)
    if smaller_area == 0:
        return 0

    return (inter_x1 - inter_x0) * (inter_y1 - inter_y0) / smaller_area
def calculate_iou(bbox1, bbox2):
    """Compute the intersection over union (IoU) of two bounding boxes.

    Args:
        bbox1 (list[float]): First box as ``[x1, y1, x2, y2]``, where
            ``(x1, y1)`` is the top-left and ``(x2, y2)`` the bottom-right corner.
        bbox2 (list[float]): Second box, same layout as ``bbox1``.

    Returns:
        float: The IoU of the two boxes, in the range [0, 1]. ``0.0`` is
        returned for disjoint boxes and ``0`` when either box has zero area.
    """
    inter_x0 = max(bbox1[0], bbox2[0])
    inter_y0 = max(bbox1[1], bbox2[1])
    inter_x1 = min(bbox1[2], bbox2[2])
    inter_y1 = min(bbox1[3], bbox2[3])

    # An inverted intersection rectangle means the boxes are disjoint.
    if inter_x1 < inter_x0 or inter_y1 < inter_y0:
        return 0.0

    overlap_area = (inter_x1 - inter_x0) * (inter_y1 - inter_y0)
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

    if area1 == 0 or area2 == 0:
        return 0

    # Union = sum of both areas minus the part counted twice.
    return overlap_area / float(area1 + area2 - overlap_area)
def _is_in(box1, box2) -> bool:
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (x0_1 >= x0_2 and # box1的左边界不在box2的左边外
y0_1 >= y0_2 and # box1的上边界不在box2的上边外
x1_1 <= x1_2 and # box1的右边界不在box2的右边外
y1_1 <= y1_2) # box1的下边界不在box2的下边外
\ No newline at end of file
...@@ -23,6 +23,22 @@ class ContentType: ...@@ -23,6 +23,22 @@ class ContentType:
INLINE_EQUATION = 'inline_equation' INLINE_EQUATION = 'inline_equation'
class CategoryId:
    """Integer ids used in the ``category_id`` field of layout detections."""
    Title = 0
    Text = 1
    Abandon = 2
    ImageBody = 3  # figure
    ImageCaption = 4
    TableBody = 5  # table
    TableCaption = 6
    TableFootnote = 7  # footnote
    InterlineEquation_Layout = 8
    InlineEquation = 13
    InterlineEquation_YOLO = 14
    OcrText = 15
    # Assigned by MagicModel to a footnote that is closer to a figure than to a table.
    ImageFootnote = 101
class MakeMode: class MakeMode:
MM_MD = 'mm_markdown' MM_MD = 'mm_markdown'
NLP_MD = 'nlp_markdown' NLP_MD = 'nlp_markdown'
......
import os
import unicodedata
# Default FTLANG_CACHE to <parent of this file's directory>/resources/fasttext-langdetect
# unless the environment already provides a value. This runs before
# fast_langdetect is imported below; presumably that library reads the
# variable to locate its model cache — confirm.
if not os.getenv("FTLANG_CACHE"):
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    root_dir = os.path.dirname(current_dir)
    ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
    os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)
# print(os.getenv("FTLANG_CACHE"))
from fast_langdetect import detect_language
def remove_invalid_surrogates(text):
    """Return ``text`` without any code point in the UTF-16 surrogate range.

    Every character whose code point lies in U+D800..U+DFFF is dropped; all
    other characters are kept in order.
    """
    kept = []
    for ch in text:
        code_point = ord(ch)
        if code_point < 0xD800 or code_point > 0xDFFF:
            kept.append(ch)
    return ''.join(kept)
def detect_lang(text: str) -> str:
    """Detect the language of ``text`` and return it in lower case.

    Newlines and surrogate code points are removed before detection. If
    ``detect_language`` raises, detection is retried once on the text with
    all characters of Unicode category ``C*`` (control, format, ...) removed.

    Args:
        text: Text to analyse.

    Returns:
        The lower-cased result of ``detect_language``, or ``""`` for empty
        input or when the result cannot be lower-cased.
    """
    if not text:
        return ""

    text = remove_invalid_surrogates(text.replace("\n", ""))

    try:
        lang_upper = detect_language(text)
    except Exception:
        # `except Exception` instead of a bare `except` so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        text_no_ctrl_chars = ''.join(
            ch for ch in text if unicodedata.category(ch)[0] != 'C'
        )
        lang_upper = detect_language(text_no_ctrl_chars)

    try:
        return lang_upper.lower()
    except Exception:
        # Best effort: a result without a usable .lower() (e.g. None) maps to "".
        return ""
if __name__ == '__main__':
    # Manual smoke test: show the cache location, then run a few samples
    # (plain, HTML-wrapped, and one containing lone UTF-16 surrogates).
    print(os.getenv("FTLANG_CACHE"))
    samples = (
        "This is a test.",
        "<html>This is a test</html>",
        "这个是中文测试。",
        "<html>这个是中文测试。</html>",
        "〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试",
    )
    for sample in samples:
        print(detect_lang(sample))
\ No newline at end of file
...@@ -4,7 +4,7 @@ import gc ...@@ -4,7 +4,7 @@ import gc
from loguru import logger from loguru import logger
import numpy as np import numpy as np
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0): def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
......
import enum from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, _is_in
from mineru.utils.enum_class import CategoryId, ContentType
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
calculate_iou)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class PosRelationEnum(enum.Enum):
LEFT = 'left'
RIGHT = 'right'
UP = 'up'
BOTTOM = 'bottom'
ALL = 'all'
class MagicModel: class MagicModel:
"""每个函数没有得到元素的时候返回空list.""" """每个函数没有得到元素的时候返回空list."""
def __init__(self, page_model_info: dict, scale: float):
    """Store one page's model output and clean it up in place.

    Args:
        page_model_info: Model output of a single page; its ``'layout_dets'``
            list is modified by the fix-up steps below.
        scale: Factor the ``poly`` coordinates are divided by to obtain ``bbox``.
    """
    self.__scale = scale
    self.__page_model_info = page_model_info
    # Add bbox info to all model data (scaling, poly -> bbox).
    self.__fix_axis()
    # Drop detections with very low confidence (score <= 0.05) to improve quality.
    self.__fix_by_remove_low_confidence()
    # Among detections with high IoU (> 0.9), drop the one with the lower confidence.
    self.__fix_by_remove_high_iou_and_low_confidence()
    self.__fix_footnote()
def __fix_axis(self): def __fix_axis(self):
for model_page_info in self.__model_list: need_remove_list = []
need_remove_list = [] layout_dets = self.__page_model_info['layout_dets']
page_no = model_page_info['page_info']['page_no'] for layout_det in layout_dets:
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio( x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
model_page_info, self.__docs.get_page(page_no) bbox = [
) int(x0 / self.__scale),
layout_dets = model_page_info['layout_dets'] int(y0 / self.__scale),
for layout_det in layout_dets: int(x1 / self.__scale),
int(y1 / self.__scale),
if layout_det.get('bbox') is not None: ]
# 兼容直接输出bbox的模型数据,如paddle layout_det['bbox'] = bbox
x0, y0, x1, y1 = layout_det['bbox'] # 删除高度或者宽度小于等于0的spans
else: if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
# 兼容直接输出poly的模型数据,如xxx need_remove_list.append(layout_det)
x0, y0, _, _, x1, y1, _, _ = layout_det['poly'] for need_remove in need_remove_list:
layout_dets.remove(need_remove)
bbox = [
int(x0 / horizontal_scale_ratio),
int(y0 / vertical_scale_ratio),
int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio),
]
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_low_confidence(self): def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list: need_remove_list = []
need_remove_list = [] layout_dets = self.__page_model_info['layout_dets']
layout_dets = model_page_info['layout_dets'] for layout_det in layout_dets:
for layout_det in layout_dets: if layout_det['score'] <= 0.05:
if layout_det['score'] <= 0.05: need_remove_list.append(layout_det)
need_remove_list.append(layout_det) else:
else: continue
continue for need_remove in need_remove_list:
for need_remove in need_remove_list: layout_dets.remove(need_remove)
layout_dets.remove(need_remove)
def __fix_by_remove_high_iou_and_low_confidence(self): def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list: need_remove_list = []
need_remove_list = [] layout_dets = self.__page_model_info['layout_dets']
layout_dets = model_page_info['layout_dets'] for layout_det1 in layout_dets:
for layout_det1 in layout_dets: for layout_det2 in layout_dets:
for layout_det2 in layout_dets: if layout_det1 == layout_det2:
if layout_det1 == layout_det2: continue
continue if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if layout_det1['category_id'] in [ if (
0, calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
1, > 0.9
2, ):
3, if layout_det1['score'] < layout_det2['score']:
4, layout_det_need_remove = layout_det1
5,
6,
7,
8,
9,
] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else: else:
continue layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else: else:
continue continue
for need_remove in need_remove_list: else:
layout_dets.remove(need_remove) continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_footnote(self):
    """Re-label footnotes that sit closer to a figure than to any table.

    Detections of this page are split into footnotes, figures and tables by
    ``category_id``. For every footnote the smallest ``_bbox_distance`` to a
    figure and to a table is computed, skipping pairs for which
    ``bbox_relative_pos`` sets more than one position flag. A footnote whose
    nearest figure is strictly closer than its nearest table (or that has no
    eligible table) gets ``category_id = CategoryId.ImageFootnote``.
    """
    footnotes = []
    figures = []
    tables = []

    for obj in self.__page_model_info['layout_dets']:
        category_id = obj['category_id']
        if category_id == CategoryId.TableFootnote:  # 7: footnote
            footnotes.append(obj)
        elif category_id == CategoryId.ImageBody:  # 3: figure
            figures.append(obj)
        elif category_id == CategoryId.TableBody:  # 5: table
            tables.append(obj)

    # Nothing can be re-labelled without both a footnote and a figure.
    # This used to be `continue` inside a per-page loop; with the loop gone
    # `continue` is a SyntaxError, so return instead.
    if not footnotes or not figures:
        return

    def nearest_distances(targets):
        """Map footnote index -> smallest distance to an eligible target."""
        nearest = {}
        for i, footnote in enumerate(footnotes):
            for target in targets:
                flags = bbox_relative_pos(footnote['bbox'], target['bbox'])
                # Only pairs with at most one position flag set are considered.
                if sum(1 for flag in flags if flag) > 1:
                    continue
                nearest[i] = min(
                    self._bbox_distance(target['bbox'], footnote['bbox']),
                    nearest.get(i, float('inf')),
                )
        return nearest

    dis_figure_footnote = nearest_distances(figures)
    dis_table_footnote = nearest_distances(tables)

    for i, footnote in enumerate(footnotes):
        if i not in dis_figure_footnote:
            continue
        if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
            footnote['category_id'] = CategoryId.ImageFootnote
def _bbox_distance(self, bbox1, bbox2): def _bbox_distance(self, bbox1, bbox2):
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2) left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
...@@ -132,68 +149,6 @@ class MagicModel: ...@@ -132,68 +149,6 @@ class MagicModel:
return bbox_distance(bbox1, bbox2) return bbox_distance(bbox1, bbox2)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
footnotes = []
figures = []
tables = []
for obj in model_page_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes): def __reduct_overlap(self, bboxes):
N = len(bboxes) N = len(bboxes)
keep = [True] * N keep = [True] * N
...@@ -205,258 +160,10 @@ class MagicModel: ...@@ -205,258 +160,10 @@ class MagicModel:
keep[i] = False keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]] return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance_v2(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
"""_summary_
Args:
page_no (int): _description_
subject_category_id (int): _description_
object_category_id (int): _description_
priority_pos (PosRelationEnum): _description_
Returns:
_type_: _description_
"""
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if (
dis_by_directions['top'][i][1] != float('inf')
and dis_by_directions['bottom'][i][1] != float('inf')
and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
):
RATIO = 3
if (
abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
)
< RATIO * axis_unit
):
if priority_pos == PosRelationEnum.BOTTOM:
sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
else:
sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
continue
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if (
abs(top_bottom_x_axis - l_x_axis)
+ dis_by_directions['bottom'][i][1]
> abs(bottom_top_x_axis - l_x_axis)
+ dis_by_directions['top'][i][1]
):
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
def __tie_up_category_by_distance_v3( def __tie_up_category_by_distance_v3(
self, self,
page_no: int,
subject_category_id: int, subject_category_id: int,
object_category_id: int, object_category_id: int,
priority_pos: PosRelationEnum,
): ):
subjects = self.__reduct_overlap( subjects = self.__reduct_overlap(
list( list(
...@@ -464,7 +171,7 @@ class MagicModel: ...@@ -464,7 +171,7 @@ class MagicModel:
lambda x: {'bbox': x['bbox'], 'score': x['score']}, lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter( filter(
lambda x: x['category_id'] == subject_category_id, lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'], self.__page_model_info['layout_dets'],
), ),
) )
) )
...@@ -475,7 +182,7 @@ class MagicModel: ...@@ -475,7 +182,7 @@ class MagicModel:
lambda x: {'bbox': x['bbox'], 'score': x['score']}, lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter( filter(
lambda x: x['category_id'] == object_category_id, lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'], self.__page_model_info['layout_dets'],
), ),
) )
) )
...@@ -605,13 +312,12 @@ class MagicModel: ...@@ -605,13 +312,12 @@ class MagicModel:
return ret return ret
def get_imgs(self):
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
page_no, 3, 4, PosRelationEnum.BOTTOM 3, 4
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL 3, CategoryId.ImageFootnote
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -625,12 +331,12 @@ class MagicModel: ...@@ -625,12 +331,12 @@ class MagicModel:
ret.append(record) ret.append(record)
return ret return ret
def get_tables_v2(self, page_no: int) -> list: def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
page_no, 5, 6, PosRelationEnum.UP 5, 6
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 5, 7, PosRelationEnum.ALL 5, 7
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -644,52 +350,31 @@ class MagicModel: ...@@ -644,52 +350,31 @@ class MagicModel:
ret.append(record) ret.append(record)
return ret return ret
def get_imgs(self, page_no: int): def get_equations(self) -> tuple[list, list, list]: # 有坐标,也有字
return self.get_imgs_v2(page_no)
def get_tables(
self, page_no: int
) -> list: # 3个坐标, caption, table主体,table-note
return self.get_tables_v2(page_no)
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type( inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex'] CategoryId.InlineEquation, ['latex']
) )
interline_equations = self.__get_blocks_by_type( interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex'] CategoryId.InterlineEquation_YOLO, ['latex']
) )
interline_equations_blocks = self.__get_blocks_by_type( interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no CategoryId.InterlineEquation_Layout
) )
return inline_equations, interline_equations, interline_equations_blocks return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标 def get_discarded(self) -> list: # 自研模型,只有坐标
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no) blocks = self.__get_blocks_by_type(CategoryId.Abandon)
return blocks return blocks
def get_text_blocks(self, page_no: int) -> list: # 自研模型搞的,只有坐标,没有字 def get_text_blocks(self) -> list: # 自研模型搞的,只有坐标,没有字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no) blocks = self.__get_blocks_by_type(CategoryId.Text)
return blocks return blocks
def get_title_blocks(self, page_no: int) -> list: # 自研模型,只有坐标,没字 def get_title_blocks(self) -> list: # 自研模型,只有坐标,没字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no) blocks = self.__get_blocks_by_type(CategoryId.Title)
return blocks return blocks
def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标 def get_all_spans(self) -> list:
text_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['category_id'] == '15':
span = {
'bbox': layout_det['bbox'],
'content': layout_det['text'],
}
text_spans.append(span)
return text_spans
def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans): def remove_duplicate_spans(spans):
new_spans = [] new_spans = []
...@@ -699,8 +384,7 @@ class MagicModel: ...@@ -699,8 +384,7 @@ class MagicModel:
return new_spans return new_spans
all_spans = [] all_spans = []
model_page_info = self.__model_list[page_no] layout_dets = self.__page_model_info['layout_dets']
layout_dets = model_page_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15] allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的""" """当成span拼接的"""
# 3: 'image', # 图片 # 3: 'image', # 图片
...@@ -713,7 +397,7 @@ class MagicModel: ...@@ -713,7 +397,7 @@ class MagicModel:
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']} span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3: if category_id == 3:
span['type'] = ContentType.Image span['type'] = ContentType.IMAGE
elif category_id == 5: elif category_id == 5:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get('latex', None) latex = layout_det.get('latex', None)
...@@ -722,50 +406,36 @@ class MagicModel: ...@@ -722,50 +406,36 @@ class MagicModel:
span['latex'] = latex span['latex'] = latex
elif html: elif html:
span['html'] = html span['html'] = html
span['type'] = ContentType.Table span['type'] = ContentType.TABLE
elif category_id == 13: elif category_id == 13:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14: elif category_id == 14:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.InterlineEquation span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15: elif category_id == 15:
span['content'] = layout_det['text'] span['content'] = layout_det['text']
span['type'] = ContentType.Text span['type'] = ContentType.TEXT
all_spans.append(span) all_spans.append(span)
return remove_duplicate_spans(all_spans) return remove_duplicate_spans(all_spans)
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高
page_w = page.w
page_h = page.h
return page_w, page_h
def __get_blocks_by_type( def __get_blocks_by_type(
self, type: int, page_no: int, extra_col: list[str] = [] self, category_type: int, extra_col=None
) -> list: ) -> list:
if extra_col is None:
extra_col = []
blocks = [] blocks = []
for page_dict in self.__model_list: layout_dets = self.__page_model_info.get('layout_dets', [])
layout_dets = page_dict.get('layout_dets', []) for item in layout_dets:
page_info = page_dict.get('page_info', {}) category_id = item.get('category_id', -1)
page_number = page_info.get('page_no', -1) bbox = item.get('bbox', None)
if page_no != page_number:
continue if category_id == category_type:
for item in layout_dets: block = {
category_id = item.get('category_id', -1) 'bbox': bbox,
bbox = item.get('bbox', None) 'score': item.get('score'),
}
if category_id == type: for col in extra_col:
block = { block[col] = item.get(col, None)
'bbox': bbox, blocks.append(block)
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks return blocks
def get_model_list(self, page_no):
return self.__model_list[page_no]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment