"examples/vscode:/vscode.git/clone" did not exist on "72a98a86c6bc17573cbe61a26061689b3830e5af"
Commit 73f85035 authored by myhloli's avatar myhloli
Browse files

refactor: optimize OCR batch processing and enhance image cropping logic

parent 101b12a1
......@@ -92,144 +92,100 @@ class BatchAnalyze:
'table_img':table_img,
})
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
raw_images = [crop_info[0] for crop_info in group_crops]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(img.shape[0] for img in raw_images)
max_w = max(img.shape[1] for img in raw_images)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for img in raw_images:
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res = [box.tolist() for box in dt_boxes]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang)
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
......@@ -245,6 +201,50 @@ class BatchAnalyze:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition
if self.table_enable:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment