Unverified Commit 763fbc60 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2536 from seedclaimer/dev

支持batch-ocr-det,速度约提升3倍(200页pdf在3090上)
parents f5016508 54950551
...@@ -2,6 +2,8 @@ import time ...@@ -2,6 +2,8 @@ import time
import cv2 import cv2
from loguru import logger from loguru import logger
from tqdm import tqdm from tqdm import tqdm
from collections import defaultdict
import numpy as np
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
...@@ -16,13 +18,14 @@ MFR_BASE_BATCH_SIZE = 16 ...@@ -16,13 +18,14 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze: class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=True):
    """Hold configuration for one batched document-analysis pass.

    Args:
        model_manager: provider of the underlying layout/OCR/table models.
        batch_ratio: multiplier applied to the base per-model batch sizes.
        show_log: whether the underlying models emit verbose logging.
        layout_model: identifier of the layout-analysis model to use.
        formula_enable: enable formula detection/recognition (MFD/MFR).
        table_enable: enable table recognition.
        enable_ocr_det_batch: when True (default), OCR text detection runs
            in resolution-grouped batches instead of one crop at a time.
    """
    # Plain attribute copies; grouped in one place so new options only
    # need a single extra tuple entry.
    for attr_name, attr_value in (
        ("model_manager", model_manager),
        ("batch_ratio", batch_ratio),
        ("show_log", show_log),
        ("layout_model", layout_model),
        ("formula_enable", formula_enable),
        ("table_enable", table_enable),
        ("enable_ocr_det_batch", enable_ocr_det_batch),
    ):
        setattr(self, attr_name, attr_value)
def __call__(self, images_with_extra_info: list) -> list: def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0: if len(images_with_extra_info) == 0:
...@@ -33,10 +36,10 @@ class BatchAnalyze: ...@@ -33,10 +36,10 @@ class BatchAnalyze:
self.model = self.model_manager.get_model( self.model = self.model_manager.get_model(
ocr=True, ocr=True,
show_log=self.show_log, show_log=self.show_log,
lang = None, lang=None,
layout_model = self.layout_model, layout_model=self.layout_model,
formula_enable = self.formula_enable, formula_enable=self.formula_enable,
table_enable = self.table_enable, table_enable=self.table_enable,
) )
images = [image for image, _, _ in images_with_extra_info] images = [image for image, _, _ in images_with_extra_info]
...@@ -101,25 +104,134 @@ class BatchAnalyze: ...@@ -101,25 +104,134 @@ class BatchAnalyze:
get_res_list_from_layout_res(layout_res) get_res_list_from_layout_res(layout_res)
) )
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list, ocr_res_list_all_page.append({
'lang':_lang, 'ocr_res_list': ocr_res_list,
'ocr_enable':ocr_enable, 'lang': _lang,
'np_array_img':np_array_img, 'ocr_enable': ocr_enable,
'single_page_mfdetrec_res':single_page_mfdetrec_res, 'np_array_img': np_array_img,
'layout_res':layout_res, 'single_page_mfdetrec_res': single_page_mfdetrec_res,
'layout_res': layout_res,
}) })
for table_res in table_res_list: for table_res in table_res_list:
table_img, _ = crop_img(table_res, np_array_img) table_img, _ = crop_img(table_res, np_array_img)
table_res_list_all_page.append({'table_res':table_res, table_res_list_all_page.append({
'lang':_lang, 'table_res': table_res,
'table_img':table_img, 'lang': _lang,
'table_img': table_img,
}) })
# 文本框检测 # OCR检测处理
det_start = time.time() if self.enable_ocr_det_batch:
det_count = 0 # 批处理模式 - 按语言和分辨率分组
# for ocr_res_list_dict in ocr_res_list_all_page: # 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing # Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang'] _lang = ocr_res_list_dict['lang']
......
...@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20): ...@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self.net.eval() self.net.eval()
self.net.to(self.device) self.net.to(self.device)
def _batch_process_same_size(self, img_list):
    """Run text detection on a batch of images that share one resolution.

    Every image in *img_list* must have the same height/width so the
    preprocessed tensors can be stacked into a single batch; the caller
    (``batch_predict`` via the resolution-grouping code) is responsible
    for padding images to a common size beforehand.

    Args:
        img_list: list of images (np.ndarray) of identical shape.
            NOTE(review): presumably BGR uint8, matching the single-image
            ``__call__`` path — confirm against the preprocess ops.

    Returns:
        tuple: ``(batch_results, total_elapse)`` where ``batch_results``
        is a list of ``(dt_boxes, elapse)`` — one entry per input image,
        in order — and ``total_elapse`` is the wall-clock time in seconds
        for the whole batch.
    """
    starttime = time.time()

    # Preprocess each image with the existing single-image pipeline.
    batch_data = []
    batch_shapes = []
    ori_imgs = []
    for img in img_list:
        ori_im = img.copy()
        ori_imgs.append(ori_im)
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        if data is None:
            # Preprocessing failed: return an empty result for every
            # image (results from already-processed images are dropped).
            return [(None, 0) for _ in img_list], 0
        img_processed, shape_list = data
        batch_data.append(img_processed)
        batch_shapes.append(shape_list)

    # Stack per-image tensors into one batch. Shapes should agree because
    # the caller groups by (padded) resolution, but guard anyway.
    try:
        batch_tensor = np.stack(batch_data, axis=0)
        batch_shapes = np.stack(batch_shapes, axis=0)
    except Exception as e:
        # Stacking failed (e.g. mismatched shapes): fall back to
        # one-by-one inference through the regular __call__ path.
        batch_results = []
        for img in img_list:
            dt_boxes, elapse = self.__call__(img)
            batch_results.append((dt_boxes, elapse))
        return batch_results, time.time() - starttime

    # Single batched forward pass (inference only, no gradients).
    with torch.no_grad():
        inp = torch.from_numpy(batch_tensor)
        inp = inp.to(self.device)
        outputs = self.net(inp)

    # Collect the network outputs that each detection algorithm's
    # postprocessor expects, moved to CPU numpy arrays.
    preds = {}
    if self.det_algorithm == "EAST":
        preds['f_geo'] = outputs['f_geo'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
    elif self.det_algorithm == 'SAST':
        preds['f_border'] = outputs['f_border'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
        preds['f_tco'] = outputs['f_tco'].cpu().numpy()
        preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
    elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
        preds['maps'] = outputs['maps'].cpu().numpy()
    elif self.det_algorithm == 'FCE':
        for i, (k, output) in enumerate(outputs.items()):
            preds['level_{}'.format(i)] = output.cpu().numpy()
    else:
        raise NotImplementedError

    # Postprocess each image's slice of the batched predictions.
    batch_results = []
    total_elapse = time.time() - starttime
    for i in range(len(img_list)):
        # Slice out one image's predictions, keeping the batch dimension
        # (i:i+1) because the postprocessor expects batched input.
        single_preds = {}
        for key, value in preds.items():
            if isinstance(value, np.ndarray):
                single_preds[key] = value[i:i + 1]  # keep batch dim
            else:
                single_preds[key] = value

        # Decode boxes from the prediction maps.
        post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
        dt_boxes = post_result[0]['points']

        # Filter/clip detected boxes against the original image bounds;
        # polygon outputs only get clipped, quad outputs get full filtering.
        if (self.det_algorithm == "SAST" and
                self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
                self.postprocess_op.box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
        # Per-image elapse is the batch time split evenly across images.
        batch_results.append((dt_boxes, total_elapse / len(img_list)))

    return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
    """Detect text boxes for many images, processing them chunk by chunk.

    The input list is split into consecutive chunks of at most
    ``max_batch_size`` images; each chunk is forwarded to
    ``_batch_process_same_size``. The caller is expected to have padded
    all images within a chunk to the same resolution.

    Args:
        img_list: list of images to run detection on.
        max_batch_size: upper bound on images per forward pass.

    Returns:
        list: one ``(dt_boxes, elapse)`` tuple per input image, in the
        original order.
    """
    if not img_list:
        return []

    results = []
    for start in range(0, len(img_list), max_batch_size):
        chunk = img_list[start:start + max_batch_size]
        # Per-chunk total elapse is unused; each entry already carries
        # its own per-image elapse.
        chunk_boxes, _chunk_elapse = self._batch_process_same_size(chunk)
        results.extend(chunk_boxes)
    return results
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment