Unverified Commit aa535316 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1453 from myhloli/dev

refactor(langdetect): simplify language detection model
parents c634e2df 3271cf75
@@ -51,6 +51,7 @@ magic-pdf --help
 ## Known issues
-- paddleocr uses an embedded onnx model and only supports Chinese and English OCR; other languages are not supported
+- paddleocr uses an embedded onnx model; only under the default language configuration can Chinese and English be recognized at a reasonably fast speed
+- when a custom lang parameter is specified, paddleocr becomes noticeably slower
 - the layoutlmv3 layout model crashes intermittently; the doclayout_yolo model from the default configuration is recommended
 - table parsing is only adapted to the rapid_table model; other models may not work
\ No newline at end of file
@@ -153,6 +153,7 @@ class PymuDocDataset(Dataset):
             logger.info(f"lang: {lang}, detect_lang: {self._lang}")
         else:
             self._lang = lang
+            logger.info(f"lang: {lang}")
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
...
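The added logger.info makes the explicit-lang branch log symmetrically with the auto-detection branch: either the caller's lang is recorded verbatim, or the detected language is logged alongside it. A hedged usage sketch, assuming PymuDocDataset is constructed from raw PDF bytes and accepts an optional lang keyword (import path and file name are assumptions, not confirmed by this diff):

```python
from magic_pdf.data.dataset import PymuDocDataset  # assumed import path

pdf_bytes = open("sample.pdf", "rb").read()  # placeholder input file

# Explicit language: the new log line records exactly what the caller passed.
ds_en = PymuDocDataset(pdf_bytes, lang="en")

# No explicit language (or an 'auto' setting): the dataset falls back to the
# YOLOv11-based auto_detect_lang path and logs both the requested and detected values.
ds_auto = PymuDocDataset(pdf_bytes)
```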
@@ -9,3 +9,4 @@ class AtomicModel:
     MFR = "mfr"
     OCR = "ocr"
     Table = "table"
+    LangDetect = "langdetect"
@@ -12,7 +12,6 @@ from magic_pdf.data.utils import load_images_from_pdf
 from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.libs.pdf_check import extract_pages
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -59,15 +58,29 @@ def get_text_images(simple_images):
 def auto_detect_lang(pdf_bytes: bytes):
     sample_docs = extract_pages(pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
-    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
+    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
     text_images = get_text_images(simple_images)
-    local_models_dir, device, configs = get_model_config()
-    # use yolo11 for language classification
-    langdetect_model_weights = str(
-        os.path.join(
-            local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
-        )
-    )
-    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
+    langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
     lang = langdetect_model.do_detect(text_images)
     return lang
\ No newline at end of file
+
+
+def model_init(model_name: str):
+    atom_model_manager = AtomModelSingleton()
+    if model_name == MODEL_NAME.YOLO_V11_LangDetect:
+        local_models_dir, device, configs = get_model_config()
+        model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.LangDetect,
+            langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
+            langdetect_model_weight=str(
+                os.path.join(
+                    local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
+                )
+            ),
+            device=device,
+        )
+    else:
+        raise ValueError(f"model_name {model_name} not found")
+    return model
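This hunk is the core of the refactor: instead of constructing YOLOv11LangDetModel inline, auto_detect_lang now asks AtomModelSingleton for the model, so repeated language-detection calls reuse one set of loaded weights. A minimal sketch of that keyed-singleton caching pattern, written as a generic stand-in rather than the library's actual AtomModelSingleton:

```python
# Illustrative sketch only: a simplified keyed singleton cache in the spirit of
# AtomModelSingleton; the real class lives in magic_pdf.model.sub_modules.model_init.
class ModelCache:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._models = {}
        return cls._instance

    def get_atom_model(self, atom_model_name: str, loader, **kwargs):
        # Key on the model name plus its init kwargs so the same weights load only once.
        key = (atom_model_name, tuple(sorted(kwargs.items())))
        if key not in self._models:
            self._models[key] = loader(**kwargs)  # expensive weight loading happens only here
        return self._models[key]


# Usage: both calls return the same cached object.
m1 = ModelCache().get_atom_model("langdetect", loader=lambda **kw: object(), weight="yolo_v11_ft.pt")
m2 = ModelCache().get_atom_model("langdetect", loader=lambda **kw: object(), weight="yolo_v11_ft.pt")
assert m1 is m2
```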
@@ -2,6 +2,7 @@
 from collections import Counter
 from uuid import uuid4
+import torch
 from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
@@ -83,10 +84,14 @@ def resize_images_to_224(image):
 class YOLOv11LangDetModel(object):
-    def __init__(self, weight, device):
-        self.model = YOLO(weight)
-        self.device = device
+    def __init__(self, langdetect_model_weight, device):
+        self.model = YOLO(langdetect_model_weight)
+        if str(device).startswith("npu"):
+            self.device = torch.device(device)
+        else:
+            self.device = device
 
     def do_detect(self, images: list):
         all_images = []
         for image in images:
@@ -99,7 +104,7 @@ class YOLOv11LangDetModel(object):
             all_images.append(resize_images_to_224(temp_image))
 
         images_lang_res = self.batch_predict(all_images, batch_size=8)
-        logger.info(f"images_lang_res: {images_lang_res}")
+        # logger.info(f"images_lang_res: {images_lang_res}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)
@@ -107,7 +112,6 @@
             language = None
         return language
 
     def predict(self, image):
         results = self.model.predict(image, verbose=False, device=self.device)
         predicted_class_id = int(results[0].probs.top1)
@@ -117,6 +121,7 @@
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_lang_res = []
         for index in range(0, len(images), batch_size):
             lang_res = [
                 image_res.cpu()
...
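Within do_detect, the per-image predictions are reduced to a single language by a simple majority vote over Counter counts. A small self-contained sketch of that voting step (the label strings below are illustrative only):

```python
from collections import Counter


def majority_language(per_image_labels: list):
    # Mirrors the voting in do_detect: count each predicted label and
    # return the most frequent one; None when nothing was classified.
    if not per_image_labels:
        return None
    counts = Counter(per_image_labels)
    return max(counts, key=counts.get)


print(majority_language(["ch", "en", "ch", "japan"]))  # -> "ch"
print(majority_language([]))                           # -> None
```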
@@ -2,8 +2,8 @@ import torch
 from loguru import logger
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
     DocLayoutYOLOModel
 from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
+def langdetect_model_init(langdetect_model_weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
+    model = YOLOv11LangDetModel(langdetect_model_weight, device)
+    return model
+
+
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
@@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                 kwargs.get('doclayout_yolo_weights'),
                 kwargs.get('device')
             )
+        else:
+            logger.error('layout model name not allow')
+            exit(1)
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get('mfd_weights'),
@@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('device'),
             kwargs.get('ocr_engine')
         )
+    elif model_name == AtomicModel.LangDetect:
+        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
+            atom_model = langdetect_model_init(
+                kwargs.get('langdetect_model_weight'),
+                kwargs.get('device')
+            )
+        else:
+            logger.error('langdetect model name not allow')
+            exit(1)
     else:
         logger.error('model name not allow')
         exit(1)
...
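With the AtomicModel.LangDetect branch registered, a langdetect model is requested through the same atom-model entry point as the other sub-models. A hedged usage sketch mirroring the get_atom_model call shown earlier in this diff; the weights path below is a placeholder, not a real location:

```python
# Sketch only: 'weights_path' stands in for
# os.path.join(local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]).
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

weights_path = "/path/to/models/LangDetect/YOLO/yolo_v11_ft.pt"  # placeholder

langdetect_model = AtomModelSingleton().get_atom_model(
    atom_model_name=AtomicModel.LangDetect,
    langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
    langdetect_model_weight=weights_path,
    device="cpu",
)
```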
@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.lang = kwargs.get('lang', 'ch')
         # use onnx when the CPU architecture is arm and CUDA is not available
         if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
             self.use_onnx = True
@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             ocr_res = []
             for img in imgs:
                 img = preprocess_image(img)
-                if self.use_onnx:
+                if self.lang in ['ch'] and self.use_onnx:
                     dt_boxes, elapse = self.additional_ocr.text_detector(img)
                 else:
                     dt_boxes, elapse = self.text_detector(img)
@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR):
                     img, cls_res_tmp, elapse = self.text_classifier(img)
                     if not rec:
                         cls_res.append(cls_res_tmp)
-                if self.use_onnx:
+                if self.lang in ['ch'] and self.use_onnx:
                     rec_res, elapse = self.additional_ocr.text_recognizer(img)
                 else:
                     rec_res, elapse = self.text_recognizer(img)
@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR):
         start = time.time()
         ori_im = img.copy()
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             dt_boxes, elapse = self.additional_ocr.text_detector(img)
         else:
             dt_boxes, elapse = self.text_detector(img)
@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR):
         time_dict['cls'] = elapse
         logger.debug("cls num : {}, elapsed : {}".format(
             len(img_crop_list), elapse))
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
         else:
             rec_res, elapse = self.text_recognizer(img_crop_list)
...
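The net effect of these four guard changes is that the embedded onnx detector/recognizer is used only when the OCR language is 'ch' (the only setup the bundled onnx models cover, per the README note above); any other lang falls through to the standard Paddle models. A small sketch of that dispatch decision with hypothetical stand-in callables:

```python
def pick_text_detector(lang: str, use_onnx: bool, onnx_detector, paddle_detector):
    # Mirrors the updated guard: onnx is used only for the default 'ch' setup,
    # and only when the onnx fallback was enabled at init time.
    if lang in ['ch'] and use_onnx:
        return onnx_detector
    return paddle_detector


# Hypothetical stand-ins for self.additional_ocr.text_detector / self.text_detector.
onnx_det = lambda img: ("boxes-from-onnx", 0.01)
paddle_det = lambda img: ("boxes-from-paddle", 0.05)

print(pick_text_detector("ch", True, onnx_det, paddle_det)("img")[0])  # boxes-from-onnx
print(pick_text_detector("en", True, onnx_det, paddle_det)("img")[0])  # boxes-from-paddle
```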
@@ -6,4 +6,4 @@ weights:
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
-  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
+  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt
\ No newline at end of file
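The yaml entry is a path relative to the local models directory; auto_detect_lang joins it onto local_models_dir via configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]. A hedged sketch of that resolution step, assuming the config is plain YAML parsed with pyyaml and that MODEL_NAME.YOLO_V11_LangDetect equals the 'yolo_v11n_langdetect' key (both are assumptions):

```python
import os
import yaml  # pyyaml; assumed to be how the weights config is parsed

# Hypothetical example values; the real ones come from get_model_config().
local_models_dir = "/opt/models"
config_text = """
weights:
  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt
"""

configs = yaml.safe_load(config_text)
# Assumes MODEL_NAME.YOLO_V11_LangDetect is the string 'yolo_v11n_langdetect'.
weight_path = str(os.path.join(local_models_dir, configs['weights']['yolo_v11n_langdetect']))
print(weight_path)  # /opt/models/LangDetect/YOLO/yolo_v11_ft.pt
```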