"docs/vscode:/vscode.git/clone" did not exist on "abc6f88b22e021a4f4739022f912c63effe6a6f3"
Commit 3271cf75 authored by myhloli

refactor(langdetect): simplify language detection model and improve logging

- Remove LangDetectMode and related conditional logic
- Use a single model weight for language detection
- Add logging for language detection results
- Update model initialization and prediction methods
parent 735f3a70
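Before the diffs, a minimal sketch of what the simplified flow amounts to: one classification weight, one predict path, and a majority vote over per-crop predictions. The helper name detect_language and the weight_path argument are illustrative only; the actual implementation is the YOLOv11LangDetModel change further down.

from collections import Counter

from ultralytics import YOLO


def detect_language(text_images, weight_path, device="cpu"):
    # Load the single fine-tuned classification weight (no per-language sub-models).
    model = YOLO(weight_path)
    votes = []
    for image in text_images:
        results = model.predict(image, verbose=False, device=device)
        class_id = int(results[0].probs.top1)      # top-1 class index
        votes.append(model.names[class_id])        # class index -> language label
    if not votes:
        return None
    counts = Counter(votes)
    return max(counts, key=counts.get)             # most frequent label wins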
@@ -153,6 +153,7 @@ class PymuDocDataset(Dataset):
            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
        else:
            self._lang = lang
+           logger.info(f"lang: {lang}")
    def __len__(self) -> int:
        """The page number of the pdf."""
        return len(self._records)
...
@@ -12,7 +12,6 @@ from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import LangDetectMode
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -63,11 +62,6 @@ def auto_detect_lang(pdf_bytes: bytes):
    text_images = get_text_images(simple_images)
    langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
    lang = langdetect_model.do_detect(text_images)
-   if lang in ["ch", "japan"]:
-       lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
-   elif lang in ["en", "fr", "german"]:
-       lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
    return lang
@@ -79,7 +73,7 @@ def model_init(model_name: str):
        model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.LangDetect,
            langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
-           langdetect_model_weights_dir=str(
+           langdetect_model_weight=str(
                os.path.join(
                    local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
                )
...
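A sketch of how a caller now obtains the language-detection model through AtomModelSingleton: the kwarg is a single weight file rather than a weights directory. The concrete path and the MODEL_NAME import location are assumptions for illustration; the real path is resolved from the weights config at the end of this diff.

import os

from magic_pdf.config.constants import MODEL_NAME  # assumed import path
from magic_pdf.libs.config_reader import get_device, get_local_models_dir
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

local_models_dir = get_local_models_dir()
langdetect_model = AtomModelSingleton().get_atom_model(
    atom_model_name=AtomicModel.LangDetect,
    langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
    # Single .pt file now; the illustrative path mirrors the yaml change below.
    langdetect_model_weight=str(
        os.path.join(local_models_dir, 'LangDetect/YOLO/yolo_v11_ft.pt')
    ),
    device=get_device(),
)
# langdetect_model.do_detect(text_images) then returns a label such as "ch" or "en".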
# Copyright (c) Opendatalab. All rights reserved.
-import os
from collections import Counter
from uuid import uuid4
@@ -19,11 +18,6 @@ language_dict = {
    "ru": "俄语"
}
-class LangDetectMode:
-   BASE = "base"
-   CH_JP = "ch_jp"
-   EN_FR_GE = "en_fr_ge"
def split_images(image, result_images=None):
    """
@@ -90,25 +84,15 @@ def resize_images_to_224(image):
class YOLOv11LangDetModel(object):
-   def __init__(self, langdetect_model_weights_dir, device):
-       langdetect_model_base_weight = str(
-           os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
-       )
-       langdetect_model_ch_jp_weight = str(
-           os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
-       )
-       langdetect_model_en_fr_ge_weight = str(
-           os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
-       )
-       self.model = YOLO(langdetect_model_base_weight)
-       self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
-       self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)
+   def __init__(self, langdetect_model_weight, device):
+       self.model = YOLO(langdetect_model_weight)
        if str(device).startswith("npu"):
            self.device = torch.device(device)
        else:
            self.device = device
-   def do_detect(self, images: list, mode=LangDetectMode.BASE):
+   def do_detect(self, images: list):
        all_images = []
        for image in images:
            width, height = image.size
@@ -119,8 +103,8 @@ class YOLOv11LangDetModel(object):
            for temp_image in temp_images:
                all_images.append(resize_images_to_224(temp_image))
-       images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
-       logger.info(f"images_lang_res: {images_lang_res}")
+       images_lang_res = self.batch_predict(all_images, batch_size=8)
+       # logger.info(f"images_lang_res: {images_lang_res}")
        if len(images_lang_res) > 0:
            count_dict = Counter(images_lang_res)
            language = max(count_dict, key=count_dict.get)
@@ -128,39 +112,20 @@ class YOLOv11LangDetModel(object):
            language = None
        return language
-   def predict(self, image, mode=LangDetectMode.BASE):
-       if mode == LangDetectMode.BASE:
-           model = self.model
-       elif mode == LangDetectMode.CH_JP:
-           model = self.ch_jp_model
-       elif mode == LangDetectMode.EN_FR_GE:
-           model = self.en_fr_ge_model
-       else:
-           model = self.model
-       results = model.predict(image, verbose=False, device=self.device)
+   def predict(self, image):
+       results = self.model.predict(image, verbose=False, device=self.device)
        predicted_class_id = int(results[0].probs.top1)
-       predicted_class_name = model.names[predicted_class_id]
+       predicted_class_name = self.model.names[predicted_class_id]
        return predicted_class_name
-   def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
+   def batch_predict(self, images: list, batch_size: int) -> list:
        images_lang_res = []
-       if mode == LangDetectMode.BASE:
-           model = self.model
-       elif mode == LangDetectMode.CH_JP:
-           model = self.ch_jp_model
-       elif mode == LangDetectMode.EN_FR_GE:
-           model = self.en_fr_ge_model
-       else:
-           model = self.model
        for index in range(0, len(images), batch_size):
            lang_res = [
                image_res.cpu()
-               for image_res in model.predict(
+               for image_res in self.model.predict(
                    images[index: index + batch_size],
                    verbose = False,
                    device=self.device,
@@ -168,7 +133,7 @@ class YOLOv11LangDetModel(object):
            ]
            for res in lang_res:
                predicted_class_id = int(res.probs.top1)
-               predicted_class_name = model.names[predicted_class_id]
+               predicted_class_name = self.model.names[predicted_class_id]
                images_lang_res.append(predicted_class_name)
        return images_lang_res
\ No newline at end of file
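For completeness, a hedged usage sketch of the refactored class; the weight path is illustrative and should point at the single classification weight configured in the yaml change at the end of this diff.

from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel

model = YOLOv11LangDetModel(
    langdetect_model_weight="/models/LangDetect/YOLO/yolo_v11_ft.pt",  # illustrative path
    device="cpu",
)
# images: list of PIL.Image text crops, as produced upstream by get_text_images
# lang = model.do_detect(images)   # majority language label, e.g. "en" or "ch", or None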
@@ -63,10 +63,10 @@ def doclayout_yolo_model_init(weight, device='cpu'):
    return model
-def langdetect_model_init(langdetect_model_weights_dir, device='cpu'):
+def langdetect_model_init(langdetect_model_weight, device='cpu'):
    if str(device).startswith("npu"):
        device = torch.device(device)
-   model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
+   model = YOLOv11LangDetModel(langdetect_model_weight, device)
    return model
@@ -168,7 +168,7 @@ def atom_model_init(model_name: str, **kwargs):
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
            atom_model = langdetect_model_init(
-               kwargs.get('langdetect_model_weights_dir'),
+               kwargs.get('langdetect_model_weight'),
                kwargs.get('device')
            )
        else:
...
@@ -6,4 +6,4 @@ weights:
  struct_eqtable: TabRec/StructEqTable
  tablemaster: TabRec/TableMaster
  rapid_table: TabRec/RapidTable
- yolo_v11n_langdetect: LangDetect/YOLO
+ yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt
\ No newline at end of file