Commit 15dd9a0f authored by myhloli's avatar myhloli
Browse files

refactor: reorganize config_reader imports and enhance format utilities

parent 3eef1218
......@@ -138,24 +138,20 @@ class MineruPipelineModel:
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
models_dir = kwargs.get('models_dir', "")
if not models_dir:
logger.error("can't found models_dir, please set models_dir")
exit(1)
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(models_dir, get_file_from_repos(ModelPath.yolo_v8_mfd))
get_file_from_repos(ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = str(
os.path.join(models_dir, get_file_from_repos(ModelPath.unimernet_small))
get_file_from_repos(ModelPath.unimernet_small)
)
self.mfr_model = atom_model_manager.get_atom_model(
......@@ -168,7 +164,7 @@ class MineruPipelineModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(models_dir, get_file_from_repos(ModelPath.doclayout_yolo))
get_file_from_repos(ModelPath.doclayout_yolo)
),
device=self.device,
)
......
......@@ -3,7 +3,7 @@ import time
from loguru import logger
from mineru.backend.pipeline.config_reader import get_device, get_llm_aided_config
from mineru.utils.config_reader import get_device, get_llm_aided_config
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
......
......@@ -4,7 +4,7 @@ import numpy as np
import torch
from .model_init import MineruPipelineModel
from .config_reader import get_local_models_dir, get_device, get_formula_config, get_table_recog_config
from mineru.utils.config_reader import get_device, get_formula_config, get_table_recog_config
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
......@@ -48,7 +48,6 @@ def custom_model_init(
):
model_init_start = time.time()
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
formula_config = get_formula_config()
......@@ -60,7 +59,6 @@ def custom_model_init(
table_config['enable'] = table_enable
model_input = {
'models_dir': local_models_dir,
'device': device,
'table_config': table_config,
'formula_config': formula_config,
......
import re
from loguru import logger
from mineru.backend.pipeline.config_reader import get_latex_delimiter_config
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.backend.pipeline.para_split import ListLineTag
from mineru.utils.enum_class import BlockType, ContentType, MakeMode
from mineru.utils.language import detect_lang
......
......@@ -3,7 +3,7 @@ import re
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import fix_two_layer_blocks, fix_title_blocks
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.version import __version__
......@@ -17,100 +17,23 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
width, height = map(int, page.get_size())
# 使用正则表达式查找所有块
pattern = (
r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
)
block_infos = re.findall(pattern, token, re.DOTALL)
blocks = []
# 解析每个块
for index, block_info in enumerate(block_infos):
block_bbox = block_info[0].strip()
x1, y1, x2, y2 = map(int, block_bbox.split())
x_1, y_1, x_2, y_2 = (
int(x1 * width / 1000),
int(y1 * height / 1000),
int(x2 * width / 1000),
int(y2 * height / 1000),
)
if x_2 < x_1:
x_1, x_2 = x_2, x_1
if y_2 < y_1:
y_1, y_2 = y_2, y_1
block_bbox = (x_1, y_1, x_2, y_2)
block_type = block_info[1].strip()
block_content = block_info[2].strip()
# print(f"坐标: {block_bbox}")
# print(f"类型: {block_type}")
# print(f"内容: {block_content}")
# print("-" * 50)
span_type = "unknown"
if block_type in [
"text",
"title",
"image_caption",
"image_footnote",
"table_caption",
"table_footnote",
"list",
"index",
]:
span_type = ContentType.TEXT
elif block_type in ["image"]:
block_type = BlockType.IMAGE_BODY
span_type = ContentType.IMAGE
elif block_type in ["table"]:
block_type = BlockType.TABLE_BODY
span_type = ContentType.TABLE
elif block_type in ["equation"]:
block_type = BlockType.INTERLINE_EQUATION
span_type = ContentType.INTERLINE_EQUATION
if span_type in ["image", "table"]:
span = {
"bbox": block_bbox,
"type": span_type,
}
if span_type == ContentType.TABLE:
span["html"] = block_content
magic_model = MagicModel(token, width, height)
image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks()
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
all_spans = magic_model.get_all_spans()
# 对image/table/interline_equation的span截图
for span in all_spans:
if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
else:
span = {
"bbox": block_bbox,
"type": span_type,
"content": block_content,
}
line = {
"bbox": block_bbox,
"spans": [span],
}
blocks.append(
{
"bbox": block_bbox,
"type": block_type,
"lines": [line],
"index": index,
}
)
image_blocks = fix_two_layer_blocks(blocks, BlockType.IMAGE)
table_blocks = fix_two_layer_blocks(blocks, BlockType.TABLE)
title_blocks = fix_title_blocks(blocks)
page_blocks = [
block
for block in blocks
if block["type"] in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]
]
page_blocks.extend([*image_blocks, *table_blocks, *title_blocks])
page_blocks = []
page_blocks.extend([*image_blocks, *table_blocks, *title_blocks, *text_blocks, *interline_equation_blocks])
# 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"])
......
......@@ -2,8 +2,205 @@ import re
from typing import Literal
from mineru.utils.boxbase import bbox_distance, is_in
from mineru.utils.enum_class import BlockType
from mineru.utils.enum_class import BlockType, ContentType
from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
from mineru.utils.format_utils import convert_otsl_to_html
class MagicModel:
    """Parse the raw VLM token output of one page into blocks and spans.

    The token string is scanned for records of the form
    ``<|box_start|>bbox<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content``
    terminated by ``<|md_end|>`` or ``<|im_end|>``. Every record becomes one
    block dict with the keys ``bbox``, ``type``, ``lines`` and ``index``
    (``index`` is the record's position in the token string), holding a
    single line with one or more spans.

    Args:
        token: Raw model output for the page.
        width: Page width; x coordinates are multiplied by ``width / 1000``.
        height: Page height; y coordinates are multiplied by ``height / 1000``.

    Raises:
        ValueError: If a record's bbox is not four whitespace-separated
            integers.
    """

    # One record: bbox, block type and payload.
    _BLOCK_PATTERN = re.compile(
        r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)",
        re.DOTALL,
    )
    # Inline formula written as \( ... \) inside a text payload.
    _INLINE_FORMULA_PATTERN = re.compile(r'\\\((.+?)\\\)')
    # Raw block type names whose payload is carried as a text span.
    _TEXT_TYPE_NAMES = (
        "text",
        "title",
        "image_caption",
        "image_footnote",
        "table_caption",
        "table_footnote",
        "list",
        "index",
    )

    def __init__(self, token: str, width, height):
        self.token = token
        self.all_spans = []

        blocks = []
        records = self._BLOCK_PATTERN.findall(token)
        for index, (raw_bbox, raw_type, raw_content) in enumerate(records):
            block_bbox = self._scale_bbox(raw_bbox, width, height)
            block_type, span_type = self._classify(raw_type.strip())
            spans = self._build_spans(block_bbox, span_type, raw_content.strip())

            # The same span dicts are shared by the block and by all_spans,
            # so in-place changes made through one are visible in the other.
            self.all_spans.extend(spans)
            blocks.append(
                {
                    "bbox": block_bbox,
                    "type": block_type,
                    "lines": [{"bbox": block_bbox, "spans": spans}],
                    "index": index,
                }
            )

        self.image_blocks = []
        self.table_blocks = []
        self.interline_equation_blocks = []
        self.text_blocks = []
        self.title_blocks = []
        for block in blocks:
            block_type = block["type"]
            if block_type in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
                self.image_blocks.append(block)
            elif block_type in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
                self.table_blocks.append(block)
            elif block_type == BlockType.INTERLINE_EQUATION:
                self.interline_equation_blocks.append(block)
            elif block_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
                # "list" and "index" records are parsed as text spans above;
                # keep them with the text blocks so they are not silently
                # dropped from the page.
                self.text_blocks.append(block)
            elif block_type == BlockType.TITLE:
                self.title_blocks.append(block)
            # Any other block type is not exposed by a getter.

    @staticmethod
    def _scale_bbox(raw_bbox, width, height):
        """Convert a raw ``"x1 y1 x2 y2"`` string into an ordered page bbox."""
        x1, y1, x2, y2 = map(int, raw_bbox.split())
        left = int(x1 * width / 1000)
        top = int(y1 * height / 1000)
        right = int(x2 * width / 1000)
        bottom = int(y2 * height / 1000)
        # Normalise so that (left, top) is never past (right, bottom).
        if right < left:
            left, right = right, left
        if bottom < top:
            top, bottom = bottom, top
        return (left, top, right, bottom)

    @classmethod
    def _classify(cls, raw_type):
        """Map a raw block type name to ``(block_type, span_type)``."""
        if raw_type in cls._TEXT_TYPE_NAMES:
            return raw_type, ContentType.TEXT
        if raw_type == "image":
            return BlockType.IMAGE_BODY, ContentType.IMAGE
        if raw_type == "table":
            return BlockType.TABLE_BODY, ContentType.TABLE
        if raw_type == "equation":
            return BlockType.INTERLINE_EQUATION, ContentType.INTERLINE_EQUATION
        return raw_type, "unknown"

    @classmethod
    def _build_spans(cls, bbox, span_type, content):
        """Return the list of span dicts for one block payload."""
        if span_type in (ContentType.IMAGE, ContentType.TABLE):
            span = {
                "bbox": bbox,
                "type": span_type,
            }
            if span_type == ContentType.TABLE:
                span["html"] = cls._table_html(content)
            return [span]
        if span_type == ContentType.INTERLINE_EQUATION:
            return [
                {
                    "bbox": bbox,
                    "type": span_type,
                    "content": isolated_formula_clean(content),
                }
            ]
        return cls._text_spans(bbox, span_type, content)

    @staticmethod
    def _table_html(content):
        """Convert OTSL paragraphs of a table payload to HTML, keep the rest."""
        return "\n\n".join(
            convert_otsl_to_html(part)
            if ("<fcel>" in part or "<ecel>" in part)
            else part
            for part in content.split("\n\n")
        )

    @classmethod
    def _text_spans(cls, bbox, span_type, content):
        """Split a text payload into text and inline-equation spans.

        Splitting only happens when the payload has at least one ``\\(`` and
        the same number of ``\\)``; otherwise the payload is one span of
        ``span_type``.
        """
        opening_count = content.count("\\(")
        if opening_count == 0 or opening_count != content.count("\\)"):
            return [
                {
                    "bbox": bbox,
                    "type": span_type,
                    "content": content,
                }
            ]

        spans = []

        def add_text(fragment):
            # Whitespace-only fragments between formulas are not kept.
            if fragment.strip():
                spans.append({
                    "bbox": bbox,
                    "type": ContentType.TEXT,
                    "content": fragment
                })

        last_end = 0
        for match in cls._INLINE_FORMULA_PATTERN.finditer(content):
            start, end = match.span()
            add_text(content[last_end:start])
            # Formula body without the surrounding \( and \).
            spans.append({
                "bbox": bbox,
                "type": ContentType.INLINE_EQUATION,
                "content": match.group(1).strip()
            })
            last_end = end
        add_text(content[last_end:])
        return spans

    def get_image_blocks(self):
        """Return the image body/caption/footnote blocks passed through
        ``fix_two_layer_blocks`` with ``BlockType.IMAGE``."""
        return fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)

    def get_table_blocks(self):
        """Return the table body/caption/footnote blocks passed through
        ``fix_two_layer_blocks`` with ``BlockType.TABLE``."""
        return fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)

    def get_title_blocks(self):
        """Return the title blocks passed through ``fix_title_blocks``."""
        return fix_title_blocks(self.title_blocks)

    def get_text_blocks(self):
        """Return the text, list and index blocks in parse order."""
        return self.text_blocks

    def get_interline_equation_blocks(self):
        """Return the interline (display) equation blocks in parse order."""
        return self.interline_equation_blocks

    def get_all_spans(self):
        """Return every span of every parsed block, in parse order."""
        return self.all_spans
def isolated_formula_clean(txt):
    """Strip one leading ``\\[`` and one trailing ``\\]`` from a formula.

    The delimiters are only removed when they sit at the very start / end of
    ``txt``; surrounding whitespace of the result is trimmed afterwards.
    """
    start = 2 if txt.startswith("\\[") else 0
    body = txt[start:]
    if body.endswith("\\]"):
        body = body[:-2]
    return body.strip()
def __reduct_overlap(bboxes):
......
import re
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.utils.enum_class import MakeMode, BlockType, ContentType
def merge_para_with_text(para_block):
latex_delimiters_config = get_latex_delimiter_config()
default_delimiters = {
'display': {'left': '$$', 'right': '$$'},
'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
para_text = ''
for line in para_block['lines']:
for span in line['spans']:
content = span['content']
for j, span in enumerate(line['spans']):
span_type = span['type']
content = ''
if span_type == ContentType.TEXT:
content = span['content']
elif span_type == ContentType.INLINE_EQUATION:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.INTERLINE_EQUATION:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
content = content.strip()
if content:
para_text += content
else:
continue
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
if j == len(line['spans']) - 1:
para_text += content
else:
para_text += f'{content} '
elif span_type == ContentType.INTERLINE_EQUATION:
para_text += content
return para_text
def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
......
......@@ -9,7 +9,7 @@ import numpy as np
import yaml
from loguru import logger
from mineru.backend.pipeline.config_reader import get_device
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
......
......@@ -7,7 +7,7 @@ from typing import List
import torch
from loguru import logger
from mineru.backend.pipeline.config_reader import get_device
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import BlockType, ModelPath
from mineru.utils.models_download_utils import get_file_from_repos
......
import re
import itertools
import html
from typing import Any, Dict, List
from pydantic import (
BaseModel,
computed_field,
model_validator,
)
class TableCell(BaseModel):
    """One table cell: its text, its spans and its row/column offsets."""

    row_span: int = 1
    col_span: int = 1
    start_row_offset_idx: int
    end_row_offset_idx: int
    start_col_offset_idx: int
    end_col_offset_idx: int
    text: str
    column_header: bool = False
    row_header: bool = False
    row_section: bool = False

    @model_validator(mode="before")
    @classmethod
    def from_dict_format(cls, data: Any) -> Any:
        """Fill in ``text`` for dict input that does not carry it yet.

        Non-dict input and dicts that already have a ``"text"`` key are
        returned untouched. Otherwise the text is read from
        ``data["bbox"]["token"]``; when that is empty, the ``"token"`` values
        of ``data["text_cell_bboxes"]`` (removed from ``data``) are joined
        with spaces instead. ``data`` is modified in place.
        """
        if not isinstance(data, Dict) or "text" in data:
            return data

        cell_text = data["bbox"].get("token", "")
        if len(cell_text) == 0:
            fragments = data.pop("text_cell_bboxes", None)
            if fragments:
                cell_text += "".join(el["token"] + " " for el in fragments)
            cell_text = cell_text.strip()
        data["text"] = cell_text
        return data
class TableData(BaseModel):  # TBD
    """Table described by a flat list of cells plus its row/column counts."""

    table_cells: List[TableCell] = []
    num_rows: int = 0
    num_cols: int = 0

    @computed_field  # type: ignore
    @property
    def grid(
        self,
    ) -> List[List[TableCell]]:
        """Return a ``num_rows`` x ``num_cols`` matrix of cells.

        Every position starts as an empty 1x1 cell; each entry of
        ``table_cells`` then overwrites all positions inside its offset
        range, clamped to the table size. The matrix is rebuilt on every
        access.
        """
        rows = []
        for r in range(self.num_rows):
            rows.append(
                [
                    TableCell(
                        text="",
                        start_row_offset_idx=r,
                        end_row_offset_idx=r + 1,
                        start_col_offset_idx=c,
                        end_col_offset_idx=c + 1,
                    )
                    for c in range(self.num_cols)
                ]
            )

        # Place real cells over the empty placeholders they cover.
        for cell in self.table_cells:
            row_from = min(cell.start_row_offset_idx, self.num_rows)
            row_to = min(cell.end_row_offset_idx, self.num_rows)
            col_from = min(cell.start_col_offset_idx, self.num_cols)
            col_to = min(cell.end_col_offset_idx, self.num_cols)
            for r in range(row_from, row_to):
                for c in range(col_from, col_to):
                    rows[r][c] = cell

        return rows
"""
OTSL
"""
# Structural tokens recognised by the OTSL table parser in this module.
# Row separator: the token stream is split into rows on this token.
OTSL_NL = "<nl>"
# Starts a new cell whose text is the piece that follows the token.
OTSL_FCEL = "<fcel>"
# Starts a new cell with no text.
OTSL_ECEL = "<ecel>"
# Counted when extending the column span of the cell to its left.
OTSL_LCEL = "<lcel>"
# Counted when extending the row span of the cell above it.
OTSL_UCEL = "<ucel>"
# Counted for both the column span and the row span of a cell.
OTSL_XCEL = "<xcel>"
def otsl_extract_tokens_and_text(s: str):
    """Split an OTSL string into structural tokens and ordered pieces.

    Args:
        s: OTSL table markup.

    Returns:
        A ``(tokens, text_parts)`` tuple. ``tokens`` lists every OTSL token
        in order of appearance. ``text_parts`` lists the tokens *and* the
        text found between them, in order, with empty or whitespace-only
        pieces dropped.
    """
    # Only the six OTSL tokens are structural; any other <...> stays text.
    pattern = (
        "("
        + "|".join(
            re.escape(tok)
            for tok in (OTSL_NL, OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL)
        )
        + ")"
    )
    tokens = re.findall(pattern, s)
    # The capturing group makes re.split keep the tokens between the texts.
    text_parts = [part for part in re.split(pattern, s) if part.strip()]
    return tokens, text_parts
def otsl_parse_texts(texts, tokens):
    """Build table cells from an OTSL token/text sequence.

    Args:
        texts: Tokens and text pieces in document order (second value
            returned by ``otsl_extract_tokens_and_text``).
        tokens: Structural tokens only, in document order.

    Returns:
        A ``(table_cells, split_row_tokens)`` tuple: the list of
        ``TableCell`` objects and the structural tokens grouped per row
        (``<nl>`` separators removed).
    """
    structural_tokens = (
        OTSL_NL,
        OTSL_FCEL,
        OTSL_ECEL,
        OTSL_LCEL,
        OTSL_UCEL,
        OTSL_XCEL,
    )
    split_row_tokens = [
        list(group)
        for is_separator, group in itertools.groupby(
            tokens, lambda tok: tok == OTSL_NL
        )
        if not is_separator
    ]
    table_cells = []
    r_idx = 0
    c_idx = 0

    def count_right(rows, c_start, r, which_tokens):
        """Count consecutive ``which_tokens`` in row ``r`` from ``c_start``."""
        if r >= len(rows):
            return 0
        row = rows[r]
        span = 0
        while c_start + span < len(row) and row[c_start + span] in which_tokens:
            span += 1
        return span

    def count_down(rows, c, r_start, which_tokens):
        """Count consecutive ``which_tokens`` in column ``c`` from ``r_start``.

        Stops at the first row that is too short to have column ``c`` so
        ragged rows cannot raise ``IndexError``.
        """
        span = 0
        while (
            r_start + span < len(rows)
            and c < len(rows[r_start + span])
            and rows[r_start + span][c] in which_tokens
        ):
            span += 1
        return span

    for i, text in enumerate(texts):
        if text in (OTSL_FCEL, OTSL_ECEL):
            cell_text = ""
            row_span = 1
            col_span = 1
            right_offset = 1
            # A filled cell owns the next piece only when that piece exists
            # and is real text; a truncated stream or an <fcel> directly
            # followed by another token yields an empty cell instead.
            if (
                text == OTSL_FCEL
                and i + 1 < len(texts)
                and texts[i + 1] not in structural_tokens
            ):
                cell_text = texts[i + 1]
                right_offset = 2

            # Look at the neighbours to the right and below to detect
            # lcel / ucel / xcel continuation tokens.
            next_right_cell = ""
            if i + right_offset < len(texts):
                next_right_cell = texts[i + right_offset]

            next_bottom_cell = ""
            if r_idx + 1 < len(split_row_tokens):
                if c_idx < len(split_row_tokens[r_idx + 1]):
                    next_bottom_cell = split_row_tokens[r_idx + 1][c_idx]

            if next_right_cell in (OTSL_LCEL, OTSL_XCEL):
                # Horizontal or 2D spanning cell.
                col_span += count_right(
                    split_row_tokens,
                    c_idx + 1,
                    r_idx,
                    [OTSL_LCEL, OTSL_XCEL],
                )
            if next_bottom_cell in (OTSL_UCEL, OTSL_XCEL):
                # Vertical or 2D spanning cell.
                row_span += count_down(
                    split_row_tokens,
                    c_idx,
                    r_idx + 1,
                    [OTSL_UCEL, OTSL_XCEL],
                )

            table_cells.append(
                TableCell(
                    text=cell_text.strip(),
                    row_span=row_span,
                    col_span=col_span,
                    start_row_offset_idx=r_idx,
                    end_row_offset_idx=r_idx + row_span,
                    start_col_offset_idx=c_idx,
                    end_col_offset_idx=c_idx + col_span,
                )
            )

        # Every cell-like token occupies one column of the current row.
        if text in (OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL):
            c_idx += 1
        if text == OTSL_NL:
            r_idx += 1
            c_idx = 0

    return table_cells, split_row_tokens
def export_to_html(table_data: "TableData"):
    """Render table data as an HTML ``<table>`` string.

    Args:
        table_data: Object exposing ``num_rows``, ``num_cols``,
            ``table_cells`` and ``grid``.

    Returns:
        ``""`` when there are no cells, otherwise ``<table>...</table>``
        with one ``<tr>`` per row. A cell is emitted only at its starting
        row/column; ``rowspan`` / ``colspan`` attributes are added when
        greater than 1, column headers use ``<th>``, and cell text is
        HTML-escaped.
    """
    if len(table_data.table_cells) == 0:
        return ""

    nrows = table_data.num_rows
    ncols = table_data.num_cols
    # Read the grid once: on TableData it is a computed property that is
    # rebuilt on every access, so reading it per cell repeats that work.
    grid = table_data.grid

    parts = ["<table>"]
    for i in range(nrows):
        parts.append("<tr>")
        for j in range(ncols):
            cell = grid[i][j]
            # Positions covered by a spanning cell are skipped; the cell is
            # written only where it starts.
            if cell.start_row_offset_idx != i or cell.start_col_offset_idx != j:
                continue

            content = html.escape(cell.text.strip())
            celltag = "th" if cell.column_header else "td"

            opening_tag = celltag
            if cell.row_span > 1:
                opening_tag += f' rowspan="{cell.row_span}"'
            if cell.col_span > 1:
                opening_tag += f' colspan="{cell.col_span}"'

            parts.append(f"<{opening_tag}>{content}</{celltag}>")
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)
def convert_otsl_to_html(otsl_content: str):
    """Convert an OTSL table string to an HTML table string.

    The content is tokenised, parsed into cells, wrapped in a ``TableData``
    whose column count is the length of the longest row, and rendered with
    ``export_to_html``.
    """
    tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content)
    table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens)

    row_count = len(split_row_tokens)
    col_count = max((len(row) for row in split_row_tokens), default=0)

    table_data = TableData(
        num_rows=row_count,
        num_cols=col_count,
        table_cells=table_cells,
    )
    return export_to_html(table_data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment