Commit 8f1f9abe authored by myhloli

refactor: enhance bounding box utilities and add configuration reader for S3 integration

parent 7285ea92
# Copyright (c) Opendatalab. All rights reserved.
import json
import os

from loguru import logger

# Configuration file name (overridable via the MINERU_TOOLS_CONFIG_JSON env var)
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')


def read_config():
    if os.path.isabs(CONFIG_FILE_NAME):
        config_file = CONFIG_FILE_NAME
    else:
        home_dir = os.path.expanduser('~')
        config_file = os.path.join(home_dir, CONFIG_FILE_NAME)

    if not os.path.exists(config_file):
        raise FileNotFoundError(f'{config_file} not found')

    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return config


def get_s3_config(bucket_name: str):
    """Read the S3 credentials for a bucket from ~/magic-pdf.json."""
    config = read_config()

    bucket_info = config.get('bucket_info')
    if bucket_name not in bucket_info:
        access_key, secret_key, storage_endpoint = bucket_info['[default]']
    else:
        access_key, secret_key, storage_endpoint = bucket_info[bucket_name]

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')

    # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")

    return access_key, secret_key, storage_endpoint


def get_s3_config_dict(path: str):
    access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}


def get_bucket_name(path):
    bucket, key = parse_bucket_key(path)
    return bucket


def parse_bucket_key(s3_full_path: str):
    """
    Input:  s3://bucket/path/to/my/file.txt
    Output: bucket, path/to/my/file.txt
    """
    s3_full_path = s3_full_path.strip()
    if s3_full_path.startswith("s3://"):
        s3_full_path = s3_full_path[5:]
    if s3_full_path.startswith("/"):
        s3_full_path = s3_full_path[1:]
    bucket, key = s3_full_path.split("/", 1)
    return bucket, key


def get_local_models_dir():
    config = read_config()
    models_dir = config.get('models-dir')
    if models_dir is None:
        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
        return '/tmp/models'
    else:
        return models_dir


def get_local_layoutreader_model_dir():
    config = read_config()
    layoutreader_model_dir = config.get('layoutreader-model-dir')
    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
        home_dir = os.path.expanduser('~')
        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
        logger.warning(f"'layoutreader-model-dir' does not exist, use {layoutreader_at_modelscope_dir_path} as default")
        return layoutreader_at_modelscope_dir_path
    else:
        return layoutreader_model_dir


def get_device():
    config = read_config()
    device = config.get('device-mode')
    if device is None:
        logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
        return 'cpu'
    else:
        return device


def get_table_recog_config():
    config = read_config()
    table_config = config.get('table-config')
    if table_config is None:
        logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        return json.loads('{"enable": true}')
    else:
        return table_config


def get_formula_config():
    config = read_config()
    formula_config = config.get('formula-config')
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        return json.loads('{"enable": true}')
    else:
        return formula_config
\ No newline at end of file
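
For reference, a minimal usage sketch of the S3 helpers above; the bucket name, credentials, and endpoint are made up for illustration and assume a ~/magic-pdf.json with a matching "bucket_info" mapping:

# Assumes ~/magic-pdf.json contains, e.g.:
# {"bucket_info": {"my-bucket": ["my-ak", "my-sk", "https://s3.example.com"]}}
bucket, key = parse_bucket_key("s3://my-bucket/path/to/file.pdf")
# bucket == "my-bucket", key == "path/to/file.pdf"
s3_conf = get_s3_config_dict("s3://my-bucket/path/to/file.pdf")
# s3_conf == {"ak": "my-ak", "sk": "my-sk", "endpoint": "https://s3.example.com"}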
# Copyright (c) Opendatalab. All rights reserved.
def result_to_middle_json(model_json, images_list, pdf_doc, image_writer):
    pass
\ No newline at end of file
from mineru.utils.pipeline_magic_model import MagicModel
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5


def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, lang=None, ocr=False):
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    width, height = map(int, page.get_size())
    magic_model = MagicModel(page_model_info, scale)


def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
    middle_json = {"pdf_info": [], "_backend": "vlm", "_version_name": __version__}
    for page_index, page_model_info in enumerate(model_list):
        page = pdf_doc[page_index]
        image_dict = images_list[page_index]
        page_info = page_model_info_to_page_info(
            page_model_info, image_dict, page, image_writer, page_index, lang=lang, ocr=ocr
        )
        middle_json["pdf_info"].append(page_info)
    return middle_json
\ No newline at end of file
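
A quick sketch of the structure this assembler is meant to return; the call below is illustrative, and the per-page entries come from page_model_info_to_page_info, which does not yet return anything in this commit:

middle_json = result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang="en", ocr=False)
# middle_json == {
#     "pdf_info": [<one entry per page>, ...],
#     "_backend": "vlm",
#     "_version_name": __version__,
# }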
@@ -2,9 +2,9 @@ import os
import time
import numpy as np
import torch
from pypdfium2 import PdfDocument
from mineru.backend.pipeline.model_init import MineruPipelineModel
from .model_init import MineruPipelineModel
from .config_reader import get_local_models_dir, get_device, get_formula_config, get_table_recog_config
from .model_json_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from ...utils.pdf_classify import classify
@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf
from loguru import logger
from ...utils.model_utils import get_vram, clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_layout_config,
get_local_models_dir,
get_table_recog_config)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
@@ -109,6 +104,7 @@ def doc_analyze(
all_image_lists = []
all_pdf_docs = []
ocr_enabled_list = []
for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
# determine the OCR setting
_ocr = False
@@ -118,6 +114,7 @@
elif parse_method == 'ocr':
_ocr = True
ocr_enabled_list[pdf_idx] = _ocr
_lang = lang_list[pdf_idx]
# collect the pages from each dataset
@@ -152,23 +149,23 @@
results.extend(batch_results)
# build the return value
# multi-dataset mode: group results by dataset
infer_results = [[] for _ in datasets]
infer_results = []
for i, page_info in enumerate(all_pages_info):
pdf_idx, page_idx, pil_img, _, _ = page_info
result = results[i]
page_info_dict = {'page_no': page_idx, 'width': pil_img.get_width(), 'height': pil_img.get_height()}
page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
page_dict = {'layout_dets': result, 'page_info': page_info_dict}
infer_results[pdf_idx].append(page_dict)
infer_results[pdf_idx][page_idx] = page_dict
middle_json_list = []
for pdf_idx, model_json in enumerate(infer_results):
for pdf_idx, model_list in enumerate(infer_results):
images_list = all_image_lists[pdf_idx]
pdf_doc = all_pdf_docs[pdf_idx]
middle_json = result_to_middle_json(model_json, images_list, pdf_doc, image_writer)
_lang = lang_list[pdf_idx]
_ocr = ocr_enabled_list[pdf_idx]
middle_json = result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr)
middle_json_list.append(middle_json)
return middle_json_list, infer_results
......
@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_version_name": __version__}
middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
for index, token in enumerate(token_list):
page = pdf_doc[index]
image_dict = images_list[index]
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2):
    elif top:
        return y2 - y1b
    return 0.0


def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
    """Use calculate_overlap_area_2_minbox_area_ratio to compute the share of the overlap
    area relative to the smaller bbox; if that share exceeds `ratio`, return the smaller
    bbox, otherwise return None."""
    x1_min, y1_min, x1_max, y1_max = bbox1
    x2_min, y2_min, x2_max, y2_max = bbox2
    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)
    overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
    if overlap_ratio > ratio:
        if area1 <= area2:
            return bbox1
        else:
            return bbox2
    else:
        return None


def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
    """Compute the overlap area of bbox1 and bbox2 as a fraction of the smaller box's area."""
    # Determine the coordinates of the intersection rectangle
    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The area of the overlap region
    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
                        (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])])
    if min_box_area == 0:
        return 0
    else:
        return intersection_area / min_box_area


def calculate_iou(bbox1, bbox2):
    """Compute the intersection over union (IoU) of two bounding boxes.

    Args:
        bbox1 (list[float]): Coordinates of the first box as [x1, y1, x2, y2], where
            (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
        bbox2 (list[float]): Coordinates of the second box, same format as `bbox1`.

    Returns:
        float: The IoU of the two boxes, in the range [0, 1].
    """
    # Determine the coordinates of the intersection rectangle
    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The area of the overlap region
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # The area of both rectangles
    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

    if any([bbox1_area == 0, bbox2_area == 0]):
        return 0

    # Compute the intersection over union by taking the intersection area
    # and dividing it by the sum of both areas minus the intersection area
    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
    return iou


def _is_in(box1, box2) -> bool:
    """Whether box1 lies entirely inside box2."""
    x0_1, y0_1, x1_1, y1_1 = box1
    x0_2, y0_2, x1_2, y1_2 = box2

    return (x0_1 >= x0_2 and  # box1's left edge is not outside box2's left edge
            y0_1 >= y0_2 and  # box1's top edge is not outside box2's top edge
            x1_1 <= x1_2 and  # box1's right edge is not outside box2's right edge
            y1_1 <= y1_2)     # box1's bottom edge is not outside box2's bottom edge
\ No newline at end of file
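
A quick worked example of the overlap helpers added above; the boxes are made up for illustration and assume the functions are in scope:

# Two axis-aligned boxes given as [x1, y1, x2, y2]
box_a = [0, 0, 10, 10]
box_b = [5, 5, 15, 15]

# Intersection is the 5x5 square from (5, 5) to (10, 10), area 25.
# Union area = 100 + 100 - 25 = 175, so IoU = 25 / 175 ≈ 0.143.
print(calculate_iou(box_a, box_b))  # ~0.1428

# The overlap covers 25 / 100 = 0.25 of the smaller box, so with ratio=0.2
# the smaller box (box_a, since area1 <= area2) is returned.
print(get_minbox_if_overlap_by_ratio(box_a, box_b, 0.2))  # [0, 0, 10, 10]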
@@ -23,6 +23,22 @@ class ContentType:
    INLINE_EQUATION = 'inline_equation'


class CategoryId:
    Title = 0
    Text = 1
    Abandon = 2
    ImageBody = 3
    ImageCaption = 4
    TableBody = 5
    TableCaption = 6
    TableFootnote = 7
    InterlineEquation_Layout = 8
    InlineEquation = 13
    InterlineEquation_YOLO = 14
    OcrText = 15
    ImageFootnote = 101


class MakeMode:
    MM_MD = 'mm_markdown'
    NLP_MD = 'nlp_markdown'
......
import os
import unicodedata

if not os.getenv("FTLANG_CACHE"):
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    root_dir = os.path.dirname(current_dir)
    ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
    os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)

# print(os.getenv("FTLANG_CACHE"))

from fast_langdetect import detect_language


def remove_invalid_surrogates(text):
    # Remove invalid UTF-16 surrogate code points
    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))


def detect_lang(text: str) -> str:
    if len(text) == 0:
        return ""
    text = text.replace("\n", "")
    text = remove_invalid_surrogates(text)
    # print(text)
    try:
        lang_upper = detect_language(text)
    except Exception:
        html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
        lang_upper = detect_language(html_no_ctrl_chars)
    try:
        lang = lang_upper.lower()
    except Exception:
        lang = ""
    return lang


if __name__ == '__main__':
    print(os.getenv("FTLANG_CACHE"))
    print(detect_lang("This is a test."))
    print(detect_lang("<html>This is a test</html>"))
    print(detect_lang("这个是中文测试。"))
    print(detect_lang("<html>这个是中文测试。</html>"))
    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
\ No newline at end of file
@@ -4,7 +4,7 @@ import gc
from loguru import logger
import numpy as np
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
......