Commit 356cb1f2 authored by myhloli's avatar myhloli
Browse files

feat(language-detection): improve language detection accuracy for specific languages

- Add separate models for Chinese/Japanese and English/French/German detection
- Implement mode-based detection to use appropriate models for different languages
- Update language detection process to use higher DPI for better accuracy
- Modify model initialization and prediction logic to support new language-specific models
parent 3fcac5ef
...@@ -51,6 +51,7 @@ magic-pdf --help ...@@ -51,6 +51,7 @@ magic-pdf --help
## 已知问题 ## 已知问题
- paddleocr使用内嵌onnx模型,仅在默认语言配置下能以较快速度对中英文进行识别
- 自定义lang参数时,paddleocr速度会存在明显下降情况
- layout模型使用layoutlmv3时会发生间歇性崩溃,建议使用默认配置的doclayout_yolo模型
- 表格解析仅适配了rapid_table模型,其他模型可能会无法使用
\ No newline at end of file
...@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf ...@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel, LangDetectMode
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
...@@ -59,15 +59,21 @@ def get_text_images(simple_images): ...@@ -59,15 +59,21 @@ def get_text_images(simple_images):
def auto_detect_lang(pdf_bytes: bytes):
    """Detect the dominant text language of a PDF.

    Samples a handful of pages, renders them at 200 DPI, crops the text
    regions, and classifies them with the YOLOv11 language-detection model.
    Easily-confused groups (Chinese/Japanese and English/French/German) get
    a second pass with a specialised model to refine the result.
    """
    sample_docs = extract_pages(pdf_bytes)
    sample_pdf_bytes = sample_docs.tobytes()
    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
    text_images = get_text_images(simple_images)
    local_models_dir, device, configs = get_model_config()
    # Classify the language with YOLOv11.
    langdetect_model_weights_dir = str(
        os.path.join(
            local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
        )
    )
    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
    lang = langdetect_model.do_detect(text_images)
    # Refine ambiguous first-pass results with the dedicated sub-models.
    if lang in ("ch", "japan"):
        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
    elif lang in ("en", "fr", "german"):
        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
    return lang
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import os
from collections import Counter from collections import Counter
from uuid import uuid4 from uuid import uuid4
import torch
from PIL import Image from PIL import Image
from loguru import logger from loguru import logger
from ultralytics import YOLO from ultralytics import YOLO
...@@ -17,6 +19,11 @@ language_dict = { ...@@ -17,6 +19,11 @@ language_dict = {
"ru": "俄语" "ru": "俄语"
} }
class LangDetectMode:
    """Mode constants selecting which language-detection model to run."""
    BASE = "base"  # default classifier covering all supported languages
    CH_JP = "ch_jp"  # disambiguation model: Chinese vs. Japanese
    EN_FR_GE = "en_fr_ge"  # disambiguation model: English vs. French vs. German
def split_images(image, result_images=None): def split_images(image, result_images=None):
""" """
...@@ -83,11 +90,25 @@ def resize_images_to_224(image): ...@@ -83,11 +90,25 @@ def resize_images_to_224(image):
class YOLOv11LangDetModel(object): class YOLOv11LangDetModel(object):
def __init__(self, weight, device): def __init__(self, langdetect_model_weights_dir, device):
self.model = YOLO(weight) langdetect_model_base_weight = str(
self.device = device os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
)
def do_detect(self, images: list): langdetect_model_ch_jp_weight = str(
os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
)
langdetect_model_en_fr_ge_weight = str(
os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
)
self.model = YOLO(langdetect_model_base_weight)
self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)
if str(device).startswith("npu"):
self.device = torch.device(device)
else:
self.device = device
def do_detect(self, images: list, mode=LangDetectMode.BASE):
all_images = [] all_images = []
for image in images: for image in images:
width, height = image.size width, height = image.size
...@@ -98,7 +119,7 @@ class YOLOv11LangDetModel(object): ...@@ -98,7 +119,7 @@ class YOLOv11LangDetModel(object):
for temp_image in temp_images: for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image)) all_images.append(resize_images_to_224(temp_image))
images_lang_res = self.batch_predict(all_images, batch_size=8) images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
logger.info(f"images_lang_res: {images_lang_res}") logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0: if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res) count_dict = Counter(images_lang_res)
...@@ -107,20 +128,39 @@ class YOLOv11LangDetModel(object): ...@@ -107,20 +128,39 @@ class YOLOv11LangDetModel(object):
language = None language = None
return language return language
def predict(self, image, mode=LangDetectMode.BASE):
if mode == LangDetectMode.BASE:
model = self.model
elif mode == LangDetectMode.CH_JP:
model = self.ch_jp_model
elif mode == LangDetectMode.EN_FR_GE:
model = self.en_fr_ge_model
else:
model = self.model
def predict(self, image): results = model.predict(image, verbose=False, device=self.device)
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1) predicted_class_id = int(results[0].probs.top1)
predicted_class_name = self.model.names[predicted_class_id] predicted_class_name = model.names[predicted_class_id]
return predicted_class_name return predicted_class_name
def batch_predict(self, images: list, batch_size: int) -> list: def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
images_lang_res = [] images_lang_res = []
if mode == LangDetectMode.BASE:
model = self.model
elif mode == LangDetectMode.CH_JP:
model = self.ch_jp_model
elif mode == LangDetectMode.EN_FR_GE:
model = self.en_fr_ge_model
else:
model = self.model
for index in range(0, len(images), batch_size): for index in range(0, len(images), batch_size):
lang_res = [ lang_res = [
image_res.cpu() image_res.cpu()
for image_res in self.model.predict( for image_res in model.predict(
images[index: index + batch_size], images[index: index + batch_size],
verbose = False, verbose = False,
device=self.device, device=self.device,
...@@ -128,7 +168,7 @@ class YOLOv11LangDetModel(object): ...@@ -128,7 +168,7 @@ class YOLOv11LangDetModel(object):
] ]
for res in lang_res: for res in lang_res:
predicted_class_id = int(res.probs.top1) predicted_class_id = int(res.probs.top1)
predicted_class_name = self.model.names[predicted_class_id] predicted_class_name = model.names[predicted_class_id]
images_lang_res.append(predicted_class_name) images_lang_res.append(predicted_class_name)
return images_lang_res return images_lang_res
\ No newline at end of file
...@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lang = kwargs.get('lang', 'ch')
# 在cpu架构为arm且不支持cuda时调用onnx、 # 在cpu架构为arm且不支持cuda时调用onnx、
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']: if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True self.use_onnx = True
...@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR):
ocr_res = [] ocr_res = []
for img in imgs: for img in imgs:
img = preprocess_image(img) img = preprocess_image(img)
if self.use_onnx: if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img) dt_boxes, elapse = self.additional_ocr.text_detector(img)
else: else:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR):
img, cls_res_tmp, elapse = self.text_classifier(img) img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec: if not rec:
cls_res.append(cls_res_tmp) cls_res.append(cls_res_tmp)
if self.use_onnx: if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img) rec_res, elapse = self.additional_ocr.text_recognizer(img)
else: else:
rec_res, elapse = self.text_recognizer(img) rec_res, elapse = self.text_recognizer(img)
...@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR):
start = time.time() start = time.time()
ori_im = img.copy() ori_im = img.copy()
if self.use_onnx: if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img) dt_boxes, elapse = self.additional_ocr.text_detector(img)
else: else:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR):
time_dict['cls'] = elapse time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format( logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse)) len(img_crop_list), elapse))
if self.use_onnx: if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list) rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else: else:
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
......
...@@ -6,4 +6,4 @@ weights: ...@@ -6,4 +6,4 @@ weights:
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable rapid_table: TabRec/RapidTable
yolo_v11n_langdetect: LangDetect/YOLO
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment