Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
ebfab424
Unverified
Commit
ebfab424
authored
Nov 15, 2024
by
linfeng
Committed by
GitHub
Nov 15, 2024
Browse files
Merge branch 'opendatalab:dev' into dev
parents
aed0941f
94f6bd83
Changes
58
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
732 additions
and
0 deletions
+732
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
...ayout/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
.../layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
...tlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
...3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
...outlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py
magic_pdf/model/sub_modules/layout/layoutlmv3/model_init.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py
magic_pdf/model/sub_modules/layout/layoutlmv3/rcnn_vl.py
+0
-0
magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py
magic_pdf/model/sub_modules/layout/layoutlmv3/visualizer.py
+0
-0
magic_pdf/model/sub_modules/mfd/__init__.py
magic_pdf/model/sub_modules/mfd/__init__.py
+0
-0
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
+12
-0
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py
+0
-0
magic_pdf/model/sub_modules/mfr/__init__.py
magic_pdf/model/sub_modules/mfr/__init__.py
+0
-0
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+98
-0
magic_pdf/model/sub_modules/mfr/unimernet/__init__.py
magic_pdf/model/sub_modules/mfr/unimernet/__init__.py
+0
-0
magic_pdf/model/sub_modules/model_init.py
magic_pdf/model/sub_modules/model_init.py
+144
-0
magic_pdf/model/sub_modules/model_utils.py
magic_pdf/model/sub_modules/model_utils.py
+51
-0
magic_pdf/model/sub_modules/ocr/__init__.py
magic_pdf/model/sub_modules/ocr/__init__.py
+0
-0
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py
+0
-0
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
+259
-0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
+168
-0
No files found.
magic_pdf/model/
pek_
sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/model_init.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/model_init.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/rcnn_vl.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/rcnn_vl.py
View file @
ebfab424
File moved
magic_pdf/model/
pek_
sub_modules/layoutlmv3/visualizer.py
→
magic_pdf/model/sub_modules/
layout/
layoutlmv3/visualizer.py
View file @
ebfab424
File moved
magic_pdf/model/sub_modules/mfd/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
0 → 100644
View file @
ebfab424
from
ultralytics
import
YOLO
class YOLOv8MFDModel(object):
    """Wrapper that runs an ultralytics ``YOLO`` model with fixed inference settings."""

    def __init__(self, weight, device='cpu'):
        """Build the underlying ``YOLO`` model and remember the target device.

        Args:
            weight: Passed unchanged to ``YOLO(...)``.
            device: Forwarded as ``device=`` on every ``predict`` call.
        """
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        """Run detection on ``image`` and return the first result entry.

        The model is always called with ``imgsz=1888``, ``conf=0.25``,
        ``iou=0.45`` and ``verbose=True``.
        """
        results = self.mfd_model.predict(
            image,
            imgsz=1888,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )
        return results[0]
magic_pdf/model/sub_modules/mfd/yolov8/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/mfr/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
0 → 100644
View file @
ebfab424
import
os
import
argparse
import
re
from
PIL
import
Image
import
torch
from
torch.utils.data
import
Dataset
,
DataLoader
from
torchvision
import
transforms
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
class MathDataset(Dataset):
    """Dataset over formula images given as file paths or already-loaded images."""

    def __init__(self, image_paths, transform=None):
        """Store the items and the optional per-item transform.

        Args:
            image_paths: Indexable collection; ``str`` entries are opened with
                ``Image.open``, anything else is used as-is.
            transform: Optional callable applied to each image.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of items."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, transformed when a transform is configured."""
        item = self.image_paths[idx]
        # Only string entries are treated as paths; other objects are assumed
        # to be images already.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        # Without a transform, hand back the image itself (this path used to
        # raise UnboundLocalError because nothing was bound to the result).
        return raw_image
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    Spaces inside ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf``
    groups are stripped first. Then whitespace between two non-letters,
    between a non-letter and a letter, and between a letter and a non-letter
    is removed repeatedly until the string stops changing. Whitespace between
    two letters is never touched.

    Args:
        s: LaTeX source string.

    Returns:
        The string with the whitespace removed as described.
    """
    # Raw strings keep the regex escapes (\W, \d) from being read as invalid
    # Python string escapes.
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'

    # Group 1 spans the whole match, so each command is rewritten in place.
    s = re.sub(text_reg, lambda match: match.group(1).replace(' ', ''), s)

    # Compiled once; applied in this order on every pass.
    rules = (
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter)),
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter)),
        re.compile(r'(%s)\s+?(%s)' % (letter, noletter)),
    )
    while True:
        news = s
        for rule in rules:
            news = rule.sub(r'\1\2', news)
        if news == s:
            # Fixed point reached: a full pass changed nothing.
            return s
        s = news
class UnimernetModel(object):
    """Formula recognizer: crops detected regions and decodes them to LaTeX strings."""

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        """Build the model from a config file and a weight directory.

        Args:
            weight_dir: Directory used for ``pytorch_model.pth``, the model
                name and the tokenizer path.
            cfg_path: Path handed to ``Config`` via an ``argparse.Namespace``.
            _device_: Device the model is moved to and batches are sent to.
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            'formula_image_eval',
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image):
        """Recognize every detected formula region in ``image``.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` (each supporting ``.cpu()``).
            image: Array accepted by ``Image.fromarray``.

        Returns:
            A list with one dict per detection, with keys ``category_id``
            (13 + class index), ``poly`` (8 ints, clockwise from the top-left
            corner), ``score`` (rounded to 2 decimals) and ``latex``.
        """
        boxes = mfd_res.boxes
        formula_list = []
        crop_boxes = []
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            })
            crop_boxes.append((xmin, ymin, xmax, ymax))

        if not formula_list:
            # Nothing detected: skip the image conversion and the model run.
            return formula_list

        # Convert the page once and reuse it for every crop.
        pil_img = Image.fromarray(image)
        mf_image_list = [pil_img.crop(box) for box in crop_boxes]

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])

        # Predictions come back in dataset order, which matches formula_list.
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
magic_pdf/model/sub_modules/mfr/unimernet/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/model_init.py
0 → 100644
View file @
ebfab424
from
loguru
import
logger
from
magic_pdf.libs.Constants
import
MODEL_NAME
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.layout.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8
import
YOLOv8MFDModel
from
magic_pdf.model.sub_modules.mfr.unimernet.Unimernet
import
UnimernetModel
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable
import
StructTableModel
from
magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle
import
TableMasterPaddleModel
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Create the table model selected by ``table_model_type``.

    Args:
        table_model_type: One of the ``MODEL_NAME`` table constants.
        model_path: Model location; unused for ``RAPID_TABLE``.
        max_time: Forwarded to ``StructTableModel`` only.
        _device_: Forwarded to ``TableMasterPaddleModel`` only.

    Returns:
        The constructed table model. An unknown type is logged and the
        process exits with status 1.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    if table_model_type == MODEL_NAME.TABLE_MASTER:
        return TableMasterPaddleModel({
            "model_dir": model_path,
            "device": _device_,
        })
    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel()

    logger.error("table model type not allow")
    exit(1)
def mfd_model_init(weight, device='cpu'):
    """Return a ``YOLOv8MFDModel`` built from ``weight`` for ``device``."""
    return YOLOv8MFDModel(weight, device)
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Return a ``UnimernetModel`` built from ``weight_dir`` and ``cfg_path`` for ``device``."""
    return UnimernetModel(weight_dir, cfg_path, device)
def layout_model_init(weight, config_file, device):
    """Return a ``Layoutlmv3_Predictor`` built from the given weight, config file and device."""
    return Layoutlmv3_Predictor(weight, config_file, device)
def doclayout_yolo_model_init(weight, device='cpu'):
    """Return a ``DocLayoutYOLOModel`` built from ``weight`` for ``device``."""
    return DocLayoutYOLOModel(weight, device)
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``ModifiedPaddleOCR`` instance.

    ``lang`` is only passed to the constructor when it is not ``None``, so
    the constructor's own default applies otherwise. All other arguments are
    always forwarded as keywords.
    """
    options = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    if lang is not None:
        options['lang'] = lang
    return ModifiedPaddleOCR(**options)
class AtomModelSingleton:
    """Process-wide cache of atomic models; every instantiation yields the same object."""

    # The single shared instance, created on first construction.
    _instance = None
    # Cache keyed by (atom_model_name, layout_model_name, lang).
    _models = {}

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it the first time."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this request, initializing it on a miss.

        Only ``atom_model_name`` and the ``layout_model_name`` / ``lang``
        keyword values take part in the cache key; all keyword arguments are
        forwarded to ``atom_model_init`` when a model has to be built.
        """
        cache_key = (
            atom_model_name,
            kwargs.get("layout_model_name", None),
            kwargs.get("lang", None),
        )
        cache = self._models
        if cache_key in cache:
            return cache[cache_key]
        cache[cache_key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return cache[cache_key]
def atom_model_init(model_name: str, **kwargs):
    """Build the atomic model named by ``model_name`` from keyword settings.

    Each branch reads its settings with ``kwargs.get``, so a missing key is
    passed on as ``None``. An unknown ``model_name`` is logged and the process
    exits with status 1. The same happens when no model was produced, for
    example when a layout request carries an unrecognized
    ``layout_model_name``.

    Returns:
        The initialized model.
    """
    opt = kwargs.get
    atom_model = None

    if model_name == AtomicModel.Layout:
        layout_name = opt("layout_model_name")
        if layout_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                opt("layout_weights"),
                opt("layout_config_file"),
                opt("device"),
            )
        elif layout_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                opt("doclayout_yolo_weights"),
                opt("device"),
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(opt("mfd_weights"), opt("device"))
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            opt("mfr_weight_dir"),
            opt("mfr_cfg_path"),
            opt("device"),
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            opt("ocr_show_log"),
            opt("det_db_box_thresh"),
            opt("lang"),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            opt("table_model_name"),
            opt("table_model_path"),
            opt("table_max_time"),
            opt("device"),
        )
    else:
        logger.error("model name not allow")
        exit(1)

    if atom_model is None:
        logger.error("model init failed")
        exit(1)
    return atom_model
magic_pdf/model/sub_modules/model_utils.py
0 → 100644
View file @
ebfab424
import
time
import
torch
from
PIL
import
Image
from
loguru
import
logger
from
magic_pdf.libs.clean_memory
import
clean_memory
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Crop the region described by ``input_res['poly']`` onto a white canvas.

    The crop box is taken from poly indices 0/1 (top-left) and 4/5
    (bottom-right). The canvas is the crop size plus ``crop_paste_x`` padding
    on the left and right and ``crop_paste_y`` padding on the top and bottom.

    Returns:
        ``(image, info)`` where ``info`` is
        ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``.
    """
    poly = input_res['poly']
    left, top = int(poly[0]), int(poly[1])
    right, bottom = int(poly[4]), int(poly[5])

    # Canvas size: crop size plus the padding on both sides.
    canvas_w = right - left + crop_paste_x * 2
    canvas_h = bottom - top + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_w, canvas_h), 'white')

    # Paste the crop at the padding offset so it is surrounded by white.
    region = input_pil_img.crop((left, top, right, bottom))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    info = [crop_paste_x, crop_paste_y, left, top, right, bottom, canvas_w, canvas_h]
    return canvas, info
# Select regions for OCR / formula regions / table regions
def get_res_list_from_layout_res(layout_res):
    """Split layout results into OCR, table and formula-box lists by category id.

    Category ids 13 and 14 are reduced to ``{"bbox": [x0, y0, x1, y1]}``
    using poly indices 0, 1, 4 and 5. Ids 0, 1, 2, 4, 6 and 7 are kept as OCR
    regions. Id 5 is kept as a table region. Everything else is dropped.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    ocr_res_list, table_res_list, single_page_mfdetrec_res = [], [], []
    for res in layout_res:
        category = int(res['category_id'])
        if category in [13, 14]:
            poly = res['poly']
            single_page_mfdetrec_res.append({
                "bbox": [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])],
            })
        elif category in [0, 1, 2, 4, 6, 7]:
            ocr_res_list.append(res)
        elif category in [5]:
            table_res_list.append(res)
    return ocr_res_list, table_res_list, single_page_mfdetrec_res
def clean_vram(device, vram_threshold=8):
    """Run ``clean_memory`` when the CUDA device has at most ``vram_threshold`` GB.

    Does nothing when CUDA is unavailable or ``device`` is ``'cpu'``. The
    time spent in ``clean_memory`` is logged.
    """
    cuda_in_use = torch.cuda.is_available() and device != 'cpu'
    if not cuda_in_use:
        return

    bytes_per_gb = 1024 ** 3
    # Convert the device's total memory from bytes to GB.
    total_memory = torch.cuda.get_device_properties(device).total_memory / bytes_per_gb
    if total_memory <= vram_threshold:
        gc_start = time.time()
        clean_memory()
        gc_time = round(time.time() - gc_start, 2)
        logger.info(f"gc time: {gc_time}")
\ No newline at end of file
magic_pdf/model/sub_modules/ocr/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py
0 → 100644
View file @
ebfab424
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
0 → 100644
View file @
ebfab424
import
math
import
numpy
as
np
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
__is_overlaps_y_exceeds_threshold
from
magic_pdf.pre_proc.ocr_dict_merge
import
merge_spans_to_line
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` bbox into a float32 array of its four corners.

    Corner order is top-left, top-right, bottom-right, bottom-left.
    """
    x0, y0, x1, y1 = bbox
    corners = [
        [x0, y0],
        [x1, y0],
        [x1, y1],
        [x0, y1],
    ]
    return np.array(corners).astype('float32')
def points_to_bbox(points):
    """Convert a four-corner point array into ``[x0, y0, x1, y1]``.

    ``x0``/``y0`` come from the first point, ``x1`` from the second point and
    ``y1`` from the third point; the fourth point is not read.
    """
    left, top = points[0]
    right, _ = points[1]
    _, bottom = points[2]
    return [left, top, right, bottom]
def merge_intervals(intervals):
    """Merge overlapping or touching ``[start, end]`` intervals.

    The input is left untouched: it is neither reordered nor are its
    intervals modified. Tuples are accepted as well as lists.

    Args:
        intervals: Iterable of two-element ``[start, end]`` sequences.

    Returns:
        A new list of ``[start, end]`` lists sorted by start, in which any
        interval starting at or before the previous end has been folded into
        the previous one.
    """
    merged = []
    # sorted() rather than list.sort() so the caller's list keeps its order.
    for interval in sorted(intervals, key=lambda x: x[0]):
        if not merged or merged[-1][1] < interval[0]:
            # Copy so later extension of this entry never leaks into the
            # caller's interval objects.
            merged.append(list(interval))
        else:
            # Overlap (or touch): extend the previous interval.
            merged[-1][1] = max(merged[-1][1], interval[1])
    return merged
def remove_intervals(original, masks):
    """Subtract the mask intervals from ``original`` and return what is left.

    The masks are merged first. Surviving pieces end one unit before a mask
    starts and resume one unit after it ends.

    Args:
        original: ``[start, end]`` range to cut.
        masks: List of ``[start, end]`` ranges to remove.

    Returns:
        List of ``[start, end]`` pieces of ``original`` not covered by a mask.
    """
    merged_masks = merge_intervals(masks)
    cursor, original_end = original

    pieces = []
    for mask_start, mask_end in merged_masks:
        # Masks entirely right or left of the remaining range do not cut it.
        if mask_start > original_end or mask_end < cursor:
            continue
        # Keep the part in front of the mask, then jump past the mask.
        if cursor < mask_start:
            pieces.append([cursor, mask_start - 1])
        cursor = max(mask_end + 1, cursor)

    # Whatever remains after the last mask is kept as well.
    if cursor <= original_end:
        pieces.append([cursor, original_end])
    return pieces
def update_det_boxes(dt_boxes, mfd_res):
    """Cut the horizontal extent of the given ``bbox`` regions out of each text box.

    For each text box, every ``mfd_res`` entry whose ``bbox`` passes
    ``__is_overlaps_y_exceeds_threshold`` against it contributes its x-range
    as a mask. The text box's x-range minus those masks yields zero or more
    new boxes that keep the original top and bottom.

    Args:
        dt_boxes: Iterable of four-corner text boxes.
        mfd_res: Iterable of dicts with a ``'bbox'`` of ``[x0, y0, x1, y1]``.

    Returns:
        List of four-corner boxes produced by ``bbox_to_points``.
    """
    new_dt_boxes = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        masks_list = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]
        kept_ranges = remove_intervals([text_bbox[0], text_bbox[2]], masks_list)
        new_dt_boxes.extend(
            bbox_to_points([kept[0], text_bbox[1], kept[1], text_bbox[3]])
            for kept in kept_ranges
        )
    return new_dt_boxes
def merge_overlapping_spans(spans):
    """Merge horizontally overlapping spans on the same line.

    The input list is not reordered or modified.

    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
    :return: A list of merged spans ordered by starting x. Spans that were
        not merged are returned as given; merged spans are tuples covering
        the union of their members.
    """
    if not spans:
        return []

    merged = []
    # sorted() rather than list.sort() so the caller's list keeps its order.
    for span in sorted(spans, key=lambda x: x[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            # Starts strictly right of the previous span: no overlap.
            merged.append(span)
            continue
        # Overlap: replace the previous span with the union of both.
        last_span = merged.pop()
        merged.append((
            min(last_span[0], x1),
            min(last_span[1], y1),
            max(last_span[2], x2),
            max(last_span[3], y2),
        ))
    return merged
def merge_det_boxes(dt_boxes):
    """Merge detection boxes into larger per-line text regions.

    Boxes whose derived bbox has non-positive width or height are set aside
    and appended to the output unchanged (presumably these are tilted boxes,
    going by the ``angle_boxes_list`` name; confirm with the detector
    output). The remaining boxes are grouped into lines by
    ``merge_spans_to_line``, and overlapping boxes within a line are merged.

    Parameters:
    dt_boxes (list): Text detection boxes, each defined by four corner points.

    Returns:
    list: Merged text regions as four-corner arrays, followed by the boxes
    that were set aside.
    """
    dt_boxes_dict_list = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        is_degenerate = text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]
        if is_degenerate:
            angle_boxes_list.append(text_box)
        else:
            dt_boxes_dict_list.append({'bbox': text_bbox, 'type': 'text'})

    # Group the regular boxes into text lines.
    lines = merge_spans_to_line(dt_boxes_dict_list)

    new_dt_boxes = []
    for line in lines:
        line_bbox_list = [span['bbox'] for span in line]
        # Merge overlaps within the line, then convert back to corner points.
        new_dt_boxes.extend(
            bbox_to_points(span) for span in merge_overlapping_spans(line_bbox_list)
        )

    new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Translate ``bbox`` entries into the coordinate frame of a cropped image.

    Args:
        single_page_mfdetrec_res: Dicts with a ``"bbox"`` of four numbers.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]`` describing the crop.

    Returns:
        List of ``{"bbox": [x0, y0, x1, y1]}`` in crop coordinates. Boxes that
        lie entirely left of, above, right of, or below the crop canvas are
        dropped.
    """
    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list

    adjusted_mfdetrec_res = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        # Shift from page coordinates to crop coordinates.
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y

        outside = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
        if not outside:
            adjusted_mfdetrec_res.append({"bbox": [x0, y0, x1, y1]})
    return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list):
    """Convert raw OCR results from crop coordinates into page-level result dicts.

    Args:
        ocr_res: Iterable of ``[four_points, (text, score)]`` entries.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]`` describing the crop; only the first four values are
            used here.

    Returns:
        List of dicts with ``category_id`` 15, ``poly`` (8 numbers), ``score``
        (rounded to 2 decimals) and ``text``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _crop_w, _crop_h = useful_list

    ocr_result_list = []
    for box_ocr_res in ocr_res:
        quad = box_ocr_res[0]
        p1, p2, p3, p4 = quad
        text, score = box_ocr_res[1]

        if calculate_angle_degrees(quad) > 0.5:
            # The box is tilted by more than 0.5 degrees: replace it with an
            # axis-aligned rectangle around its geometric centre.
            x_center = sum(point[0] for point in quad) / 4
            y_center = sum(point[1] for point in quad) / 4
            rect_h = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            rect_w = p3[0] - p1[0]
            p1 = [x_center - rect_w / 2, y_center - rect_h / 2]
            p2 = [x_center + rect_w / 2, y_center - rect_h / 2]
            p3 = [x_center + rect_w / 2, y_center + rect_h / 2]
            p4 = [x_center - rect_w / 2, y_center + rect_h / 2]

        # Map the corners back into the original (page) coordinate system.
        poly = []
        for corner in (p1, p2, p3, p4):
            poly.append(corner[0] - paste_x + xmin)
            poly.append(corner[1] - paste_y + ymin)

        ocr_result_list.append({
            'category_id': 15,
            'poly': poly,
            'score': float(round(score, 2)),
            'text': text,
        })
    return ocr_result_list
def calculate_angle_degrees(poly):
    """Return the absolute mean angle, in degrees, of a quad's two diagonals to the x-axis.

    The diagonals run from ``poly[0]`` to ``poly[2]`` and from ``poly[1]`` to
    ``poly[3]``. A vertical diagonal is treated as having infinite slope.
    """
    first_diagonal = (poly[0], poly[2])
    second_diagonal = (poly[1], poly[3])

    def diagonal_angle(start, end):
        # Angle from the slope; a vertical segment gets an infinite slope.
        if end[0] != start[0]:
            slope = (end[1] - start[1]) / (end[0] - start[0])
        else:
            slope = float('inf')
        return math.degrees(math.atan(slope))

    first_angle = diagonal_angle(*first_diagonal)
    second_angle = diagonal_angle(*second_diagonal)
    # Average the two diagonal angles and drop the sign.
    return abs((first_angle + second_angle) / 2)
magic_pdf/model/
pek_
sub_modules/
self
_mod
ify
.py
→
magic_pdf/model/sub_modules/
ocr/paddleocr/ppocr_273
_mod.py
View file @
ebfab424
import
time
import
copy
import
copy
import
base64
import
time
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
from
io
import
BytesIO
from
PIL
import
Image
from
paddleocr
import
PaddleOCR
from
paddleocr
import
PaddleOCR
from
paddleocr.ppocr.utils.logging
import
get_logger
from
paddleocr.paddleocr
import
check_img
,
logger
from
paddleocr.ppocr.utils.utility
import
check_and_read
,
alpha_to_color
,
binarize_img
from
paddleocr.ppocr.utils.utility
import
alpha_to_color
,
binarize_img
from
paddleocr.tools.infer.utility
import
draw_ocr_box_txt
,
get_rotate_crop_image
,
get_minarea_rect_crop
from
paddleocr.tools.infer.predict_system
import
sorted_boxes
from
paddleocr.tools.infer.utility
import
get_rotate_crop_image
,
get_minarea_rect_crop
from
magic_pdf.libs.boxbase
import
__is_overlaps_y_exceeds_threshold
from
magic_pdf.pre_proc.ocr_dict_merge
import
merge_spans_to_line
logger
=
get_logger
()
def img_decode(content: bytes):
    """Decode encoded image bytes with ``cv2.imdecode`` using ``IMREAD_UNCHANGED``."""
    return cv2.imdecode(
        np.frombuffer(content, dtype=np.uint8),
        cv2.IMREAD_UNCHANGED,
    )
def check_img(img):
    """Normalize an image given as bytes, a file path, or an array.

    * ``bytes`` are decoded with ``img_decode``.
    * A ``str`` is treated as a file path. It is read through
      ``check_and_read``; when that reports neither GIF nor PDF, the file is
      decoded with ``img_decode``, falling back to a PIL -> JPEG -> OpenCV
      re-encode when decoding yields ``None``.
    * A 2-D ``numpy`` array is converted from grayscale to BGR.

    Returns:
        The image, or ``None`` when a file path could not be loaded (an error
        is logged in that case). Any other input is returned unchanged.
    """
    if isinstance(img, bytes):
        img = img_decode(img)

    if isinstance(img, str):
        image_file = img
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            with open(image_file, 'rb') as f:
                img_str = f.read()
            img = img_decode(img_str)
            if img is None:
                # OpenCV could not decode the bytes directly: let PIL read
                # the file and re-encode it as JPEG for OpenCV.
                try:
                    buf = BytesIO()
                    Image.open(BytesIO(img_str)).convert('RGB').save(buf, 'jpeg')
                    img_array = np.frombuffer(buf.getvalue(), np.uint8)
                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
                except Exception:
                    # Exception rather than a bare except, so that
                    # KeyboardInterrupt and SystemExit still propagate.
                    logger.error("error in loading image:{}".format(image_file))
                    return None
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            return None

    if isinstance(img, np.ndarray) and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return img
def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right.

    Boxes are first sorted by the (y, x) of their first point. Then each box
    is moved in front of its predecessors for as long as their first-point
    y values differ by less than 10 and the box lies further left.

    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2]
    return:
        list of the N boxes (each of shape [4, 2]) in reading order
    """
    num_boxes = dt_boxes.shape[0]
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    for i in range(num_boxes - 1):
        j = i
        while j >= 0:
            upper, lower = _boxes[j], _boxes[j + 1]
            same_row = abs(lower[0][1] - upper[0][1]) < 10
            if not (same_row and lower[0][0] < upper[0][0]):
                break
            # Same visual row and out of left-to-right order: swap.
            _boxes[j], _boxes[j + 1] = lower, upper
            j -= 1
    return _boxes
def
bbox_to_points
(
bbox
):
""" 将bbox格式转换为四个顶点的数组 """
x0
,
y0
,
x1
,
y1
=
bbox
return
np
.
array
([[
x0
,
y0
],
[
x1
,
y0
],
[
x1
,
y1
],
[
x0
,
y1
]]).
astype
(
'float32'
)
def
points_to_bbox
(
points
):
""" 将四个顶点的数组转换为bbox格式 """
x0
,
y0
=
points
[
0
]
x1
,
_
=
points
[
1
]
_
,
y1
=
points
[
2
]
return
[
x0
,
y0
,
x1
,
y1
]
def
merge_intervals
(
intervals
):
# Sort the intervals based on the start value
intervals
.
sort
(
key
=
lambda
x
:
x
[
0
])
merged
=
[]
for
interval
in
intervals
:
# If the list of merged intervals is empty or if the current
# interval does not overlap with the previous, simply append it.
if
not
merged
or
merged
[
-
1
][
1
]
<
interval
[
0
]:
merged
.
append
(
interval
)
else
:
# Otherwise, there is overlap, so we merge the current and previous intervals.
merged
[
-
1
][
1
]
=
max
(
merged
[
-
1
][
1
],
interval
[
1
])
return
merged
def
remove_intervals
(
original
,
masks
):
# Merge all mask intervals
merged_masks
=
merge_intervals
(
masks
)
result
=
[]
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
update_det_boxes
,
merge_det_boxes
original_start
,
original_end
=
original
for
mask
in
merged_masks
:
mask_start
,
mask_end
=
mask
# If the mask starts after the original range, ignore it
if
mask_start
>
original_end
:
continue
# If the mask ends before the original range starts, ignore it
if
mask_end
<
original_start
:
continue
# Remove the masked part from the original range
if
original_start
<
mask_start
:
result
.
append
([
original_start
,
mask_start
-
1
])
original_start
=
max
(
mask_end
+
1
,
original_start
)
# Add the remaining part of the original range, if any
if
original_start
<=
original_end
:
result
.
append
([
original_start
,
original_end
])
return
result
def
update_det_boxes
(
dt_boxes
,
mfd_res
):
new_dt_boxes
=
[]
for
text_box
in
dt_boxes
:
text_bbox
=
points_to_bbox
(
text_box
)
masks_list
=
[]
for
mf_box
in
mfd_res
:
mf_bbox
=
mf_box
[
'bbox'
]
if
__is_overlaps_y_exceeds_threshold
(
text_bbox
,
mf_bbox
):
masks_list
.
append
([
mf_bbox
[
0
],
mf_bbox
[
2
]])
text_x_range
=
[
text_bbox
[
0
],
text_bbox
[
2
]]
text_remove_mask_range
=
remove_intervals
(
text_x_range
,
masks_list
)
temp_dt_box
=
[]
for
text_remove_mask
in
text_remove_mask_range
:
temp_dt_box
.
append
(
bbox_to_points
([
text_remove_mask
[
0
],
text_bbox
[
1
],
text_remove_mask
[
1
],
text_bbox
[
3
]]))
if
len
(
temp_dt_box
)
>
0
:
new_dt_boxes
.
extend
(
temp_dt_box
)
return
new_dt_boxes
def
merge_overlapping_spans
(
spans
):
"""
Merges overlapping spans on the same line.
:param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
:return: A list of merged spans
"""
# Return an empty list if the input spans list is empty
if
not
spans
:
return
[]
# Sort spans by their starting x-coordinate
spans
.
sort
(
key
=
lambda
x
:
x
[
0
])
# Initialize the list of merged spans
merged
=
[]
for
span
in
spans
:
# Unpack span coordinates
x1
,
y1
,
x2
,
y2
=
span
# If the merged list is empty or there's no horizontal overlap, add the span directly
if
not
merged
or
merged
[
-
1
][
2
]
<
x1
:
merged
.
append
(
span
)
else
:
# If there is horizontal overlap, merge the current span with the previous one
last_span
=
merged
.
pop
()
# Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
x1
=
min
(
last_span
[
0
],
x1
)
y1
=
min
(
last_span
[
1
],
y1
)
x2
=
max
(
last_span
[
2
],
x2
)
y2
=
max
(
last_span
[
3
],
y2
)
# Add the merged span back to the list
merged
.
append
((
x1
,
y1
,
x2
,
y2
))
# Return the list of merged spans
return
merged
def
merge_det_boxes
(
dt_boxes
):
"""
Merge detection boxes.
This function takes a list of detected bounding boxes, each represented by four corner points.
The goal is to merge these bounding boxes into larger text regions.
Parameters:
dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
Returns:
list: A list containing the merged text regions, where each region is represented by four corner points.
"""
# Convert the detection boxes into a dictionary format with bounding boxes and type
dt_boxes_dict_list
=
[]
for
text_box
in
dt_boxes
:
text_bbox
=
points_to_bbox
(
text_box
)
text_box_dict
=
{
'bbox'
:
text_bbox
,
'type'
:
'text'
,
}
dt_boxes_dict_list
.
append
(
text_box_dict
)
# Merge adjacent text regions into lines
lines
=
merge_spans_to_line
(
dt_boxes_dict_list
)
# Initialize a new list for storing the merged text regions
new_dt_boxes
=
[]
for
line
in
lines
:
line_bbox_list
=
[]
for
span
in
line
:
line_bbox_list
.
append
(
span
[
'bbox'
])
# Merge overlapping text regions within the same line
merged_spans
=
merge_overlapping_spans
(
line_bbox_list
)
# Convert the merged text regions back to point format and add them to the new detection box list
for
span
in
merged_spans
:
new_dt_boxes
.
append
(
bbox_to_points
(
span
))
return
new_dt_boxes
class
ModifiedPaddleOCR
(
PaddleOCR
):
class
ModifiedPaddleOCR
(
PaddleOCR
):
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
True
,
bin
=
False
,
inv
=
False
,
mfd_res
=
None
,
alpha_color
=
(
255
,
255
,
255
)):
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
True
,
bin
=
False
,
inv
=
False
,
alpha_color
=
(
255
,
255
,
255
),
mfd_res
=
None
,
):
"""
"""
OCR with PaddleOCR
OCR with PaddleOCR
args:
args:
...
@@ -347,7 +125,9 @@ class ModifiedPaddleOCR(PaddleOCR):
...
@@ -347,7 +125,9 @@ class ModifiedPaddleOCR(PaddleOCR):
dt_boxes
=
sorted_boxes
(
dt_boxes
)
dt_boxes
=
sorted_boxes
(
dt_boxes
)
dt_boxes
=
merge_det_boxes
(
dt_boxes
)
# @todo 目前是在bbox层merge,对倾斜文本行的兼容性不佳,需要修改成支持poly的merge
# dt_boxes = merge_det_boxes(dt_boxes)
if
mfd_res
:
if
mfd_res
:
bef
=
time
.
time
()
bef
=
time
.
time
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment