Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
62b7582f
Commit
62b7582f
authored
Apr 01, 2025
by
myhloli
Browse files
Merge remote-tracking branch 'origin/dev' into dev
parents
978ef41c
41f1fb8a
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2414 additions
and
150 deletions
+2414
-150
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+97
-24
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+33
-97
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+1
-1
magic_pdf/model/sub_modules/model_init.py
magic_pdf/model/sub_modules/model_init.py
+31
-28
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py
...c_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py
+1
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/models_config.yml
...model/sub_modules/ocr/paddleocr2pytorch/models_config.yml
+49
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/ocr_utils.py
..._pdf/model/sub_modules/ocr/paddleocr2pytorch/ocr_utils.py
+368
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py
...model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py
+189
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/__init__.py
.../sub_modules/ocr/paddleocr2pytorch/pytorchocr/__init__.py
+0
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py
..._modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py
+39
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py
...modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py
+23
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py
...s/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py
+48
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py
.../ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py
+418
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/__init__.py
...les/ocr/paddleocr2pytorch/pytorchocr/modeling/__init__.py
+0
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/arch_config.yaml
...cr/paddleocr2pytorch/pytorchocr/modeling/arch_config.yaml
+366
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py
...ocr2pytorch/pytorchocr/modeling/architectures/__init__.py
+25
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py
...r2pytorch/pytorchocr/modeling/architectures/base_model.py
+105
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
...ddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
+62
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py
...pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py
+269
-0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py
...dleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py
+290
-0
No files found.
magic_pdf/model/batch_analyze.py
View file @
62b7582f
...
...
@@ -5,10 +5,10 @@ import torch
from
loguru
import
logger
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.model.
pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.
sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr
2pytorch
.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
YOLO_LAYOUT_BASE_BATCH_SIZE
=
1
...
...
@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
class
BatchAnalyze
:
def
__init__
(
self
,
model
:
CustomPEKModel
,
batch_ratio
:
int
):
self
.
model
=
model
def
__init__
(
self
,
model
_manager
,
batch_ratio
:
int
,
show_log
,
layout_model
,
formula_enable
,
table_enable
):
self
.
model
_manager
=
model_manager
self
.
batch_ratio
=
batch_ratio
def
__call__
(
self
,
images
:
list
)
->
list
:
self
.
show_log
=
show_log
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
__call__
(
self
,
images_with_extra_info
:
list
)
->
list
:
if
len
(
images_with_extra_info
)
==
0
:
return
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
_
,
fst_ocr
,
fst_lang
=
images_with_extra_info
[
0
]
self
.
model
=
self
.
model_manager
.
get_model
(
fst_ocr
,
self
.
show_log
,
fst_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
for
image
in
images
:
...
...
@@ -73,12 +85,14 @@ class BatchAnalyze:
# 清理显存
clean_vram
(
self
.
model
.
device
,
vram_threshold
=
8
)
ocr
_time
=
0
ocr
_count
=
0
det
_time
=
0
det
_count
=
0
table_time
=
0
table_count
=
0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
_
,
ocr_enable
,
_lang
=
images_with_extra_info
[
index
]
self
.
model
=
self
.
model_manager
.
get_model
(
ocr_enable
,
self
.
show_log
,
_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
layout_res
=
images_layout_res
[
index
]
np_array_img
=
images
[
index
]
...
...
@@ -86,7 +100,7 @@ class BatchAnalyze:
get_res_list_from_layout_res
(
layout_res
)
)
# ocr识别
ocr
_start
=
time
.
time
()
det
_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
...
...
@@ -99,21 +113,21 @@ class BatchAnalyze:
# OCR recognition
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
else
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
,
rec
=
False
)[
0
]
#
if
ocr_enable
:
#
ocr_res = self.model.ocr_model.ocr(
#
new_image, mfd_res=adjusted_mfdetrec_res
#
)[0]
#
else:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
,
rec
=
False
)[
0
]
# Integration results
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
,
ocr_enable
,
new_image
,
_lang
)
layout_res
.
extend
(
ocr_result_list
)
ocr
_time
+=
time
.
time
()
-
ocr
_start
ocr
_count
+=
len
(
ocr_res_list
)
det
_time
+=
time
.
time
()
-
det
_start
det
_count
+=
len
(
ocr_res_list
)
# 表格识别 table recognition
if
self
.
model
.
apply_table
:
...
...
@@ -158,11 +172,70 @@ class BatchAnalyze:
table_time
+=
time
.
time
()
-
table_start
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
'ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
else
:
logger
.
info
(
f
'det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
logger
.
info
(
f
'ocr-det time:
{
round
(
det_time
,
2
)
}
, image num:
{
det_count
}
'
)
if
self
.
model
.
apply_table
:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
# Create dictionaries to store items by language
need_ocr_lists_by_lang
=
{}
# Dict of lists for each language
img_crop_lists_by_lang
=
{}
# Dict of lists for each language
for
layout_res
in
images_layout_res
:
for
layout_res_item
in
layout_res
:
if
layout_res_item
[
'category_id'
]
in
[
15
]:
if
'np_img'
in
layout_res_item
and
'lang'
in
layout_res_item
:
lang
=
layout_res_item
[
'lang'
]
# Initialize lists for this language if not exist
if
lang
not
in
need_ocr_lists_by_lang
:
need_ocr_lists_by_lang
[
lang
]
=
[]
img_crop_lists_by_lang
[
lang
]
=
[]
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang
[
lang
].
append
(
layout_res_item
)
img_crop_lists_by_lang
[
lang
].
append
(
layout_res_item
[
'np_img'
])
# Remove the fields after adding to lists
layout_res_item
.
pop
(
'np_img'
)
layout_res_item
.
pop
(
'lang'
)
if
len
(
img_crop_lists_by_lang
)
>
0
:
# Process OCR by language
rec_time
=
0
rec_start
=
time
.
time
()
total_processed
=
0
# Process each language separately
for
lang
,
img_crop_list
in
img_crop_lists_by_lang
.
items
():
if
len
(
img_crop_list
)
>
0
:
# Get OCR results for this language's images
atom_model_manager
=
AtomModelSingleton
()
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
'ocr'
,
ocr_show_log
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
lang
)
ocr_res_list
=
ocr_model
.
ocr
(
img_crop_list
,
det
=
False
)[
0
]
# Verify we have matching counts
assert
len
(
ocr_res_list
)
==
len
(
need_ocr_lists_by_lang
[
lang
]),
f
'ocr_res_list:
{
len
(
ocr_res_list
)
}
, need_ocr_list:
{
len
(
need_ocr_lists_by_lang
[
lang
])
}
for lang:
{
lang
}
'
# Process OCR results for this language
for
index
,
layout_res_item
in
enumerate
(
need_ocr_lists_by_lang
[
lang
]):
ocr_text
,
ocr_score
=
ocr_res_list
[
index
]
layout_res_item
[
'text'
]
=
ocr_text
layout_res_item
[
'score'
]
=
float
(
round
(
ocr_score
,
2
))
total_processed
+=
len
(
img_crop_list
)
rec_time
+=
time
.
time
()
-
rec_start
logger
.
info
(
f
'ocr-rec time:
{
round
(
rec_time
,
2
)
}
, total images processed:
{
total_processed
}
'
)
return
images_layout_res
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
62b7582f
...
...
@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from
loguru
import
logger
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
import
magic_pdf.model
as
model_config
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
...
...
@@ -141,7 +141,7 @@ def doc_analyze(
else
len
(
dataset
)
-
1
)
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
1
00
))
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
2
00
))
images
=
[]
page_wh_list
=
[]
for
index
in
range
(
len
(
dataset
)):
...
...
@@ -150,16 +150,20 @@ def doc_analyze(
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
lang
is
None
or
lang
==
'auto'
:
images_with_extra_info
=
[(
images
[
index
],
ocr
,
dataset
.
_lang
)
for
index
in
range
(
len
(
dataset
))]
else
:
images_with_extra_info
=
[(
images
[
index
],
ocr
,
lang
)
for
index
in
range
(
len
(
dataset
))]
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_images
=
[
images
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
),
batch_size
)]
batch_images
=
[
images
_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
_with_extra_info
),
batch_size
)]
else
:
batch_images
=
[
images
]
batch_images
=
[
images
_with_extra_info
]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
model_json
=
[]
...
...
@@ -181,7 +185,7 @@ def doc_analyze(
def
batch_doc_analyze
(
datasets
:
list
[
Dataset
],
ocr
:
bool
=
False
,
parse_method
:
str
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
...
...
@@ -192,47 +196,31 @@ def batch_doc_analyze(
batch_size
=
MIN_BATCH_INFERENCE_SIZE
images
=
[]
page_wh_list
=
[]
lang_list
=
[]
lang_s
=
set
()
images_with_extra_info
=
[]
for
dataset
in
datasets
:
for
index
in
range
(
len
(
dataset
)):
if
lang
is
None
or
lang
==
'auto'
:
lang_list
.
append
(
dataset
.
_lang
)
_lang
=
dataset
.
_lang
else
:
lang_list
.
append
(
lang
)
lang_s
.
add
(
lang_list
[
-
1
])
_lang
=
lang
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
parse_method
==
'auto'
:
images_with_extra_info
.
append
((
images
[
-
1
],
dataset
.
classify
()
==
SupportedPdfParseMethod
.
OCR
,
_lang
))
else
:
images_with_extra_info
.
append
((
images
[
-
1
],
parse_method
==
'ocr'
,
_lang
))
batch_images
=
[]
img_idx_list
=
[]
for
t_lang
in
lang_s
:
tmp_img_idx_list
=
[]
for
i
,
_lang
in
enumerate
(
lang_list
):
if
_lang
==
t_lang
:
tmp_img_idx_list
.
append
(
i
)
img_idx_list
.
extend
(
tmp_img_idx_list
)
if
batch_size
>=
len
(
tmp_img_idx_list
):
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
tmp_img_idx_list
]))
else
:
slices
=
[
tmp_img_idx_list
[
k
:
k
+
batch_size
]
for
k
in
range
(
0
,
len
(
tmp_img_idx_list
),
batch_size
)]
for
arr
in
slices
:
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
arr
]))
unorder_results
=
[]
for
sn
,
(
_lang
,
batch_image
)
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
_lang
,
layout_model
,
formula_enable
,
table_enable
)
unorder_results
.
extend
(
result
)
results
=
[
None
]
*
len
(
img_idx_list
)
for
i
,
idx
in
enumerate
(
img_idx_list
):
results
[
idx
]
=
unorder_results
[
i
]
batch_images
=
[
images_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images_with_extra_info
),
batch_size
)]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
True
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
infer_results
=
[]
from
magic_pdf.operators.models
import
InferenceResult
for
index
in
range
(
len
(
datasets
)):
dataset
=
datasets
[
index
]
...
...
@@ -248,11 +236,10 @@ def batch_doc_analyze(
def
may_batch_image_analyze
(
images
:
list
[
np
.
ndarray
],
images
_with_extra_info
:
list
[
(
np
.
ndarray
,
bool
,
str
)
],
idx
:
int
,
ocr
:
bool
=
False
,
ocr
:
bool
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
...
...
@@ -263,10 +250,8 @@ def may_batch_image_analyze(
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
batch_analyze
=
False
batch_ratio
=
1
device
=
get_device
()
...
...
@@ -290,64 +275,15 @@ def may_batch_image_analyze(
else
:
batch_ratio
=
1
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
batch_analyze
=
True
#
batch_analyze = True
elif
str
(
device
).
startswith
(
'mps'
):
batch_analyze
=
True
doc_analyze_start
=
time
.
time
()
if
batch_analyze
:
"""# batch analyze
images = []
page_wh_list = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
results
=
batch_model
(
images
)
"""
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
page_width, page_height = page_wh_list.pop(0)
else:
result = []
page_height = 0
page_width = 0
# batch_analyze = True
pass
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
else
:
# single analyze
"""
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
doc_analyze_start
=
time
.
time
()
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
results
=
[]
for
img_idx
,
img
in
enumerate
(
images
):
inference_start
=
time
.
time
()
result
=
custom_model
(
img
)
logger
.
info
(
f
'-----image index :
{
img_idx
}
, image inference total time:
{
round
(
time
.
time
()
-
inference_start
,
2
)
}
-----'
)
results
.
append
(
result
)
batch_model
=
BatchAnalyze
(
model_manager
,
batch_ratio
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
=
batch_model
(
images_with_extra_info
)
gc_start
=
time
.
time
()
clean_memory
(
get_device
())
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
62b7582f
...
...
@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr
2pytorch
.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
...
...
magic_pdf/model/sub_modules/model_init.py
View file @
62b7582f
...
...
@@ -7,32 +7,33 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8
import
YOLOv8MFDModel
from
magic_pdf.model.sub_modules.mfr.unimernet.Unimernet
import
UnimernetModel
try
:
from
magic_pdf_ascend_plugin.libs.license_verifier
import
(
LicenseExpiredError
,
LicenseFormatError
,
LicenseSignatureError
,
load_license
)
from
magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu
import
ModifiedPaddleOCR
from
magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu
import
RapidTableModel
license_key
=
load_license
()
logger
.
info
(
f
'Using Ascend Plugin Success, License id is
{
license_key
[
"payload"
][
"id"
]
}
,'
f
' License expired at
{
license_key
[
"payload"
][
"date"
][
"end_date"
]
}
'
)
except
Exception
as
e
:
if
isinstance
(
e
,
ImportError
):
pass
elif
isinstance
(
e
,
LicenseFormatError
):
logger
.
error
(
'Ascend Plugin: Invalid license format. Please check the license file.'
)
elif
isinstance
(
e
,
LicenseSignatureError
):
logger
.
error
(
'Ascend Plugin: Invalid signature. The license may be tampered with.'
)
elif
isinstance
(
e
,
LicenseExpiredError
):
logger
.
error
(
'Ascend Plugin: License has expired. Please renew your license.'
)
elif
isinstance
(
e
,
FileNotFoundError
):
logger
.
error
(
'Ascend Plugin: Not found License file.'
)
else
:
logger
.
error
(
f
'Ascend Plugin:
{
e
}
'
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
RapidTableModel
from
magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle
import
PytorchPaddleOCR
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
RapidTableModel
# try:
# from magic_pdf_ascend_plugin.libs.license_verifier import (
# LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
# load_license)
# from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
# from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
# license_key = load_license()
# logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
# f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
# if isinstance(e, ImportError):
# pass
# elif isinstance(e, LicenseFormatError):
# logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
# elif isinstance(e, LicenseSignatureError):
# logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
# elif isinstance(e, LicenseExpiredError):
# logger.error('Ascend Plugin: License has expired. Please renew your license.')
# elif isinstance(e, FileNotFoundError):
# logger.error('Ascend Plugin: Not found License file.')
# else:
# logger.error(f'Ascend Plugin: {e}')
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
,
ocr_engine
=
None
,
table_sub_model_name
=
None
):
...
...
@@ -94,7 +95,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio
=
1.8
,
):
if
lang
is
not
None
and
lang
!=
''
:
model
=
ModifiedPaddleOCR
(
# model = ModifiedPaddleOCR(
model
=
PytorchPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
,
...
...
@@ -102,7 +104,8 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio
=
det_db_unclip_ratio
,
)
else
:
model
=
ModifiedPaddleOCR
(
# model = ModifiedPaddleOCR(
model
=
PytorchPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
use_dilation
=
use_dilation
,
...
...
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py
0 → 100644
View file @
62b7582f
# Copyright (c) Opendatalab. All rights reserved.
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/models_config.yml
0 → 100644
View file @
62b7582f
lang
:
ch
:
det
:
ch_PP-OCRv4_det_infer.pth
rec
:
ch_PP-OCRv4_rec_infer.pth
dict
:
ppocr_keys_v1.txt
en
:
det
:
en_PP-OCRv3_det_infer.pth
rec
:
en_PP-OCRv4_rec_infer.pth
dict
:
en_dict.txt
korean
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
korean_PP-OCRv3_rec_infer.pth
dict
:
korean_dict.txt
japan
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
japan_PP-OCRv3_rec_infer.pth
dict
:
japan_dict.txt
chinese_cht
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
chinese_cht_PP-OCRv3_rec_infer.pth
dict
:
chinese_cht_dict.txt
ta
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
ta_PP-OCRv3_rec_infer.pth
dict
:
ta_dict.txt
te
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
te_PP-OCRv3_rec_infer.pth
dict
:
te_dict.txt
ka
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
ka_PP-OCRv3_rec_infer.pth
dict
:
ka_dict.txt
latin
:
det
:
en_PP-OCRv3_det_infer.pth
rec
:
latin_PP-OCRv3_rec_infer.pth
dict
:
latin_dict.txt
arabic
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
arabic_PP-OCRv3_rec_infer.pth
dict
:
arabic_dict.txt
cyrillic
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
cyrillic_PP-OCRv3_rec_infer.pth
dict
:
cyrillic_dict.txt
devanagari
:
det
:
Multilingual_PP-OCRv3_det_infer.pth
rec
:
devanagari_PP-OCRv3_rec_infer.pth
dict
:
devanagari_dict.txt
\ No newline at end of file
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/ocr_utils.py
0 → 100644
View file @
62b7582f
# Copyright (c) Opendatalab. All rights reserved.
import
copy
import
cv2
import
numpy
as
np
from
magic_pdf.pre_proc.ocr_dict_merge
import
merge_spans_to_line
from
magic_pdf.libs.boxbase
import
__is_overlaps_y_exceeds_threshold
def
img_decode
(
content
:
bytes
):
np_arr
=
np
.
frombuffer
(
content
,
dtype
=
np
.
uint8
)
return
cv2
.
imdecode
(
np_arr
,
cv2
.
IMREAD_UNCHANGED
)
def
check_img
(
img
):
if
isinstance
(
img
,
bytes
):
img
=
img_decode
(
img
)
if
isinstance
(
img
,
np
.
ndarray
)
and
len
(
img
.
shape
)
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
return
img
def
alpha_to_color
(
img
,
alpha_color
=
(
255
,
255
,
255
)):
if
len
(
img
.
shape
)
==
3
and
img
.
shape
[
2
]
==
4
:
B
,
G
,
R
,
A
=
cv2
.
split
(
img
)
alpha
=
A
/
255
R
=
(
alpha_color
[
0
]
*
(
1
-
alpha
)
+
R
*
alpha
).
astype
(
np
.
uint8
)
G
=
(
alpha_color
[
1
]
*
(
1
-
alpha
)
+
G
*
alpha
).
astype
(
np
.
uint8
)
B
=
(
alpha_color
[
2
]
*
(
1
-
alpha
)
+
B
*
alpha
).
astype
(
np
.
uint8
)
img
=
cv2
.
merge
((
B
,
G
,
R
))
return
img
def
preprocess_image
(
_image
):
alpha_color
=
(
255
,
255
,
255
)
_image
=
alpha_to_color
(
_image
,
alpha_color
)
return
_image
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes
=
dt_boxes
.
shape
[
0
]
sorted_boxes
=
sorted
(
dt_boxes
,
key
=
lambda
x
:
(
x
[
0
][
1
],
x
[
0
][
0
]))
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
for
j
in
range
(
i
,
-
1
,
-
1
):
if
abs
(
_boxes
[
j
+
1
][
0
][
1
]
-
_boxes
[
j
][
0
][
1
])
<
10
and
\
(
_boxes
[
j
+
1
][
0
][
0
]
<
_boxes
[
j
][
0
][
0
]):
tmp
=
_boxes
[
j
]
_boxes
[
j
]
=
_boxes
[
j
+
1
]
_boxes
[
j
+
1
]
=
tmp
else
:
break
return
_boxes
def
bbox_to_points
(
bbox
):
""" 将bbox格式转换为四个顶点的数组 """
x0
,
y0
,
x1
,
y1
=
bbox
return
np
.
array
([[
x0
,
y0
],
[
x1
,
y0
],
[
x1
,
y1
],
[
x0
,
y1
]]).
astype
(
'float32'
)
def
points_to_bbox
(
points
):
""" 将四个顶点的数组转换为bbox格式 """
x0
,
y0
=
points
[
0
]
x1
,
_
=
points
[
1
]
_
,
y1
=
points
[
2
]
return
[
x0
,
y0
,
x1
,
y1
]
def
merge_intervals
(
intervals
):
# Sort the intervals based on the start value
intervals
.
sort
(
key
=
lambda
x
:
x
[
0
])
merged
=
[]
for
interval
in
intervals
:
# If the list of merged intervals is empty or if the current
# interval does not overlap with the previous, simply append it.
if
not
merged
or
merged
[
-
1
][
1
]
<
interval
[
0
]:
merged
.
append
(
interval
)
else
:
# Otherwise, there is overlap, so we merge the current and previous intervals.
merged
[
-
1
][
1
]
=
max
(
merged
[
-
1
][
1
],
interval
[
1
])
return
merged
def
remove_intervals
(
original
,
masks
):
# Merge all mask intervals
merged_masks
=
merge_intervals
(
masks
)
result
=
[]
original_start
,
original_end
=
original
for
mask
in
merged_masks
:
mask_start
,
mask_end
=
mask
# If the mask starts after the original range, ignore it
if
mask_start
>
original_end
:
continue
# If the mask ends before the original range starts, ignore it
if
mask_end
<
original_start
:
continue
# Remove the masked part from the original range
if
original_start
<
mask_start
:
result
.
append
([
original_start
,
mask_start
-
1
])
original_start
=
max
(
mask_end
+
1
,
original_start
)
# Add the remaining part of the original range, if any
if
original_start
<=
original_end
:
result
.
append
([
original_start
,
original_end
])
return
result
def
update_det_boxes
(
dt_boxes
,
mfd_res
):
new_dt_boxes
=
[]
angle_boxes_list
=
[]
for
text_box
in
dt_boxes
:
if
calculate_is_angle
(
text_box
):
angle_boxes_list
.
append
(
text_box
)
continue
text_bbox
=
points_to_bbox
(
text_box
)
masks_list
=
[]
for
mf_box
in
mfd_res
:
mf_bbox
=
mf_box
[
'bbox'
]
if
__is_overlaps_y_exceeds_threshold
(
text_bbox
,
mf_bbox
):
masks_list
.
append
([
mf_bbox
[
0
],
mf_bbox
[
2
]])
text_x_range
=
[
text_bbox
[
0
],
text_bbox
[
2
]]
text_remove_mask_range
=
remove_intervals
(
text_x_range
,
masks_list
)
temp_dt_box
=
[]
for
text_remove_mask
in
text_remove_mask_range
:
temp_dt_box
.
append
(
bbox_to_points
([
text_remove_mask
[
0
],
text_bbox
[
1
],
text_remove_mask
[
1
],
text_bbox
[
3
]]))
if
len
(
temp_dt_box
)
>
0
:
new_dt_boxes
.
extend
(
temp_dt_box
)
new_dt_boxes
.
extend
(
angle_boxes_list
)
return
new_dt_boxes
def
merge_overlapping_spans
(
spans
):
"""
Merges overlapping spans on the same line.
:param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
:return: A list of merged spans
"""
# Return an empty list if the input spans list is empty
if
not
spans
:
return
[]
# Sort spans by their starting x-coordinate
spans
.
sort
(
key
=
lambda
x
:
x
[
0
])
# Initialize the list of merged spans
merged
=
[]
for
span
in
spans
:
# Unpack span coordinates
x1
,
y1
,
x2
,
y2
=
span
# If the merged list is empty or there's no horizontal overlap, add the span directly
if
not
merged
or
merged
[
-
1
][
2
]
<
x1
:
merged
.
append
(
span
)
else
:
# If there is horizontal overlap, merge the current span with the previous one
last_span
=
merged
.
pop
()
# Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
x1
=
min
(
last_span
[
0
],
x1
)
y1
=
min
(
last_span
[
1
],
y1
)
x2
=
max
(
last_span
[
2
],
x2
)
y2
=
max
(
last_span
[
3
],
y2
)
# Add the merged span back to the list
merged
.
append
((
x1
,
y1
,
x2
,
y2
))
# Return the list of merged spans
return
merged
def
merge_det_boxes
(
dt_boxes
):
"""
Merge detection boxes.
This function takes a list of detected bounding boxes, each represented by four corner points.
The goal is to merge these bounding boxes into larger text regions.
Parameters:
dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
Returns:
list: A list containing the merged text regions, where each region is represented by four corner points.
"""
# Convert the detection boxes into a dictionary format with bounding boxes and type
dt_boxes_dict_list
=
[]
angle_boxes_list
=
[]
for
text_box
in
dt_boxes
:
text_bbox
=
points_to_bbox
(
text_box
)
if
calculate_is_angle
(
text_box
):
angle_boxes_list
.
append
(
text_box
)
continue
text_box_dict
=
{
'bbox'
:
text_bbox
,
'type'
:
'text'
,
}
dt_boxes_dict_list
.
append
(
text_box_dict
)
# Merge adjacent text regions into lines
lines
=
merge_spans_to_line
(
dt_boxes_dict_list
)
# Initialize a new list for storing the merged text regions
new_dt_boxes
=
[]
for
line
in
lines
:
line_bbox_list
=
[]
for
span
in
line
:
line_bbox_list
.
append
(
span
[
'bbox'
])
# Merge overlapping text regions within the same line
merged_spans
=
merge_overlapping_spans
(
line_bbox_list
)
# Convert the merged text regions back to point format and add them to the new detection box list
for
span
in
merged_spans
:
new_dt_boxes
.
append
(
bbox_to_points
(
span
))
new_dt_boxes
.
extend
(
angle_boxes_list
)
return
new_dt_boxes
def
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
):
paste_x
,
paste_y
,
xmin
,
ymin
,
xmax
,
ymax
,
new_width
,
new_height
=
useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res
=
[]
for
mf_res
in
single_page_mfdetrec_res
:
mf_xmin
,
mf_ymin
,
mf_xmax
,
mf_ymax
=
mf_res
[
"bbox"
]
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0
=
mf_xmin
-
xmin
+
paste_x
y0
=
mf_ymin
-
ymin
+
paste_y
x1
=
mf_xmax
-
xmin
+
paste_x
y1
=
mf_ymax
-
ymin
+
paste_y
# Filter formula blocks outside the graph
if
any
([
x1
<
0
,
y1
<
0
])
or
any
([
x0
>
new_width
,
y0
>
new_height
]):
continue
else
:
adjusted_mfdetrec_res
.
append
({
"bbox"
:
[
x0
,
y0
,
x1
,
y1
],
})
return
adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
    """Convert raw OCR det/rec results from crop coordinates back into page
    coordinates, emitting category-15 (text span) layout records.

    Args:
        ocr_res: either det+rec items ``[points, (text, score)]`` (len == 2)
            or det-only items ``points`` (the four corner points).
        useful_list: (paste_x, paste_y, xmin, ymin, xmax, ymax,
            new_width, new_height) describing the crop/paste transform.
        ocr_enable: when True the cropped text image is kept in the record
            ('np_img') for a later recognition pass.
        new_image: the cropped/pasted image the boxes were detected on.
        lang: language tag propagated into each record when ocr_enable.

    Returns:
        list of dicts with 'category_id', 'poly', 'score', 'text' and,
        when ocr_enable, 'np_img' and 'lang'.
    """
    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
    ocr_result_list = []
    ori_im = new_image.copy()
    for box_ocr_res in ocr_res:
        if len(box_ocr_res) == 2:
            # det+rec result: [four points, (text, score)].
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
            # logger.info(f"text: {text}, score: {score}")
            if score < 0.6:  # filter out low-confidence results
                continue
        else:
            # det-only result: just the four corner points.
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1
            if ocr_enable:
                tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
                # NOTE(review): img_crop is only bound on this det-only path;
                # the ocr_enable append below assumes ocr_res never mixes
                # det+rec items with ocr_enable=True — confirm with callers.
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
        # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
        # if average_angle_degrees > 0.5:
        poly = [p1, p2, p3, p4]
        if calculate_is_angle(poly):
            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
            # Tilt w.r.t. the x-axis exceeds the threshold: straighten the box
            # around its geometric center.
            x_center = sum(point[0] for point in poly) / 4
            y_center = sum(point[1] for point in poly) / 4
            # NOTE(review): these shadow the new_width/new_height unpacked from
            # useful_list above (which are unused afterwards in this function).
            new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            new_width = p3[0] - p1[0]
            p1 = [x_center - new_width / 2, y_center - new_height / 2]
            p2 = [x_center + new_width / 2, y_center - new_height / 2]
            p3 = [x_center + new_width / 2, y_center + new_height / 2]
            p4 = [x_center - new_width / 2, y_center + new_height / 2]
        # Convert the coordinates back to the original coordinate system
        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
        if ocr_enable:
            # Deferred recognition: keep the crop so text can be filled later.
            ocr_result_list.append({
                'category_id': 15,
                'poly': p1 + p2 + p3 + p4,
                'score': 1,
                'text': text,
                'np_img': img_crop,
                'lang': lang,
            })
        else:
            ocr_result_list.append({
                'category_id': 15,
                'poly': p1 + p2 + p3 + p4,
                'score': float(round(score, 2)),
                'text': text,
            })
    return ocr_result_list
def calculate_is_angle(poly):
    """Return True when the quadrilateral *poly* is noticeably tilted.

    Compares the drop of the p1->p3 diagonal against the average of the two
    vertical edges; within 0.8x-1.2x of that height the box is considered
    axis-aligned.
    """
    top_left, top_right, bottom_right, bottom_left = poly
    height = ((bottom_left[1] - top_left[1]) + (bottom_right[1] - top_right[1])) / 2
    diagonal_drop = bottom_right[1] - top_left[1]
    # logger.info((p3[1] - p1[1])/height)
    return not (0.8 * height <= diagonal_drop <= 1.2 * height)
def get_rotate_crop_image(img, points):
    """Perspective-crop a quadrilateral text region out of *img*.

    The quad *points* (4x2 float32 array, clockwise from top-left) is warped
    onto an axis-aligned rectangle whose size comes from the longer of each
    pair of opposing edges. Crops much taller than wide (h/w >= 1.5) are
    rotated 90 degrees so the recognizer sees roughly horizontal text.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    top_edge = np.linalg.norm(points[0] - points[1])
    bottom_edge = np.linalg.norm(points[2] - points[3])
    left_edge = np.linalg.norm(points[0] - points[3])
    right_edge = np.linalg.norm(points[1] - points[2])
    img_crop_width = int(max(top_edge, bottom_edge))
    img_crop_height = int(max(left_edge, right_edge))
    # Destination rectangle for the perspective warp.
    pts_std = np.float32([
        [0, 0],
        [img_crop_width, 0],
        [img_crop_width, img_crop_height],
        [0, img_crop_height],
    ])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M,
        (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
\ No newline at end of file
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py
0 → 100644
View file @
62b7582f
# Copyright (c) Opendatalab. All rights reserved.
import
copy
import
os.path
from
pathlib
import
Path
import
cv2
import
numpy
as
np
import
yaml
from
loguru
import
logger
from
magic_pdf.libs.config_reader
import
get_device
,
get_local_models_dir
from
.ocr_utils
import
check_img
,
preprocess_image
,
sorted_boxes
,
merge_det_boxes
,
update_det_boxes
,
get_rotate_crop_image
from
.tools.infer.predict_system
import
TextSystem
from
.tools.infer
import
pytorchocr_utility
as
utility
import
argparse
# Language-code groups: every code in a group shares one recognition model,
# so concrete codes are collapsed to the group name before the model lookup
# (see PytorchPaddleOCR.__init__).
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
    'sa', 'bgc'
]
def get_model_params(lang, config):
    """Fetch the det/rec model filenames and char-dict filename for *lang*.

    Args:
        lang: normalized language key (e.g. 'ch', 'latin', 'arabic').
        config: parsed models_config.yml; expected to hold a 'lang' table.

    Returns:
        (det, rec, dict_file); entries are None when absent from the config.

    Raises:
        Exception: if *lang* has no entry in the config's 'lang' table.
    """
    lang_table = config['lang']
    if lang not in lang_table:
        raise Exception(f'Language {lang} not supported')
    params = lang_table[lang]
    return params.get('det'), params.get('rec'), params.get('dict')
# Directory containing this module; used to resolve the bundled
# models_config.yml and the character-dict files shipped with the package.
root_dir = Path(__file__).resolve().parent
class PytorchPaddleOCR(TextSystem):
    """PaddleOCR-compatible det+rec pipeline backed by converted PyTorch weights.

    Maps the requested language onto a shared model family, resolves local
    model/dict paths, and initializes the underlying TextSystem with the
    merged CLI-default + kwargs namespace.
    """

    def __init__(self, *args, **kwargs):
        parser = utility.init_args()
        # Fill every CLI default; *args is normally empty here.
        args = parser.parse_args(args)

        self.lang = kwargs.get('lang', 'ch')

        # Collapse concrete language codes into their shared model family.
        if self.lang in latin_lang:
            self.lang = 'latin'
        elif self.lang in arabic_lang:
            self.lang = 'arabic'
        elif self.lang in cyrillic_lang:
            self.lang = 'cyrillic'
        elif self.lang in devanagari_lang:
            self.lang = 'devanagari'
        else:
            pass

        # Look up model filenames for the (possibly remapped) language.
        models_config_path = os.path.join(root_dir, 'models_config.yml')
        with open(models_config_path) as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)
        ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
        kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
        kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
        kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'dict', dict_file)
        kwargs['device'] = get_device()
        # kwargs take precedence over the parsed CLI defaults.
        default_args = vars(args)
        default_args.update(kwargs)
        args = argparse.Namespace(**default_args)
        super().__init__(args)

    def ocr(self,
            img,
            det=True,
            rec=True,
            mfd_res=None,
            ):
        """Run detection and/or recognition.

        Args:
            img: ndarray / bytes / path str, or (rec-only) a list of images.
            det: run text detection.
            rec: run text recognition.
            mfd_res: formula-detection results used to carve formula regions
                out of the detected text boxes.

        Returns:
            One entry per input image: [box, (text, score)] pairs for
            det+rec, box lists for det-only, (text, score) lists for
            rec-only; None for an image with no results.  Returns None when
            both det and rec are False.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            exit(0)
        img = check_img(img)
        imgs = [img]
        if det and rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
                logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                if dt_boxes is None:
                    ocr_res.append(None)
                    continue
                dt_boxes = sorted_boxes(dt_boxes)
                # merge_det_boxes and update_det_boxes both convert polys to
                # bboxes and back, so heavily tilted text boxes must already
                # have been filtered out.
                dt_boxes = merge_det_boxes(dt_boxes)
                if mfd_res:
                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        elif not det and rec:
            ocr_res = []
            for img in imgs:
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                rec_res, elapse = self.text_recognizer(img)
                logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
                ocr_res.append(rec_res)
            return ocr_res

    def __call__(self, img, mfd_res=None):
        """Full det->crop->rec pass; returns (filter_boxes, filter_rec_res)
        with results below self.drop_score removed, or (None, None) when
        there is no input or no detections."""
        if img is None:
            logger.debug("no valid image provided")
            return None, None

        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            return None, None
        else:
            pass
        logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))

        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
        # merge_det_boxes and update_det_boxes both convert polys to bboxes
        # and back, so heavily tilted text boxes must already have been
        # filtered out.
        dt_boxes = merge_det_boxes(dt_boxes)
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
if __name__ == '__main__':
    # Ad-hoc smoke test: run full det+rec on a local screenshot and print the
    # results in the same shape as PytorchPaddleOCR.ocr(det=True, rec=True).
    # NOTE(review): hardcoded developer path; replace before reuse.
    pytorch_paddle_ocr = PytorchPaddleOCR()
    img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
    dt_boxes, rec_res = pytorch_paddle_ocr(img)
    ocr_res = []
    if not dt_boxes and not rec_res:
        ocr_res.append(None)
    else:
        tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        ocr_res.append(tmp_res)
    print(ocr_res)
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/__init__.py
0 → 100755
View file @
62b7582f
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py
0 → 100755
View file @
62b7582f
import
os
import
torch
from
.modeling.architectures.base_model
import
BaseModel
class BaseOCRV20:
    """Thin wrapper that builds a BaseModel network and runs inference.

    The network is put in eval mode at construction time; weights are loaded
    separately via read_pytorch_weights/load_state_dict or
    load_pytorch_weights.
    """

    def __init__(self, config, **kwargs):
        # config: architecture description dict (see arch_config.yaml).
        self.config = config
        self.build_net(**kwargs)
        self.net.eval()

    def build_net(self, **kwargs):
        """Instantiate the torch module described by self.config."""
        self.net = BaseModel(self.config, **kwargs)

    def read_pytorch_weights(self, weights_path):
        """Load a state dict from disk without applying it to the network.

        Raises:
            FileNotFoundError: if *weights_path* does not exist.
        """
        if not os.path.exists(weights_path):
            raise FileNotFoundError('{} is not existed.'.format(weights_path))
        # weights_only=True matches load_pytorch_weights below and avoids
        # unpickling arbitrary objects from the checkpoint file.
        return torch.load(weights_path, weights_only=True)

    def get_out_channels(self, weights):
        """Infer the output channel count from the last entry of a state dict.

        A trailing 2-D '*.weight' tensor is treated as a linear layer
        (out channels = second dim); anything else uses the first dim.
        """
        last_key = list(weights.keys())[-1]
        last_shape = list(weights.values())[-1].numpy().shape
        if last_key.endswith('.weight') and len(last_shape) == 2:
            return last_shape[1]
        return last_shape[0]

    def load_state_dict(self, weights):
        """Apply an already-loaded state dict to the network."""
        self.net.load_state_dict(weights)
        print('weights is loaded.')

    def load_pytorch_weights(self, weights_path):
        """Load a checkpoint file from disk and apply it to the network."""
        self.net.load_state_dict(torch.load(weights_path, weights_only=True))
        print('model is loaded: {}'.format(weights_path))

    def inference(self, inputs):
        """Run one forward pass with gradients disabled."""
        with torch.no_grad():
            infer = self.net(inputs)
        return infer
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py
0 → 100755
View file @
62b7582f
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import numpy as np
# import paddle
import signal
import random

# Make the package root importable so sibling modules resolve even when this
# file is executed outside the normal package layout.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import copy
# from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
# import paddle.distributed as dist

# Re-export the preprocessing entry points of this sub-package.
from .imaug import transform, create_operators
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py
0 → 100755
View file @
62b7582f
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
# from .iaa_augment import IaaAugment
# from .make_border_map import MakeBorderMap
# from .make_shrink_map import MakeShrinkMap
# from .random_crop_data import EastRandomCropData, PSERandomCrop
# from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
# from .randaugment import RandAugment
from
.operators
import
*
# from .label_ops import *
# from .east_process import *
# from .sast_process import *
# from .gen_table_mask import *
def transform(data, ops=None):
    """Run *data* through a sequence of operator callables.

    Each op receives the current data dict and returns the transformed dict,
    or None to signal that the sample should be dropped.

    Returns:
        The transformed data, or None if any op rejected it.
    """
    pipeline = [] if ops is None else ops
    result = data
    for op in pipeline:
        result = op(result)
        if result is None:
            return None
    return result
def create_operators(op_param_list, global_config=None):
    """Instantiate pipeline operators from a config list.

    Args:
        op_param_list: list of single-key dicts {OpClassName: kwargs-or-None}.
        global_config: extra kwargs merged into every operator's params.

    Returns:
        list of instantiated operator objects, in config order.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    instantiated = []
    for entry in op_param_list:
        assert isinstance(entry, dict) and len(entry) == 1, "yaml format error"
        name = list(entry)[0]
        raw_params = entry[name]
        params = {} if raw_params is None else raw_params
        if global_config is not None:
            params.update(global_config)
        # NOTE: eval() resolves the operator class by name in this module's
        # namespace — config files must come from trusted input only.
        instantiated.append(eval(name)(**params))
    return instantiated
\ No newline at end of file
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py
0 → 100755
View file @
62b7582f
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
sys
import
six
import
cv2
import
numpy
as
np
class DecodeImage(object):
    """Decode raw image bytes held in data['image'] into an ndarray.

    img_mode: 'GRAY' expands the decoded image to 3-channel BGR; 'RGB'
    reverses the cv2 BGR channel order; anything else leaves BGR as-is.
    """

    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
        # channel_first: emit CHW instead of HWC.
        # Extra kwargs (from the shared yaml config) are ignored.
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        img = data['image']
        # Raw file contents: str on py2, bytes on py3.
        if six.PY2:
            assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        # 1 == cv2.IMREAD_COLOR: always decode to 3-channel BGR.
        img = cv2.imdecode(img, 1)
        if img is None:
            # Undecodable bytes: drop the sample.
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]  # BGR -> RGB

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data
class NRTRDecodeImage(object):
    """Decode raw image bytes for the NRTR recognizer, ending in grayscale.

    Same decoding path as DecodeImage, but the result is always converted
    to single-channel grayscale at the end.
    """

    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
        # channel_first: emit CHW instead of HWC (see NOTE below).
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        img = data['image']
        # Raw file contents: str on py2, bytes on py3.
        if six.PY2:
            assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        # 1 == cv2.IMREAD_COLOR: decode to 3-channel BGR first.
        img = cv2.imdecode(img, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # NOTE(review): after the grayscale conversion img is 2-D, so this
        # 3-axis transpose would raise for channel_first=True — presumably
        # this operator is only ever configured with channel_first=False;
        # confirm against the configs that use it.
        if self.channel_first:
            img = img.transpose((2, 0, 1))
        data['image'] = img
        return data
class NormalizeImage(object):
    """ normalize image such as substract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        # scale may arrive from yaml as a string expression (e.g. '1./255.').
        # NOTE(review): eval() on a config value — config must be trusted.
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        # ImageNet statistics by default.
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        # Broadcastable shape for per-channel stats, CHW or HWC layout.
        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        # (img * scale - mean) / std, elementwise per channel.
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        # No configuration; kwargs from the shared yaml config are ignored.
        pass

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        # HWC -> CHW.
        data['image'] = img.transpose((2, 0, 1))
        return data
class Fasttext(object):
    """Attach a fastText embedding of data['label'] as data['fast_label']."""

    def __init__(self, path="None", **kwargs):
        # Deferred import: fasttext is an optional dependency, only needed
        # when this operator is configured.
        import fasttext
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        label = data['label']
        fast_label = self.fast_model[label]
        data['fast_label'] = fast_label
        return data
class KeepKeys(object):
    """Project the data dict onto an ordered list of selected values."""

    def __init__(self, keep_keys, **kwargs):
        # keep_keys: keys to extract, in output order.
        # Extra kwargs (from the shared yaml config) are ignored.
        self.keep_keys = keep_keys

    def __call__(self, data):
        return [data[key] for key in self.keep_keys]
class Resize(object):
    """Resize the image to a fixed size and scale the text polygons to match."""

    def __init__(self, size=(640, 640), **kwargs):
        # size: target (height, width).
        self.size = size

    def resize_image(self, img):
        """Resize to self.size; returns the image and [ratio_h, ratio_w]."""
        resize_h, resize_w = self.size
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def __call__(self, data):
        img = data['image']
        text_polys = data['polys']

        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
        # Rescale every polygon vertex by the same ratios.
        new_boxes = []
        for box in text_polys:
            new_box = []
            for cord in box:
                new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
            new_boxes.append(new_box)
        data['image'] = img_resize
        data['polys'] = np.array(new_boxes, dtype=np.float32)
        return data
class DetResizeForTest(object):
    """Inference-time resize for text detection.

    Three policies, chosen by which kwarg is present:
      - 'image_shape'    -> resize_type 1: fixed target shape.
      - 'limit_side_len' -> resize_type 0: scale so the min/max side respects
                            a limit, then snap to a multiple of 32.
      - 'resize_long'    -> resize_type 2: scale the longer side to
                            resize_long, then snap to a multiple of 128.
    Defaults to type 0 with limit_side_len=736, limit_type='min'.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize data['image'] and record [src_h, src_w, ratio_h, ratio_w]
        in data['shape'] so detections can be mapped back."""
        img = data['image']
        src_h, src_w, _ = img.shape

        if self.resize_type == 0:
            # img, shape = self.resize_image_type0(img)
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            # img, shape = self.resize_image_type1(img)
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type1(self, img):
        """Resize to the fixed self.image_shape (h, w)."""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        # return img, np.array([ori_h, ori_w])
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == 'max':
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Snap both sides to multiples of 32 (network stride), minimum 32.
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed so only real errors reach here.
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to self.resize_long, snap to multiples of 128."""
        h, w, _ = img.shape

        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
    """Inference-time resize for end-to-end text spotting models."""

    def __init__(self, **kwargs):
        super(E2EResizeForTest, self).__init__()
        # max_side_len: upper bound for the longer image side.
        # valid_set: dataset name; 'totaltext' selects a dedicated policy.
        self.max_side_len = kwargs['max_side_len']
        self.valid_set = kwargs['valid_set']

    def __call__(self, data):
        img = data['image']
        src_h, src_w, _ = img.shape
        if self.valid_set == 'totaltext':
            im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
                img, max_side_len=self.max_side_len)
        else:
            im_resized, (ratio_h, ratio_w) = self.resize_image(
                img, max_side_len=self.max_side_len)
        data['image'] = im_resized
        # Record original size and ratios so predictions can be mapped back.
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_for_totaltext(self, im, max_side_len=512):
        """Upscale by 1.25x unless that exceeds max_side_len, then snap both
        sides up to multiples of 128."""
        h, w, _ = im.shape

        resize_w = w
        resize_h = h
        ratio = 1.25
        if h * ratio > max_side_len:
            ratio = float(max_side_len) / resize_h

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return im, (ratio_h, ratio_w)

    def resize_image(self, im, max_side_len=512):
        """
        resize image to a size multiple of max_stride which is required by the network
        :param im: the resized image
        :param max_side_len: limit of max image size to avoid out of memory in gpu
        :return: the resized image and the resize ratio
        """
        h, w, _ = im.shape

        resize_w = w
        resize_h = h

        # Fix the longer side
        if resize_h > resize_w:
            ratio = float(max_side_len) / resize_h
        else:
            ratio = float(max_side_len) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return im, (ratio_h, ratio_w)
class KieResize(object):
    """Resize + pad an image into a fixed 1024x1024 canvas for KIE models,
    rescaling the annotation boxes accordingly."""

    def __init__(self, **kwargs):
        super(KieResize, self).__init__()
        # img_scale: (max_side, min_side) bound for the resize.
        self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
            'img_scale'][1]

    def __call__(self, data):
        img = data['image']
        points = data['points']
        src_h, src_w, _ = img.shape
        im_resized, scale_factor, [ratio_h, ratio_w
                                   ], [new_h, new_w] = self.resize_image(img)
        resize_points = self.resize_boxes(img, points, scale_factor)
        # Keep the originals alongside the resized versions.
        data['ori_image'] = img
        data['ori_boxes'] = points
        data['points'] = resize_points
        data['image'] = im_resized
        data['shape'] = np.array([new_h, new_w])
        return data

    def resize_image(self, img):
        # Fixed-size zero canvas the resized image is pasted into (top-left).
        norm_img = np.zeros([1024, 1024, 3], dtype='float32')
        scale = [512, 1024]
        h, w = img.shape[:2]
        max_long_edge = max(scale)
        max_short_edge = min(scale)
        # Scale so long side <= 1024 and short side <= 512.
        scale_factor = min(max_long_edge / max(h, w),
                           max_short_edge / min(h, w))
        resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
            scale_factor) + 0.5)
        max_stride = 32
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(img, (resize_w, resize_h))
        new_h, new_w = im.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
        # (x, y, x, y) multiplier applied to the annotation boxes.
        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        norm_img[:new_h, :new_w, :] = im
        return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]

    def resize_boxes(self, im, points, scale_factor):
        points = points * scale_factor
        # NOTE(review): `im` here is the ORIGINAL image, so the clip bounds
        # are the pre-resize shape, not the resized canvas — confirm this is
        # intentional with the callers.
        img_shape = im.shape[:2]
        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
        return points
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/__init__.py
0 → 100644
View file @
62b7582f
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/arch_config.yaml
0 → 100644
View file @
62b7582f
ch_ptocr_mobile_v2.0_cls_infer
:
model_type
:
cls
algorithm
:
CLS
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.35
model_name
:
small
Neck
:
Head
:
name
:
ClsHead
class_dim
:
2
Multilingual_PP-OCRv3_det_infer
:
model_type
:
det
algorithm
:
DB
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
True
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
en_PP-OCRv3_det_infer
:
model_type
:
det
algorithm
:
DB
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
True
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
en_PP-OCRv4_rec_infer
:
model_type
:
rec
algorithm
:
SVTR_LCNet
Transform
:
Backbone
:
name
:
PPLCNetV3
scale
:
0.95
Head
:
name
:
MultiHead
out_channels_list
:
CTCLabelDecode
:
97
#'blank' + ...(62) + ' '
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
120
depth
:
2
hidden_dims
:
120
kernel_size
:
[
1
,
3
]
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
NRTRHead
:
nrtr_dim
:
384
max_text_length
:
25
ch_PP-OCRv4_det_infer
:
model_type
:
det
algorithm
:
DB
Transform
:
null
Backbone
:
name
:
PPLCNetV3
scale
:
0.75
det
:
True
Neck
:
name
:
RSEFPN
out_channels
:
96
shortcut
:
True
Head
:
name
:
DBHead
k
:
50
ch_PP-OCRv4_det_server_infer
:
model_type
:
det
algorithm
:
DB
Transform
:
null
Backbone
:
name
:
PPHGNet_small
det
:
True
Neck
:
name
:
LKPAN
out_channels
:
256
intracl
:
true
Head
:
name
:
PFHeadLocal
k
:
50
mode: "large"
ch_PP-OCRv4_rec_infer
:
model_type
:
rec
algorithm
:
SVTR_LCNet
Transform
:
Backbone
:
name
:
PPLCNetV3
scale
:
0.95
Head
:
name
:
MultiHead
out_channels_list
:
CTCLabelDecode
:
6625
#'blank' + ...(6623) + ' '
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
120
depth
:
2
hidden_dims
:
120
kernel_size
:
[
1
,
3
]
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
NRTRHead
:
nrtr_dim
:
384
max_text_length
:
25
ch_PP-OCRv4_rec_server_infer
:
model_type
:
rec
algorithm
:
SVTR_HGNet
Transform
:
Backbone
:
name
:
PPHGNet_small
Head
:
name
:
MultiHead
out_channels_list
:
CTCLabelDecode
:
6625
#'blank' + ...(6623) + ' '
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
120
depth
:
2
hidden_dims
:
120
kernel_size
:
[
1
,
3
]
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
NRTRHead
:
nrtr_dim
:
384
max_text_length
:
25
chinese_cht_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 8423
fc_decay
:
0.00001
latin_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 187
fc_decay
:
0.00001
cyrillic_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 165
fc_decay
:
0.00001
arabic_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 164
fc_decay
:
0.00001
korean_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 3690
fc_decay
:
0.00001
japan_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 4401
fc_decay
:
0.00001
ta_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 130
fc_decay
:
0.00001
te_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 153
fc_decay
:
0.00001
ka_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 155
fc_decay
:
0.00001
devanagari_PP-OCRv3_rec_infer
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
name
:
SequenceEncoder
encoder_type
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
name
:
CTCHead
# out_channels: 169
fc_decay
:
0.00001
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py
0 → 100644
View file @
62b7582f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
__all__
=
[
"build_model"
]
def build_model(config, **kwargs):
    """Construct a BaseModel from *config* without mutating the caller's dict.

    Args:
        config: architecture description dict; deep-copied before use so the
            caller's config is left untouched.
        **kwargs: forwarded to the BaseModel constructor.
    """
    from .base_model import BaseModel

    return BaseModel(copy.deepcopy(config), **kwargs)
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py
0 → 100644
View file @
62b7582f
from
torch
import
nn
from
..backbones
import
build_backbone
from
..heads
import
build_head
from
..necks
import
build_neck
class
BaseModel
(
nn
.
Module
):
    def __init__(self, config, **kwargs):
        """
        the module for OCR.
        args:
            config (dict): the super parameters for module.
        """
        super(BaseModel, self).__init__()

        in_channels = config.get("in_channels", 3)
        model_type = config["model_type"]
        # build backbone, backbone is need for det, rec and cls
        # A missing or null section disables that stage entirely.
        if "Backbone" not in config or config["Backbone"] is None:
            self.use_backbone = False
        else:
            self.use_backbone = True
            config["Backbone"]["in_channels"] = in_channels
            self.backbone = build_backbone(config["Backbone"], model_type)
            # Each stage's output channels feed the next stage's input.
            in_channels = self.backbone.out_channels

        # build neck
        # for rec, neck can be cnn,rnn or reshape(None)
        # for det, neck can be FPN, BIFPN and so on.
        # for cls, neck should be none
        if "Neck" not in config or config["Neck"] is None:
            self.use_neck = False
        else:
            self.use_neck = True
            config["Neck"]["in_channels"] = in_channels
            self.neck = build_neck(config["Neck"])
            in_channels = self.neck.out_channels

        # # build head, head is need for det, rec and cls
        if "Head" not in config or config["Head"] is None:
            self.use_head = False
        else:
            self.use_head = True
            config["Head"]["in_channels"] = in_channels
            self.head = build_head(config["Head"], **kwargs)

        self.return_all_feats = config.get("return_all_feats", False)
        self._initialize_weights()
def
_initialize_weights
(
self
):
# weight initialization
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
mode
=
"fan_out"
)
if
m
.
bias
is
not
None
:
nn
.
init
.
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
nn
.
init
.
ones_
(
m
.
weight
)
nn
.
init
.
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
normal_
(
m
.
weight
,
0
,
0.01
)
if
m
.
bias
is
not
None
:
nn
.
init
.
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
ConvTranspose2d
):
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
mode
=
"fan_out"
)
if
m
.
bias
is
not
None
:
nn
.
init
.
zeros_
(
m
.
bias
)
def
forward
(
self
,
x
):
y
=
dict
()
if
self
.
use_backbone
:
x
=
self
.
backbone
(
x
)
if
isinstance
(
x
,
dict
):
y
.
update
(
x
)
else
:
y
[
"backbone_out"
]
=
x
final_name
=
"backbone_out"
if
self
.
use_neck
:
x
=
self
.
neck
(
x
)
if
isinstance
(
x
,
dict
):
y
.
update
(
x
)
else
:
y
[
"neck_out"
]
=
x
final_name
=
"neck_out"
if
self
.
use_head
:
x
=
self
.
head
(
x
)
# for multi head, save ctc neck out for udml
if
isinstance
(
x
,
dict
)
and
"ctc_nect"
in
x
.
keys
():
y
[
"neck_out"
]
=
x
[
"ctc_neck"
]
y
[
"head_out"
]
=
x
elif
isinstance
(
x
,
dict
):
y
.
update
(
x
)
else
:
y
[
"head_out"
]
=
x
if
self
.
return_all_feats
:
if
self
.
training
:
return
y
elif
isinstance
(
x
,
dict
):
return
x
else
:
return
{
final_name
:
x
}
else
:
return
x
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
0 → 100644
View file @
62b7582f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[
"build_backbone"
]
def build_backbone(config, model_type):
    """Instantiate a backbone network from its configuration.

    Args:
        config (dict): backbone configuration. Must contain a "name" key
            naming the backbone class; all other keys are passed to that
            class's constructor. NOTE: "name" is popped, so *config* is
            mutated by this call.
        model_type (str): "det", "rec" or "cls"; selects which backbone
            classes are available.

    Returns:
        The instantiated backbone module.

    Raises:
        NotImplementedError: if *model_type* is not supported.
        AssertionError: if the requested backbone name is not in the
            whitelist for the given model type.
    """
    if model_type == "det":
        from .det_mobilenet_v3 import MobileNetV3
        from .rec_hgnet import PPHGNet_small
        from .rec_lcnetv3 import PPLCNetV3

        support_dict = [
            "MobileNetV3",
            "ResNet",
            "ResNet_vd",
            "ResNet_SAST",
            "PPLCNetV3",
            "PPHGNet_small",
        ]
    elif model_type == "rec" or model_type == "cls":
        from .rec_hgnet import PPHGNet_small
        from .rec_lcnetv3 import PPLCNetV3
        from .rec_mobilenet_v3 import MobileNetV3
        from .rec_svtrnet import SVTRNet
        from .rec_mv1_enhance import MobileNetV1Enhance

        support_dict = [
            "MobileNetV1Enhance",
            "MobileNetV3",
            "ResNet",
            "ResNetFPN",
            "MTB",
            "ResNet31",
            "SVTRNet",
            "ViTSTR",
            "DenseNet",
            "PPLCNetV3",
            "PPHGNet_small",
        ]
    else:
        raise NotImplementedError

    module_name = config.pop("name")
    # BUG FIX: the error message read "model typs"; corrected to "model type".
    assert module_name in support_dict, Exception(
        "when model type is {}, backbone only support {}".format(
            model_type, support_dict
        )
    )
    # eval() resolves the class from the names imported locally above and is
    # confined to the whitelist checked by the preceding assert.
    module_class = eval(module_name)(**config)
    return module_class
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py
0 → 100644
View file @
62b7582f
from
torch
import
nn
from
..common
import
Activation
def make_divisible(v, divisor=8, min_value=None):
    """Round *v* to the nearest multiple of *divisor*.

    The result never falls below *min_value* (defaults to *divisor*) and is
    bumped up one step if plain rounding would drop more than 10% below *v*.
    """
    floor = divisor if min_value is None else min_value
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(floor, rounded)
    # Shrinking a layer by more than 10% hurts accuracy; take the next
    # multiple instead of rounding down that far.
    return result + divisor if result < 0.9 * v else result
class ConvBNLayer(nn.Module):
    """Bias-free Conv2d followed by BatchNorm and an optional activation."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        groups=1,
        if_act=True,
        act=None,
        name=None,
    ):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        # The conv bias would be cancelled by the following BatchNorm,
        # so it is disabled.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if if_act:
            self.act = Activation(act_type=act, inplace=True)

    def forward(self, x):
        out = self.bn(self.conv(x))
        return self.act(out) if self.if_act else out
class SEModule(nn.Module):
    """Squeeze-and-Excitation: global-pool, two 1x1 convs with a bottleneck,
    then a hard-sigmoid gate multiplied back onto the input."""

    def __init__(self, in_channels, reduction=4, name=""):
        super(SEModule, self).__init__()
        squeezed = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=squeezed,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.relu1 = Activation(act_type="relu", inplace=True)
        self.conv2 = nn.Conv2d(
            in_channels=squeezed,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.hard_sigmoid = Activation(act_type="hard_sigmoid", inplace=True)

    def forward(self, inputs):
        # Per-channel gate in [0, 1], broadcast over the spatial dims.
        gate = self.avg_pool(inputs)
        gate = self.relu1(self.conv1(gate))
        gate = self.hard_sigmoid(self.conv2(gate))
        return inputs * gate
class ResidualUnit(nn.Module):
    """Inverted-residual unit: 1x1 expand -> depthwise conv -> optional SE ->
    1x1 linear projection, with an identity shortcut when the block keeps
    both the resolution and the channel count."""

    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        kernel_size,
        stride,
        use_se,
        act=None,
        name="",
    ):
        super(ResidualUnit, self).__init__()
        # The shortcut is only shape-compatible for stride-1, width-preserving blocks.
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act,
            name=name + "_expand",
        )
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,  # groups == channels: depthwise convolution
            if_act=True,
            act=act,
            name=name + "_depthwise",
        )
        if self.if_se:
            self.mid_se = SEModule(mid_channels, name=name + "_se")
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=False,  # linear projection: no activation
            act=None,
            name=name + "_linear",
        )

    def forward(self, inputs):
        out = self.bottleneck_conv(self.expand_conv(inputs))
        if self.if_se:
            out = self.mid_se(out)
        out = self.linear_conv(out)
        return inputs + out if self.if_shortcut else out
class MobileNetV3(nn.Module):
    def __init__(
        self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
    ):
        """
        the MobilenetV3 backbone network for detection module.
        Args:
            in_channels (int): number of input image channels.
            model_name (str): "large" or "small" MobileNetV3 variant.
            scale (float): width multiplier; must be one of
                0.35 / 0.5 / 0.75 / 1.0 / 1.25.
            disable_se (bool): if True, drop all Squeeze-and-Excitation blocks.
        """
        super(MobileNetV3, self).__init__()

        self.disable_se = disable_se

        if model_name == "large":
            # Per-row config: kernel size, expansion channels, output channels,
            # use SE, nonlinearity, stride.
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, "relu", 1],
                [3, 64, 24, False, "relu", 2],
                [3, 72, 24, False, "relu", 1],
                [5, 72, 40, True, "relu", 2],
                [5, 120, 40, True, "relu", 1],
                [5, 120, 40, True, "relu", 1],
                [3, 240, 80, False, "hard_swish", 2],
                [3, 200, 80, False, "hard_swish", 1],
                [3, 184, 80, False, "hard_swish", 1],
                [3, 184, 80, False, "hard_swish", 1],
                [3, 480, 112, True, "hard_swish", 1],
                [3, 672, 112, True, "hard_swish", 1],
                [5, 672, 160, True, "hard_swish", 2],
                [5, 960, 160, True, "hard_swish", 1],
                [5, 960, 160, True, "hard_swish", 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, "relu", 2],
                [3, 72, 24, False, "relu", 2],
                [3, 88, 24, False, "relu", 1],
                [5, 96, 40, True, "hard_swish", 2],
                [5, 240, 40, True, "hard_swish", 1],
                [5, 240, 40, True, "hard_swish", 1],
                [5, 120, 48, True, "hard_swish", 1],
                [5, 144, 48, True, "hard_swish", 1],
                [5, 288, 96, True, "hard_swish", 2],
                [5, 576, 96, True, "hard_swish", 1],
                [5, 576, 96, True, "hard_swish", 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError(
                "mode[" + model_name + "_model] is not implemented!"
            )

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert (
            scale in supported_scale
        ), "supported scale are {} but input scale is {}".format(supported_scale, scale)
        inplanes = 16
        # conv1: initial stem conv, halves the spatial resolution.
        self.conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act="hard_swish",
            name="conv1",
        )

        self.stages = nn.ModuleList()
        self.out_channels = []
        block_list = []
        i = 0
        inplanes = make_divisible(inplanes * scale)
        for k, exp, c, se, nl, s in cfg:
            se = se and not self.disable_se
            # A stride-2 row after the first few rows starts a new stage:
            # flush the accumulated blocks so each stage ends right before a
            # downsampling unit, and record that stage's channel count.
            if s == 2 and i > 2:
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                    name="conv" + str(i + 2),
                )
            )
            inplanes = make_divisible(scale * c)
            i += 1
        # Final 1x1 conv widening to cls_ch_squeeze channels, appended to the
        # last stage.
        block_list.append(
            ConvBNLayer(
                in_channels=inplanes,
                out_channels=make_divisible(scale * cls_ch_squeeze),
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                if_act=True,
                act="hard_swish",
                name="conv_last",
            )
        )
        self.stages.append(nn.Sequential(*block_list))
        self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
        # for i, stage in enumerate(self.stages):
        #     self.add_sublayer(sublayer=stage, name="stage{}".format(i))

    def forward(self, x):
        # Returns the output of every stage (multi-scale feature pyramid),
        # one entry per element of self.out_channels.
        x = self.conv(x)
        out_list = []
        for stage in self.stages:
            x = stage(x)
            out_list.append(x)
        return out_list
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py
0 → 100644
View file @
62b7582f
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
class ConvBNAct(nn.Module):
    """Bias-free Conv2d with "same" padding + BatchNorm + optional ReLU."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
    ):
        super().__init__()
        self.use_act = use_act
        # (kernel_size - 1) // 2 gives "same" padding for odd kernels;
        # the conv bias is omitted because BatchNorm would cancel it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if use_act:
            self.act = nn.ReLU()

    def forward(self, x):
        out = self.bn(self.conv(x))
        return self.act(out) if self.use_act else out
class ESEModule(nn.Module):
    """Effective Squeeze-and-Excitation: a single 1x1 conv on the pooled
    channel descriptor, sigmoid-gated and multiplied back onto the input."""

    def __init__(self, channels):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Per-channel gate in (0, 1), broadcast over spatial dimensions.
        gate = self.avg_pool(x)
        gate = self.sigmoid(self.conv(gate))
        return gate * x
class HG_Block(nn.Module):
    """HGNet block: a chain of 3x3 ConvBNActs whose input and every
    intermediate output are concatenated, fused by a 1x1 conv and gated by
    ESE attention; optionally adds a residual connection."""

    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        layer_num,
        identity=False,
    ):
        super().__init__()
        self.identity = identity

        convs = [
            ConvBNAct(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=1,
            )
        ]
        convs.extend(
            ConvBNAct(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=1,
            )
            for _ in range(layer_num - 1)
        )
        self.layers = nn.ModuleList(convs)

        # feature aggregation: the block input plus all layer_num outputs
        # are concatenated along the channel axis before the 1x1 fuse conv.
        total_channels = in_channels + layer_num * mid_channels
        self.aggregation_conv = ConvBNAct(
            in_channels=total_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
        )
        self.att = ESEModule(out_channels)

    def forward(self, x):
        shortcut = x
        feats = [x]
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        fused = self.att(self.aggregation_conv(torch.cat(feats, dim=1)))
        if self.identity:
            fused += shortcut
        return fused
class HG_Stage(nn.Module):
    """One PPHGNet stage: an optional depthwise downsample conv followed by a
    chain of HG_Blocks (the first without, the rest with, a residual path)."""

    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        block_num,
        layer_num,
        downsample=True,
        stride=[2, 1],
    ):
        super().__init__()
        # self.downsample is either False (no-op) or a depthwise ConvBNAct;
        # forward() relies on its truthiness. The mutable list default is
        # never mutated, so it is safe here.
        self.downsample = downsample
        if downsample:
            self.downsample = ConvBNAct(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=3,
                stride=stride,
                groups=in_channels,
                use_act=False,
            )

        chain = [
            HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
        ]
        chain.extend(
            HG_Block(out_channels, mid_channels, out_channels, layer_num, identity=True)
            for _ in range(block_num - 1)
        )
        self.blocks = nn.Sequential(*chain)

    def forward(self, x):
        if self.downsample:
            x = self.downsample(x)
        return self.blocks(x)
class PPHGNet(nn.Module):
    """
    PPHGNet
    Args:
        stem_channels: list. Stem channel list of PPHGNet. NOTE: this list is
            mutated (in_channels is inserted at position 0).
        stage_config: dict. The configuration of each stage of PPHGNet, such as
            the number of channels, stride, etc. Each value unpacks to
            (in_channels, mid_channels, out_channels, block_num, downsample, stride).
        layer_num: int. Number of layers of HG_Block.
        in_channels: int. Number of input image channels.
        det: boolean. Detection mode: return multi-scale stage outputs instead
            of a single pooled feature map.
        out_indices: list or None. Stage indices to collect in detection mode;
            defaults to [0, 1, 2, 3].
    Returns:
        model: nn.Layer. Specific PPHGNet model depends on args.
    """

    def __init__(
        self,
        stem_channels,
        stage_config,
        layer_num,
        in_channels=3,
        det=False,
        out_indices=None,
    ):
        super().__init__()
        self.det = det
        self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]

        # stem: a chain of 3x3 ConvBNActs; only the first one downsamples.
        stem_channels.insert(0, in_channels)
        self.stem = nn.Sequential(
            *[
                ConvBNAct(
                    in_channels=stem_channels[i],
                    out_channels=stem_channels[i + 1],
                    kernel_size=3,
                    stride=2 if i == 0 else 1,
                )
                for i in range(len(stem_channels) - 1)
            ]
        )

        if self.det:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # stages
        self.stages = nn.ModuleList()
        self.out_channels = []
        for block_id, k in enumerate(stage_config):
            (
                in_channels,
                mid_channels,
                out_channels,
                block_num,
                downsample,
                stride,
            ) = stage_config[k]
            self.stages.append(
                HG_Stage(
                    in_channels,
                    mid_channels,
                    out_channels,
                    block_num,
                    layer_num,
                    downsample,
                    stride,
                )
            )
            if block_id in self.out_indices:
                self.out_channels.append(out_channels)

        # In recognition mode out_channels is a single int (the last stage's
        # width), not the per-stage list used in detection mode.
        if not self.det:
            self.out_channels = stage_config["stage4"][2]

        self._init_weights()

    def _init_weights(self):
        # Kaiming init for convolutions; unit scale / zero shift for batch
        # norm; zero bias for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.stem(x)
        if self.det:
            x = self.pool(x)
        out = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            # Detection mode collects the selected stages as a feature pyramid.
            if self.det and i in self.out_indices:
                out.append(x)
        if self.det:
            return out

        # Recognition mode: pool the final feature map. The train/eval pooling
        # shapes differ; presumably they coincide for the expected input size —
        # TODO confirm against the recognition pipeline.
        if self.training:
            x = F.adaptive_avg_pool2d(x, [1, 40])
        else:
            x = F.avg_pool2d(x, [3, 2])
        return x
def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
    """
    PPHGNet_small
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
                    NOTE: accepted for API compatibility but not used by this
                    implementation (no weight loading happens here).
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
                  NOTE: also unused by this implementation.
        det: bool=False. Build the detection variant (plain stride-2 stages)
             instead of the recognition variant (asymmetric strides).
    Returns:
        model: nn.Layer. Specific `PPHGNet_small` model depends on args.
    """
    # Each stage entry:
    # in_channels, mid_channels, out_channels, blocks, downsample, stride
    stage_config_det = {
        # in_channels, mid_channels, out_channels, blocks, downsample
        "stage1": [128, 128, 256, 1, False, 2],
        "stage2": [256, 160, 512, 1, True, 2],
        "stage3": [512, 192, 768, 2, True, 2],
        "stage4": [768, 224, 1024, 1, True, 2],
    }

    # Recognition strides are [h, w] pairs: height is reduced faster than
    # width to preserve horizontal resolution for text.
    stage_config_rec = {
        # in_channels, mid_channels, out_channels, blocks, downsample
        "stage1": [128, 128, 256, 1, True, [2, 1]],
        "stage2": [256, 160, 512, 1, True, [1, 2]],
        "stage3": [512, 192, 768, 2, True, [2, 1]],
        "stage4": [768, 224, 1024, 1, True, [2, 1]],
    }

    model = PPHGNet(
        stem_channels=[64, 64, 128],
        stage_config=stage_config_det if det else stage_config_rec,
        layer_num=6,
        det=det,
        **kwargs
    )
    return model
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment