Commit d7d85a28 authored by myhloli
Browse files

feat(ocr): implement language-specific OCR processing

- Add support for multiple languages in OCR processing
- Create separate lists for each language to improve processing efficiency
- Update OCR model initialization to use PytorchPaddleOCR instead of ModifiedPaddleOCR
- Modify get_ocr_result_list function to include language information
- Improve logging for OCR detection and recognition
parent a330651d
...@@ -124,7 +124,7 @@ class BatchAnalyze: ...@@ -124,7 +124,7 @@ class BatchAnalyze:
# Integration results # Integration results
if ocr_res: if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image) ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
layout_res.extend(ocr_result_list) layout_res.extend(ocr_result_list)
det_time += time.time() - det_start det_time += time.time() - det_start
det_count += len(ocr_res_list) det_count += len(ocr_res_list)
...@@ -177,27 +177,58 @@ class BatchAnalyze: ...@@ -177,27 +177,58 @@ class BatchAnalyze:
if self.model.apply_table: if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}') logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
need_ocr_list = [] # Create dictionaries to store items by language
img_crop_list = [] need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res: for layout_res in images_layout_res:
for layout_res_item in layout_res: for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]: if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item: if 'np_img' in layout_res_item and 'lang' in layout_res_item:
need_ocr_list.append(layout_res_item) lang = layout_res_item['lang']
img_crop_list.append(layout_res_item['np_img'])
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img') layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
rec_time = 0
rec_start = time.time()
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
need_ocr_list = need_ocr_lists_by_lang[lang]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
total_processed += len(img_crop_list)
rec_time = 0 rec_time += time.time() - rec_start
rec_start = time.time() logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
if len(img_crop_list) > 0:
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, layout_res_item in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
rec_time += time.time() - rec_start
logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
......
...@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel ...@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
......
...@@ -7,32 +7,33 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv ...@@ -7,32 +7,33 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
try: from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf_ascend_plugin.libs.license_verifier import ( # try:
LicenseExpiredError, LicenseFormatError, LicenseSignatureError, # from magic_pdf_ascend_plugin.libs.license_verifier import (
load_license) # LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR # load_license)
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel # from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
license_key = load_license() # from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},' # license_key = load_license()
f' License expired at {license_key["payload"]["date"]["end_date"]}') # logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
except Exception as e: # f' License expired at {license_key["payload"]["date"]["end_date"]}')
if isinstance(e, ImportError): # except Exception as e:
pass # if isinstance(e, ImportError):
elif isinstance(e, LicenseFormatError): # pass
logger.error('Ascend Plugin: Invalid license format. Please check the license file.') # elif isinstance(e, LicenseFormatError):
elif isinstance(e, LicenseSignatureError): # logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.') # elif isinstance(e, LicenseSignatureError):
elif isinstance(e, LicenseExpiredError): # logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
logger.error('Ascend Plugin: License has expired. Please renew your license.') # elif isinstance(e, LicenseExpiredError):
elif isinstance(e, FileNotFoundError): # logger.error('Ascend Plugin: License has expired. Please renew your license.')
logger.error('Ascend Plugin: Not found License file.') # elif isinstance(e, FileNotFoundError):
else: # logger.error('Ascend Plugin: Not found License file.')
logger.error(f'Ascend Plugin: {e}') # else:
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR # logger.error(f'Ascend Plugin: {e}')
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel # # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None): def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
...@@ -94,7 +95,8 @@ def ocr_model_init(show_log: bool = False, ...@@ -94,7 +95,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=1.8, det_db_unclip_ratio=1.8,
): ):
if lang is not None and lang != '': if lang is not None and lang != '':
model = ModifiedPaddleOCR( # model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log, show_log=show_log,
det_db_box_thresh=det_db_box_thresh, det_db_box_thresh=det_db_box_thresh,
lang=lang, lang=lang,
...@@ -102,7 +104,8 @@ def ocr_model_init(show_log: bool = False, ...@@ -102,7 +104,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=det_db_unclip_ratio, det_db_unclip_ratio=det_db_unclip_ratio,
) )
else: else:
model = ModifiedPaddleOCR( # model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log, show_log=show_log,
det_db_box_thresh=det_db_box_thresh, det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation, use_dilation=use_dilation,
......
...@@ -261,7 +261,7 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list): ...@@ -261,7 +261,7 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
return adjusted_mfdetrec_res return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image): def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
ocr_result_list = [] ocr_result_list = []
ori_im = new_image.copy() ori_im = new_image.copy()
...@@ -307,9 +307,10 @@ def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image): ...@@ -307,9 +307,10 @@ def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
ocr_result_list.append({ ocr_result_list.append({
'category_id': 15, 'category_id': 15,
'poly': p1 + p2 + p3 + p4, 'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)), 'score': 1,
'text': text, 'text': text,
'np_img': img_crop, 'np_img': img_crop,
'lang': lang,
}) })
else: else:
ocr_result_list.append({ ocr_result_list.append({
......
...@@ -66,6 +66,7 @@ class PytorchPaddleOCR(TextSystem): ...@@ -66,6 +66,7 @@ class PytorchPaddleOCR(TextSystem):
for img in imgs: for img in imgs:
img = preprocess_image(img) img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
ocr_res.append(None) ocr_res.append(None)
continue continue
...@@ -84,6 +85,7 @@ class PytorchPaddleOCR(TextSystem): ...@@ -84,6 +85,7 @@ class PytorchPaddleOCR(TextSystem):
img = preprocess_image(img) img = preprocess_image(img)
img = [img] img = [img]
rec_res, elapse = self.text_recognizer(img) rec_res, elapse = self.text_recognizer(img)
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res) ocr_res.append(rec_res)
return ocr_res return ocr_res
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment