Commit 0039d113 authored by myhloli's avatar myhloli
Browse files

refactor: improve batch processing logic and enhance OCR result handling

parent 7c8fb44b
......@@ -14,7 +14,7 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = False):
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = formula_enable
self.table_enable = table_enable
......@@ -150,17 +150,17 @@ class BatchAnalyze:
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
for crop_info in group_crops:
img = crop_info[0]
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
......@@ -177,28 +177,38 @@ class BatchAnalyze:
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if dt_boxes is not None and len(dt_boxes) > 0:
# 直接应用原始OCR流程中的关键处理步骤
from mineru.utils.ocr_utils import (
merge_det_boxes, update_det_boxes, sorted_boxes
)
# 1. 排序检测框
if len(dt_boxes) > 0:
dt_boxes_sorted = sorted_boxes(dt_boxes)
else:
dt_boxes_sorted = []
# 2. 合并相邻检测框
if dt_boxes_sorted:
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
else:
dt_boxes_merged = []
# 3. 根据公式位置更新检测框(关键步骤!)
if dt_boxes_merged and adjusted_mfdetrec_res:
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
else:
dt_boxes_final = dt_boxes_merged
# 构造OCR结果格式
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
......@@ -227,8 +237,9 @@ class BatchAnalyze:
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang)
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment