ModelZoo / LPRNet_pytorch / Commits

Commit 1fe2937e, authored Mar 06, 2023 by liuhy
Modify test code
parent 40622bae

Showing 1 changed file with 85 additions and 52 deletions

test.py  +85  -52
...
...
@@ -4,60 +4,89 @@ import os
import torch
import numpy as np
from lprnet import build_lprnet
from load_data import CHARS
from load_data import CHARS, LPRDataLoader
import time
from torch.utils.data import *
from torch.autograd import Variable


def infer(args, image, model):
    img = cv2.imread(image)
    height, width, _ = img.shape
    if height != 24 or width != 94:
        img = cv2.resize(img, (94, 24))
    img = img.astype('float32')
    img -= 127.5
    img *= 0.0078125
    img = np.transpose(img, (2, 0, 1))

    with torch.no_grad():
        img = torch.from_numpy(img).unsqueeze(0).to(args.device)
        preb = model(img)
    preb = preb.detach().cpu().numpy().squeeze()
    preb_label = []
    for j in range(preb.shape[1]):
        preb_label.append(np.argmax(preb[:, j], axis=0))
    no_repeat_blank_label = []
    pre_c = preb_label[0]
    if pre_c != len(CHARS) - 1:
        no_repeat_blank_label.append(pre_c)
    for c in preb_label:
        if (pre_c == c) or (c == len(CHARS) - 1):
            if c == len(CHARS) - 1:
                pre_c = c
            continue
        no_repeat_blank_label.append(c)
        pre_c = c
    return ''.join(list(map(lambda x: CHARS[x], no_repeat_blank_label)))


def collate_fn(batch):
    imgs = []
    labels = []
    lengths = []
    for _, sample in enumerate(batch):
        img, label, length = sample
        imgs.append(torch.from_numpy(img))
        labels.extend(label)
        lengths.append(length)
    labels = np.asarray(labels).flatten().astype(np.float32)
    return (torch.stack(imgs, 0), torch.from_numpy(labels), lengths)


def Greedy_Decode_Eval(Net, datasets, args):
    epoch_size = len(datasets) // args.batch_size
    batch_iterator = iter(DataLoader(datasets, args.batch_size, shuffle=True,
                                     num_workers=args.num_workers, collate_fn=collate_fn))

    Tp = 0
    Tn_1 = 0
    Tn_2 = 0
    t1 = time.time()
    for i in range(epoch_size):
        # load train data
        images, labels, lengths = next(batch_iterator)
        start = 0
        targets = []
        for length in lengths:
            label = labels[start:start + length]
            targets.append(label)
            start += length
        targets = np.array([el.numpy() for el in targets])

        if args.cuda:
            images = Variable(images.cuda())
        else:
            images = Variable(images)

        # forward
        prebs = Net(images)
        # greedy decode
        prebs = prebs.cpu().detach().numpy()
        preb_max = np.argmax(prebs, axis=1)
        preb_labels = list()
        for preb in preb_max:
            no_repeat_blank_label = list()
            pre_c = preb[0]
            if pre_c != len(CHARS) - 1:
                no_repeat_blank_label.append(pre_c)
            for c in preb:  # dropout repeate label and blank label
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == len(CHARS) - 1:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)
        for i, label in enumerate(preb_labels):
            if len(label) != len(targets[i]):
                Tn_1 += 1
                continue
            if (np.asarray(targets[i]) == np.asarray(label)).all():
                Tp += 1
            else:
                Tn_2 += 1
    Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp + Tn_1 + Tn_2)))
    t2 = time.time()
    print("[Info] Test Speed: {}s 1/{}]".format((t2 - t1) / len(datasets), len(datasets)))


def validation(args):
    model = build_lprnet(len(CHARS))
    model.load_state_dict(torch.load(args.model, map_location=args.device))
    model.to(args.device)

    lprnet = build_lprnet(class_num=len(CHARS), phase=args.phase_train)
    lprnet.load_state_dict(torch.load(args.model))
    lprnet.to(args.device)
    print("Successful to build network!")

    test_img_dirs = os.path.expanduser(args.imgpath)
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size)
    Greedy_Decode_Eval(lprnet, test_dataset, args)

    if os.path.isdir(args.imgpath):
        images = os.listdir(args.imgpath)
        count = 0
        time1 = time.perf_counter()
        for image in images:
            result = infer(args, os.path.join(args.imgpath, image), model)
            if result == image[:-4]:
                count += 1
            print('Image: ' + image + ' recongise result: ' + result)
        time2 = time.perf_counter()
        print('accuracy rate:', count / len(images))
        print('average time', (time2 - time1) / count * 1000)
    else:
        result = infer(args, args.imgpath, model)
        print('Image: ' + args.imgpath + ' recongise result: ' + result)

    if args.export_onnx:
        print('export pytorch model to onnx model...')
        onnx_input = torch.randn(1, 3, 24, 94, device=args.device)
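Both infer() and Greedy_Decode_Eval() above collapse the per-frame argmax output with the usual greedy CTC rule: drop the blank class (index len(CHARS) - 1) and merge consecutive repeats. A minimal, self-contained sketch of that rule, using a toy alphabet instead of the project's CHARS list:

    import numpy as np

    # Toy alphabet for illustration only; the real CHARS list comes from load_data.
    TOY_CHARS = ['0', '1', '2', 'A', 'B', '-']    # last entry acts as the CTC blank
    BLANK = len(TOY_CHARS) - 1

    def greedy_ctc_decode(logits):
        """logits: (num_classes, seq_len) array, one column of scores per time step."""
        frame_labels = np.argmax(logits, axis=0)  # best class per time step
        decoded = []
        pre_c = BLANK
        for c in frame_labels:
            if c != BLANK and c != pre_c:         # skip blanks and repeated labels
                decoded.append(c)
            pre_c = c
        return ''.join(TOY_CHARS[i] for i in decoded)

    # Frames A A <blank> B B 1 <blank> decode to "AB1".
    demo = np.full((len(TOY_CHARS), 7), -1.0)
    for t, cls in enumerate([3, 3, BLANK, 4, 4, 1, BLANK]):
        demo[cls, t] = 1.0
    print(greedy_ctc_decode(demo))   # -> AB1

The pixel normalization in infer() (subtract 127.5, then multiply by 0.0078125, i.e. divide by 128) maps 8-bit values into roughly [-1, 1), which presumably mirrors the preprocessing used when the network was trained.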
...
...
@@ -70,16 +99,20 @@ def validation(args):
            dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
            opset_version=12,
        )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to vaildate net')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='parameters to train net')
    parser.add_argument('--img_size', default=[94, 24], help='the image size')
    parser.add_argument('--imgpath', default="imgs", help='the image path')
    parser.add_argument('--model', default='model/lprnet.pth', help='model path to vaildate')
    parser.add_argument('--imgpath', default='imgs', help='the image path')
    parser.add_argument('--batch_size', default=100, type=int, help='testing batch size.')
    parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model')
    parser.add_argument('--device', default='cuda', help='Use cuda to vaildate model')
    parser.add_argument('--export_onnx', default=False, help='export model to onnx')
    parser.add_argument('--dynamic', default=False, help='use dynamic batch size')
    parser.add_argument('--phase_train', default=False, type=bool, help='train or test phase flag.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading')

    args = parser.parse_args()
    validation(args)
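The second hunk shows only the tail of the torch.onnx.export call (dynamic_axes and opset_version). A hedged, self-contained sketch of what that export step likely looks like; the stand-in model, the 'lprnet.onnx' output path, and everything beyond the 'input'/'output' names visible in dynamic_axes are assumptions, not taken from the commit:

    import torch
    import torch.nn as nn

    # Stand-in for the real LPRNet, used only so the export call runs end to end.
    model = nn.Conv2d(3, 68, kernel_size=3, padding=1)
    dynamic = False                                   # mirrors args.dynamic

    onnx_input = torch.randn(1, 3, 24, 94)            # same input shape as in validation()
    torch.onnx.export(
        model,
        onnx_input,
        'lprnet.onnx',                                # output path is an assumption
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if dynamic else None,
        opset_version=12,
    )

With the defaults from the argument parser, a typical run of the modified script would be something along the lines of python test.py --model model/lprnet.pth --imgpath imgs --device cuda; since --export_onnx declares no type, any non-empty value (e.g. --export_onnx 1) is treated as truthy and enables the export branch.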