Commit a330651d authored by myhloli

feat(ocr): implement separate detection and recognition processes

- Split OCR process into detection and recognition stages
- Update batch analysis and document analysis pipelines
- Modify OCR result formatting and handling
- Remove unused imports and optimize code structure
parent a9b37b71
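The change replaces the single combined OCR call with two passes: a detection-only pass that runs per cropped layout region and records text-line boxes plus image crops, and a recognition-only pass that runs once over all collected crops. A minimal sketch of the call pattern, assuming the PaddleOCR-style `ocr()` interface seen in the diffs below (`crop_text_line` and the variable names are illustrative, not part of the codebase):

    # Pass 1: detection only -- rec=False returns text-line boxes without transcribing them.
    det_boxes = ocr_model.ocr(region_img, mfd_res=adjusted_mfdetrec_res, rec=False)[0]

    # Crop one image per detected box; crops are accumulated across all regions and pages.
    img_crop_list = [crop_text_line(region_img, box) for box in det_boxes]

    # Pass 2: recognition only -- det=False transcribes every crop in a single batched call.
    rec_results = ocr_model.ocr(img_crop_list, det=False)[0]
    for (text, score), box in zip(rec_results, det_boxes):
        pass  # attach text/score back to the layout item that produced the box

The new `ocr-det` and `ocr-rec` log lines added below time the two passes separately.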
@@ -8,7 +8,7 @@ from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)

 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
@@ -85,8 +85,8 @@ class BatchAnalyze:
             # 清理显存
             clean_vram(self.model.device, vram_threshold=8)

-        ocr_time = 0
-        ocr_count = 0
+        det_time = 0
+        det_count = 0
         table_time = 0
         table_count = 0
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
@@ -100,7 +100,7 @@ class BatchAnalyze:
                 get_res_list_from_layout_res(layout_res)
             )
             # ocr识别
-            ocr_start = time.time()
+            det_start = time.time()
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
@@ -113,21 +113,21 @@ class BatchAnalyze:

                 # OCR recognition
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

-                if ocr_enable:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res
-                    )[0]
-                else:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                    )[0]
+                # if ocr_enable:
+                #     ocr_res = self.model.ocr_model.ocr(
+                #         new_image, mfd_res=adjusted_mfdetrec_res
+                #     )[0]
+                # else:
+                ocr_res = self.model.ocr_model.ocr(
+                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                )[0]

                 # Integration results
                 if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
                     layout_res.extend(ocr_result_list)

-            ocr_time += time.time() - ocr_start
-            ocr_count += len(ocr_res_list)
+            det_time += time.time() - det_start
+            det_count += len(ocr_res_list)

             # 表格识别 table recognition
             if self.model.apply_table:
@@ -172,9 +172,33 @@ class BatchAnalyze:
                 table_time += time.time() - table_start
                 table_count += len(table_res_list)

-        if self.model.apply_ocr:
-            logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
+        logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
         if self.model.apply_table:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')

+        need_ocr_list = []
+        img_crop_list = []
+        for layout_res in images_layout_res:
+            for layout_res_item in layout_res:
+                if layout_res_item['category_id'] in [15]:
+                    if 'np_img' in layout_res_item:
+                        need_ocr_list.append(layout_res_item)
+                        img_crop_list.append(layout_res_item['np_img'])
+                        layout_res_item.pop('np_img')
+
+        rec_time = 0
+        rec_start = time.time()
+        if len(img_crop_list) > 0:
+            ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
+            assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
+            for index, layout_res_item in enumerate(need_ocr_list):
+                ocr_text, ocr_score = ocr_res_list[index]
+                layout_res_item['text'] = ocr_text
+                layout_res_item['score'] = float(round(ocr_score, 2))
+        rec_time += time.time() - rec_start
+        logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
+
         return images_layout_res
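Read as a standalone helper, the recognition pass added above amounts to the following (a simplified restatement of the new block; the function name `run_batched_recognition` is illustrative only):

    def run_batched_recognition(images_layout_res, ocr_model):
        # Collect every text-line item (category_id 15) that detection tagged with a crop.
        need_ocr_list, img_crop_list = [], []
        for layout_res in images_layout_res:
            for item in layout_res:
                if item['category_id'] == 15 and 'np_img' in item:
                    need_ocr_list.append(item)
                    img_crop_list.append(item.pop('np_img'))

        if img_crop_list:
            # det=False: recognition only, one batched call over all collected crops.
            rec_results = ocr_model.ocr(img_crop_list, det=False)[0]
            for item, (text, score) in zip(need_ocr_list, rec_results):
                item['text'] = text
                item['score'] = float(round(score, 2))
        return images_layout_res

Note one behavioral detail: `pop('np_img')` removes the crop as soon as it is queued, so the numpy image never remains in the final layout results.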
@@ -141,7 +141,7 @@ def doc_analyze(
         else len(dataset) - 1
     )

-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
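The only change here is the default of `MIN_BATCH_INFERENCE_SIZE`, which moves from 100 to 200. The value can still be overridden through the `MINERU_MIN_BATCH_INFERENCE_SIZE` environment variable before calling `doc_analyze`; for example (the value 50 is arbitrary):

    import os

    # Override the new default of 200 before doc_analyze reads the variable.
    os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '50'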
 # Copyright (c) Opendatalab. All rights reserved.
+import copy
 import cv2
 import numpy as np

 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
@@ -259,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
     return adjusted_mfdetrec_res


-def get_ocr_result_list(ocr_res, useful_list):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     ocr_result_list = []
+    ori_im = new_image.copy()
     for box_ocr_res in ocr_res:

         if len(box_ocr_res) == 2:
@@ -273,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
         else:
             p1, p2, p3, p4 = box_ocr_res
             text, score = "", 1
+            if ocr_enable:
+                tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)

         # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
         # if average_angle_degrees > 0.5:
         poly = [p1, p2, p3, p4]
@@ -295,12 +303,21 @@ def get_ocr_result_list(ocr_res, useful_list):
         p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
         p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]

-        ocr_result_list.append({
-            'category_id': 15,
-            'poly': p1 + p2 + p3 + p4,
-            'score': float(round(score, 2)),
-            'text': text,
-        })
+        if ocr_enable:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': float(round(score, 2)),
+                'text': text,
+                'np_img': img_crop,
+            })
+        else:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': float(round(score, 2)),
+                'text': text,
+            })

     return ocr_result_list
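`get_rotate_crop_image` is called above but not defined in this diff; the sketch below assumes the PaddleOCR-style implementation, which warps a (possibly rotated) quadrilateral text box into an axis-aligned crop with a perspective transform and re-orients very tall crops for the recognizer:

    import cv2
    import numpy as np

    def get_rotate_crop_image(img, points):
        # points: 4x2 float32 array ordered top-left, top-right, bottom-right, bottom-left.
        crop_w = int(max(np.linalg.norm(points[0] - points[1]),
                         np.linalg.norm(points[2] - points[3])))
        crop_h = int(max(np.linalg.norm(points[0] - points[3]),
                         np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(img, M, (crop_w, crop_h),
                                      borderMode=cv2.BORDER_REPLICATE,
                                      flags=cv2.INTER_CUBIC)
        # Rotate tall, narrow crops (likely vertical text) so lines read horizontally.
        h, w = dst_img.shape[0:2]
        if h * 1.0 / w >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

This is also why `tmp_box` is cast to `float32` before the call: `cv2.getPerspectiveTransform` requires float32 point arrays.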
@@ -21,12 +21,9 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
-from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
-from concurrent.futures import ThreadPoolExecutor
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.post_proc.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2