Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
8ccfff6f
Unverified
Commit
8ccfff6f
authored
Dec 12, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Dec 12, 2024
Browse files
Merge pull request #1281 from IMSUVEN/dev
feat: add batch prediction methods for YOLOv8 and Unimernet models
parents
d0a3058b
7ce9edc6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
364 additions
and
36 deletions
+364
-36
magic_pdf/config/exceptions.py
magic_pdf/config/exceptions.py
+7
-0
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+228
-0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
.../model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
+41
-7
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
+18
-2
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+70
-27
No files found.
magic_pdf/config/exceptions.py
View file @
8ccfff6f
...
@@ -30,3 +30,10 @@ class EmptyData(Exception):
...
@@ -30,3 +30,10 @@ class EmptyData(Exception):
def
__str__
(
self
):
def
__str__
(
self
):
return
f
'Empty data:
{
self
.
msg
}
'
return
f
'Empty data:
{
self
.
msg
}
'
class CUDA_NOT_AVAILABLE(Exception):
    """Error signalling that CUDA is not available.

    Args:
        msg: Detail text appended to the ``'CUDA not available: '`` prefix
            when the exception is rendered with ``str()``.
    """

    def __init__(self, msg):
        # Forward the message to Exception so that ``args`` is populated;
        # without it, copy/pickle reconstruction calls ``cls()`` with no
        # arguments and fails with a TypeError.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'CUDA not available: {self.msg}'
\ No newline at end of file
magic_pdf/model/batch_analyze.py
0 → 100644
View file @
8ccfff6f
import
time
import
cv2
import
numpy
as
np
import
torch
from
loguru
import
logger
from
PIL
import
Image
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.exceptions
import
CUDA_NOT_AVAILABLE
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.model.doc_analyze_by_custom_model
import
ModelSingleton
from
magic_pdf.model.operators
import
InferenceResult
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
,
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
,
)
# Base batch sizes per model. BatchAnalyze multiplies each one by its
# batch_ratio to obtain the batch size passed to the matching batch_predict().
# Base batch size for the DocLayout-YOLO layout model.
YOLO_LAYOUT_BASE_BATCH_SIZE
=
4
# Base batch size for the formula detection (MFD) model.
MFD_BASE_BATCH_SIZE
=
1
# Base batch size for the formula recognition (MFR) model.
MFR_BASE_BATCH_SIZE
=
16
class BatchAnalyze:
    """Run layout, formula, OCR and table models over a list of page images.

    Layout detection and formula detection/recognition are run over the whole
    list of images (in batches scaled by ``batch_ratio``); OCR and table
    recognition are then run page by page on the regions found by the layout
    step.

    Args:
        model: The loaded model bundle whose sub-models and ``apply_*`` flags
            drive the analysis.
        batch_ratio: Multiplier applied to the module-level base batch sizes.
    """

    def __init__(self, model: CustomPEKModel, batch_ratio: int):
        self.model = model
        self.batch_ratio = batch_ratio

    def __call__(self, images: list) -> list:
        """Analyze ``images`` and return one detection list per image.

        Args:
            images: Page images; each one is passed to the layout model and to
                ``PIL.Image.fromarray``.

        Returns:
            A list with one entry per input image. Each entry is that page's
            list of layout detections, extended in place with formula and OCR
            results; table regions additionally receive an ``"html"`` key
            when table recognition succeeds.

        Raises:
            ValueError: If the model's layout model name is not one of the
                supported layout models.
        """
        images_layout_res = self._detect_layout(images)

        if self.model.apply_formula:
            # Formula detection
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
            )
            # Formula recognition
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]

        # Free GPU memory
        clean_vram(self.model.device, vram_threshold=8)

        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
        for image, layout_res in zip(images, images_layout_res):
            pil_img = Image.fromarray(image)
            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )
            self._run_ocr(pil_img, ocr_res_list, single_page_mfdetrec_res, layout_res)
            # Table recognition
            if self.model.apply_table:
                self._run_tables(pil_img, table_res_list)

        # The original fell off the end here and returned None even though the
        # signature promises a list and the caller consumes one.
        return images_layout_res

    def _detect_layout(self, images: list) -> list:
        """Return one layout-detection list per image using the configured layout model."""
        layout_model_name = self.model.layout_model_name
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3 is called one image at a time
            return [
                self.model.layout_model(image, ignore_catids=[]) for image in images
            ]
        if layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            return self.model.layout_model.batch_predict(
                images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
            )
        # Previously this case surfaced later as an unbound-local error.
        raise ValueError(f"unsupported layout model: {layout_model_name}")

    def _run_ocr(self, pil_img, ocr_res_list, single_page_mfdetrec_res, layout_res):
        """OCR every region in ``ocr_res_list`` and extend ``layout_res`` with the results."""
        ocr_start = time.time()
        # Process each area that requires OCR processing
        for res in ocr_res_list:
            new_image, useful_list = crop_img(
                res, pil_img, crop_paste_x=50, crop_paste_y=50
            )
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                single_page_mfdetrec_res, useful_list
            )

            # OCR recognition
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
            if self.model.apply_ocr:
                ocr_res = self.model.ocr_model.ocr(
                    new_image, mfd_res=adjusted_mfdetrec_res
                )[0]
            else:
                # rec=False is passed when OCR is disabled; the elapsed time
                # is then logged as "det time" below.
                ocr_res = self.model.ocr_model.ocr(
                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                )[0]

            # Integration results
            if ocr_res:
                layout_res.extend(get_ocr_result_list(ocr_res, useful_list))

        ocr_cost = round(time.time() - ocr_start, 2)
        if self.model.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
        else:
            logger.info(f"det time: {ocr_cost}")

    def _run_tables(self, pil_img, table_res_list):
        """Recognize each table region and store valid HTML under ``res["html"]``."""
        table_start = time.time()
        for res in table_res_list:
            new_image, _ = crop_img(res, pil_img)
            single_table_start_time = time.time()
            html_code = self._table_html(new_image)
            run_time = time.time() - single_table_start_time
            if run_time > self.model.table_max_time:
                logger.warning(
                    f"table recognition processing exceeds max time {self.model.table_max_time}s"
                )
            # Check whether the model returned a usable result
            if not html_code:
                logger.warning(
                    "table recognition processing fails, not get html return"
                )
            elif html_code.strip().endswith(("</html>", "</table>")):
                res["html"] = html_code
            else:
                logger.warning(
                    "table recognition processing fails, not found expected HTML table end"
                )
        logger.info(f"table time: {round(time.time() - table_start, 2)}")

    def _table_html(self, table_image):
        """Return the HTML produced by the configured table model, or None."""
        table_model_name = self.model.table_model_name
        if table_model_name == MODEL_NAME.STRUCT_EQTABLE:
            with torch.no_grad():
                table_result = self.model.table_model.predict(table_image, "html")
            return table_result[0] if len(table_result) > 0 else None
        if table_model_name == MODEL_NAME.TABLE_MASTER:
            return self.model.table_model.img2html(table_image)
        if table_model_name == MODEL_NAME.RAPID_TABLE:
            # predict() returns a 3-tuple here; only the HTML is used.
            html_code, _, _ = self.model.table_model.predict(table_image)
            return html_code
        return None
def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """
    Perform batch analysis on a document dataset.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
        end_page_id (int, optional): The ending page ID (inclusive) for analysis. Defaults to None,
            which means analyze till the last page. A value of 0 selects only the first page.
        lang (str, optional): Language for OCR. An empty string is treated as None. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
            Pages outside [start_page_id, end_page_id] get an empty ``layout_dets`` list.
    """
    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")

    lang = None if lang == "" else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    # Compare against None explicitly: a truthiness test treated
    # end_page_id=0 as "not given" and analyzed every page.
    if end_page_id is None:
        end_page_id = len(dataset) - 1

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    # Render each page once, keeping its dimensions for the output and its
    # image only when the page is inside the requested range.
    page_infos = []
    images = []
    for index in range(len(dataset)):
        img_dict = dataset.get_page(index).get_image()
        page_infos.append(
            {"page_no": index, "height": img_dict["height"], "width": img_dict["width"]}
        )
        if start_page_id <= index <= end_page_id:
            images.append(img_dict["img"])

    # batch analyze
    analyze_result = iter(batch_model(images))

    model_json = []
    for page_info in page_infos:
        if start_page_id <= page_info["page_no"] <= end_page_id:
            # Results come back in the same order the images were queued.
            result = next(analyze_result)
        else:
            result = []
        model_json.append({"layout_dets": result, "page_info": page_info})

    # TODO: clean memory when gpu memory is not enough
    clean_memory()

    return InferenceResult(model_json, dataset)
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
View file @
8ccfff6f
...
@@ -8,14 +8,48 @@ class DocLayoutYOLOModel(object):
...
@@ -8,14 +8,48 @@ class DocLayoutYOLOModel(object):
def predict(self, image):
    """Detect layout regions on a single image.

    Runs the wrapped model on ``image`` and converts the boxes of the first
    result into dictionaries with the keys ``category_id`` (int), ``poly``
    (the box corners as ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``)
    and ``score`` (confidence rounded to 3 decimals).
    """
    first_result = self.model.predict(
        image,
        imgsz=1024,
        conf=0.25,
        iou=0.45,
        verbose=True,
        device=self.device,
    )[0]
    boxes = first_result.boxes

    detections = []
    for corners, confidence, category in zip(
        boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()
    ):
        # Coordinates are truncated to ints via int().
        left, top, right, bottom = (int(value.item()) for value in corners)
        detections.append(
            {
                "category_id": int(category.item()),
                "poly": [left, top, right, top, right, bottom, left, bottom],
                "score": round(float(confidence.item()), 3),
            }
        )
    return detections
\ No newline at end of file
def batch_predict(self, images: list, batch_size: int) -> list:
    """Detect layout regions on many images, ``batch_size`` images per model call.

    Args:
        images: Images to analyze.
        batch_size: Number of images handed to the model per call.

    Returns:
        One list per input image, in input order. Each list holds dictionaries
        with ``category_id`` (int), ``poly`` (box corners as
        ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``) and ``score``
        (confidence rounded to 3 decimals).
    """
    images_layout_res = []
    for index in range(0, len(images), batch_size):
        batch_results = self.model.predict(
            images[index : index + batch_size],
            imgsz=1024,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )
        for image_res in batch_results:
            # The model call yields a sequence of per-image results, so each
            # result is moved to the CPU individually; the sequence itself
            # has no .cpu().
            boxes = image_res.cpu().boxes
            layout_res = []
            for xyxy, conf, cla in zip(boxes.xyxy, boxes.conf, boxes.cls):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                layout_res.append(
                    {
                        "category_id": int(cla.item()),
                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                        "score": round(float(conf.item()), 3),
                    }
                )
            images_layout_res.append(layout_res)
    return images_layout_res
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
View file @
8ccfff6f
...
@@ -2,11 +2,27 @@ from ultralytics import YOLO
...
@@ -2,11 +2,27 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object):
    """Wrapper around a YOLO model used for formula detection (MFD).

    Args:
        weight: Weights argument passed straight to ``YOLO``.
        device: Device passed to every ``predict`` call. Defaults to ``"cpu"``.
    """

    def __init__(self, weight, device="cpu"):
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        """Run detection on one image and return the first result object."""
        mfd_res = self.mfd_model.predict(
            image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device
        )[0]
        return mfd_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        """Run detection on many images, ``batch_size`` images per model call.

        Args:
            images: Images to analyze.
            batch_size: Number of images handed to the model per call.

        Returns:
            One result object per input image, in input order, each moved to
            the CPU with its own ``.cpu()``.
        """
        images_mfd_res = []
        for index in range(0, len(images), batch_size):
            batch_results = self.mfd_model.predict(
                images[index : index + batch_size],
                imgsz=1888,
                conf=0.25,
                iou=0.45,
                verbose=True,
                device=self.device,
            )
            # The model call yields a sequence of per-image results; call
            # .cpu() on each result, since the sequence itself has none.
            images_mfd_res.extend(image_res.cpu() for image_res in batch_results)
        return images_mfd_res
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
View file @
8ccfff6f
import
os
import
argparse
import
argparse
import
os
import
re
import
re
from
PIL
import
Image
import
torch
import
torch
from
torch.utils.data
import
Dataset
,
DataLoader
import
unimernet.tasks
as
tasks
from
PIL
import
Image
from
torch.utils.data
import
DataLoader
,
Dataset
from
torchvision
import
transforms
from
torchvision
import
transforms
from
unimernet.common.config
import
Config
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
from
unimernet.processors
import
load_processor
...
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
...
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    First, spaces inside ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf``
    groups written as ``\\cmd {...}`` are stripped. Then whitespace between a
    non-letter and any character, or between a letter and a non-letter, is
    removed repeatedly until the string stops changing. Whitespace between
    two letters is kept.

    Args:
        s: LaTeX source text.

    Returns:
        The text with the redundant whitespace removed.
    """
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = r"[a-zA-Z]"
    # Raw string: written as a plain literal, \W and \d are invalid escape
    # sequences that newer Python versions warn about.
    noletter = r"[\W_^\d]"

    # Group 1 spans the whole match, so stripping its spaces in a single pass
    # is equivalent to collecting the matches first and substituting after.
    s = re.sub(text_reg, lambda match: match.group(1).replace(" ", ""), s)

    # Compile once; the patterns do not change between passes.
    patterns = (
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter)),
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter)),
        re.compile(r"(%s)\s+?(%s)" % (letter, noletter)),
    )
    while True:
        news = s
        for pattern in patterns:
            news = pattern.sub(r"\1\2", news)
        if news == s:
            return s
        s = news
class
UnimernetModel
(
object
):
class
UnimernetModel
(
object
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
'cpu'
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
"cpu"
):
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
cfg
=
Config
(
args
)
cfg
=
Config
(
args
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.pth"
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.pth"
)
...
@@ -62,20 +60,28 @@ class UnimernetModel(object):
...
@@ -62,20 +60,28 @@ class UnimernetModel(object):
self
.
device
=
_device_
self
.
device
=
_device_
self
.
model
.
to
(
_device_
)
self
.
model
.
to
(
_device_
)
self
.
model
.
eval
()
self
.
model
.
eval
()
vis_processor
=
load_processor
(
'formula_image_eval'
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
)
vis_processor
=
load_processor
(
self
.
mfr_transform
=
transforms
.
Compose
([
vis_processor
,
])
"formula_image_eval"
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
,
)
self
.
mfr_transform
=
transforms
.
Compose
(
[
vis_processor
,
]
)
def
predict
(
self
,
mfd_res
,
image
):
def
predict
(
self
,
mfd_res
,
image
):
formula_list
=
[]
formula_list
=
[]
mf_image_list
=
[]
mf_image_list
=
[]
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()
):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
new_item
=
{
'
category_id
'
:
13
+
int
(
cla
.
item
()),
"
category_id
"
:
13
+
int
(
cla
.
item
()),
'
poly
'
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
"
poly
"
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
'
score
'
:
round
(
float
(
conf
.
item
()),
2
),
"
score
"
:
round
(
float
(
conf
.
item
()),
2
),
'
latex
'
:
''
,
"
latex
"
:
""
,
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
pil_img
=
Image
.
fromarray
(
image
)
pil_img
=
Image
.
fromarray
(
image
)
...
@@ -88,11 +94,48 @@ class UnimernetModel(object):
...
@@ -88,11 +94,48 @@ class UnimernetModel(object):
for
mf_img
in
dataloader
:
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
self
.
device
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
'
image
'
:
mf_img
})
output
=
self
.
model
.
generate
({
"
image
"
:
mf_img
})
mfr_res
.
extend
(
output
[
'
pred_str
'
])
mfr_res
.
extend
(
output
[
"
pred_str
"
])
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
res
[
'
latex
'
]
=
latex_rm_whitespace
(
latex
)
res
[
"
latex
"
]
=
latex_rm_whitespace
(
latex
)
return
formula_list
return
formula_list
def batch_predict(
    self, images_mfd_res: list, images: list, batch_size: int = 64
) -> list:
    """Recognize the formulas detected on several images in shared batches.

    For every detection in ``images_mfd_res`` a formula entry is created
    (``category_id`` is the detected class plus 13, ``poly`` the box corners,
    ``score`` the confidence rounded to 2 decimals, ``latex`` initially
    empty) and the box is cropped out of the matching image. All crops are
    then fed through the recognition model ``batch_size`` at a time, and the
    whitespace-cleaned predictions are written back into the ``latex``
    fields in detection order.

    Returns:
        One list of formula entries per element of ``images_mfd_res``.
    """
    per_image_formulas = []
    formula_crops = []
    # Flat view over every formula entry, aligned with formula_crops.
    pending_entries = []

    for image_index, detections in enumerate(images_mfd_res):
        page = Image.fromarray(images[image_index])
        boxes = detections.boxes
        page_formulas = []
        for corners, confidence, category in zip(boxes.xyxy, boxes.conf, boxes.cls):
            left, top, right, bottom = (int(value.item()) for value in corners)
            page_formulas.append(
                {
                    "category_id": 13 + int(category.item()),
                    "poly": [left, top, right, top, right, bottom, left, bottom],
                    "score": round(float(confidence.item()), 2),
                    "latex": "",
                }
            )
            formula_crops.append(page.crop((left, top, right, bottom)))
        per_image_formulas.append(page_formulas)
        pending_entries.extend(page_formulas)

    loader = DataLoader(
        MathDataset(formula_crops, transform=self.mfr_transform),
        batch_size=batch_size,
        num_workers=0,
    )
    predictions = []
    for batch in loader:
        batch = batch.to(self.device)
        with torch.no_grad():
            generated = self.model.generate({"image": batch})
        predictions.extend(generated["pred_str"])

    # The entries are the same dict objects stored in per_image_formulas, so
    # filling them here updates the returned structure.
    for entry, latex in zip(pending_entries, predictions):
        entry["latex"] = latex_rm_whitespace(latex)
    return per_image_formulas
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment