Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
4101c357
Commit
4101c357
authored
Jul 12, 2024
by
zhaoxiaomeng
Browse files
refactor(model): update init methods and improve model loading logic
parent
b6df9b18
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
29 deletions
+36
-29
magic_pdf/model/__init__.py
magic_pdf/model/__init__.py
+1
-1
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+35
-28
No files found.
magic_pdf/model/__init__.py
View file @
4101c357
__use_inside_model__
=
Tru
e
__use_inside_model__
=
Fals
e
__model_mode__
=
"full"
__model_mode__
=
"full"
magic_pdf/model/pdf_extract_kit.py
View file @
4101c357
import
os
import
os
import
time
import
cv2
import
cv2
import
numpy
as
np
import
yaml
import
yaml
from
PIL
import
Image
import
time
from
ultralytics
import
YOLO
import
argparse
import
numpy
as
np
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
paddleocr
import
draw_ocr
from
PIL
import
Image
from
torchvision
import
transforms
from
torch.utils.data
import
Dataset
,
DataLoader
from
ultralytics
import
YOLO
from
unimernet.common.config
import
Config
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
from
unimernet.processors
import
load_processor
import
argparse
from
torchvision
import
transforms
from
torch.utils.data
import
Dataset
,
DataLoader
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
magic_pdf.model.pek_sub_modules.post_process
import
get_croped_image
,
latex_rm_whitespace
from
magic_pdf.model.pek_sub_modules.post_process
import
get_croped_image
,
latex_rm_whitespace
from
magic_pdf.model.pek_sub_modules.self_modify
import
ModifiedPaddleOCR
from
magic_pdf.model.pek_sub_modules.self_modify
import
ModifiedPaddleOCR
def
layout
_model_init
(
weight
,
config_file
,
device
):
def
mfd
_model_init
(
weight
):
model
=
Layoutlmv3_Predictor
(
weight
,
config_file
,
device
)
mfd_
model
=
YOLO
(
weight
)
return
model
return
mfd_
model
def
mfr_model_init
(
weight_dir
,
cfg_path
,
device
=
'cpu'
):
def
mfr_model_init
(
weight_dir
,
cfg_path
,
_
device
_
=
'cpu'
):
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
cfg
=
Config
(
args
)
cfg
=
Config
(
args
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.bin"
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.bin"
)
...
@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
...
@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
cfg
.
config
.
model
.
tokenizer_config
.
path
=
weight_dir
cfg
.
config
.
model
.
tokenizer_config
.
path
=
weight_dir
task
=
tasks
.
setup_task
(
cfg
)
task
=
tasks
.
setup_task
(
cfg
)
model
=
task
.
build_model
(
cfg
)
model
=
task
.
build_model
(
cfg
)
model
=
model
.
to
(
device
)
model
=
model
.
to
(
_
device
_
)
vis_processor
=
load_processor
(
'formula_image_eval'
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
)
vis_processor
=
load_processor
(
'formula_image_eval'
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
)
return
model
,
vis_processor
return
model
,
vis_processor
def
layout_model_init
(
weight
,
config_file
,
device
):
model
=
Layoutlmv3_Predictor
(
weight
,
config_file
,
device
)
return
model
class
MathDataset
(
Dataset
):
class
MathDataset
(
Dataset
):
def
__init__
(
self
,
image_paths
,
transform
=
None
):
def
__init__
(
self
,
image_paths
,
transform
=
None
):
self
.
image_paths
=
image_paths
self
.
image_paths
=
image_paths
...
@@ -54,10 +60,11 @@ class MathDataset(Dataset):
...
@@ -54,10 +60,11 @@ class MathDataset(Dataset):
raw_image
=
self
.
image_paths
[
idx
]
raw_image
=
self
.
image_paths
[
idx
]
if
self
.
transform
:
if
self
.
transform
:
image
=
self
.
transform
(
raw_image
)
image
=
self
.
transform
(
raw_image
)
return
image
return
image
class
CustomPEKModel
:
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
"""
"""
======== model init ========
======== model init ========
...
@@ -88,24 +95,24 @@ class CustomPEKModel:
...
@@ -88,24 +95,24 @@ class CustomPEKModel:
self
.
device
=
kwargs
.
get
(
"device"
,
self
.
configs
[
"config"
][
"device"
])
self
.
device
=
kwargs
.
get
(
"device"
,
self
.
configs
[
"config"
][
"device"
])
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
# 初始化layout模型
self
.
layout_model
=
layout_model_init
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
'layout'
]),
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
),
device
=
self
.
device
)
# 初始化公式识别
# 初始化公式识别
if
self
.
apply_formula
:
if
self
.
apply_formula
:
# 初始化公式检测模型
# 初始化公式检测模型
self
.
mfd_model
=
YOLO
(
model
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfd"
])))
self
.
mfd_model
=
mfd_model_init
(
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfd"
])))
# 初始化公式解析模型
# 初始化公式解析模型
mfr_config_path
=
os
.
path
.
join
(
model_config_dir
,
'UniMERNet'
,
'demo.yaml'
)
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfr"
]))
self
.
mfr_model
,
mfr_vis_processors
=
mfr_model_init
(
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfr"
]),
self
.
mfr_model
,
mfr_vis_processors
=
mfr_model_init
(
mfr_weight_dir
,
mfr_cfg_path
,
_device_
=
self
.
device
)
mfr_config_path
,
device
=
self
.
device
)
self
.
mfr_transform
=
transforms
.
Compose
([
mfr_vis_processors
,
])
self
.
mfr_transform
=
transforms
.
Compose
([
mfr_vis_processors
,
])
# 初始化layout模型
self
.
layout_model
=
Layoutlmv3_Predictor
(
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
'layout'
])),
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
)
# 初始化ocr
# 初始化ocr
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
self
.
ocr_model
=
ModifiedPaddleOCR
(
show_log
=
show_log
)
self
.
ocr_model
=
ModifiedPaddleOCR
(
show_log
=
show_log
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment