Unverified Commit 4bb54393 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1427 from opendatalab/release-1.0.0

Release 1.0.0
parents 04f084ac 1c9f9942
......@@ -55,7 +55,7 @@ class FileBasedDataWriter(DataWriter):
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
if not os.path.exists(os.path.dirname(fn_path)):
if not os.path.exists(os.path.dirname(fn_path)) and os.path.dirname(fn_path) != "":
os.makedirs(os.path.dirname(fn_path), exist_ok=True)
with open(fn_path, 'wb') as f:
......
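The added `!= ""` guard above matters because `os.path.dirname` returns an empty string for a bare filename, and `os.makedirs('')` raises. A minimal standalone sketch of the edge case (paths hypothetical):

import os

# os.path.dirname('out.bin') == '' — calling os.makedirs('') raises
# FileNotFoundError, which the extra != "" check above avoids.
for fn_path in ('out.bin', 'sub/dir/out.bin'):
    d = os.path.dirname(fn_path)
    print(repr(d))  # '' for the bare name, 'sub/dir' otherwise
    if d != '' and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)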
import os
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
......@@ -22,10 +22,10 @@ class MultiS3Mixin:
"""
if len(default_prefix) == 0:
raise InvalidConfig('default_prefix must be provided')
arr = default_prefix.strip("/").split("/")
arr = default_prefix.strip('/').split('/')
self.default_bucket = arr[0]
self.default_prefix = "/".join(arr[1:])
self.default_prefix = '/'.join(arr[1:])
found_default_bucket_config = False
for conf in s3_configs:
......@@ -103,7 +103,8 @@ class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
path = os.path.join(self.default_prefix, path)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_reader.read_at(path, offset, limit)
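The switch above from `os.path.join` to explicit `'/'` concatenation is deliberate: S3 keys always use `/`, while `os.path.join` is platform-dependent (backslashes on Windows), and the empty-prefix guard avoids keys with a leading slash. A small sketch of the guarded join (helper name hypothetical):

def join_s3_key(prefix: str, path: str) -> str:
    # mirror of the guarded join above: empty prefix -> bare path
    return f'{prefix}/{path}' if prefix else path

print(join_s3_key('user/data', 'doc.pdf'))  # 'user/data/doc.pdf'
print(join_s3_key('', 'doc.pdf'))           # 'doc.pdf'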
......@@ -139,5 +140,6 @@ class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
path = os.path.join(self.default_prefix, path)
if self.default_prefix:
path = self.default_prefix + '/' + path
return s3_writer.write(path, data)
......@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Callable, Iterator
import fitz
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
......@@ -133,7 +134,7 @@ class Dataset(ABC):
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes):
def __init__(self, bits: bytes, lang=None):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
......@@ -144,6 +145,15 @@ class PymuDocDataset(Dataset):
self._data_bits = bits
self._raw_data = bits
if lang == '':
self._lang = None
elif lang == 'auto':
from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
self._lang = auto_detect_lang(bits)
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
else:
self._lang = lang
logger.info(f"lang: {lang}")
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
......@@ -197,6 +207,8 @@ class PymuDocDataset(Dataset):
Returns:
Any: return the result generated by proc
"""
if 'lang' in kwargs and self._lang is not None:
kwargs['lang'] = self._lang
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
......
import json
import os
import tempfile
import shutil
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
from magic_pdf.utils.office_to_pdf import convert_file_to_pdf, ConvertToPdfError
def read_jsonl(
s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
......@@ -58,23 +60,68 @@ def read_local_pdfs(path: str) -> list[PymuDocDataset]:
list[PymuDocDataset]: each pdf file will be converted to a PymuDocDataset
"""
if os.path.isdir(path):
reader = FileBasedDataReader(path)
return [
PymuDocDataset(reader.read(doc_path.name))
for doc_path in Path(path).glob('*.pdf')
]
reader = FileBasedDataReader()
ret = []
for root, _, files in os.walk(path):
for file in files:
suffix = file.split('.')
if suffix[-1] == 'pdf':
ret.append(PymuDocDataset(reader.read(os.path.join(root, file))))
return ret
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [PymuDocDataset(bits)]
def read_local_office(path: str) -> list[PymuDocDataset]:
"""Read ms-office file (ppt, pptx, doc, docx) from path or directory.
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
Args:
path (str): ms-office file or directory that contains ms-office files
Returns:
list[PymuDocDataset]: each ms-office file will converted to a PymuDocDataset
Raises:
ConvertToPdfError: Failed to convert ms-office file to pdf via libreoffice
FileNotFoundError: File not Found
Exception: Unknown Exception raised
"""
suffixes = ['.ppt', '.pptx', '.doc', '.docx']
fns = []
ret = []
if os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
suffix = Path(file).suffix
if suffix in suffixes:
fns.append((os.path.join(root, file)))
else:
fns.append(path)
reader = FileBasedDataReader()
temp_dir = tempfile.mkdtemp()
try:
    for fn in fns:
        # raises ConvertToPdfError / FileNotFoundError on failure
        convert_file_to_pdf(fn, temp_dir)
        fn_path = Path(fn)
        pdf_fn = f"{temp_dir}/{fn_path.stem}.pdf"
        ret.append(PymuDocDataset(reader.read(pdf_fn)))
finally:
    shutil.rmtree(temp_dir)  # clean up even when conversion fails
return ret
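A hedged usage sketch (directory and module path assumed): every Office document found is converted to PDF in a temporary directory via LibreOffice, then wrapped like any other PDF:

from magic_pdf.data.read_api import read_local_office  # assumed module path

datasets = read_local_office('reports/')  # picks up .ppt/.pptx/.doc/.docx
print(len(datasets), 'documents converted')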
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
def read_local_images(path: str, suffixes: list[str]=['.png', '.jpg']) -> list[ImageDataset]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['.jpg', '.png']
Returns:
list[ImageDataset]: each image file will be converted to an ImageDataset
......@@ -82,12 +129,12 @@ def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
if os.path.isdir(path):
imgs_bits = []
s_suffixes = set(suffixes)
reader = FileBasedDataReader(path)
reader = FileBasedDataReader()
for root, _, files in os.walk(path):
for file in files:
suffix = file.split('.')
if suffix[-1] in s_suffixes:
imgs_bits.append(reader.read(file))
suffix = Path(file).suffix
if suffix in s_suffixes:
imgs_bits.append(reader.read(os.path.join(root, file)))
return [ImageDataset(bits) for bits in imgs_bits]
else:
reader = FileBasedDataReader()
......
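Matching usage sketch for the corrected image reader (paths hypothetical); note that the filter now compares `Path(file).suffix`, so entries must include the leading dot:

from magic_pdf.data.read_api import read_local_images  # assumed module path

imgs = read_local_images('scans/', suffixes=['.png', '.jpg'])  # leading dots required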
import fitz
import numpy as np
from loguru import logger
from magic_pdf.utils.annotations import ImportPIL
......@@ -30,3 +31,37 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
@ImportPIL
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
from PIL import Image
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
logger.warning('end_page_id is out of range, falling back to the last page')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict)
return images
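For scale, the `dpi / 72` matrix means a 200 dpi render of a 612x792 pt US-Letter page comes out at 1700x2200 px, comfortably under the 4500 px cap that triggers the unscaled fallback:

dpi = 200
scale = dpi / 72  # PDF user space is 72 points per inch
print(round(612 * scale), round(792 * scale))  # -> 1700 2200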
......@@ -7,7 +7,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.para.para_split_v3 import ListLineTag
from magic_pdf.post_proc.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line):
......@@ -61,7 +61,8 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block)}'
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
......@@ -125,16 +126,6 @@ def detect_language(text):
return 'empty'
# split ligature characters
def __replace_ligatures(text: str):
text = re.sub(r'ﬁ', 'fi', text)  # replace the fi ligature
text = re.sub(r'ﬂ', 'fl', text)  # replace the fl ligature
text = re.sub(r'ﬀ', 'ff', text)  # replace the ff ligature
text = re.sub(r'ﬃ', 'ffi', text)  # replace the ffi ligature
text = re.sub(r'ﬄ', 'ffl', text)  # replace the ffl ligature
return text
def merge_para_with_text(para_block):
block_text = ''
for line in para_block['lines']:
......@@ -196,10 +187,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.Title:
title_level = get_title_level(para_block)
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'text_level': 1,
'text_level': title_level,
}
elif para_type == BlockType.InterlineEquation:
para_content = {
......@@ -299,3 +291,12 @@ def union_make(pdf_info_dict: list,
return '\n\n'.join(output_content)
elif make_mode == MakeMode.STANDARD_FORMAT:
return output_content
def get_title_level(block):
title_level = block.get('level', 1)
if title_level > 4:
title_level = 4
elif title_level < 1:
title_level = 1
return title_level
\ No newline at end of file
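Behaviorally, `get_title_level` clamps the model-predicted heading level into [1, 4], so a level-3 title renders as `### ...` in the markdown maker above. An equivalent one-line sketch (alternative spelling, not the committed code):

def clamp_title_level(block: dict) -> int:
    return max(1, min(4, block.get('level', 1)))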
......@@ -3,8 +3,15 @@ import torch
import gc
def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def clean_memory(device='cuda'):
if device == 'cuda':
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
torch_npu.npu.empty_cache()
elif str(device).startswith("mps"):
torch.mps.empty_cache()
gc.collect()
\ No newline at end of file
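Hedged usage sketch of the new device dispatch (device strings as configured elsewhere; the npu branch additionally requires `torch_npu` to be installed):

clean_memory('cuda')   # empties CUDA caches when torch.cuda.is_available()
clean_memory('npu:0')  # routes through torch_npu (imported inside the function)
clean_memory('mps')    # Apple Silicon cache release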
......@@ -116,6 +116,15 @@ def get_formula_config():
else:
return formula_config
def get_llm_aided_config():
config = read_config()
llm_aided_config = config.get('llm-aided-config')
if llm_aided_config is None:
logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return llm_aided_config
if __name__ == '__main__':
ak, sk, endpoint = get_s3_config('llm-raw')
......@@ -394,17 +394,13 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
pdf_docs.save(f'{out_path}/{filename}')
def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
def draw_char_bbox(pdf_bytes, out_path, filename):
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
for block in page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']:
for line in block['lines']:
for span in line['spans']:
for char in span['chars']:
char_bbox = char['bbox']
page.draw_rect(char_bbox, color=[1, 0, 0], fill=None, fill_opacity=1, width=0.3, overlay=True,)
pdf_docs.save(f'{out_path}/{filename}')
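PyMuPDF's 'rawdict' output nests blocks → lines → spans → chars, each char carrying its own bbox, which is what `draw_char_bbox` outlines. A minimal traversal sketch (file name hypothetical; image blocks carry no 'lines' key, hence the guard):

import fitz  # PyMuPDF

doc = fitz.open('sample.pdf')  # hypothetical file
raw = doc[0].get_text('rawdict')
for block in raw['blocks']:
    for line in block.get('lines', []):  # image blocks have no 'lines'
        for span in line['spans']:
            for char in span['chars']:
                print(char['c'], char['bbox'])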
......@@ -16,11 +16,14 @@ def detect_lang(text: str) -> str:
if len(text) == 0:
return ""
text = text.replace("\n", "")
try:
lang_upper = detect_language(text)
except:
html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
lang_upper = detect_language(html_no_ctrl_chars)
try:
lang = lang_upper.lower()
except:
......
from typing import Callable
from abc import ABC, abstractmethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.pipe.operators import PipeResult
__use_inside_model__ = True
__model_mode__ = "full"
class InferenceResultBase(ABC):
@abstractmethod
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related to the model inference result
"""
self._infer_res = inference_results
self._dataset = dataset
@abstractmethod
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@abstractmethod
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@abstractmethod
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Lets the user pick the first page to process
end_page_id (int, optional): Defaults to the last page index of the dataset. Lets the user pick the last page to process
debug_mode (bool, optional): Defaults to False. Dumps more logs if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Lets the user pick the first page to process
end_page_id (int, optional): Defaults to the last page index of the dataset. Lets the user pick the last page to process
debug_mode (bool, optional): Defaults to False. Dumps more logs if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
pass
__model_mode__ = 'full'
\ No newline at end of file
import time
import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model: CustomPEKModel, batch_ratio: int):
self.model = model
self.batch_ratio = batch_ratio
def __call__(self, images: list) -> list:
images_layout_res = []
layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_images = []
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
layout_images.append(new_image)
modified_images.append([image_index, useful_list])
else:
layout_images.append(pil_img)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
)
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res['poly'])):
if i % 2 == 0:
res['poly'][i] = (
res['poly'][i] - useful_list[0] + useful_list[2]
)
else:
res['poly'][i] = (
res['poly'][i] - useful_list[1] + useful_list[3]
)
logger.info(
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
)
if self.model.apply_formula:
# formula detection
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE
)
logger.info(
f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
)
# formula recognition
mfr_start_time = time.time()
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
logger.info(
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
)
# free VRAM
clean_vram(self.model.device, vram_threshold=8)
ocr_time = 0
ocr_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)):
layout_res = images_layout_res[index]
pil_img = Image.fromarray(images[index])
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# OCR recognition
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(
res, pil_img, crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list
)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
if self.model.apply_ocr:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res
)[0]
else:
ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# integrate OCR results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)
ocr_time += time.time() - ocr_start
ocr_count += len(ocr_res_list)
# table recognition
if self.model.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
html_code = None
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.model.table_model.predict(
new_image, 'html'
)
if len(table_result) > 0:
html_code = table_result[0]
elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.model.table_model.img2html(new_image)
elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = (
self.model.table_model.predict(new_image)
)
run_time = time.time() - single_table_start_time
if run_time > self.model.table_max_time:
logger.warning(
f'table recognition processing exceeds max time {self.model.table_max_time}s'
)
# check whether a valid result was returned
if html_code:
expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
res['html'] = html_code
else:
logger.warning(
'table recognition failed: expected HTML table end tag not found'
)
else:
logger.warning(
'table recognition failed: no HTML returned'
)
table_time += time.time() - table_start
table_count += len(table_res_list)
if self.model.apply_ocr:
logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
else:
logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res
def doc_batch_analyze(
dataset: Dataset,
ocr: bool = False,
show_log: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
batch_ratio: int | None = None,
) -> InferenceResult:
"""Perform batch analysis on a document dataset.
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
"""
if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE('batch analyze is not supported in CPU mode')
lang = None if lang == '' else lang
# TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id is not None else len(dataset) - 1
model_manager = ModelSingleton()
custom_model: CustomPEKModel = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
)
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
model_json = []
# batch analyze
images = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
analyze_result = batch_model(images)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory(get_device())
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
return InferenceResult(model_json, dataset)
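Hedged usage sketch: batch analysis is CUDA-only, and `batch_ratio` scales the per-stage base batch sizes linearly (layout 4, MFD 1, MFR 16), so `batch_ratio=2` means layout batches of 8 and MFR batches of 32. `ds` is a previously built dataset; values are illustrative:

infer_result = doc_batch_analyze(ds, ocr=True, lang='en', batch_ratio=2)
model_json = infer_result.get_infer_res()  # per the InferenceResultBase contract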
import os
import time
import fitz
import numpy as np
from loguru import logger
# disable paddle's signal handler
import paddle
from loguru import logger
paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
......@@ -28,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
from magic_pdf.model.operators import InferenceResult
from magic_pdf.operators.models import InferenceResult
def dict_compare(d1, d2):
......@@ -45,47 +42,6 @@ def remove_duplicates_dicts(lst):
return unique_dicts
def load_images_from_pdf(
pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
try:
from PIL import Image
except ImportError:
logger.error('Pillow not installed, please install by pip.')
exit(1)
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
logger.warning('end_page_id is out of range, falling back to the last page')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 4500 after scaling, do not scale further.
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict)
return images
class ModelSingleton:
_instance = None
_models = {}
......@@ -198,9 +154,6 @@ def doc_analyze(
table_enable=None,
) -> InferenceResult:
if lang == '':
lang = None
model_manager = ModelSingleton()
custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
......@@ -230,7 +183,7 @@ def doc_analyze(
model_json.append(page_dict)
gc_start = time.time()
clean_memory()
clean_memory(get_device())
gc_time = round(time.time() - gc_start, 2)
logger.info(f'gc time: {gc_time}')
......
......@@ -3,12 +3,9 @@ import enum
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
get_overlap_area)
from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
calculate_iou)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6
......@@ -208,393 +205,6 @@ class MagicModel:
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id
):
"""假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
只能属于一个 subject."""
ret = []
MAX_DIS_OF_POINT = 10**9 + 7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(subject_idx, object_idx):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0
other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(merged_bbox, other_object['bbox'])
* 1.0
/ box_area(all_bboxes[object_idx]['bbox']),
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
return ratio
def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf')
x0 = min(
all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
)
y0 = min(
all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
)
x1 = max(
all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
)
y1 = max(
all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
)
object_area = abs(
all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
) * abs(
all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
)
for i in range(len(all_bboxes)):
if (
i == subject_idx
or all_bboxes[i]['category_id'] != subject_category_id
):
continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
all_bboxes[i]['bbox'], [x0, y0, x1, y1]
):
i_area = abs(
all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
if i_area >= object_area:
ret = min(float('inf'), dis[i][object_idx])
return ret
def expand_bbbox(idxes):
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
return min(x0s), min(y0s), max(x1s), max(y1s)
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
subject_object_relation_map = {}
subjects.sort(
key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
) # get the distance !
all_bboxes = []
for v in subjects:
all_bboxes.append(
{
'category_id': subject_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
for v in objects:
all_bboxes.append(
{
'category_id': object_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
N = len(all_bboxes)
dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
for i in range(N):
for j in range(i):
if (
all_bboxes[i]['category_id'] == subject_category_id
and all_bboxes[j]['category_id'] == subject_category_id
):
continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j]
used = set()
for i in range(N):
# find the objects associated with the i-th subject
if all_bboxes[i]['category_id'] != subject_category_id:
continue
seen = set()
candidates = []
arr = []
for j in range(N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
if (
all_bboxes[j]['category_id'] != object_category_id
or j in used
or dis[i][j] == MAX_DIS_OF_POINT
):
continue
left, right, _, _ = bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)  # correctness here is guaranteed by the pos_flag_count logic above
if left or right:
one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
else:
one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
if dis[i][j] > one_way_dis:
continue
arr.append((dis[i][j], j))
arr.sort(key=lambda x: x[0])
if len(arr) > 0:
"""
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1])
seen.add(arr[0][1])
# initial seeds acquired
for j in set(candidates):
tmp = []
for k in range(i + 1, N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
if (
all_bboxes[k]['category_id'] != object_category_id
or k in used
or k in seen
or dis[j][k] == MAX_DIS_OF_POINT
or dis[j][k] > dis[i][j]
):
continue
is_nearest = True
for ni in range(i + 1, N):
if ni in (j, k) or ni in used or ni in seen:
continue
if not float_gt(dis[ni][k], dis[j][k]):
is_nearest = False
break
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis):
continue
tmp.append(k)
seen.add(k)
candidates = tmp
if len(candidates) == 0:
break
# All captions nearest to this figure have been collected, plus the captions nearest to those captions.
# First expand the bbox,
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
# split into 4 candidate regions; for each, compute the area of the merged rectangle occupied by the objects falling inside it
caption_poses = [
[ox0, oy0, ix0, oy1],
[ox0, oy0, ox1, iy0],
[ox0, iy1, ox1, oy1],
[ix1, oy0, ox1, oy1],
]
caption_areas = []
for bbox in caption_poses:
embed_arr = []
for idx in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]['bbox'], bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
embed_arr.append(idx)
if len(embed_arr) > 0:
embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
caption_areas.append(
int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
)
else:
caption_areas.append(0)
subject_object_relation_map[i] = []
if max(caption_areas) > 0:
max_area_idx = caption_areas.index(max(caption_areas))
caption_bbox = caption_poses[max_area_idx]
for j in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]['bbox'], caption_bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
used.add(j)
subject_object_relation_map[i].append(j)
for i in sorted(subject_object_relation_map.keys()):
result = {
'subject_body': all_bboxes[i]['bbox'],
'all': all_bboxes[i]['bbox'],
'score': all_bboxes[i]['score'],
}
if len(subject_object_relation_map[i]) > 0:
x0 = min(
[all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
)
y0 = min(
[all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
)
x1 = max(
[all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
)
y1 = max(
[all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
)
result['object_body'] = [x0, y0, x1, y1]
result['all'] = [
min(x0, all_bboxes[i]['bbox'][0]),
min(y0, all_bboxes[i]['bbox'][1]),
max(x1, all_bboxes[i]['bbox'][2]),
max(y1, all_bboxes[i]['bbox'][3]),
]
ret.append(result)
total_subject_object_dis = 0
# sum the distances of the already-paired subjects and objects
for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)
# compute distances between unmatched subjects and objects (approximate)
with_caption_subject = set(
[
key
for key in subject_object_relation_map.keys()
if len(subject_object_relation_map[i]) > 0
]
)
for i in range(N):
if all_bboxes[i]['category_id'] != object_category_id or i in used:
continue
candidates = []
for j in range(N):
if (
all_bboxes[j]['category_id'] != subject_category_id
or j in with_caption_subject
):
continue
candidates.append((dis[i][j], j))
if len(candidates) > 0:
candidates.sort(key=lambda x: x[0])
total_subject_object_dis += candidates[0][1]
with_caption_subject.add(j)
return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
self,
page_no: int,
......@@ -879,52 +489,12 @@ class MagicModel:
return ret
def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance(
page_no, 3, CategoryId.ImageFootnote
)
ret = []
N, M = len(with_captions), len(with_footnotes)
assert N == M
for i in range(N):
record = {
'score': with_captions[i]['score'],
'img_caption_bbox': with_captions[i].get('object_body', None),
'img_body_bbox': with_captions[i]['subject_body'],
'img_footnote_bbox': with_footnotes[i].get('object_body', None),
}
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
return self.get_imgs_v2(page_no)
def get_tables(
self, page_no: int
) -> list:  # three bboxes: caption, table body, table footnote
with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
ret = []
N, M = len(with_captions), len(with_footnotes)
assert N == M
for i in range(N):
record = {
'score': with_captions[i]['score'],
'table_caption_bbox': with_captions[i].get('object_body', None),
'table_body_bbox': with_captions[i]['subject_body'],
'table_footnote_bbox': with_footnotes[i].get('object_body', None),
}
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
return self.get_tables_v2(page_no)
def get_equations(self, page_no: int) -> list:  # has both coordinates and text
inline_equations = self.__get_blocks_by_type(
......@@ -1043,4 +613,3 @@ class MagicModel:
def get_model_list(self, page_no):
return self.__model_list[page_no]
......@@ -9,3 +9,4 @@ class AtomicModel:
MFR = "mfr"
OCR = "ocr"
Table = "table"
LangDetect = "langdetect"
......@@ -10,7 +10,6 @@ from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
......@@ -88,6 +87,14 @@ class CustomPEKModel:
)
# initialize the parsing configuration
self.device = kwargs.get('device', 'cpu')
if str(self.device).startswith("npu"):
import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0'
elif str(self.device).startswith("mps"):
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
......@@ -114,11 +121,12 @@ class CustomPEKModel:
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device,
device='cpu' if str(self.device).startswith("mps") else self.device,
)
# initialize the layout model
......@@ -165,12 +173,17 @@ class CustomPEKModel:
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device,
ocr_engine=self.ocr_model,
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
pil_img = Image.fromarray(image)
width, height = pil_img.size
# logger.info(f'width: {width}, height: {height}')
# layout detection
layout_start = time.time()
layout_res = []
......@@ -179,30 +192,28 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
img_pil = Image.fromarray(image)
width, height = img_pil.size
# logger.info(f'width: {width}, height: {height}')
input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image)
for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
if height > width:
input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image)
for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
else:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}')
pil_img = Image.fromarray(image)
if self.apply_formula:
# formula detection
mfd_start = time.time()
......
# Copyright (c) Opendatalab. All rights reserved.
import os
from pathlib import Path
import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
def get_model_config():
local_models_dir = get_local_models_dir()
device = get_device()
current_file_path = os.path.abspath(__file__)
root_dir = Path(current_file_path).parents[3]
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return root_dir, local_models_dir, device, configs
def get_text_images(simple_images):
_, local_models_dir, device, configs = get_model_config()
atom_model_manager = AtomModelSingleton()
temp_layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
)
),
device=device,
)
text_images = []
for simple_image in simple_images:
image = Image.fromarray(simple_image['img'])
layout_res = temp_layout_model.predict(image)
# crop each text block from the page image
for res in layout_res:
if res['category_id'] in [1]:
x1, y1, _, _, x2, y2, _, _ = res['poly']
# preliminary filtering (skip regions whose width and height are both under 100)
if x2 - x1 < 100 and y2 - y1 < 100:
continue
text_images.append(image.crop((x1, y1, x2, y2)))
return text_images
def auto_detect_lang(pdf_bytes: bytes):
sample_docs = extract_pages(pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
text_images = get_text_images(simple_images)
langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
lang = langdetect_model.do_detect(text_images)
return lang
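End to end, `auto_detect_lang` samples pages from the PDF, crops text blocks with a DocLayout-YOLO pass, and lets the YOLOv11 classifier vote. Hedged usage sketch (file path hypothetical):

with open('sample.pdf', 'rb') as f:  # hypothetical file
    lang = auto_detect_lang(f.read())
print(lang)  # e.g. 'en', 'ch', 'japan'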
def model_init(model_name: str):
atom_model_manager = AtomModelSingleton()
if model_name == MODEL_NAME.YOLO_V11_LangDetect:
root_dir, _, device, _ = get_model_config()
model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.LangDetect,
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
langdetect_model_weight=str(os.path.join(root_dir, 'resources', 'yolov11-langdetect', 'yolo_v11_ft.pt')),
device=device,
)
else:
raise ValueError(f"model_name {model_name} not found")
return model
# Copyright (c) Opendatalab. All rights reserved.
from collections import Counter
from uuid import uuid4
import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
language_dict = {
"ch": "中文简体",
"en": "英语",
"japan": "日语",
"korean": "韩语",
"fr": "法语",
"german": "德语",
"ar": "阿拉伯语",
"ru": "俄语"
}
def split_images(image, result_images=None):
"""
对输入文件夹内的图片进行处理,若图片竖向(y方向)分辨率超过400,则进行拆分,
每次平分图片,直至拆分出的图片竖向分辨率都满足400以下,将处理后的图片(拆分后的子图片)保存到输出文件夹。
避免保存因裁剪区域超出图片范围导致出现的无效黑色图片部分。
"""
if result_images is None:
result_images = []
width, height = image.size
long_side = max(width, height)  # length of the longer side
if long_side <= 400:
result_images.append(image)
return result_images
new_long_side = long_side // 2
sub_images = []
if width >= height:  # width is the longer side
for x in range(0, width, new_long_side):
# skip crop boxes that would exceed the image bounds
if x + new_long_side > width:
continue
box = (x, 0, x + new_long_side, height)
sub_image = image.crop(box)
sub_images.append(sub_image)
else:  # height is the longer side
for y in range(0, height, new_long_side):
# skip crop boxes that would exceed the image bounds
if y + new_long_side > height:
continue
box = (0, y, width, y + new_long_side)
sub_image = image.crop(box)
sub_images.append(sub_image)
for sub_image in sub_images:
split_images(sub_image, result_images)
return result_images
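Worked example of the recursion: a 1000x300 strip halves along its long side (1000 → 500 → 250) and yields four 250x300 tiles, each with long side ≤ 400:

from PIL import Image

tiles = split_images(Image.new('RGB', (1000, 300)))
print(len(tiles), tiles[0].size)  # -> 4 (250, 300)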
def resize_images_to_224(image):
"""
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
"""
try:
width, height = image.size
if width < 224 or height < 224:
new_image = Image.new('RGB', (224, 224), (0, 0, 0))
paste_x = (224 - width) // 2
paste_y = (224 - height) // 2
new_image.paste(image, (paste_x, paste_y))
image = new_image
else:
image = image.resize((224, 224), Image.Resampling.LANCZOS)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return image
except Exception as e:
logger.exception(e)
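Both branches yield 224x224: crops with either side under 224 are centered on a black canvas, everything else is LANCZOS-resized. Quick check:

from PIL import Image

print(resize_images_to_224(Image.new('RGB', (150, 90))).size)   # (224, 224), padded
print(resize_images_to_224(Image.new('RGB', (500, 300))).size)  # (224, 224), resized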
class YOLOv11LangDetModel(object):
def __init__(self, langdetect_model_weight, device):
self.model = YOLO(langdetect_model_weight)
if str(device).startswith("npu"):
self.device = torch.device(device)
else:
self.device = device
def do_detect(self, images: list):
all_images = []
for image in images:
width, height = image.size
# logger.info(f"image size: {width} x {height}")
if width < 100 and height < 100:
continue
temp_images = split_images(image)
for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image))
images_lang_res = self.batch_predict(all_images, batch_size=8)
# logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
else:
language = None
return language
def predict(self, image):
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
return predicted_class_name
def batch_predict(self, images: list, batch_size: int) -> list:
images_lang_res = []
for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index: index + batch_size],
verbose=False,
device=self.device,
)
]
for res in lang_res:
predicted_class_id = int(res.probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
images_lang_res.append(predicted_class_name)
return images_lang_res
\ No newline at end of file
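Hedged usage sketch (weights path and crop image hypothetical): `do_detect` tiles each crop via `split_images`, normalizes the tiles to 224x224, classifies them in batches of 8, and returns the majority language code:

from PIL import Image

model = YOLOv11LangDetModel('yolo_v11_ft.pt', device='cpu')  # hypothetical weights path
print(model.do_detect([Image.open('crop.png')]))             # e.g. 'en'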
# Copyright (c) Opendatalab. All rights reserved.