Commit a330651d authored by myhloli
Browse files

feat(ocr): implement separate detection and recognition processes

- Split OCR process into detection and recognition stages
- Update batch analysis and document analysis pipelines
- Modify OCR result formatting and handling
- Remove unused imports and optimize code structure
parent a9b37b71
......@@ -8,7 +8,7 @@ from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
......@@ -85,8 +85,8 @@ class BatchAnalyze:
# 清理显存
clean_vram(self.model.device, vram_threshold=8)
ocr_time = 0
ocr_count = 0
det_time = 0
det_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
......@@ -100,7 +100,7 @@ class BatchAnalyze:
get_res_list_from_layout_res(layout_res)
)
# ocr识别
ocr_start = time.time()
det_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(
......@@ -113,21 +113,21 @@ class BatchAnalyze:
# OCR recognition
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if ocr_enable:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res
)[0]
else:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# if ocr_enable:
# ocr_res = self.model.ocr_model.ocr(
# new_image, mfd_res=adjusted_mfdetrec_res
# )[0]
# else:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
layout_res.extend(ocr_result_list)
ocr_time += time.time() - ocr_start
ocr_count += len(ocr_res_list)
det_time += time.time() - det_start
det_count += len(ocr_res_list)
# 表格识别 table recognition
if self.model.apply_table:
......@@ -172,9 +172,33 @@ class BatchAnalyze:
table_time += time.time() - table_start
table_count += len(table_res_list)
if self.model.apply_ocr:
logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
need_ocr_list = []
img_crop_list = []
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item:
need_ocr_list.append(layout_res_item)
img_crop_list.append(layout_res_item['np_img'])
layout_res_item.pop('np_img')
rec_time = 0
rec_start = time.time()
if len(img_crop_list) > 0:
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, layout_res_item in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
rec_time += time.time() - rec_start
logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
return images_layout_res
......@@ -141,7 +141,7 @@ def doc_analyze(
else len(dataset) - 1
)
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
images = []
page_wh_list = []
for index in range(len(dataset)):
......
# Copyright (c) Opendatalab. All rights reserved.
import copy
import cv2
import numpy as np
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
......@@ -259,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list):
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
ocr_result_list = []
ori_im = new_image.copy()
for box_ocr_res in ocr_res:
if len(box_ocr_res) == 2:
......@@ -273,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
else:
p1, p2, p3, p4 = box_ocr_res
text, score = "", 1
if ocr_enable:
tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
img_crop = get_rotate_crop_image(ori_im, tmp_box)
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# if average_angle_degrees > 0.5:
poly = [p1, p2, p3, p4]
......@@ -295,12 +303,21 @@ def get_ocr_result_list(ocr_res, useful_list):
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)),
'text': text,
})
if ocr_enable:
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)),
'text': text,
'np_img': img_crop,
})
else:
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)),
'text': text,
})
return ocr_result_list
......
......@@ -21,12 +21,9 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
from concurrent.futures import ThreadPoolExecutor
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.post_proc.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment