Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
LPRNet_pytorch
Commits
dfc4d118
"src/vscode:/vscode.git/clone" did not exist on "ea8ae8c6397d8333760471e573e4d8ca4646efd0"
Commit
dfc4d118
authored
Mar 01, 2023
by
liuhy
Browse files
修改代码
parent
9adcf60d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
130 additions
and
121 deletions
+130
-121
LPRNet_ORT_infer.py
LPRNet_ORT_infer.py
+11
-3
test.py
test.py
+29
-11
train.py
train.py
+90
-107
No files found.
LPRNet_ORT_infer.py
View file @
dfc4d118
...
# Script entry point for ONNX-runtime LPRNet inference (reconstructed from
# a diff-scraped page: duplicated old/new lines removed, formatting restored).
# Post-commit behavior: run inference over every image in the test directory
# and report the exact-match accuracy, using the file stem as ground truth.
if __name__ == '__main__':
    model_name = 'model/LPRNet.onnx'
    # model_name = 'LPRNet.onnx'
    # image = 'imgs/川JK0707.jpg'
    import os

    # NOTE(review): hard-coded container path — confirm it matches deployment.
    test_dir = '/code/lpr_ori/data/test'
    images = os.listdir(test_dir)
    count = 0
    for image in images:
        # File names are "<plate>.jpg", so the label is the name without
        # the 4-character extension.
        label = image[:-4]
        InferRes = LPRNetInference(model_name, os.path.join(test_dir, image))
        print(image, 'Inference Result:', InferRes)
        if label == InferRes:
            count += 1
    # Guard against an empty directory to avoid ZeroDivisionError.
    print('acc rate:', count / len(images) if images else 0.0)
test.py
View file @
dfc4d118
import
argparse
import
argparse
import
cv2
import
cv2
import
os
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
lprnet
import
build_lprnet
from
lprnet
import
build_lprnet
from
load_data
import
CHARS
from
load_data
import
CHARS
def
infer
(
args
,
image
,
model
):
def
validation
(
args
):
img
=
cv2
.
imread
(
image
)
model
=
build_lprnet
(
len
(
CHARS
))
model
.
load_state_dict
(
torch
.
load
(
args
.
model
,
map_location
=
args
.
device
))
model
.
to
(
args
.
device
)
img
=
cv2
.
imread
(
args
.
img
)
height
,
width
,
_
=
img
.
shape
height
,
width
,
_
=
img
.
shape
if
height
!=
24
or
width
!=
94
:
if
height
!=
24
or
width
!=
94
:
img
=
cv2
.
resize
(
img
,
(
94
,
24
))
img
=
cv2
.
resize
(
img
,
(
94
,
24
))
...
@@ -38,7 +34,26 @@ def validation(args):
...
@@ -38,7 +34,26 @@ def validation(args):
continue
continue
no_repeat_blank_label
.
append
(
c
)
no_repeat_blank_label
.
append
(
c
)
pre_c
=
c
pre_c
=
c
return
''
.
join
(
list
(
map
(
lambda
x
:
CHARS
[
x
],
no_repeat_blank_label
)))
def
validation
(
args
):
model
=
build_lprnet
(
len
(
CHARS
))
model
.
load_state_dict
(
torch
.
load
(
args
.
model
,
map_location
=
args
.
device
))
model
.
to
(
args
.
device
)
if
os
.
path
.
isdir
(
args
.
imgpath
):
images
=
os
.
listdir
(
args
.
imgpath
)
count
=
0
for
image
in
images
:
res
=
infer
(
args
,
os
.
path
.
join
(
args
.
imgpath
,
image
),
model
)
if
res
==
image
[:
-
4
]:
count
+=
1
print
(
'Image: '
+
image
+
' recongise result: '
+
res
)
print
(
'acc rate:'
,
count
/
len
(
images
))
else
:
res
=
infer
(
args
,
args
.
imgpath
,
model
)
print
(
'Image: '
+
args
.
imgpath
+
' recongise result: '
+
res
)
if
args
.
export_onnx
:
if
args
.
export_onnx
:
print
(
'export pytroch model to onnx model...'
)
print
(
'export pytroch model to onnx model...'
)
onnx_input
=
torch
.
randn
(
1
,
3
,
24
,
94
,
device
=
args
.
device
)
onnx_input
=
torch
.
randn
(
1
,
3
,
24
,
94
,
device
=
args
.
device
)
...
@@ -51,16 +66,19 @@ def validation(args):
...
@@ -51,16 +66,19 @@ def validation(args):
dynamic_axes
=
{
'input'
:
{
0
:
'batch'
},
'output'
:
{
0
:
'batch'
}}
if
args
.
dynamic
else
None
,
dynamic_axes
=
{
'input'
:
{
0
:
'batch'
},
'output'
:
{
0
:
'batch'
}}
if
args
.
dynamic
else
None
,
opset_version
=
12
,
opset_version
=
12
,
)
)
return
''
.
join
(
list
(
map
(
lambda
x
:
CHARS
[
x
],
no_repeat_blank_label
)))
return
res
# CLI entry point for test.py (reconstructed from the diff scrape).
# Parses validation options and runs `validation(args)`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to vaildate net')
    # parser.add_argument('--model', default='model/lprnet.pth', help='model path to vaildate')
    parser.add_argument('--model', default='weights/Final_LPRNet_model.pth',
                        help='model path to vaildate')
    # parser.add_argument('--imgpath', default='imgs/川JK0707.jpg', help='the image path')
    # --imgpath may be a single image file or a directory of test images.
    parser.add_argument('--imgpath', default='/code/lpr_ori/data/test',
                        help='the image path')
    parser.add_argument('--device', default='cuda',
                        help='Use cuda to vaildate model')
    # NOTE(review): `default=False` without action='store_true' means any
    # value given on the command line (even "False") parses as a truthy
    # string — consider action='store_true'; left as-is to keep the CLI
    # contract unchanged.
    parser.add_argument('--export_onnx', default=False,
                        help='export model to onnx')
    parser.add_argument('--dynamic', default=False,
                        help='use dynamic batch size')
    args = parser.parse_args()
    result = validation(args)
    print('recongise result:', result)
train.py
View file @
dfc4d118
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# /usr/bin/env/python3
# /usr/bin/env/python3
from
load_data
import
CHARS
,
CHARS_DICT
,
LPRDataLoader
from
load_data
import
CHARS
,
CHARS_DICT
,
LPRDataLoader
from
lprnet
import
build_lprnet
from
lprnet
import
build_lprnet
# import torch.backends.cudnn as cudnn
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.utils.data
import
*
from
torch.utils.data
import
*
...
@@ -14,7 +13,7 @@ import torch
...
@@ -14,7 +13,7 @@ import torch
import
time
import
time
import
os
import
os
print
(
torch
.
cuda
.
is_available
())
print
(
'Cuda Availabel:'
,
torch
.
cuda
.
is_available
())
def sparse_tuple_for_ctc(T_length, lengths):
    """Build the per-sample length tuples required by ``torch.nn.CTCLoss``.

    Args:
        T_length: time-step length of the network output; the same value is
            used for every sample in the batch.
        lengths: iterable holding the target label length of each sample.

    Returns:
        ``(input_lengths, target_lengths)`` — two tuples of equal length:
        the first repeats ``T_length`` once per sample, the second mirrors
        ``lengths``.
    """
    input_lengths = []
    target_lengths = []
    for ch in lengths:
        input_lengths.append(T_length)
        target_lengths.append(ch)
    return tuple(input_lengths), tuple(target_lengths)
def
adjust_learning_rate
(
optimizer
,
cur_epoch
,
base_lr
,
lr_schedule
):
def
adjust_learning_rate
(
optimizer
,
cur_epoch
,
base_lr
,
lr_schedule
):
...
@@ -39,37 +37,8 @@ def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
...
@@ -39,37 +37,8 @@ def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
lr
=
base_lr
lr
=
base_lr
for
param_group
in
optimizer
.
param_groups
:
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr
param_group
[
'lr'
]
=
lr
return
lr
return
lr
def get_parser():
    """Parse and return the command-line arguments for training LPRNet.

    Returns:
        argparse.Namespace with all training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description='parameters to train net')
    parser.add_argument('--max_epoch', default=15,
                        help='epoch to train the network')
    # Image size is [width, height] = [94, 24], the fixed LPRNet input.
    parser.add_argument('--img_size', default=[94, 24],
                        help='the image size')
    parser.add_argument('--train_img_dirs', default="data/train",
                        help='the train images path')
    parser.add_argument('--test_img_dirs', default="data/test",
                        help='the test images path')
    parser.add_argument('--dropout_rate', default=0.5,
                        help='dropout rate.')
    parser.add_argument('--learning_rate', default=0.1,
                        help='base value of learning rate.')
    parser.add_argument('--lpr_max_len', default=8,
                        help='license plate number max length.')
    parser.add_argument('--train_batch_size', default=64,
                        help='training batch size.')
    parser.add_argument('--test_batch_size', default=10,
                        help='testing batch size.')
    # NOTE(review): type=bool on argparse does NOT parse "False" as False —
    # any non-empty string is truthy; kept to preserve the existing CLI.
    parser.add_argument('--phase_train', default=True, type=bool,
                        help='train or test phase flag.')
    parser.add_argument('--num_workers', default=8, type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument('--cuda', default=True, type=bool,
                        help='Use cuda to train model')
    parser.add_argument('--resume_epoch', default=10, type=int,
                        help='resume iter for retraining')
    parser.add_argument('--save_interval', default=2000, type=int,
                        help='interval for save model state dict')
    parser.add_argument('--test_interval', default=2000, type=int,
                        help='interval for evaluate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=2e-5, type=float,
                        help='Weight decay for SGD')
    parser.add_argument('--lr_schedule', default=[4, 8, 12, 14, 16],
                        help='schedule for learning rate.')
    parser.add_argument('--save_folder', default='./weights/',
                        help='Location to save checkpoint models')
    # parser.add_argument('--pretrained_model', default='./weights/Final_LPRNet_model.pth', help='pretrained base model')
    parser.add_argument('--pretrained_model', default='',
                        help='pretrained base model')

    args = parser.parse_args()
    return args
def
collate_fn
(
batch
):
def
collate_fn
(
batch
):
imgs
=
[]
imgs
=
[]
labels
=
[]
labels
=
[]
...
@@ -80,12 +49,71 @@ def collate_fn(batch):
...
@@ -80,12 +49,71 @@ def collate_fn(batch):
labels
.
extend
(
label
)
labels
.
extend
(
label
)
lengths
.
append
(
length
)
lengths
.
append
(
length
)
labels
=
np
.
asarray
(
labels
).
flatten
().
astype
(
np
.
int16
)
labels
=
np
.
asarray
(
labels
).
flatten
().
astype
(
np
.
int16
)
return
(
torch
.
stack
(
imgs
,
0
),
torch
.
from_numpy
(
labels
),
lengths
)
return
(
torch
.
stack
(
imgs
,
0
),
torch
.
from_numpy
(
labels
),
lengths
)
def Greedy_Decode_Eval(Net, datasets, args):
    """Evaluate `Net` on `datasets` with greedy CTC decoding and print accuracy.

    Reconstructed from a diff-scraped page (the function was moved in this
    commit, so the new body appears once in the scrape).

    Args:
        Net: the LPRNet model (switched to eval mode inside the loop).
        datasets: dataset compatible with DataLoader and `collate_fn`.
        args: namespace providing test_batch_size, num_workers and cuda.
    """
    epoch_size = len(datasets) // args.test_batch_size
    batch_iterator = iter(DataLoader(datasets, args.test_batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     collate_fn=collate_fn))

    Tp = 0      # true positives: exact label match
    Tn_1 = 0    # wrong length predictions
    Tn_2 = 0    # same length but wrong content
    t1 = time.time()
    for i in range(epoch_size):
        # load train data
        images, labels, lengths = next(batch_iterator)
        # `labels` is a flat array; slice it back into per-sample targets.
        start = 0
        targets = []
        for length in lengths:
            label = labels[start:start + length]
            targets.append(label)
            start += length
        targets = np.array([el.numpy() for el in targets])

        if args.cuda:
            images = Variable(images.cuda())
        else:
            images = Variable(images)

        # forward
        Net.eval()
        prebs = Net(images)
        # greedy decode
        prebs = prebs.cpu().detach().numpy()
        preb_labels = []
        # NOTE(review): the inner loops shadow the outer `i`; harmless here
        # because the outer `i` is never read afterwards, but kept as-is.
        for i in range(prebs.shape[0]):
            preb = prebs[i, :, :]
            # Per time-step argmax over the class axis.
            preb_label = []
            for j in range(preb.shape[1]):
                preb_label.append(np.argmax(preb[:, j], axis=0))
            # Collapse repeats and drop the blank (index len(CHARS)-1),
            # i.e. standard CTC greedy decoding.
            no_repeat_blank_label = []
            pre_c = preb_label[0]
            if pre_c != len(CHARS) - 1:
                no_repeat_blank_label.append(pre_c)
            for c in preb_label:
                # dropout repeate label and blank label
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == len(CHARS) - 1:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)

        for i, label in enumerate(preb_labels):
            if len(label) != len(targets[i]):
                Tn_1 += 1
                continue
            if (np.asarray(targets[i]) == np.asarray(label)).all():
                Tp += 1
            else:
                Tn_2 += 1

    Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]"
          .format(Acc, Tp, Tn_1, Tn_2, (Tp + Tn_1 + Tn_2)))
    t2 = time.time()
    print("[Info] Test Speed: {}s 1/{}]"
          .format((t2 - t1) / len(datasets), len(datasets)))
def
train
(
args
):
T_length
=
18
# args.lpr_max_len
T_length
=
18
# args.lpr_max_len
epoch
=
0
+
args
.
resume_epoch
epoch
=
0
+
args
.
resume_epoch
loss_val
=
0
loss_val
=
0
...
@@ -121,8 +149,6 @@ def train():
...
@@ -121,8 +149,6 @@ def train():
print
(
"initial net weights successful!"
)
print
(
"initial net weights successful!"
)
# define optimizer
# define optimizer
# optimizer = optim.SGD(lprnet.parameters(), lr=args.learning_rate,
# momentum=args.momentum, weight_decay=args.weight_decay)
optimizer
=
optim
.
RMSprop
(
lprnet
.
parameters
(),
lr
=
args
.
learning_rate
,
alpha
=
0.9
,
eps
=
1e-08
,
optimizer
=
optim
.
RMSprop
(
lprnet
.
parameters
(),
lr
=
args
.
learning_rate
,
alpha
=
0.9
,
eps
=
1e-08
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
train_img_dirs
=
os
.
path
.
expanduser
(
args
.
train_img_dirs
)
train_img_dirs
=
os
.
path
.
expanduser
(
args
.
train_img_dirs
)
...
@@ -148,17 +174,14 @@ def train():
...
@@ -148,17 +174,14 @@ def train():
epoch
+=
1
epoch
+=
1
if
iteration
!=
0
and
iteration
%
args
.
save_interval
==
0
:
if
iteration
!=
0
and
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'LPRNet_'
+
'
_
iteration_'
+
repr
(
iteration
)
+
'.pth'
)
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'LPRNet_'
+
'iteration_'
+
repr
(
iteration
)
+
'.pth'
)
if
(
iteration
+
1
)
%
args
.
test_interval
==
0
:
if
(
iteration
+
1
)
%
args
.
test_interval
==
0
:
Greedy_Decode_Eval
(
lprnet
,
test_dataset
,
args
)
Greedy_Decode_Eval
(
lprnet
,
test_dataset
,
args
)
# lprnet.train() # should be switch to train mode
start_time
=
time
.
time
()
start_time
=
time
.
time
()
# load train data
# load train data
images
,
labels
,
lengths
=
next
(
batch_iterator
)
images
,
labels
,
lengths
=
next
(
batch_iterator
)
# labels = np.array([el.numpy() for el in labels]).T
# print(labels)
# get ctc parameters
# get ctc parameters
input_lengths
,
target_lengths
=
sparse_tuple_for_ctc
(
T_length
,
lengths
)
input_lengths
,
target_lengths
=
sparse_tuple_for_ctc
(
T_length
,
lengths
)
# update lr
# update lr
...
@@ -173,12 +196,8 @@ def train():
...
@@ -173,12 +196,8 @@ def train():
# forward
# forward
logits
=
lprnet
(
images
)
logits
=
lprnet
(
images
)
# print(logits.size())
log_probs
=
logits
.
permute
(
2
,
0
,
1
)
# for ctc loss: T x N x C
log_probs
=
logits
.
permute
(
2
,
0
,
1
)
# for ctc loss: T x N x C
# print(labels.shape)
log_probs
=
log_probs
.
log_softmax
(
2
).
requires_grad_
()
log_probs
=
log_probs
.
log_softmax
(
2
).
requires_grad_
()
# log_probs = log_probs.detach().requires_grad_()
# print(log_probs.shape)
# backprop
# backprop
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
loss
=
ctc_loss
(
log_probs
,
labels
,
input_lengths
=
input_lengths
,
target_lengths
=
target_lengths
)
loss
=
ctc_loss
(
log_probs
,
labels
,
input_lengths
=
input_lengths
,
target_lengths
=
target_lengths
)
...
@@ -199,67 +218,31 @@ def train():
...
@@ -199,67 +218,31 @@ def train():
# save final parameters
# save final parameters
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'Final_LPRNet_model.pth'
)
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'Final_LPRNet_model.pth'
)
def
Greedy_Decode_Eval
(
Net
,
datasets
,
args
):
def
get_parser
():
# TestNet = Net.eval()
parser
=
argparse
.
ArgumentParser
(
description
=
'parameters to train net'
)
epoch_size
=
len
(
datasets
)
//
args
.
test_batch_size
parser
.
add_argument
(
'--max_epoch'
,
default
=
15
,
help
=
'epoch to train the network'
)
batch_iterator
=
iter
(
DataLoader
(
datasets
,
args
.
test_batch_size
,
shuffle
=
True
,
num_workers
=
args
.
num_workers
,
collate_fn
=
collate_fn
))
parser
.
add_argument
(
'--img_size'
,
default
=
[
94
,
24
],
help
=
'the image size'
)
parser
.
add_argument
(
'--train_img_dirs'
,
default
=
"data/train"
,
help
=
'the train images path'
)
Tp
=
0
parser
.
add_argument
(
'--test_img_dirs'
,
default
=
"data/test"
,
help
=
'the test images path'
)
Tn_1
=
0
parser
.
add_argument
(
'--dropout_rate'
,
default
=
0.5
,
help
=
'dropout rate.'
)
Tn_2
=
0
parser
.
add_argument
(
'--learning_rate'
,
default
=
0.1
,
help
=
'base value of learning rate.'
)
t1
=
time
.
time
()
parser
.
add_argument
(
'--lpr_max_len'
,
default
=
8
,
help
=
'license plate number max length.'
)
for
i
in
range
(
epoch_size
):
parser
.
add_argument
(
'--train_batch_size'
,
default
=
64
,
help
=
'training batch size.'
)
# load train data
parser
.
add_argument
(
'--test_batch_size'
,
default
=
10
,
help
=
'testing batch size.'
)
images
,
labels
,
lengths
=
next
(
batch_iterator
)
parser
.
add_argument
(
'--phase_train'
,
default
=
True
,
type
=
bool
,
help
=
'train or test phase flag.'
)
start
=
0
parser
.
add_argument
(
'--num_workers'
,
default
=
8
,
type
=
int
,
help
=
'Number of workers used in dataloading'
)
targets
=
[]
parser
.
add_argument
(
'--cuda'
,
default
=
True
,
type
=
bool
,
help
=
'Use cuda to train model'
)
for
length
in
lengths
:
parser
.
add_argument
(
'--resume_epoch'
,
default
=
10
,
type
=
int
,
help
=
'resume iter for retraining'
)
label
=
labels
[
start
:
start
+
length
]
parser
.
add_argument
(
'--save_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for save model state dict'
)
targets
.
append
(
label
)
parser
.
add_argument
(
'--test_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for evaluate'
)
start
+=
length
parser
.
add_argument
(
'--momentum'
,
default
=
0.9
,
type
=
float
,
help
=
'momentum'
)
targets
=
np
.
array
([
el
.
numpy
()
for
el
in
targets
])
parser
.
add_argument
(
'--weight_decay'
,
default
=
2e-5
,
type
=
float
,
help
=
'Weight decay for SGD'
)
parser
.
add_argument
(
'--lr_schedule'
,
default
=
[
4
,
8
,
12
,
14
,
16
],
help
=
'schedule for learning rate.'
)
if
args
.
cuda
:
parser
.
add_argument
(
'--save_folder'
,
default
=
'./weights/'
,
help
=
'Location to save checkpoint models'
)
images
=
Variable
(
images
.
cuda
())
parser
.
add_argument
(
'--pretrained_model'
,
default
=
'./weights/Final_LPRNet_model.pth'
,
help
=
'pretrained base model'
)
else
:
args
=
parser
.
parse_args
()
images
=
Variable
(
images
)
return
args
# forward
prebs
=
Net
(
images
)
# greedy decode
prebs
=
prebs
.
cpu
().
detach
().
numpy
()
preb_labels
=
list
()
for
i
in
range
(
prebs
.
shape
[
0
]):
preb
=
prebs
[
i
,
:,
:]
preb_label
=
list
()
for
j
in
range
(
preb
.
shape
[
1
]):
preb_label
.
append
(
np
.
argmax
(
preb
[:,
j
],
axis
=
0
))
no_repeat_blank_label
=
list
()
pre_c
=
preb_label
[
0
]
if
pre_c
!=
len
(
CHARS
)
-
1
:
no_repeat_blank_label
.
append
(
pre_c
)
for
c
in
preb_label
:
# dropout repeate label and blank label
if
(
pre_c
==
c
)
or
(
c
==
len
(
CHARS
)
-
1
):
if
c
==
len
(
CHARS
)
-
1
:
pre_c
=
c
continue
no_repeat_blank_label
.
append
(
c
)
pre_c
=
c
preb_labels
.
append
(
no_repeat_blank_label
)
for
i
,
label
in
enumerate
(
preb_labels
):
if
len
(
label
)
!=
len
(
targets
[
i
]):
Tn_1
+=
1
continue
if
(
np
.
asarray
(
targets
[
i
])
==
np
.
asarray
(
label
)).
all
():
Tp
+=
1
else
:
Tn_2
+=
1
Acc
=
Tp
*
1.0
/
(
Tp
+
Tn_1
+
Tn_2
)
print
(
"[Info] Test Accuracy: {} [{}:{}:{}:{}]"
.
format
(
Acc
,
Tp
,
Tn_1
,
Tn_2
,
(
Tp
+
Tn_1
+
Tn_2
)))
t2
=
time
.
time
()
print
(
"[Info] Test Speed: {}s 1/{}]"
.
format
((
t2
-
t1
)
/
len
(
datasets
),
len
(
datasets
)))
# Script entry point: after this commit, `train` takes the parsed args
# explicitly instead of calling get_parser() internally.
if __name__ == "__main__":
    args = get_parser()
    train(args)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment