Commit 20438bd2 authored by myhloli

feat(language-detection): add YOLOv11 language detection model

- Add YOLOv11 language detection model for PDF documents
- Implement language detection in PymuDocDataset
- Update app.py to include 'auto' language option
- Create language detection utilities and constants
parent 9e4ebea9
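For context, a minimal usage sketch of the new `lang='auto'` path. The import paths (`magic_pdf.data.dataset`, `magic_pdf.model.doc_analyze_by_custom_model`) are assumptions not shown in this diff, and `sample.pdf` is a hypothetical input file.

```python
# Hedged sketch: module paths are assumed, sample.pdf is a hypothetical input.
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('sample.pdf', 'rb') as f:
    pdf_bytes = f.read()

# lang='auto' triggers the YOLOv11-based language detection added in this commit.
ds = PymuDocDataset(pdf_bytes, lang='auto')

# apply() forwards the detected language (ds._lang) to doc_analyze.
infer_result = ds.apply(doc_analyze, ocr=True, lang=ds._lang)
```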
......@@ -52,6 +52,8 @@ class MODEL_NAME:
RAPID_TABLE = 'rapid_table'
YOLO_V11_LangDetect = 'yolo_v11n_langdetect'
PARSE_TYPE_TXT = 'txt'
PARSE_TYPE_OCR = 'ocr'
......
......@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
from typing import Callable, Iterator
import fitz
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify
from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
class PageableData(ABC):
......@@ -133,7 +135,7 @@ class Dataset(ABC):
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes):
def __init__(self, bits: bytes, lang=None):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
......@@ -144,6 +146,13 @@ class PymuDocDataset(Dataset):
self._data_bits = bits
self._raw_data = bits
if lang == '':
self._lang = None
elif lang == 'auto':
self._lang = auto_detect_lang(bits)
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
else:
self._lang = lang
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
......@@ -197,6 +206,8 @@ class PymuDocDataset(Dataset):
Returns:
Any: return the result generated by proc
"""
if 'lang' in kwargs and self._lang is not None:
kwargs['lang'] = self._lang
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
......
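Note the behaviour of `apply` above: when the dataset resolved a language (for example via `lang='auto'`), it overrides any `lang` keyword the caller passes, so downstream processors always see the detected value. A toy sketch of that override, continuing from the example under the commit message (`echo_lang` is hypothetical):

```python
# Toy processor that simply echoes the lang kwarg it receives.
def echo_lang(dataset, lang=None):
    return lang

ds_auto = PymuDocDataset(pdf_bytes, lang='auto')  # suppose detection yields 'en'
print(ds_auto.apply(echo_lang, lang='ch'))        # 'en' -- kwargs['lang'] was overridden

ds_plain = PymuDocDataset(pdf_bytes)              # no lang given, _lang stays None
print(ds_plain.apply(echo_lang, lang='ch'))       # 'ch' -- no override when _lang is None
```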
import fitz
import numpy as np
from loguru import logger
from magic_pdf.utils.annotations import ImportPIL
......@@ -30,3 +31,37 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
@ImportPIL
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
from PIL import Image
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
logger.warning('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict)
return images
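For reference, a short usage sketch of the relocated helper (the import path matches the one used later in this commit; `sample.pdf` is a hypothetical input):

```python
# Render pages 0-2 to RGB numpy arrays at 96 dpi; pages outside the range
# come back as empty placeholders ({'img': [], 'width': 0, 'height': 0}).
from magic_pdf.data.utils import load_images_from_pdf

with open('sample.pdf', 'rb') as f:
    pages = load_images_from_pdf(f.read(), dpi=96, start_page_id=0, end_page_id=2)

for idx, page in enumerate(pages):
    print(idx, page['width'], page['height'])
```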
import os
import time
import fitz
import numpy as np
from loguru import logger
# Disable paddle's signal handling
......@@ -44,47 +42,6 @@ def remove_duplicates_dicts(lst):
return unique_dicts
def load_images_from_pdf(
pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
try:
from PIL import Image
except ImportError:
logger.error('Pillow not installed, please install by pip.')
exit(1)
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
logger.warning('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict)
return images
class ModelSingleton:
_instance = None
_models = {}
......@@ -197,9 +154,6 @@ def doc_analyze(
table_enable=None,
) -> InferenceResult:
if lang == '':
lang = None
model_manager = ModelSingleton()
custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
import os
from pathlib import Path
import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
def get_model_config():
local_models_dir = get_local_models_dir()
device = get_device()
current_file_path = os.path.abspath(__file__)
root_dir = Path(current_file_path).parents[3]
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return local_models_dir, device, configs
def get_text_images(simple_images):
local_models_dir, device, configs = get_model_config()
atom_model_manager = AtomModelSingleton()
temp_layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
)
),
device=device,
)
text_images = []
for simple_image in simple_images:
image = Image.fromarray(simple_image['img'])
layout_res = temp_layout_model.predict(image)
# Crop each detected text block from the page image
for res in layout_res:
if res['category_id'] in [1]:
x1, y1, _, _, x2, y2, _, _ = res['poly']
# Initial filtering: skip crops whose width and height are both under 100 px
if x2 - x1 < 100 and y2 - y1 < 100:
continue
text_images.append(image.crop((x1, y1, x2, y2)))
return text_images
def auto_detect_lang(pdf_bytes: bytes):
sample_docs = extract_pages(pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
text_images = get_text_images(simple_images)
local_models_dir, device, configs = get_model_config()
# Language classification with YOLOv11
langdetect_model_weights = str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
)
langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
lang = langdetect_model.do_detect(text_images)
return lang
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
from collections import Counter
from uuid import uuid4
from PIL import Image
from loguru import logger
from ultralytics import YOLO
language_dict = {
"ch": "中文简体",  # Simplified Chinese
"en": "英语",  # English
"japan": "日语",  # Japanese
"korean": "韩语",  # Korean
"fr": "法语",  # French
"german": "德语",  # German
"ar": "阿拉伯语",  # Arabic
"ru": "俄语"  # Russian
}
def split_images(image, result_images=None):
"""
Recursively split an image whose longer side exceeds 400 px: halve it along the longer
dimension on each pass until every sub-image's longer side is at most 400 px, collecting
the resulting sub-images in result_images. Crop boxes that would extend past the image
boundary are skipped to avoid producing invalid black regions.
"""
if result_images is None:
result_images = []
width, height = image.size
long_side = max(width, height)  # length of the longer side
if long_side <= 400:
result_images.append(image)
return result_images
new_long_side = long_side // 2
sub_images = []
if width >= height:  # width is the longer side
for x in range(0, width, new_long_side):
# Skip crop boxes that would extend past the image boundary
if x + new_long_side > width:
continue
box = (x, 0, x + new_long_side, height)
sub_image = image.crop(box)
sub_images.append(sub_image)
else:  # height is the longer side
for y in range(0, height, new_long_side):
# Skip crop boxes that would extend past the image boundary
if y + new_long_side > height:
continue
box = (0, y, width, y + new_long_side)
sub_image = image.crop(box)
sub_images.append(sub_image)
for sub_image in sub_images:
split_images(sub_image, result_images)
return result_images
def resize_images_to_224(image):
"""
If either dimension is smaller than 224, center the image on a 224x224 black background;
otherwise resize it to 224x224. Returns the processed image.
"""
try:
width, height = image.size
if width < 224 or height < 224:
new_image = Image.new('RGB', (224, 224), (0, 0, 0))
paste_x = (224 - width) // 2
paste_y = (224 - height) // 2
new_image.paste(image, (paste_x, paste_y))
image = new_image
else:
image = image.resize((224, 224), Image.Resampling.LANCZOS)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return image
except Exception as e:
logger.exception(e)
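A small worked example of the preprocessing these two helpers implement, using a synthetic crop so the numbers are easy to follow (the import path matches the one used in the detection utilities above):

```python
from PIL import Image
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import (
    resize_images_to_224, split_images,
)

# A 900x300 crop is halved along its long side (900 -> 450 -> 225) until every
# tile's longer side is <= 400 px, then each tile is normalised to 224x224.
crop = Image.new('RGB', (900, 300), (255, 255, 255))
tiles = split_images(crop)
inputs = [resize_images_to_224(t) for t in tiles]
print(len(tiles), inputs[0].size)  # 4 (224, 224)
```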
class YOLOv11LangDetModel(object):
def __init__(self, weight, device):
self.model = YOLO(weight)
self.device = device
def do_detect(self, images: list):
all_images = []
for image in images:
width, height = image.size
# logger.info(f"image size: {width} x {height}")
if width < 100 and height < 100:
continue
temp_images = split_images(image)
for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image))
images_lang_res = self.batch_predict(all_images, batch_size=8)
logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
else:
language = None
return language
def predict(self, image):
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
return predicted_class_name
def batch_predict(self, images: list, batch_size: int) -> list:
images_lang_res = []
for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index: index + batch_size],
verbose=False,
device=self.device,
)
]
for res in lang_res:
predicted_class_id = int(res.probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
images_lang_res.append(predicted_class_name)
return images_lang_res
\ No newline at end of file
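And a hedged end-to-end sketch of the detector class itself; the weights path and input crops are assumptions (in the pipeline they come from model_configs.yaml via get_model_config() and from get_text_images()):

```python
from PIL import Image
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel

# Hypothetical local weights path; the real one is resolved from model_configs.yaml
# (yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt) under the models dir.
model = YOLOv11LangDetModel('/models/LangDetect/YOLO/yolo_v11_cls_ft.pt', device='cpu')

crops = [Image.open(p) for p in ('crop_0.png', 'crop_1.png')]  # hypothetical crops
lang = model.do_detect(crops)  # most frequent predicted label, e.g. 'en', or None
print(lang)
```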
# Copyright (c) Opendatalab. All rights reserved.
......@@ -9,7 +9,7 @@ class DocLayoutYOLOModel(object):
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
......@@ -35,7 +35,7 @@ class DocLayoutYOLOModel(object):
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=True,
verbose=False,
device=self.device,
)
]
......
......@@ -5,4 +5,5 @@ weights:
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable
\ No newline at end of file
rapid_table: TabRec/RapidTable
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
\ No newline at end of file
......@@ -95,9 +95,6 @@ def do_parse(
f_draw_model_bbox = True
f_draw_line_sort_bbox = True
if lang == '':
lang = None
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes, start_page_id, end_page_id
)
......@@ -109,7 +106,7 @@ def do_parse(
)
image_dir = str(os.path.basename(local_image_dir))
ds = PymuDocDataset(pdf_bytes)
ds = PymuDocDataset(pdf_bytes, lang=lang)
if len(model_list) == 0:
if model_config.__use_inside_model__:
......@@ -118,50 +115,50 @@ def do_parse(
infer_result = ds.apply(
doc_analyze,
ocr=False,
lang=lang,
lang=ds._lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
else:
infer_result = ds.apply(
doc_analyze,
ocr=True,
lang=lang,
lang=ds._lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
elif parse_method == 'txt':
infer_result = ds.apply(
doc_analyze,
ocr=False,
lang=lang,
lang=ds._lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
elif parse_method == 'ocr':
infer_result = ds.apply(
doc_analyze,
ocr=True,
lang=lang,
lang=ds._lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
else:
logger.error('unknown parse method')
......@@ -174,20 +171,20 @@ def do_parse(
infer_result = InferenceResult(model_list, ds)
if parse_method == 'ocr':
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
elif parse_method == 'txt':
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
else:
if ds.classify() == SupportedPdfParseMethod.TXT:
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
else:
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
image_writer, debug_mode=True, lang=ds._lang
)
......
......@@ -159,7 +159,7 @@ devanagari_lang = [
]
other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
all_lang = ['']
all_lang = ['', 'auto']
all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
......