Commit d7d85a28 authored by myhloli
Browse files

feat(ocr): implement language-specific OCR processing

- Add support for multiple languages in OCR processing
- Create separate lists for each language to improve processing efficiency
- Update OCR model initialization to use PytorchPaddleOCR instead of ModifiedPaddleOCR
- Modify get_ocr_result_list function to include language information
- Improve logging for OCR detection and recognition
parent a330651d
......@@ -124,7 +124,7 @@ class BatchAnalyze:
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
layout_res.extend(ocr_result_list)
det_time += time.time() - det_start
det_count += len(ocr_res_list)
......@@ -177,27 +177,58 @@ class BatchAnalyze:
if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
need_ocr_list = []
img_crop_list = []
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item:
need_ocr_list.append(layout_res_item)
img_crop_list.append(layout_res_item['np_img'])
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
rec_time = 0
rec_start = time.time()
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
need_ocr_list = need_ocr_lists_by_lang[lang]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
total_processed += len(img_crop_list)
rec_time = 0
rec_start = time.time()
if len(img_crop_list) > 0:
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, layout_res_item in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2))
rec_time += time.time() - rec_start
logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
rec_time += time.time() - rec_start
logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
......
......@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
......
......@@ -7,32 +7,33 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
try:
from magic_pdf_ascend_plugin.libs.license_verifier import (
LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
load_license)
from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
license_key = load_license()
logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
f' License expired at {license_key["payload"]["date"]["end_date"]}')
except Exception as e:
if isinstance(e, ImportError):
pass
elif isinstance(e, LicenseFormatError):
logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
elif isinstance(e, LicenseSignatureError):
logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
elif isinstance(e, LicenseExpiredError):
logger.error('Ascend Plugin: License has expired. Please renew your license.')
elif isinstance(e, FileNotFoundError):
logger.error('Ascend Plugin: Not found License file.')
else:
logger.error(f'Ascend Plugin: {e}')
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
# try:
# from magic_pdf_ascend_plugin.libs.license_verifier import (
# LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
# load_license)
# from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
# from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
# license_key = load_license()
# logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
# f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
# if isinstance(e, ImportError):
# pass
# elif isinstance(e, LicenseFormatError):
# logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
# elif isinstance(e, LicenseSignatureError):
# logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
# elif isinstance(e, LicenseExpiredError):
# logger.error('Ascend Plugin: License has expired. Please renew your license.')
# elif isinstance(e, FileNotFoundError):
# logger.error('Ascend Plugin: Not found License file.')
# else:
# logger.error(f'Ascend Plugin: {e}')
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
......@@ -94,7 +95,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
lang=lang,
......@@ -102,7 +104,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
model = ModifiedPaddleOCR(
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
......
......@@ -261,7 +261,7 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
ocr_result_list = []
ori_im = new_image.copy()
......@@ -307,9 +307,10 @@ def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)),
'score': 1,
'text': text,
'np_img': img_crop,
'lang': lang,
})
else:
ocr_result_list.append({
......
......@@ -66,6 +66,7 @@ class PytorchPaddleOCR(TextSystem):
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
if dt_boxes is None:
ocr_res.append(None)
continue
......@@ -84,6 +85,7 @@ class PytorchPaddleOCR(TextSystem):
img = preprocess_image(img)
img = [img]
rec_res, elapse = self.text_recognizer(img)
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res)
return ocr_res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment