Unverified Commit 919280aa authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'dev' into multi_gpu_v2

parents ea9336c0 c6881d83
import click
def arg_parse(ctx: 'click.Context') -> dict:
    """Parse leftover ``--key value`` style arguments from a click context.

    Walks ``ctx.args`` (the unconsumed extra arguments click collected) and
    builds a dict of keyword arguments:

    - ``--some-name value`` becomes ``{"some_name": coerced(value)}``
      (dashes in the name are converted to underscores);
    - ``--flag`` followed by another ``--option`` (or nothing) is treated as
      a boolean flag and becomes ``{"flag": True}``;
    - values are coerced: ``"true"``/``"false"`` (case-insensitive) to bool,
      strings containing ``.`` to float when possible, otherwise int when
      possible, and finally left as the raw string.

    Tokens that do not start with ``--`` and are not consumed as a value are
    skipped. Returns the resulting dict (empty when there are no extra args).
    """

    def _coerce(raw: str):
        # Convert a raw CLI token to bool/float/int when it looks like one;
        # fall back to the original string on any conversion failure.
        lowered = raw.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
        if '.' in raw:
            try:
                return float(raw)
            except ValueError:
                return raw
        try:
            return int(raw)
        except ValueError:
            return raw

    extra_kwargs = {}
    args = ctx.args
    i = 0
    while i < len(args):
        token = args[i]
        if token.startswith('--'):
            # --some-name -> some_name
            param_name = token[2:].replace('-', '_')
            if i + 1 < len(args) and not args[i + 1].startswith('--'):
                # Option with an explicit value: consume both tokens.
                extra_kwargs[param_name] = _coerce(args[i + 1])
                i += 2
                continue
            # Bare flag (next token is another option or end of args).
            extra_kwargs[param_name] = True
        i += 1
    return extra_kwargs
\ No newline at end of file
......@@ -21,6 +21,7 @@ class ContentType:
TEXT = 'text'
INTERLINE_EQUATION = 'interline_equation'
INLINE_EQUATION = 'inline_equation'
EQUATION = 'equation'
class CategoryId:
......
......@@ -132,6 +132,35 @@ def otsl_parse_texts(texts, tokens):
r_idx = 0
c_idx = 0
# Check and complete the matrix
if split_row_tokens:
max_cols = max(len(row) for row in split_row_tokens)
# Insert additional <ecel> to tags
for row_idx, row in enumerate(split_row_tokens):
while len(row) < max_cols:
row.append(OTSL_ECEL)
# Insert additional <ecel> to texts
new_texts = []
text_idx = 0
for row_idx, row in enumerate(split_row_tokens):
for col_idx, token in enumerate(row):
new_texts.append(token)
if text_idx < len(texts) and texts[text_idx] == token:
text_idx += 1
if (text_idx < len(texts) and
texts[text_idx] not in [OTSL_NL, OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL]):
new_texts.append(texts[text_idx])
text_idx += 1
new_texts.append(OTSL_NL)
if text_idx < len(texts) and texts[text_idx] == OTSL_NL:
text_idx += 1
texts = new_texts
def count_right(tokens, c_idx, r_idx, which_tokens):
span = 0
c_idx_iter = c_idx
......@@ -235,10 +264,11 @@ def export_to_html(table_data: TableData):
body = ""
grid = table_data.grid
for i in range(nrows):
body += "<tr>"
for j in range(ncols):
cell: TableCell = table_data.grid[i][j]
cell: TableCell = grid[i][j]
rowspan, rowstart = (
cell.row_span,
......
# Copyright (c) Opendatalab. All rights reserved.
from loguru import logger
from openai import OpenAI
import ast
import json_repair
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text
......@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
if block["type"] == "title":
origin_title_list.append(block)
title_text = merge_para_with_text(block)
page_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
if 'line_avg_height' in block:
line_avg_height = block['line_avg_height']
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_block_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
title_block_line_height_list.append(int(bbox[3] - bbox[1]))
if len(title_block_line_height_list) > 0:
line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
i += 1
# logger.info(f"Title list: {title_dict}")
......@@ -91,7 +96,6 @@ Corrected title list:
if "</think>" in content:
idx = content.index("</think>") + len("</think>")
content = content[idx:].strip()
import json_repair
dict_completion = json_repair.loads(content)
dict_completion = {int(k): int(v) for k, v in dict_completion.items()}
......
......@@ -57,8 +57,12 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin
relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
elif repo_mode == 'vlm':
# VLM 模式下,直接下载整个模型目录
cache_dir = snapshot_download(repo)
# VLM 模式下,根据 relative_path 的不同处理方式
if relative_path == "/":
cache_dir = snapshot_download(repo)
else:
relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
if not cache_dir:
raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")
......
......@@ -5,9 +5,11 @@ import numpy as np
class OcrConfidence:
min_confidence = 0.68
min_confidence = 0.5
min_width = 3
LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4 # 一般情况下,行宽度超过高度4倍时才是一个正常的横向文本块
def merge_spans_to_line(spans, threshold=0.6):
if len(spans) == 0:
......@@ -20,7 +22,7 @@ def merge_spans_to_line(spans, threshold=0.6):
current_line = [spans[0]]
for span in spans[1:]:
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
current_line.append(span)
else:
# 否则,开始新行
......@@ -33,9 +35,9 @@ def merge_spans_to_line(spans, threshold=0.6):
return lines
def __is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
def _is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
_, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2
......@@ -45,7 +47,21 @@ def __is_overlaps_y_exceeds_threshold(bbox1,
# max_height = max(height1, height2)
min_height = min(height1, height2)
return (overlap / min_height) > overlap_ratio_threshold
return (overlap / min_height) > overlap_ratio_threshold if min_height > 0 else False
def _is_overlaps_x_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在x轴上是否有重叠,并且该重叠区域的宽度占两个bbox宽度更低的那个超过指定阈值"""
x0_1, _, x1_1, _ = bbox1
x0_2, _, x1_2, _ = bbox2
overlap = max(0, min(x1_1, x1_2) - max(x0_1, x0_2))
width1, width2 = x1_1 - x0_1, x1_2 - x0_2
min_width = min(width1, width2)
return (overlap / min_width) > overlap_ratio_threshold if min_width > 0 else False
def img_decode(content: bytes):
......@@ -178,7 +194,7 @@ def update_det_boxes(dt_boxes, mfd_res):
masks_list = []
for mf_box in mfd_res:
mf_bbox = mf_box['bbox']
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
if _is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
masks_list.append([mf_bbox[0], mf_bbox[2]])
text_x_range = [text_bbox[0], text_bbox[2]]
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
......@@ -266,12 +282,27 @@ def merge_det_boxes(dt_boxes):
for span in line:
line_bbox_list.append(span['bbox'])
# Merge overlapping text regions within the same line
merged_spans = merge_overlapping_spans(line_bbox_list)
# 计算整行的宽度和高度
min_x = min(bbox[0] for bbox in line_bbox_list)
max_x = max(bbox[2] for bbox in line_bbox_list)
min_y = min(bbox[1] for bbox in line_bbox_list)
max_y = max(bbox[3] for bbox in line_bbox_list)
line_width = max_x - min_x
line_height = max_y - min_y
# 只有当行宽度超过高度4倍时才进行合并
if line_width > line_height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
# Convert the merged text regions back to point format and add them to the new detection box list
for span in merged_spans:
new_dt_boxes.append(bbox_to_points(span))
# Merge overlapping text regions within the same line
merged_spans = merge_overlapping_spans(line_bbox_list)
# Convert the merged text regions back to point format and add them to the new detection box list
for span in merged_spans:
new_dt_boxes.append(bbox_to_points(span))
else:
# 不进行合并,直接添加原始区域
for bbox in line_bbox_list:
new_dt_boxes.append(bbox_to_points(bbox))
new_dt_boxes.extend(angle_boxes_list)
......
......@@ -15,7 +15,7 @@ def page_to_image(
scale = dpi / 72
long_side_length = max(*page.get_size())
if long_side_length > max_width_or_height:
if (long_side_length*scale) > max_width_or_height:
scale = max_width_or_height / long_side_length
bitmap: PdfBitmap = page.render(scale=scale) # type: ignore
......
This diff is collapsed.
# Copyright (c) Opendatalab. All rights reserved.
import collections
import re
import statistics
......@@ -187,7 +188,7 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
span['chars'] = []
new_spans.append(span)
need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars, median_span_height)
"""对未填充的span进行ocr"""
if len(need_ocr_spans) > 0:
......@@ -208,14 +209,26 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
return spans
def fill_char_in_spans(spans, all_chars):
def fill_char_in_spans(spans, all_chars, median_span_height):
# 简单从上到下排一下序
spans = sorted(spans, key=lambda x: x['bbox'][1])
grid_size = median_span_height
grid = collections.defaultdict(list)
for i, span in enumerate(spans):
start_cell = int(span['bbox'][1] / grid_size)
end_cell = int(span['bbox'][3] / grid_size)
for cell_idx in range(start_cell, end_cell + 1):
grid[cell_idx].append(i)
for char in all_chars:
char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2
cell_idx = int(char_center_y / grid_size)
candidate_span_indices = grid.get(cell_idx, [])
for span in spans:
for span_idx in candidate_span_indices:
span = spans[span_idx]
if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
span['chars'].append(char)
break
......
__version__ = "2.0.5"
__version__ = "2.0.6"
......@@ -3,10 +3,7 @@
## Project List
- Projects compatible with version 2.0:
- [gradio_app](./gradio_app/README.md): Web application based on Gradio
- [multi_gpu_v2](./multi_gpu_v2/README.md): Multi-GPU parallel processing based on LitServe
- Projects not yet compatible with version 2.0:
- [web_api](./web_api/README.md): Web API based on FastAPI
- [multi_gpu](./multi_gpu/README.md): Multi-GPU parallel processing based on LitServe
- [mcp](./mcp/README.md): MCP server based on the official API
......@@ -3,10 +3,7 @@
## 项目列表
- 已兼容2.0版本的项目列表
- [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
- [multi_gpu_v2](./multi_gpu_v2/README_zh.md): 基于 LitServe 的多 GPU 并行处理
- 未兼容2.0版本的项目列表
- [web_api](./web_api/README.md): 基于 FastAPI 的 Web API
- [multi_gpu](./multi_gpu/README.md): 基于 LitServe 的多 GPU 并行处理
- [mcp](./mcp/README.md): 基于官方api的mcp server
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment