Unverified Commit 763fbc60 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2536 from seedclaimer/dev

支持batch-ocr-det,速度约提升3倍(200页pdf在3090上)
parents f5016508 54950551
......@@ -2,6 +2,8 @@ import time
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
......@@ -16,27 +18,28 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=True):
self.model_manager = model_manager
self.batch_ratio = batch_ratio
self.show_log = show_log
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
layout_start_time = time.time()
self.model = self.model_manager.get_model(
ocr=True,
show_log=self.show_log,
lang = None,
layout_model = self.layout_model,
formula_enable = self.formula_enable,
table_enable = self.table_enable,
lang=None,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
images = [image for image, _, _ in images_with_extra_info]
......@@ -101,43 +104,152 @@ class BatchAnalyze:
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'np_array_img':np_array_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
ocr_res_list_all_page.append({
'ocr_res_list': ocr_res_list,
'lang': _lang,
'ocr_enable': ocr_enable,
'np_array_img': np_array_img,
'single_page_mfdetrec_res': single_page_mfdetrec_res,
'layout_res': layout_res,
})
for table_res in table_res_list:
table_img, _ = crop_img(table_res, np_array_img)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':table_img,
})
# 文本框检测
det_start = time.time()
det_count = 0
# for ocr_res_list_dict in ocr_res_list_all_page:
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
table_res_list_all_page.append({
'table_res': table_res,
'lang': _lang,
'table_img': table_img,
})
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
......
......@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self.net.eval()
self.net.to(self.device)
def _batch_process_same_size(self, img_list):
"""
对相同尺寸的图像进行批处理
Args:
img_list: 相同尺寸的图像列表
Returns:
batch_results: 批处理结果列表
total_elapse: 总耗时
"""
starttime = time.time()
# 预处理所有图像
batch_data = []
batch_shapes = []
ori_imgs = []
for img in img_list:
ori_im = img.copy()
ori_imgs.append(ori_im)
data = {'image': img}
data = transform(data, self.preprocess_op)
if data is None:
# 如果预处理失败,返回空结果
return [(None, 0) for _ in img_list], 0
img_processed, shape_list = data
batch_data.append(img_processed)
batch_shapes.append(shape_list)
# 堆叠成批处理张量
try:
batch_tensor = np.stack(batch_data, axis=0)
batch_shapes = np.stack(batch_shapes, axis=0)
except Exception as e:
# 如果堆叠失败,回退到逐个处理
batch_results = []
for img in img_list:
dt_boxes, elapse = self.__call__(img)
batch_results.append((dt_boxes, elapse))
return batch_results, time.time() - starttime
# 批处理推理
with torch.no_grad():
inp = torch.from_numpy(batch_tensor)
inp = inp.to(self.device)
outputs = self.net(inp)
# 处理输出
preds = {}
if self.det_algorithm == "EAST":
preds['f_geo'] = outputs['f_geo'].cpu().numpy()
preds['f_score'] = outputs['f_score'].cpu().numpy()
elif self.det_algorithm == 'SAST':
preds['f_border'] = outputs['f_border'].cpu().numpy()
preds['f_score'] = outputs['f_score'].cpu().numpy()
preds['f_tco'] = outputs['f_tco'].cpu().numpy()
preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs['maps'].cpu().numpy()
elif self.det_algorithm == 'FCE':
for i, (k, output) in enumerate(outputs.items()):
preds['level_{}'.format(i)] = output.cpu().numpy()
else:
raise NotImplementedError
# 后处理每个图像的结果
batch_results = []
total_elapse = time.time() - starttime
for i in range(len(img_list)):
# 提取单个图像的预测结果
single_preds = {}
for key, value in preds.items():
if isinstance(value, np.ndarray):
single_preds[key] = value[i:i + 1] # 保持批次维度
else:
single_preds[key] = value
# 后处理
post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
dt_boxes = post_result[0]['points']
# 过滤和裁剪检测框
if (self.det_algorithm == "SAST" and
self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
batch_results.append((dt_boxes, total_elapse / len(img_list)))
return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
"""
批处理预测方法,支持多张图像同时检测
Args:
img_list: 图像列表
max_batch_size: 最大批处理大小
Returns:
batch_results: 批处理结果列表,每个元素为(dt_boxes, elapse)
"""
if not img_list:
return []
batch_results = []
# 分批处理
for i in range(0, len(img_list), max_batch_size):
batch_imgs = img_list[i:i + max_batch_size]
# assert尺寸一致
batch_dt_boxes, batch_elapse = self._batch_process_same_size(batch_imgs)
batch_results.extend(batch_dt_boxes)
return batch_results
def order_points_clockwise(self, pts):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment