Commit 356cb1f2 authored by myhloli's avatar myhloli
Browse files

feat(language-detection): improve language detection accuracy for specific languages

- Add separate models for Chinese/Japanese and English/French/German detection
- Implement mode-based detection to use appropriate models for different languages
- Update language detection process to use higher DPI for better accuracy
- Modify model initialization and prediction logic to support new language-specific models
parent 3fcac5ef
......@@ -51,6 +51,7 @@ magic-pdf --help
## 已知问题
- paddleocr使用内嵌onnx模型,仅支持中英文ocr,不支持其他语言ocr
- paddleocr使用内嵌onnx模型,仅在默认语言配置下能以较快速度对中英文进行识别
- 自定义lang参数时,paddleocr速度会存在明显下降情况
- layout模型使用layoutlmv3时会发生间歇性崩溃,建议使用默认配置的doclayout_yolo模型
- 表格解析仅适配了rapid_table模型,其他模型可能会无法使用
\ No newline at end of file
......@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel, LangDetectMode
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
......@@ -59,15 +59,21 @@ def get_text_images(simple_images):
def auto_detect_lang(pdf_bytes: bytes):
    """Detect the language of a PDF using the YOLOv11 language classifiers.

    Detection runs in two stages: a first pass with the base classifier,
    then, if the result falls into the Chinese/Japanese or the
    English/French/German family, a second pass with the classifier
    dedicated to that family.

    Args:
        pdf_bytes: Raw bytes of the PDF document.

    Returns:
        The language label produced by ``do_detect``; the second-stage
        label when a family-specific pass ran, otherwise the first-stage
        result unchanged (which may be ``None`` if ``do_detect`` produced
        no label).
    """
    # Rasterise the pages returned by extract_pages at 200 DPI and keep
    # only the text crops used for classification.
    sampled_pdf_bytes = extract_pages(pdf_bytes).tobytes()
    page_images = load_images_from_pdf(sampled_pdf_bytes, dpi=200)
    text_images = get_text_images(page_images)

    # Use YOLOv11 for language classification; the configured weights
    # entry is a directory that holds all classifier checkpoints.
    local_models_dir, device, configs = get_model_config()
    weights_subdir = configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
    weights_dir = str(os.path.join(local_models_dir, weights_subdir))
    detector = YOLOv11LangDetModel(weights_dir, device)

    coarse_lang = detector.do_detect(text_images)
    if coarse_lang in ("ch", "japan"):
        return detector.do_detect(text_images, mode=LangDetectMode.CH_JP)
    if coarse_lang in ("en", "fr", "german"):
        return detector.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
    return coarse_lang
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import os
from collections import Counter
from uuid import uuid4
import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
......@@ -17,6 +19,11 @@ language_dict = {
"ru": "俄语"
}
class LangDetectMode:
    """String constants selecting which language classifier to run."""

    # General-purpose classifier used for the first detection pass.
    BASE = "base"
    # Classifier dedicated to telling Chinese and Japanese apart.
    CH_JP = "ch_jp"
    # Classifier dedicated to telling English, French and German apart.
    EN_FR_GE = "en_fr_ge"
def split_images(image, result_images=None):
"""
......@@ -83,11 +90,25 @@ def resize_images_to_224(image):
class YOLOv11LangDetModel(object):
def __init__(self, weight, device):
self.model = YOLO(weight)
self.device = device
def do_detect(self, images: list):
def __init__(self, langdetect_model_weights_dir, device):
    """Load the base classifier and the two language-family classifiers.

    Args:
        langdetect_model_weights_dir: Directory containing the checkpoints
            ``yolo_v11_cls_ft.pt``, ``yolo_v11_cls_ch_jp.pt`` and
            ``yolo_v11_cls_en_fr_ge.pt``.
        device: Device passed to every ``predict`` call. Values whose
            string form starts with ``"npu"`` are wrapped in
            ``torch.device``; anything else is stored as given.
    """

    def checkpoint_path(filename):
        # Resolve a checkpoint file name inside the weights directory.
        return str(os.path.join(langdetect_model_weights_dir, filename))

    self.model = YOLO(checkpoint_path('yolo_v11_cls_ft.pt'))
    self.ch_jp_model = YOLO(checkpoint_path('yolo_v11_cls_ch_jp.pt'))
    self.en_fr_ge_model = YOLO(checkpoint_path('yolo_v11_cls_en_fr_ge.pt'))

    is_npu = str(device).startswith("npu")
    self.device = torch.device(device) if is_npu else device
def do_detect(self, images: list, mode=LangDetectMode.BASE):
all_images = []
for image in images:
width, height = image.size
......@@ -98,7 +119,7 @@ class YOLOv11LangDetModel(object):
for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image))
images_lang_res = self.batch_predict(all_images, batch_size=8)
images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
......@@ -107,20 +128,39 @@ class YOLOv11LangDetModel(object):
language = None
return language
def predict(self, image, mode=LangDetectMode.BASE):
    """Classify one image and return the predicted class name.

    Args:
        image: Input handed directly to the selected model's ``predict``.
        mode: One of the ``LangDetectMode`` constants; any value other
            than ``CH_JP`` or ``EN_FR_GE`` selects the base model.

    Returns:
        The entry of the selected model's ``names`` for the top-1 class.
    """
    # Only the two specialised modes switch models; everything else,
    # including unknown values, falls back to the base classifier.
    if mode == LangDetectMode.CH_JP:
        classifier = self.ch_jp_model
    elif mode == LangDetectMode.EN_FR_GE:
        classifier = self.en_fr_ge_model
    else:
        classifier = self.model

    first_result = classifier.predict(image, verbose=False, device=self.device)[0]
    top1_class_id = int(first_result.probs.top1)
    return classifier.names[top1_class_id]
def batch_predict(self, images: list, batch_size: int) -> list:
def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
images_lang_res = []
if mode == LangDetectMode.BASE:
model = self.model
elif mode == LangDetectMode.CH_JP:
model = self.ch_jp_model
elif mode == LangDetectMode.EN_FR_GE:
model = self.en_fr_ge_model
else:
model = self.model
for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
for image_res in self.model.predict(
for image_res in model.predict(
images[index: index + batch_size],
verbose = False,
device=self.device,
......@@ -128,7 +168,7 @@ class YOLOv11LangDetModel(object):
]
for res in lang_res:
predicted_class_id = int(res.probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
predicted_class_name = model.names[predicted_class_id]
images_lang_res.append(predicted_class_name)
return images_lang_res
\ No newline at end of file
......@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lang = kwargs.get('lang', 'ch')
# Use ONNX when the CPU architecture is ARM and CUDA is not available
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
......@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR):
ocr_res = []
for img in imgs:
img = preprocess_image(img)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
......@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR):
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
......@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR):
start = time.time()
ori_im = img.copy()
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
......@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR):
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
......
......@@ -6,4 +6,4 @@ weights:
struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
\ No newline at end of file
yolo_v11n_langdetect: LangDetect/YOLO
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment