"vscode:/vscode.git/clone" did not exist on "ac155f7415c74730a32e3cad7850894c614b1564"
Commit 59d6b195 authored by myhloli
Browse files

refactor(model): integrate AtomModelSingleton for OCR and improve OCR result handling

- Replace direct OCR model access with AtomModelSingleton for better model management
- Round OCR scores to 2 decimal places for consistency
- Improve error handling and logging in batch analysis
- Simplify OCR result processing in pdf_parse_union_core_v2.py
parent d7d85a28
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from loguru import logger from loguru import logger
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
...@@ -212,15 +212,21 @@ class BatchAnalyze: ...@@ -212,15 +212,21 @@ class BatchAnalyze:
for lang, img_crop_list in img_crop_lists_by_lang.items(): for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0: if len(img_crop_list) > 0:
# Get OCR results for this language's images # Get OCR results for this language's images
ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0] atom_model_manager = AtomModelSingleton()
need_ocr_list = need_ocr_lists_by_lang[lang] ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
# Verify we have matching counts # Verify we have matching counts
assert len(ocr_res_list) == len( assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)} for lang: {lang}' need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language # Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_list): for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index] ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(round(ocr_score, 2)) layout_res_item['score'] = float(round(ocr_score, 2))
......
...@@ -309,7 +309,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -309,7 +309,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
# logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}") # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
if ocr_score > 0.5 and len(ocr_text) > 0: if ocr_score > 0.5 and len(ocr_text) > 0:
span['content'] = ocr_text span['content'] = ocr_text
span['score'] = ocr_score span['score'] = float(round(ocr_score, 2))
else: else:
spans.remove(span) spans.remove(span)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment