Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
4bb54393
Unverified
Commit
4bb54393
authored
Jan 10, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Jan 10, 2025
Browse files
Merge pull request #1427 from opendatalab/release-1.0.0
Release 1.0.0
parents
04f084ac
1c9f9942
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
702 additions
and
682 deletions
+702
-682
magic_pdf/data/data_reader_writer/filebase.py
magic_pdf/data/data_reader_writer/filebase.py
+1
-1
magic_pdf/data/data_reader_writer/multi_bucket_s3.py
magic_pdf/data/data_reader_writer/multi_bucket_s3.py
+8
-6
magic_pdf/data/dataset.py
magic_pdf/data/dataset.py
+13
-1
magic_pdf/data/read_api.py
magic_pdf/data/read_api.py
+59
-12
magic_pdf/data/utils.py
magic_pdf/data/utils.py
+35
-0
magic_pdf/dict2md/ocr_mkcontent.py
magic_pdf/dict2md/ocr_mkcontent.py
+14
-13
magic_pdf/libs/clean_memory.py
magic_pdf/libs/clean_memory.py
+11
-4
magic_pdf/libs/config_reader.py
magic_pdf/libs/config_reader.py
+9
-0
magic_pdf/libs/draw_bbox.py
magic_pdf/libs/draw_bbox.py
+8
-12
magic_pdf/libs/language.py
magic_pdf/libs/language.py
+3
-0
magic_pdf/model/__init__.py
magic_pdf/model/__init__.py
+1
-125
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+275
-0
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+4
-51
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+4
-435
magic_pdf/model/model_list.py
magic_pdf/model/model_list.py
+1
-0
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+33
-22
magic_pdf/model/sub_modules/language_detection/__init__.py
magic_pdf/model/sub_modules/language_detection/__init__.py
+1
-0
magic_pdf/model/sub_modules/language_detection/utils.py
magic_pdf/model/sub_modules/language_detection/utils.py
+82
-0
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
...f/model/sub_modules/language_detection/yolov11/YOLOv11.py
+139
-0
magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py
.../model/sub_modules/language_detection/yolov11/__init__.py
+1
-0
No files found.
magic_pdf/data/data_reader_writer/filebase.py
View file @
4bb54393
...
...
@@ -55,7 +55,7 @@ class FileBasedDataWriter(DataWriter):
if
not
os
.
path
.
isabs
(
fn_path
)
and
len
(
self
.
_parent_dir
)
>
0
:
fn_path
=
os
.
path
.
join
(
self
.
_parent_dir
,
path
)
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
fn_path
)):
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
fn_path
))
and
os
.
path
.
dirname
(
fn_path
)
!=
""
:
os
.
makedirs
(
os
.
path
.
dirname
(
fn_path
),
exist_ok
=
True
)
with
open
(
fn_path
,
'wb'
)
as
f
:
...
...
magic_pdf/data/data_reader_writer/multi_bucket_s3.py
View file @
4bb54393
import
os
from
magic_pdf.config.exceptions
import
InvalidConfig
,
InvalidParams
from
magic_pdf.data.data_reader_writer.base
import
DataReader
,
DataWriter
from
magic_pdf.data.io.s3
import
S3Reader
,
S3Writer
...
...
@@ -22,10 +22,10 @@ class MultiS3Mixin:
"""
if
len
(
default_prefix
)
==
0
:
raise
InvalidConfig
(
'default_prefix must be provided'
)
arr
=
default_prefix
.
strip
(
"/"
).
split
(
"/"
)
arr
=
default_prefix
.
strip
(
'/'
).
split
(
'/'
)
self
.
default_bucket
=
arr
[
0
]
self
.
default_prefix
=
"/"
.
join
(
arr
[
1
:])
self
.
default_prefix
=
'/'
.
join
(
arr
[
1
:])
found_default_bucket_config
=
False
for
conf
in
s3_configs
:
...
...
@@ -103,7 +103,8 @@ class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
s3_reader
=
self
.
__get_s3_client
(
bucket_name
)
else
:
s3_reader
=
self
.
__get_s3_client
(
self
.
default_bucket
)
path
=
os
.
path
.
join
(
self
.
default_prefix
,
path
)
if
self
.
default_prefix
:
path
=
self
.
default_prefix
+
'/'
+
path
return
s3_reader
.
read_at
(
path
,
offset
,
limit
)
...
...
@@ -139,5 +140,6 @@ class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
s3_writer
=
self
.
__get_s3_client
(
bucket_name
)
else
:
s3_writer
=
self
.
__get_s3_client
(
self
.
default_bucket
)
path
=
os
.
path
.
join
(
self
.
default_prefix
,
path
)
if
self
.
default_prefix
:
path
=
self
.
default_prefix
+
'/'
+
path
return
s3_writer
.
write
(
path
,
data
)
magic_pdf/data/dataset.py
View file @
4bb54393
...
...
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from
typing
import
Callable
,
Iterator
import
fitz
from
loguru
import
logger
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.schemas
import
PageInfo
...
...
@@ -133,7 +134,7 @@ class Dataset(ABC):
class
PymuDocDataset
(
Dataset
):
def
__init__
(
self
,
bits
:
bytes
):
def
__init__
(
self
,
bits
:
bytes
,
lang
=
None
):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
...
...
@@ -144,6 +145,15 @@ class PymuDocDataset(Dataset):
self
.
_data_bits
=
bits
self
.
_raw_data
=
bits
if
lang
==
''
:
self
.
_lang
=
None
elif
lang
==
'auto'
:
from
magic_pdf.model.sub_modules.language_detection.utils
import
auto_detect_lang
self
.
_lang
=
auto_detect_lang
(
bits
)
logger
.
info
(
f
"lang:
{
lang
}
, detect_lang:
{
self
.
_lang
}
"
)
else
:
self
.
_lang
=
lang
logger
.
info
(
f
"lang:
{
lang
}
"
)
def
__len__
(
self
)
->
int
:
"""The page number of the pdf."""
return
len
(
self
.
_records
)
...
...
@@ -197,6 +207,8 @@ class PymuDocDataset(Dataset):
Returns:
Any: return the result generated by proc
"""
if
'lang'
in
kwargs
and
self
.
_lang
is
not
None
:
kwargs
[
'lang'
]
=
self
.
_lang
return
proc
(
self
,
*
args
,
**
kwargs
)
def
classify
(
self
)
->
SupportedPdfParseMethod
:
...
...
magic_pdf/data/read_api.py
View file @
4bb54393
import
json
import
os
import
tempfile
import
shutil
from
pathlib
import
Path
from
magic_pdf.config.exceptions
import
EmptyData
,
InvalidParams
from
magic_pdf.data.data_reader_writer
import
(
FileBasedDataReader
,
MultiBucketS3DataReader
)
from
magic_pdf.data.dataset
import
ImageDataset
,
PymuDocDataset
from
magic_pdf.utils.office_to_pdf
import
convert_file_to_pdf
,
ConvertToPdfError
def
read_jsonl
(
s3_path_or_local
:
str
,
s3_client
:
MultiBucketS3DataReader
|
None
=
None
...
...
@@ -58,23 +60,68 @@ def read_local_pdfs(path: str) -> list[PymuDocDataset]:
list[PymuDocDataset]: each pdf file will converted to a PymuDocDataset
"""
if
os
.
path
.
isdir
(
path
):
reader
=
FileBasedDataReader
(
path
)
return
[
PymuDocDataset
(
reader
.
read
(
doc_path
.
name
))
for
doc_path
in
Path
(
path
).
glob
(
'*.pdf'
)
]
reader
=
FileBasedDataReader
()
ret
=
[]
for
root
,
_
,
files
in
os
.
walk
(
path
):
for
file
in
files
:
suffix
=
file
.
split
(
'.'
)
if
suffix
[
-
1
]
==
'pdf'
:
ret
.
append
(
PymuDocDataset
(
reader
.
read
(
os
.
path
.
join
(
root
,
file
))))
return
ret
else
:
reader
=
FileBasedDataReader
()
bits
=
reader
.
read
(
path
)
return
[
PymuDocDataset
(
bits
)]
def
read_local_office
(
path
:
str
)
->
list
[
PymuDocDataset
]:
"""Read ms-office file (ppt, pptx, doc, docx) from path or directory.
def
read_local_images
(
path
:
str
,
suffixes
:
list
[
str
])
->
list
[
ImageDataset
]:
Args:
path (str): ms-office file or directory that contains ms-office files
Returns:
list[PymuDocDataset]: each ms-office file will converted to a PymuDocDataset
Raises:
ConvertToPdfError: Failed to convert ms-office file to pdf via libreoffice
FileNotFoundError: File not Found
Exception: Unknown Exception raised
"""
suffixes
=
[
'.ppt'
,
'.pptx'
,
'.doc'
,
'.docx'
]
fns
=
[]
ret
=
[]
if
os
.
path
.
isdir
(
path
):
for
root
,
_
,
files
in
os
.
walk
(
path
):
for
file
in
files
:
suffix
=
Path
(
file
).
suffix
if
suffix
in
suffixes
:
fns
.
append
((
os
.
path
.
join
(
root
,
file
)))
else
:
fns
.
append
(
path
)
reader
=
FileBasedDataReader
()
temp_dir
=
tempfile
.
mkdtemp
()
for
fn
in
fns
:
try
:
convert_file_to_pdf
(
fn
,
temp_dir
)
except
ConvertToPdfError
as
e
:
raise
e
except
FileNotFoundError
as
e
:
raise
e
except
Exception
as
e
:
raise
e
fn_path
=
Path
(
fn
)
pdf_fn
=
f
"
{
temp_dir
}
/
{
fn_path
.
stem
}
.pdf"
ret
.
append
(
PymuDocDataset
(
reader
.
read
(
pdf_fn
)))
shutil
.
rmtree
(
temp_dir
)
return
ret
def
read_local_images
(
path
:
str
,
suffixes
:
list
[
str
]
=
[
'.png'
,
'.jpg'
])
->
list
[
ImageDataset
]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['
.
jpg', '
.
png']
Returns:
list[ImageDataset]: each image file will converted to a ImageDataset
...
...
@@ -82,12 +129,12 @@ def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
if
os
.
path
.
isdir
(
path
):
imgs_bits
=
[]
s_suffixes
=
set
(
suffixes
)
reader
=
FileBasedDataReader
(
path
)
reader
=
FileBasedDataReader
()
for
root
,
_
,
files
in
os
.
walk
(
path
):
for
file
in
files
:
suffix
=
file
.
s
plit
(
'.'
)
if
suffix
[
-
1
]
in
s_suffixes
:
imgs_bits
.
append
(
reader
.
read
(
file
))
suffix
=
Path
(
file
)
.
s
uffix
if
suffix
in
s_suffixes
:
imgs_bits
.
append
(
reader
.
read
(
os
.
path
.
join
(
root
,
file
))
)
return
[
ImageDataset
(
bits
)
for
bits
in
imgs_bits
]
else
:
reader
=
FileBasedDataReader
()
...
...
magic_pdf/data/utils.py
View file @
4bb54393
import
fitz
import
numpy
as
np
from
loguru
import
logger
from
magic_pdf.utils.annotations
import
ImportPIL
...
...
@@ -30,3 +31,37 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
return
img_dict
@
ImportPIL
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
from
PIL
import
Image
images
=
[]
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
(
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
)
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
'end_page_id is out of range, use images length'
)
end_page_id
=
pdf_page_num
-
1
for
index
in
range
(
0
,
doc
.
page_count
):
if
start_page_id
<=
index
<=
end_page_id
:
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
# If the width or height exceeds 4500 after scaling, do not scale further.
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
else
:
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
images
.
append
(
img_dict
)
return
images
magic_pdf/dict2md/ocr_mkcontent.py
View file @
4bb54393
...
...
@@ -7,7 +7,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.markdown_utils
import
ocr_escape_special_markdown_char
from
magic_pdf.p
ara
.para_split_v3
import
ListLineTag
from
magic_pdf.p
ost_proc
.para_split_v3
import
ListLineTag
def
__is_hyphen_at_line_end
(
line
):
...
...
@@ -61,7 +61,8 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
if
para_type
in
[
BlockType
.
Text
,
BlockType
.
List
,
BlockType
.
Index
]:
para_text
=
merge_para_with_text
(
para_block
)
elif
para_type
==
BlockType
.
Title
:
para_text
=
f
'#
{
merge_para_with_text
(
para_block
)
}
'
title_level
=
get_title_level
(
para_block
)
para_text
=
f
'
{
"#"
*
title_level
}
{
merge_para_with_text
(
para_block
)
}
'
elif
para_type
==
BlockType
.
InterlineEquation
:
para_text
=
merge_para_with_text
(
para_block
)
elif
para_type
==
BlockType
.
Image
:
...
...
@@ -125,16 +126,6 @@ def detect_language(text):
return
'empty'
# 连写字符拆分
def
__replace_ligatures
(
text
:
str
):
text
=
re
.
sub
(
r
'fi'
,
'fi'
,
text
)
# 替换 fi 连写符
text
=
re
.
sub
(
r
'fl'
,
'fl'
,
text
)
# 替换 fl 连写符
text
=
re
.
sub
(
r
'ff'
,
'ff'
,
text
)
# 替换 ff 连写符
text
=
re
.
sub
(
r
'ffi'
,
'ffi'
,
text
)
# 替换 ffi 连写符
text
=
re
.
sub
(
r
'ffl'
,
'ffl'
,
text
)
# 替换 ffl 连写符
return
text
def
merge_para_with_text
(
para_block
):
block_text
=
''
for
line
in
para_block
[
'lines'
]:
...
...
@@ -196,10 +187,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
'text'
:
merge_para_with_text
(
para_block
),
}
elif
para_type
==
BlockType
.
Title
:
title_level
=
get_title_level
(
para_block
)
para_content
=
{
'type'
:
'text'
,
'text'
:
merge_para_with_text
(
para_block
),
'text_level'
:
1
,
'text_level'
:
title_level
,
}
elif
para_type
==
BlockType
.
InterlineEquation
:
para_content
=
{
...
...
@@ -299,3 +291,12 @@ def union_make(pdf_info_dict: list,
return
'
\n\n
'
.
join
(
output_content
)
elif
make_mode
==
MakeMode
.
STANDARD_FORMAT
:
return
output_content
def
get_title_level
(
block
):
title_level
=
block
.
get
(
'level'
,
1
)
if
title_level
>
4
:
title_level
=
4
elif
title_level
<
1
:
title_level
=
1
return
title_level
\ No newline at end of file
magic_pdf/libs/clean_memory.py
View file @
4bb54393
...
...
@@ -3,8 +3,15 @@ import torch
import
gc
def
clean_memory
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
ipc_collect
()
def
clean_memory
(
device
=
'cuda'
):
if
device
==
'cuda'
:
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
ipc_collect
()
elif
str
(
device
).
startswith
(
"npu"
):
import
torch_npu
if
torch_npu
.
npu
.
is_available
():
torch_npu
.
npu
.
empty_cache
()
elif
str
(
device
).
startswith
(
"mps"
):
torch
.
mps
.
empty_cache
()
gc
.
collect
()
\ No newline at end of file
magic_pdf/libs/config_reader.py
View file @
4bb54393
...
...
@@ -116,6 +116,15 @@ def get_formula_config():
else
:
return
formula_config
def
get_llm_aided_config
():
config
=
read_config
()
llm_aided_config
=
config
.
get
(
'llm-aided-config'
)
if
llm_aided_config
is
None
:
logger
.
warning
(
f
"'llm-aided-config' not found in
{
CONFIG_FILE_NAME
}
, use 'None' as default"
)
return
None
else
:
return
llm_aided_config
if
__name__
==
'__main__'
:
ak
,
sk
,
endpoint
=
get_s3_config
(
'llm-raw'
)
magic_pdf/libs/draw_bbox.py
View file @
4bb54393
...
...
@@ -394,17 +394,13 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
pdf_docs
.
save
(
f
'
{
out_path
}
/
{
filename
}
'
)
def
draw_layout_sort_bbox
(
pdf_info
,
pdf_bytes
,
out_path
,
filename
):
layout_bbox_list
=
[]
for
page
in
pdf_info
:
page_block_list
=
[]
for
block
in
page
[
'para_blocks'
]:
bbox
=
block
[
'bbox'
]
page_block_list
.
append
(
bbox
)
layout_bbox_list
.
append
(
page_block_list
)
def
draw_char_bbox
(
pdf_bytes
,
out_path
,
filename
):
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
for
i
,
page
in
enumerate
(
pdf_docs
):
draw_bbox_with_number
(
i
,
layout_bbox_list
,
page
,
[
255
,
0
,
0
],
False
)
pdf_docs
.
save
(
f
'
{
out_path
}
/
{
filename
}
_layout_sort.pdf'
)
for
block
in
page
.
get_text
(
'rawdict'
,
flags
=
fitz
.
TEXT_PRESERVE_LIGATURES
|
fitz
.
TEXT_PRESERVE_WHITESPACE
|
fitz
.
TEXT_MEDIABOX_CLIP
)[
'blocks'
]:
for
line
in
block
[
'lines'
]:
for
span
in
line
[
'spans'
]:
for
char
in
span
[
'chars'
]:
char_bbox
=
char
[
'bbox'
]
page
.
draw_rect
(
char_bbox
,
color
=
[
1
,
0
,
0
],
fill
=
None
,
fill_opacity
=
1
,
width
=
0.3
,
overlay
=
True
,)
pdf_docs
.
save
(
f
'
{
out_path
}
/
{
filename
}
'
)
magic_pdf/libs/language.py
View file @
4bb54393
...
...
@@ -16,11 +16,14 @@ def detect_lang(text: str) -> str:
if
len
(
text
)
==
0
:
return
""
text
=
text
.
replace
(
"
\n
"
,
""
)
try
:
lang_upper
=
detect_language
(
text
)
except
:
html_no_ctrl_chars
=
''
.
join
([
l
for
l
in
text
if
unicodedata
.
category
(
l
)[
0
]
not
in
[
'C'
,
]])
lang_upper
=
detect_language
(
html_no_ctrl_chars
)
try
:
lang
=
lang_upper
.
lower
()
except
:
...
...
magic_pdf/model/__init__.py
View file @
4bb54393
from
typing
import
Callable
from
abc
import
ABC
,
abstractmethod
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.pipe.operators
import
PipeResult
__use_inside_model__
=
True
__model_mode__
=
"full"
class
InferenceResultBase
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
inference_results
:
list
,
dataset
:
Dataset
):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self
.
_infer_res
=
inference_results
self
.
_dataset
=
dataset
@
abstractmethod
def
draw_model
(
self
,
file_path
:
str
)
->
None
:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@
abstractmethod
def
dump_model
(
self
,
writer
:
DataWriter
,
file_path
:
str
):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@
abstractmethod
def
get_infer_res
(
self
):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@
abstractmethod
def
apply
(
self
,
proc
:
Callable
,
*
args
,
**
kwargs
):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@
abstractmethod
def
pipe_auto_mode
(
self
,
imageWriter
:
DataWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
)
->
PipeResult
:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@
abstractmethod
def
pipe_txt_mode
(
self
,
imageWriter
:
DataWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
)
->
PipeResult
:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@
abstractmethod
def
pipe_ocr_mode
(
self
,
imageWriter
:
DataWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
)
->
PipeResult
:
pass
__model_mode__
=
'full'
\ No newline at end of file
magic_pdf/model/batch_analyze.py
0 → 100644
View file @
4bb54393
import
time
import
cv2
import
numpy
as
np
import
torch
from
loguru
import
logger
from
PIL
import
Image
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.exceptions
import
CUDA_NOT_AVAILABLE
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_device
from
magic_pdf.model.doc_analyze_by_custom_model
import
ModelSingleton
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
from
magic_pdf.operators.models
import
InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE
=
4
MFD_BASE_BATCH_SIZE
=
1
MFR_BASE_BATCH_SIZE
=
16
class
BatchAnalyze
:
def
__init__
(
self
,
model
:
CustomPEKModel
,
batch_ratio
:
int
):
self
.
model
=
model
self
.
batch_ratio
=
batch_ratio
def
__call__
(
self
,
images
:
list
)
->
list
:
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
for
image
in
images
:
layout_res
=
self
.
model
.
layout_model
(
image
,
ignore_catids
=
[])
images_layout_res
.
append
(
layout_res
)
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
layout_images
=
[]
modified_images
=
[]
for
image_index
,
image
in
enumerate
(
images
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
if
height
>
width
:
input_res
=
{
'poly'
:
[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
new_image
,
useful_list
=
crop_img
(
input_res
,
pil_img
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
)
layout_images
.
append
(
new_image
)
modified_images
.
append
([
image_index
,
useful_list
])
else
:
layout_images
.
append
(
pil_img
)
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
layout_images
,
self
.
batch_ratio
*
YOLO_LAYOUT_BASE_BATCH_SIZE
)
for
image_index
,
useful_list
in
modified_images
:
for
res
in
images_layout_res
[
image_index
]:
for
i
in
range
(
len
(
res
[
'poly'
])):
if
i
%
2
==
0
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
)
else
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
)
logger
.
info
(
f
'layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
if
self
.
model
.
apply_formula
:
# 公式检测
mfd_start_time
=
time
.
time
()
images_mfd_res
=
self
.
model
.
mfd_model
.
batch_predict
(
images
,
self
.
batch_ratio
*
MFD_BASE_BATCH_SIZE
)
logger
.
info
(
f
'mfd time:
{
round
(
time
.
time
()
-
mfd_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
# 公式识别
mfr_start_time
=
time
.
time
()
images_formula_list
=
self
.
model
.
mfr_model
.
batch_predict
(
images_mfd_res
,
images
,
batch_size
=
self
.
batch_ratio
*
MFR_BASE_BATCH_SIZE
,
)
for
image_index
in
range
(
len
(
images
)):
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
logger
.
info
(
f
'mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
# 清理显存
clean_vram
(
self
.
model
.
device
,
vram_threshold
=
8
)
ocr_time
=
0
ocr_count
=
0
table_time
=
0
table_count
=
0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
layout_res
=
images_layout_res
[
index
]
pil_img
=
Image
.
fromarray
(
images
[
index
])
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
get_res_list_from_layout_res
(
layout_res
)
)
# ocr识别
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
else
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
,
rec
=
False
)[
0
]
# Integration results
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
layout_res
.
extend
(
ocr_result_list
)
ocr_time
+=
time
.
time
()
-
ocr_start
ocr_count
+=
len
(
ocr_res_list
)
# 表格识别 table recognition
if
self
.
model
.
apply_table
:
table_start
=
time
.
time
()
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
single_table_start_time
=
time
.
time
()
html_code
=
None
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
table_result
=
self
.
model
.
table_model
.
predict
(
new_image
,
'html'
)
if
len
(
table_result
)
>
0
:
html_code
=
table_result
[
0
]
elif
self
.
model
.
table_model_name
==
MODEL_NAME
.
TABLE_MASTER
:
html_code
=
self
.
model
.
table_model
.
img2html
(
new_image
)
elif
self
.
model
.
table_model_name
==
MODEL_NAME
.
RAPID_TABLE
:
html_code
,
table_cell_bboxes
,
elapse
=
(
self
.
model
.
table_model
.
predict
(
new_image
)
)
run_time
=
time
.
time
()
-
single_table_start_time
if
run_time
>
self
.
model
.
table_max_time
:
logger
.
warning
(
f
'table recognition processing exceeds max time
{
self
.
model
.
table_max_time
}
s'
)
# 判断是否返回正常
if
html_code
:
expected_ending
=
html_code
.
strip
().
endswith
(
'</html>'
)
or
html_code
.
strip
().
endswith
(
'</table>'
)
if
expected_ending
:
res
[
'html'
]
=
html_code
else
:
logger
.
warning
(
'table recognition processing fails, not found expected HTML table end'
)
else
:
logger
.
warning
(
'table recognition processing fails, not get html return'
)
table_time
+=
time
.
time
()
-
table_start
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
'ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
else
:
logger
.
info
(
f
'det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
if
self
.
model
.
apply_table
:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
return
images_layout_res
def
doc_batch_analyze
(
dataset
:
Dataset
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
batch_ratio
:
int
|
None
=
None
,
)
->
InferenceResult
:
"""Perform batch analysis on a document dataset.
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
"""
if
not
torch
.
cuda
.
is_available
():
raise
CUDA_NOT_AVAILABLE
(
'batch analyze not support in CPU mode'
)
lang
=
None
if
lang
==
''
else
lang
# TODO: auto detect batch size
batch_ratio
=
1
if
batch_ratio
is
None
else
batch_ratio
end_page_id
=
end_page_id
if
end_page_id
else
len
(
dataset
)
model_manager
=
ModelSingleton
()
custom_model
:
CustomPEKModel
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
model_json
=
[]
# batch analyze
images
=
[]
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
analyze_result
=
batch_model
(
images
)
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
page_width
=
img_dict
[
'width'
]
page_height
=
img_dict
[
'height'
]
if
start_page_id
<=
index
<=
end_page_id
:
result
=
analyze_result
.
pop
(
0
)
else
:
result
=
[]
page_info
=
{
'page_no'
:
index
,
'height'
:
page_height
,
'width'
:
page_width
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time
=
time
.
time
()
clean_memory
(
get_device
())
logger
.
info
(
f
'clean memory time:
{
round
(
time
.
time
()
-
clean_memory_start_time
,
2
)
}
'
)
return
InferenceResult
(
model_json
,
dataset
)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
4bb54393
import
os
import
time
import
fitz
import
numpy
as
np
from
loguru
import
logger
# 关闭paddle的信号处理
import
paddle
from
loguru
import
logger
paddle
.
disable_signal_handler
()
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'YOLO_VERBOSE'
]
=
'False'
# disable yolo logger
try
:
import
torchtext
...
...
@@ -28,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir
,
get_table_recog_config
)
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.
model.
operators
import
InferenceResult
from
magic_pdf.operators
.models
import
InferenceResult
def
dict_compare
(
d1
,
d2
):
...
...
@@ -45,47 +42,6 @@ def remove_duplicates_dicts(lst):
return
unique_dicts
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
try
:
from
PIL
import
Image
except
ImportError
:
logger
.
error
(
'Pillow not installed, please install by pip.'
)
exit
(
1
)
images
=
[]
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
(
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
)
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
'end_page_id is out of range, use images length'
)
end_page_id
=
pdf_page_num
-
1
for
index
in
range
(
0
,
doc
.
page_count
):
if
start_page_id
<=
index
<=
end_page_id
:
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
# If the width or height exceeds 4500 after scaling, do not scale further.
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
else
:
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
images
.
append
(
img_dict
)
return
images
class
ModelSingleton
:
_instance
=
None
_models
=
{}
...
...
@@ -198,9 +154,6 @@ def doc_analyze(
table_enable
=
None
,
)
->
InferenceResult
:
if
lang
==
''
:
lang
=
None
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
...
...
@@ -230,7 +183,7 @@ def doc_analyze(
model_json
.
append
(
page_dict
)
gc_start
=
time
.
time
()
clean_memory
()
clean_memory
(
get_device
()
)
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
logger
.
info
(
f
'gc time:
{
gc_time
}
'
)
...
...
magic_pdf/model/magic_model.py
View file @
4bb54393
...
...
@@ -3,12 +3,9 @@ import enum
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
calculate_overlap_area_in_bbox1_area_ratio
,
get_overlap_area
)
from
magic_pdf.libs.boxbase
import
(
_is_in
,
bbox_distance
,
bbox_relative_pos
,
calculate_iou
)
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO
=
0.6
...
...
@@ -208,393 +205,6 @@ class MagicModel:
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def __tie_up_category_by_distance(
    self, page_no, subject_category_id, object_category_id
):
    """Pair subject blocks with nearby object blocks on one page.

    Assumes every subject owns at most one object (several adjacent objects
    may be merged into a single one) and every object belongs to exactly one
    subject.

    Args:
        page_no: page index into ``self.__model_list``.
        subject_category_id: category id of the subject blocks (e.g. bodies).
        object_category_id: category id of the object blocks (e.g. captions).

    Returns:
        Tuple ``(ret, total_subject_object_dis)``: ``ret`` is a list of
        records with keys ``subject_body``, ``all``, ``score`` and, when a
        pairing was found, ``object_body``; ``total_subject_object_dis`` is
        the summed pairing distance (used to compare alternative pairings).
    """
    ret = []
    MAX_DIS_OF_POINT = 10**9 + 7
    # The subject and object bboxes are merged into one big bbox ("merged
    # bbox").  Any unrelated block overlapping the merged bbox by more than
    # MERGE_BOX_OVERLAP_AREA_RATIO of the object's area disqualifies the
    # pairing; otherwise the shortest subject/object distance decides.

    def search_overlap_between_boxes(subject_idx, object_idx):
        # Max ratio of (overlap between any unrelated block and the merged
        # bbox) to the object's own area.
        idxes = [subject_idx, object_idx]
        x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
        y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
        x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
        y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
        merged_bbox = [
            min(x0s),
            min(y0s),
            max(x1s),
            max(y1s),
        ]
        ratio = 0

        other_objects = list(
            map(
                lambda x: {'bbox': x['bbox'], 'score': x['score']},
                filter(
                    lambda x: x['category_id']
                    not in (object_category_id, subject_category_id),
                    self.__model_list[page_no]['layout_dets'],
                ),
            )
        )
        for other_object in other_objects:
            ratio = max(
                ratio,
                get_overlap_area(merged_bbox, other_object['bbox'])
                * 1.0
                / box_area(all_bboxes[object_idx]['bbox']),
            )
            if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                break
        return ratio

    def may_find_other_nearest_bbox(subject_idx, object_idx):
        # Smallest distance from the candidate object to a *different*
        # subject that overlaps the merged area and is at least as large as
        # the object; inf when no such subject exists.
        ret = float('inf')

        x0 = min(
            all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
        )
        y0 = min(
            all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
        )
        x1 = max(
            all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
        )
        y1 = max(
            all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
        )

        object_area = abs(
            all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
        ) * abs(
            all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
        )

        for i in range(len(all_bboxes)):
            if (
                i == subject_idx
                or all_bboxes[i]['category_id'] != subject_category_id
            ):
                continue
            if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
                all_bboxes[i]['bbox'], [x0, y0, x1, y1]
            ):
                i_area = abs(
                    all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
                ) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
                if i_area >= object_area:
                    # FIX: was min(float('inf'), dis[i][object_idx]), which
                    # simply overwrote ret with the last candidate instead of
                    # keeping the minimum across all qualifying subjects.
                    ret = min(ret, dis[i][object_idx])
        return ret

    def expand_bbbox(idxes):
        # Bounding box of the union of the given bboxes.
        x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
        y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
        x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
        y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
        return min(x0s), min(y0s), max(x1s), max(y1s)

    subjects = self.__reduct_overlap(
        list(
            map(
                lambda x: {'bbox': x['bbox'], 'score': x['score']},
                filter(
                    lambda x: x['category_id'] == subject_category_id,
                    self.__model_list[page_no]['layout_dets'],
                ),
            )
        )
    )

    objects = self.__reduct_overlap(
        list(
            map(
                lambda x: {'bbox': x['bbox'], 'score': x['score']},
                filter(
                    lambda x: x['category_id'] == object_category_id,
                    self.__model_list[page_no]['layout_dets'],
                ),
            )
        )
    )
    subject_object_relation_map = {}

    # Process subjects roughly top-left to bottom-right.
    subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

    all_bboxes = []

    for v in subjects:
        all_bboxes.append(
            {
                'category_id': subject_category_id,
                'bbox': v['bbox'],
                'score': v['score'],
            }
        )

    for v in objects:
        all_bboxes.append(
            {
                'category_id': object_category_id,
                'bbox': v['bbox'],
                'score': v['score'],
            }
        )

    # Pairwise distances between every subject/object pair.
    N = len(all_bboxes)
    dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]

    for i in range(N):
        for j in range(i):
            if (
                all_bboxes[i]['category_id'] == subject_category_id
                and all_bboxes[j]['category_id'] == subject_category_id
            ):
                continue

            subject_idx, object_idx = i, j
            if all_bboxes[j]['category_id'] == subject_category_id:
                subject_idx, object_idx = j, i

            if (
                search_overlap_between_boxes(subject_idx, object_idx)
                >= MERGE_BOX_OVERLAP_AREA_RATIO
            ):
                dis[i][j] = float('inf')
                dis[j][i] = dis[i][j]
                continue

            dis[i][j] = self._bbox_distance(
                all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
            )
            dis[j][i] = dis[i][j]

    used = set()
    for i in range(N):
        # Find the objects associated with subject i.
        if all_bboxes[i]['category_id'] != subject_category_id:
            continue
        seen = set()
        candidates = []
        arr = []
        for j in range(N):

            pos_flag_count = sum(
                list(
                    map(
                        lambda x: 1 if x else 0,
                        bbox_relative_pos(
                            all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                        ),
                    )
                )
            )
            if pos_flag_count > 1:
                continue
            if (
                all_bboxes[j]['category_id'] != object_category_id
                or j in used
                or dis[i][j] == MAX_DIS_OF_POINT
            ):
                continue
            left, right, _, _ = bbox_relative_pos(
                all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
            )

            # pos_flag_count above guarantees the object lies on exactly one
            # side, so the subject's extent on that axis caps the distance.
            if left or right:
                one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
            else:
                one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
            if dis[i][j] > one_way_dis:
                continue
            arr.append((dis[i][j], j))

        arr.sort(key=lambda x: x[0])
        if len(arr) > 0:
            # Known caveat: the object nearest to this subject may lie beyond
            # another subject, e.g. [this subject][other subject][object];
            # may_find_other_nearest_bbox rejects those seeds.
            if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
                candidates.append(arr[0][1])
                seen.add(arr[0][1])

        # Initial seed acquired; absorb further objects that are closest to
        # the already-collected group.
        for j in set(candidates):
            tmp = []
            for k in range(i + 1, N):
                pos_flag_count = sum(
                    list(
                        map(
                            lambda x: 1 if x else 0,
                            bbox_relative_pos(
                                all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
                            ),
                        )
                    )
                )
                if pos_flag_count > 1:
                    continue

                if (
                    all_bboxes[k]['category_id'] != object_category_id
                    or k in used
                    or k in seen
                    or dis[j][k] == MAX_DIS_OF_POINT
                    or dis[j][k] > dis[i][j]
                ):
                    continue

                is_nearest = True
                for ni in range(i + 1, N):
                    if ni in (j, k) or ni in used or ni in seen:
                        continue

                    if not float_gt(dis[ni][k], dis[j][k]):
                        is_nearest = False
                        break

                if is_nearest:
                    nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
                    n_dis = bbox_distance(
                        all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
                    )
                    if float_gt(dis[i][j], n_dis):
                        continue
                    tmp.append(k)
                    seen.add(k)

            candidates = tmp
            if len(candidates) == 0:
                break

        # `seen` now holds every caption closest to this figure (plus the
        # captions closest to those captions).  Expand the bbox first.
        ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
        ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']

        # Four capture regions (left / top / bottom / right of the subject);
        # measure the merged object area falling in each.
        caption_poses = [
            [ox0, oy0, ix0, oy1],
            [ox0, oy0, ox1, iy0],
            [ox0, iy1, ox1, oy1],
            [ix1, oy0, ox1, oy1],
        ]

        caption_areas = []
        for bbox in caption_poses:
            embed_arr = []
            for idx in seen:
                if (
                    calculate_overlap_area_in_bbox1_area_ratio(
                        all_bboxes[idx]['bbox'], bbox
                    )
                    > CAPATION_OVERLAP_AREA_RATIO
                ):
                    embed_arr.append(idx)

            if len(embed_arr) > 0:
                embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
                embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
                embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
                embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
                caption_areas.append(
                    int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
                )
            else:
                caption_areas.append(0)

        subject_object_relation_map[i] = []
        if max(caption_areas) > 0:
            max_area_idx = caption_areas.index(max(caption_areas))
            caption_bbox = caption_poses[max_area_idx]

            for j in seen:
                if (
                    calculate_overlap_area_in_bbox1_area_ratio(
                        all_bboxes[j]['bbox'], caption_bbox
                    )
                    > CAPATION_OVERLAP_AREA_RATIO
                ):
                    used.add(j)
                    subject_object_relation_map[i].append(j)

    for i in sorted(subject_object_relation_map.keys()):
        result = {
            'subject_body': all_bboxes[i]['bbox'],
            'all': all_bboxes[i]['bbox'],
            'score': all_bboxes[i]['score'],
        }

        if len(subject_object_relation_map[i]) > 0:
            x0 = min(
                [all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
            )
            y0 = min(
                [all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
            )
            x1 = max(
                [all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
            )
            y1 = max(
                [all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
            )
            result['object_body'] = [x0, y0, x1, y1]
            result['all'] = [
                min(x0, all_bboxes[i]['bbox'][0]),
                min(y0, all_bboxes[i]['bbox'][1]),
                max(x1, all_bboxes[i]['bbox'][2]),
                max(y1, all_bboxes[i]['bbox'][3]),
            ]
        ret.append(result)

    total_subject_object_dis = 0
    # Distance contributed by the pairs found above.
    for i in subject_object_relation_map.keys():
        for j in subject_object_relation_map[i]:
            total_subject_object_dis += bbox_distance(
                all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
            )

    # Approximate distance for unmatched objects: attach each to its nearest
    # caption-less subject.
    with_caption_subject = set(
        [
            key
            for key in subject_object_relation_map.keys()
            # FIX: was subject_object_relation_map[i] — the stale loop
            # variable `i` instead of the comprehension variable `key`.
            if len(subject_object_relation_map[key]) > 0
        ]
    )
    for i in range(N):
        if all_bboxes[i]['category_id'] != object_category_id or i in used:
            continue
        candidates = []
        for j in range(N):
            if (
                all_bboxes[j]['category_id'] != subject_category_id
                or j in with_caption_subject
            ):
                continue
            candidates.append((dis[i][j], j))
        if len(candidates) > 0:
            candidates.sort(key=lambda x: x[0])
            # FIX: was candidates[0][1] — the subject *index* — instead of
            # the distance candidates[0][0].
            total_subject_object_dis += candidates[0][0]
            # FIX: was .add(j) with the stale inner loop variable; mark the
            # chosen nearest subject instead.
            with_caption_subject.add(candidates[0][1])
    return ret, total_subject_object_dis
def
__tie_up_category_by_distance_v2
(
self
,
page_no
:
int
,
...
...
@@ -879,52 +489,12 @@ class MagicModel:
return
ret
def
get_imgs
(
self
,
page_no
:
int
):
with_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
4
)
with_footnotes
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
CategoryId
.
ImageFootnote
)
ret
=
[]
N
,
M
=
len
(
with_captions
),
len
(
with_footnotes
)
assert
N
==
M
for
i
in
range
(
N
):
record
=
{
'score'
:
with_captions
[
i
][
'score'
],
'img_caption_bbox'
:
with_captions
[
i
].
get
(
'object_body'
,
None
),
'img_body_bbox'
:
with_captions
[
i
][
'subject_body'
],
'img_footnote_bbox'
:
with_footnotes
[
i
].
get
(
'object_body'
,
None
),
}
x0
=
min
(
with_captions
[
i
][
'all'
][
0
],
with_footnotes
[
i
][
'all'
][
0
])
y0
=
min
(
with_captions
[
i
][
'all'
][
1
],
with_footnotes
[
i
][
'all'
][
1
])
x1
=
max
(
with_captions
[
i
][
'all'
][
2
],
with_footnotes
[
i
][
'all'
][
2
])
y1
=
max
(
with_captions
[
i
][
'all'
][
3
],
with_footnotes
[
i
][
'all'
][
3
])
record
[
'bbox'
]
=
[
x0
,
y0
,
x1
,
y1
]
ret
.
append
(
record
)
return
ret
return
self
.
get_imgs_v2
(
page_no
)
def
get_tables
(
self
,
page_no
:
int
)
->
list
:
# 3个坐标, caption, table主体,table-note
with_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
5
,
6
)
with_footnotes
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
5
,
7
)
ret
=
[]
N
,
M
=
len
(
with_captions
),
len
(
with_footnotes
)
assert
N
==
M
for
i
in
range
(
N
):
record
=
{
'score'
:
with_captions
[
i
][
'score'
],
'table_caption_bbox'
:
with_captions
[
i
].
get
(
'object_body'
,
None
),
'table_body_bbox'
:
with_captions
[
i
][
'subject_body'
],
'table_footnote_bbox'
:
with_footnotes
[
i
].
get
(
'object_body'
,
None
),
}
x0
=
min
(
with_captions
[
i
][
'all'
][
0
],
with_footnotes
[
i
][
'all'
][
0
])
y0
=
min
(
with_captions
[
i
][
'all'
][
1
],
with_footnotes
[
i
][
'all'
][
1
])
x1
=
max
(
with_captions
[
i
][
'all'
][
2
],
with_footnotes
[
i
][
'all'
][
2
])
y1
=
max
(
with_captions
[
i
][
'all'
][
3
],
with_footnotes
[
i
][
'all'
][
3
])
record
[
'bbox'
]
=
[
x0
,
y0
,
x1
,
y1
]
ret
.
append
(
record
)
return
ret
return
self
.
get_tables_v2
(
page_no
)
def
get_equations
(
self
,
page_no
:
int
)
->
list
:
# 有坐标,也有字
inline_equations
=
self
.
__get_blocks_by_type
(
...
...
@@ -1043,4 +613,3 @@ class MagicModel:
def get_model_list(self, page_no):
    """Return the raw per-page inference record for *page_no*.

    The record is whatever was stored in ``self.__model_list`` for that
    page (elsewhere in this class it is read as a dict carrying a
    ``'layout_dets'`` list).
    """
    return self.__model_list[page_no]
magic_pdf/model/model_list.py
View file @
4bb54393
...
...
@@ -9,3 +9,4 @@ class AtomicModel:
MFR
=
"mfr"
OCR
=
"ocr"
Table
=
"table"
LangDetect
=
"langdetect"
magic_pdf/model/pdf_extract_kit.py
View file @
4bb54393
...
...
@@ -10,7 +10,6 @@ from loguru import logger
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'YOLO_VERBOSE'
]
=
'False'
# disable yolo logger
try
:
import
torchtext
...
...
@@ -88,6 +87,14 @@ class CustomPEKModel:
)
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
'device'
,
'cpu'
)
if
str
(
self
.
device
).
startswith
(
"npu"
):
import
torch_npu
os
.
environ
[
'FLAGS_npu_jit_compile'
]
=
'0'
os
.
environ
[
'FLAGS_use_stride_kernel'
]
=
'0'
elif
str
(
self
.
device
).
startswith
(
"mps"
):
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
logger
.
info
(
'using device: {}'
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
'models_dir'
,
os
.
path
.
join
(
root_dir
,
'resources'
,
'models'
)
...
...
@@ -114,11 +121,12 @@ class CustomPEKModel:
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
mfr_model_name
])
)
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
'UniMERNet'
,
'demo.yaml'
))
self
.
mfr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_cfg_path
=
mfr_cfg_path
,
device
=
self
.
device
,
device
=
'cpu'
if
str
(
self
.
device
).
startswith
(
"mps"
)
else
self
.
device
,
)
# 初始化layout模型
...
...
@@ -165,12 +173,17 @@ class CustomPEKModel:
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
,
ocr_engine
=
self
.
ocr_model
,
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
# logger.info(f'width: {width}, height: {height}')
# layout检测
layout_start
=
time
.
time
()
layout_res
=
[]
...
...
@@ -179,30 +192,28 @@ class CustomPEKModel:
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
i
mg_pil
=
Image
.
fromarray
(
image
)
width
,
height
=
img_pil
.
size
# logger.info(f'width: {width}, height: {height}'
)
input_res
=
{
"poly"
:[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
new_image
,
useful_list
=
crop_img
(
input_res
,
img_pil
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
)
paste_x
,
paste_y
,
xmin
,
ymin
,
xmax
,
ymax
,
new_width
,
new_height
=
useful_list
layout_res
=
self
.
layout_model
.
predict
(
new_image
)
for
res
in
layout_res
:
p1
,
p2
,
p3
,
p4
,
p5
,
p6
,
p7
,
p8
=
res
[
'poly'
]
p1
=
p
1
-
paste_x
+
xmin
p2
=
p
2
-
paste_y
+
ymin
p3
=
p
3
-
paste_x
+
xmin
p4
=
p
4
-
paste_y
+
ymin
p5
=
p
5
-
paste_x
+
xmin
p6
=
p
6
-
paste_y
+
ymin
p7
=
p7
-
paste_x
+
xmin
p8
=
p8
-
paste_y
+
ymin
res
[
'poly'
]
=
[
p1
,
p2
,
p3
,
p4
,
p5
,
p6
,
p7
,
p8
]
i
f
height
>
width
:
input_res
=
{
"poly"
:[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
new_image
,
useful_list
=
crop_img
(
input_res
,
pil_img
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
)
paste_x
,
paste_y
,
xmin
,
ymin
,
xmax
,
ymax
,
new_
width
,
new_
height
=
useful_list
layout_res
=
self
.
layout_model
.
predict
(
new_image
)
for
res
in
layout_res
:
p1
,
p2
,
p3
,
p4
,
p5
,
p6
,
p7
,
p8
=
res
[
'poly'
]
p1
=
p1
-
paste_x
+
xmin
p2
=
p2
-
paste_y
+
ymin
p3
=
p
3
-
paste_x
+
xmin
p4
=
p
4
-
paste_y
+
ymin
p5
=
p
5
-
paste_x
+
xmin
p6
=
p
6
-
paste_y
+
ymin
p7
=
p
7
-
paste_x
+
xmin
p8
=
p
8
-
paste_y
+
ymin
res
[
'poly'
]
=
[
p1
,
p2
,
p3
,
p4
,
p5
,
p6
,
p7
,
p8
]
else
:
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
'layout detection time:
{
layout_cost
}
'
)
pil_img
=
Image
.
fromarray
(
image
)
if
self
.
apply_formula
:
# 公式检测
mfd_start
=
time
.
time
()
...
...
magic_pdf/model/sub_modules/language_detection/__init__.py
0 → 100644
View file @
4bb54393
# Copyright (c) Opendatalab. All rights reserved.
magic_pdf/model/sub_modules/language_detection/utils.py
0 → 100644
View file @
4bb54393
# Copyright (c) Opendatalab. All rights reserved.
import
os
from
pathlib
import
Path
import
yaml
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.data.utils
import
load_images_from_pdf
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
from
magic_pdf.libs.pdf_check
import
extract_pages
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
def get_model_config():
    """Load model-resource locations and the parsed model_configs.yaml.

    Returns:
        Tuple ``(root_dir, local_models_dir, device, configs)``: the package
        root (three directory levels above this file), the local models
        directory and device from the user's config, and the parsed
        ``resources/model_config/model_configs.yaml`` contents.
    """
    local_models_dir = get_local_models_dir()
    device = get_device()
    current_file_path = os.path.abspath(__file__)
    # parents[3]: language_detection -> sub_modules -> model -> magic_pdf
    root_dir = Path(current_file_path).parents[3]
    model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
    config_path = os.path.join(model_config_dir, 'model_configs.yaml')
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load is sufficient for a plain config file and, unlike
        # yaml.load with FullLoader, cannot construct arbitrary objects.
        configs = yaml.safe_load(f)
    return root_dir, local_models_dir, device, configs
def get_text_images(simple_images):
    """Crop the detected plain-text regions out of every page image.

    Runs a DocLayout-YOLO layout model over each page dict (``'img'`` is an
    array convertible via ``Image.fromarray``) and returns PIL crops of the
    regions with ``category_id == 1``, skipping regions whose width and
    height are both under 100 pixels.
    """
    _, local_models_dir, device, configs = get_model_config()
    layout_model = AtomModelSingleton().get_atom_model(
        atom_model_name=AtomicModel.Layout,
        layout_model_name=MODEL_NAME.DocLayout_YOLO,
        doclayout_yolo_weights=str(
            os.path.join(
                local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
            )
        ),
        device=device,
    )
    crops = []
    for page in simple_images:
        page_image = Image.fromarray(page['img'])
        # Crop each detected text block.
        for det in layout_model.predict(page_image):
            if det['category_id'] not in [1]:
                continue
            x1, y1, _, _, x2, y2, _, _ = det['poly']
            # Rough cleaning: drop blocks smaller than 100x100.
            if x2 - x1 < 100 and y2 - y1 < 100:
                continue
            crops.append(page_image.crop((x1, y1, x2, y2)))
    return crops
def auto_detect_lang(pdf_bytes: bytes):
    """Detect the dominant language of a PDF from sampled page images.

    Samples pages via ``extract_pages``, renders them at 200 dpi, crops the
    text blocks, and lets the YOLOv11 language detector vote on them.
    """
    sample_pdf_bytes = extract_pages(pdf_bytes).tobytes()
    page_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
    text_images = get_text_images(page_images)
    detector = model_init(MODEL_NAME.YOLO_V11_LangDetect)
    return detector.do_detect(text_images)
def model_init(model_name: str):
    """Build the atomic model identified by *model_name*.

    Only ``MODEL_NAME.YOLO_V11_LangDetect`` is supported; any other name
    raises ``ValueError``.
    """
    manager = AtomModelSingleton()
    # Guard clause instead of if/else: unsupported names fail fast.
    if model_name != MODEL_NAME.YOLO_V11_LangDetect:
        raise ValueError(f"model_name {model_name} not found")

    root_dir, _, device, _ = get_model_config()
    weight_path = str(
        os.path.join(root_dir, 'resources', 'yolov11-langdetect', 'yolo_v11_ft.pt')
    )
    return manager.get_atom_model(
        atom_model_name=AtomicModel.LangDetect,
        langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
        langdetect_model_weight=weight_path,
        device=device,
    )
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
0 → 100644
View file @
4bb54393
# Copyright (c) Opendatalab. All rights reserved.
from
collections
import
Counter
from
uuid
import
uuid4
import
torch
from
PIL
import
Image
from
loguru
import
logger
from
ultralytics
import
YOLO
# Detector language codes mapped to display names (values are Chinese
# language names).
# NOTE(review): not referenced within this file — presumably consumed by
# callers or logging; verify before removing.
language_dict = {
    "ch": "中文简体",
    "en": "英语",
    "japan": "日语",
    "korean": "韩语",
    "fr": "法语",
    "german": "德语",
    "ar": "阿拉伯语",
    "ru": "俄语"
}
def split_images(image, result_images=None, max_side=400):
    """Recursively split *image* until every piece's longer side <= max_side.

    The image is halved along its longer dimension; crop regions that would
    extend past the image boundary are skipped (so a remainder strip is
    dropped when the side length does not divide evenly), which also avoids
    producing invalid black padding.

    Args:
        image: PIL image (anything exposing ``.size`` and ``.crop``).
        result_images: accumulator list used by the recursion; leave as None.
        max_side: maximum allowed length of a piece's longer side
            (default 400, matching the original hard-coded threshold).

    Returns:
        List of image pieces, each with its longer side <= max_side.
    """
    if result_images is None:
        result_images = []

    width, height = image.size
    long_side = max(width, height)  # length of the longer edge

    if long_side <= max_side:
        result_images.append(image)
        return result_images

    new_long_side = long_side // 2
    sub_images = []
    if width >= height:
        # Width is the longer edge: cut into vertical strips.
        for x in range(0, width, new_long_side):
            # Skip crops that would run past the right edge.
            if x + new_long_side > width:
                continue
            sub_images.append(image.crop((x, 0, x + new_long_side, height)))
    else:
        # Height is the longer edge: cut into horizontal strips.
        for y in range(0, height, new_long_side):
            # Skip crops that would run past the bottom edge.
            if y + new_long_side > height:
                continue
            sub_images.append(image.crop((0, y, width, y + new_long_side)))

    for sub_image in sub_images:
        split_images(sub_image, result_images, max_side)

    return result_images
def resize_images_to_224(image):
    """Normalize *image* to exactly 224x224 for the classifier.

    If either side is smaller than 224, the image is pasted centered onto a
    black 224x224 canvas; otherwise it is resized to 224x224 with LANCZOS
    resampling.

    NOTE(review): when one side is < 224 while the other is > 224, the paste
    offset for the larger side is negative, so that side gets cropped rather
    than scaled — confirm this is intended.
    On any exception the error is logged and the function implicitly
    returns None.
    """
    try:
        width, height = image.size
        if width < 224 or height < 224:
            # Pad small images onto a black canvas, centered.
            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
            paste_x = (224 - width) // 2
            paste_y = (224 - height) // 2
            new_image.paste(image, (paste_x, paste_y))
            image = new_image
        else:
            image = image.resize((224, 224), Image.Resampling.LANCZOS)
        # uuid = str(uuid4())
        # image.save(f"/tmp/{uuid}.jpg")
        return image
    except Exception as e:
        logger.exception(e)
class YOLOv11LangDetModel(object):
    """Page-language detector built on a YOLOv11 classification model."""

    def __init__(self, langdetect_model_weight, device):
        # Load the fine-tuned classification weights.
        self.model = YOLO(langdetect_model_weight)
        # NPU devices need an explicit torch.device wrapper.
        if str(device).startswith("npu"):
            self.device = torch.device(device)
        else:
            self.device = device

    def do_detect(self, images: list):
        """Return the majority language over all usable crops, or None."""
        prepared = []
        for image in images:
            width, height = image.size
            # Skip crops too small in both dimensions to classify reliably.
            if width < 100 and height < 100:
                continue
            for piece in split_images(image):
                prepared.append(resize_images_to_224(piece))

        predictions = self.batch_predict(prepared, batch_size=8)
        if len(predictions) > 0:
            votes = Counter(predictions)
            return max(votes, key=votes.get)
        return None

    def predict(self, image):
        """Classify one image; return the predicted language class name."""
        results = self.model.predict(image, verbose=False, device=self.device)
        top_class = int(results[0].probs.top1)
        return self.model.names[top_class]

    def batch_predict(self, images: list, batch_size: int) -> list:
        """Classify *images* in batches; return one language name per image."""
        predictions = []
        for start in range(0, len(images), batch_size):
            batch_results = [
                res.cpu()
                for res in self.model.predict(
                    images[start: start + batch_size],
                    verbose=False,
                    device=self.device,
                )
            ]
            for res in batch_results:
                top_class = int(res.probs.top1)
                predictions.append(self.model.names[top_class])
        return predictions
\ No newline at end of file
magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py
0 → 100644
View file @
4bb54393
# Copyright (c) Opendatalab. All rights reserved.
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment