Commit 735f3a70 authored by myhloli's avatar myhloli
Browse files

feat(model): add language detection model and update related modules

- Add language detection model initialization and integration
- Update model list to include language detection
- Refactor language detection utils for better model management
parent 356cb1f2
...@@ -9,3 +9,4 @@ class AtomicModel: ...@@ -9,3 +9,4 @@ class AtomicModel:
MFR = "mfr" MFR = "mfr"
OCR = "ocr" OCR = "ocr"
Table = "table" Table = "table"
LangDetect = "langdetect"
...@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf ...@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel, LangDetectMode from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import LangDetectMode
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
...@@ -61,19 +61,32 @@ def auto_detect_lang(pdf_bytes: bytes): ...@@ -61,19 +61,32 @@ def auto_detect_lang(pdf_bytes: bytes):
sample_pdf_bytes = sample_docs.tobytes() sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200) simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
text_images = get_text_images(simple_images) text_images = get_text_images(simple_images)
local_models_dir, device, configs = get_model_config() langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
# 用yolo11做语言分类
langdetect_model_weights_dir = str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
)
langdetect_model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
lang = langdetect_model.do_detect(text_images) lang = langdetect_model.do_detect(text_images)
if lang in ["ch", "japan"]: if lang in ["ch", "japan"]:
lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP) lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
elif lang in ["en", "fr", "german"]: elif lang in ["en", "fr", "german"]:
lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE) lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
return lang return lang
\ No newline at end of file
def model_init(model_name: str):
atom_model_manager = AtomModelSingleton()
if model_name == MODEL_NAME.YOLO_V11_LangDetect:
local_models_dir, device, configs = get_model_config()
model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.LangDetect,
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
langdetect_model_weights_dir=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
),
device=device,
)
else:
raise ValueError(f"model_name {model_name} not found")
return model
...@@ -2,8 +2,8 @@ import torch ...@@ -2,8 +2,8 @@ import torch
from loguru import logger from loguru import logger
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
...@@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'): ...@@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model return model
def langdetect_model_init(langdetect_model_weights_dir, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
return model
def ocr_model_init(show_log: bool = False, def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=None, lang=None,
...@@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('doclayout_yolo_weights'), kwargs.get('doclayout_yolo_weights'),
kwargs.get('device') kwargs.get('device')
) )
else:
logger.error('layout model name not allow')
exit(1)
elif model_name == AtomicModel.MFD: elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init( atom_model = mfd_model_init(
kwargs.get('mfd_weights'), kwargs.get('mfd_weights'),
...@@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('device'), kwargs.get('device'),
kwargs.get('ocr_engine') kwargs.get('ocr_engine')
) )
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
atom_model = langdetect_model_init(
kwargs.get('langdetect_model_weights_dir'),
kwargs.get('device')
)
else:
logger.error('langdetect model name not allow')
exit(1)
else: else:
logger.error('model name not allow') logger.error('model name not allow')
exit(1) exit(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment