Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
huastant
LPR
Commits
dfc4d118
You need to sign in or sign up before continuing.
Commit
dfc4d118
authored
Mar 01, 2023
by
liuhy
Browse files
修改代码
parent
9adcf60d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
130 additions
and
121 deletions
+130
-121
LPRNet_ORT_infer.py
LPRNet_ORT_infer.py
+11
-3
test.py
test.py
+29
-11
train.py
train.py
+90
-107
No files found.
LPRNet_ORT_infer.py
View file @
dfc4d118
...
@@ -60,6 +60,14 @@ def LPRNetInference(model, imgs):
...
@@ -60,6 +60,14 @@ def LPRNetInference(model, imgs):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
model_name
=
'model/LPRNet.onnx'
model_name
=
'model/LPRNet.onnx'
# model_name = 'LPRNet.onnx'
# model_name = 'LPRNet.onnx'
image
=
'imgs/川JK0707.jpg'
# image = 'imgs/川JK0707.jpg'
InferRes
=
LPRNetInference
(
model_name
,
image
)
import
os
print
(
image
,
'Inference Result:'
,
InferRes
)
images
=
os
.
listdir
(
'/code/lpr_ori/data/test'
)
count
=
0
for
image
in
images
:
label
=
image
[:
-
4
]
InferRes
=
LPRNetInference
(
model_name
,
os
.
path
.
join
(
'/code/lpr_ori/data/test'
,
image
))
print
(
image
,
'Inference Result:'
,
InferRes
)
if
label
==
InferRes
:
count
+=
1
print
(
'acc rate:'
,
count
/
len
(
images
))
test.py
View file @
dfc4d118
import
argparse
import
argparse
import
cv2
import
cv2
import
os
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
lprnet
import
build_lprnet
from
lprnet
import
build_lprnet
from
load_data
import
CHARS
from
load_data
import
CHARS
def
infer
(
args
,
image
,
model
):
def
validation
(
args
):
img
=
cv2
.
imread
(
image
)
model
=
build_lprnet
(
len
(
CHARS
))
model
.
load_state_dict
(
torch
.
load
(
args
.
model
,
map_location
=
args
.
device
))
model
.
to
(
args
.
device
)
img
=
cv2
.
imread
(
args
.
img
)
height
,
width
,
_
=
img
.
shape
height
,
width
,
_
=
img
.
shape
if
height
!=
24
or
width
!=
94
:
if
height
!=
24
or
width
!=
94
:
img
=
cv2
.
resize
(
img
,
(
94
,
24
))
img
=
cv2
.
resize
(
img
,
(
94
,
24
))
...
@@ -38,7 +34,26 @@ def validation(args):
...
@@ -38,7 +34,26 @@ def validation(args):
continue
continue
no_repeat_blank_label
.
append
(
c
)
no_repeat_blank_label
.
append
(
c
)
pre_c
=
c
pre_c
=
c
return
''
.
join
(
list
(
map
(
lambda
x
:
CHARS
[
x
],
no_repeat_blank_label
)))
def
validation
(
args
):
model
=
build_lprnet
(
len
(
CHARS
))
model
.
load_state_dict
(
torch
.
load
(
args
.
model
,
map_location
=
args
.
device
))
model
.
to
(
args
.
device
)
if
os
.
path
.
isdir
(
args
.
imgpath
):
images
=
os
.
listdir
(
args
.
imgpath
)
count
=
0
for
image
in
images
:
res
=
infer
(
args
,
os
.
path
.
join
(
args
.
imgpath
,
image
),
model
)
if
res
==
image
[:
-
4
]:
count
+=
1
print
(
'Image: '
+
image
+
' recongise result: '
+
res
)
print
(
'acc rate:'
,
count
/
len
(
images
))
else
:
res
=
infer
(
args
,
args
.
imgpath
,
model
)
print
(
'Image: '
+
args
.
imgpath
+
' recongise result: '
+
res
)
if
args
.
export_onnx
:
if
args
.
export_onnx
:
print
(
'export pytroch model to onnx model...'
)
print
(
'export pytroch model to onnx model...'
)
onnx_input
=
torch
.
randn
(
1
,
3
,
24
,
94
,
device
=
args
.
device
)
onnx_input
=
torch
.
randn
(
1
,
3
,
24
,
94
,
device
=
args
.
device
)
...
@@ -51,16 +66,19 @@ def validation(args):
...
@@ -51,16 +66,19 @@ def validation(args):
dynamic_axes
=
{
'input'
:
{
0
:
'batch'
},
'output'
:
{
0
:
'batch'
}}
if
args
.
dynamic
else
None
,
dynamic_axes
=
{
'input'
:
{
0
:
'batch'
},
'output'
:
{
0
:
'batch'
}}
if
args
.
dynamic
else
None
,
opset_version
=
12
,
opset_version
=
12
,
)
)
return
''
.
join
(
list
(
map
(
lambda
x
:
CHARS
[
x
],
no_repeat_blank_label
)))
return
res
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'parameters to vaildate net'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'parameters to vaildate net'
)
parser
.
add_argument
(
'--model'
,
default
=
'model/lprnet.pth'
,
help
=
'model path to vaildate'
)
# parser.add_argument('--model', default='model/lprnet.pth', help='model path to vaildate')
parser
.
add_argument
(
'--img'
,
default
=
'imgs/川JK0707.jpg'
,
help
=
'the image path'
)
parser
.
add_argument
(
'--model'
,
default
=
'weights/Final_LPRNet_model.pth'
,
help
=
'model path to vaildate'
)
# parser.add_argument('--imgpath', default='imgs/川JK0707.jpg', help='the image path')
parser
.
add_argument
(
'--imgpath'
,
default
=
'/code/lpr_ori/data/test'
,
help
=
'the image path'
)
parser
.
add_argument
(
'--device'
,
default
=
'cuda'
,
help
=
'Use cuda to vaildate model'
)
parser
.
add_argument
(
'--device'
,
default
=
'cuda'
,
help
=
'Use cuda to vaildate model'
)
parser
.
add_argument
(
'--export_onnx'
,
default
=
False
,
help
=
'export model to onnx'
)
parser
.
add_argument
(
'--export_onnx'
,
default
=
False
,
help
=
'export model to onnx'
)
parser
.
add_argument
(
'--dynamic'
,
default
=
False
,
help
=
'use dynamic batch size'
)
parser
.
add_argument
(
'--dynamic'
,
default
=
False
,
help
=
'use dynamic batch size'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
result
=
validation
(
args
)
result
=
validation
(
args
)
print
(
'recongise result:'
,
result
)
train.py
View file @
dfc4d118
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# /usr/bin/env/python3
# /usr/bin/env/python3
from
load_data
import
CHARS
,
CHARS_DICT
,
LPRDataLoader
from
load_data
import
CHARS
,
CHARS_DICT
,
LPRDataLoader
from
lprnet
import
build_lprnet
from
lprnet
import
build_lprnet
# import torch.backends.cudnn as cudnn
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.utils.data
import
*
from
torch.utils.data
import
*
...
@@ -14,7 +13,7 @@ import torch
...
@@ -14,7 +13,7 @@ import torch
import
time
import
time
import
os
import
os
print
(
torch
.
cuda
.
is_available
())
print
(
'Cuda Availabel:'
,
torch
.
cuda
.
is_available
())
def
sparse_tuple_for_ctc
(
T_length
,
lengths
):
def
sparse_tuple_for_ctc
(
T_length
,
lengths
):
input_lengths
=
[]
input_lengths
=
[]
...
@@ -23,7 +22,6 @@ def sparse_tuple_for_ctc(T_length, lengths):
...
@@ -23,7 +22,6 @@ def sparse_tuple_for_ctc(T_length, lengths):
for
ch
in
lengths
:
for
ch
in
lengths
:
input_lengths
.
append
(
T_length
)
input_lengths
.
append
(
T_length
)
target_lengths
.
append
(
ch
)
target_lengths
.
append
(
ch
)
return
tuple
(
input_lengths
),
tuple
(
target_lengths
)
return
tuple
(
input_lengths
),
tuple
(
target_lengths
)
def
adjust_learning_rate
(
optimizer
,
cur_epoch
,
base_lr
,
lr_schedule
):
def
adjust_learning_rate
(
optimizer
,
cur_epoch
,
base_lr
,
lr_schedule
):
...
@@ -39,37 +37,8 @@ def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
...
@@ -39,37 +37,8 @@ def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
lr
=
base_lr
lr
=
base_lr
for
param_group
in
optimizer
.
param_groups
:
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr
param_group
[
'lr'
]
=
lr
return
lr
return
lr
def
get_parser
():
parser
=
argparse
.
ArgumentParser
(
description
=
'parameters to train net'
)
parser
.
add_argument
(
'--max_epoch'
,
default
=
15
,
help
=
'epoch to train the network'
)
parser
.
add_argument
(
'--img_size'
,
default
=
[
94
,
24
],
help
=
'the image size'
)
parser
.
add_argument
(
'--train_img_dirs'
,
default
=
"data/train"
,
help
=
'the train images path'
)
parser
.
add_argument
(
'--test_img_dirs'
,
default
=
"data/test"
,
help
=
'the test images path'
)
parser
.
add_argument
(
'--dropout_rate'
,
default
=
0.5
,
help
=
'dropout rate.'
)
parser
.
add_argument
(
'--learning_rate'
,
default
=
0.1
,
help
=
'base value of learning rate.'
)
parser
.
add_argument
(
'--lpr_max_len'
,
default
=
8
,
help
=
'license plate number max length.'
)
parser
.
add_argument
(
'--train_batch_size'
,
default
=
64
,
help
=
'training batch size.'
)
parser
.
add_argument
(
'--test_batch_size'
,
default
=
10
,
help
=
'testing batch size.'
)
parser
.
add_argument
(
'--phase_train'
,
default
=
True
,
type
=
bool
,
help
=
'train or test phase flag.'
)
parser
.
add_argument
(
'--num_workers'
,
default
=
8
,
type
=
int
,
help
=
'Number of workers used in dataloading'
)
parser
.
add_argument
(
'--cuda'
,
default
=
True
,
type
=
bool
,
help
=
'Use cuda to train model'
)
parser
.
add_argument
(
'--resume_epoch'
,
default
=
10
,
type
=
int
,
help
=
'resume iter for retraining'
)
parser
.
add_argument
(
'--save_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for save model state dict'
)
parser
.
add_argument
(
'--test_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for evaluate'
)
parser
.
add_argument
(
'--momentum'
,
default
=
0.9
,
type
=
float
,
help
=
'momentum'
)
parser
.
add_argument
(
'--weight_decay'
,
default
=
2e-5
,
type
=
float
,
help
=
'Weight decay for SGD'
)
parser
.
add_argument
(
'--lr_schedule'
,
default
=
[
4
,
8
,
12
,
14
,
16
],
help
=
'schedule for learning rate.'
)
parser
.
add_argument
(
'--save_folder'
,
default
=
'./weights/'
,
help
=
'Location to save checkpoint models'
)
# parser.add_argument('--pretrained_model', default='./weights/Final_LPRNet_model.pth', help='pretrained base model')
parser
.
add_argument
(
'--pretrained_model'
,
default
=
''
,
help
=
'pretrained base model'
)
args
=
parser
.
parse_args
()
return
args
def
collate_fn
(
batch
):
def
collate_fn
(
batch
):
imgs
=
[]
imgs
=
[]
labels
=
[]
labels
=
[]
...
@@ -80,12 +49,71 @@ def collate_fn(batch):
...
@@ -80,12 +49,71 @@ def collate_fn(batch):
labels
.
extend
(
label
)
labels
.
extend
(
label
)
lengths
.
append
(
length
)
lengths
.
append
(
length
)
labels
=
np
.
asarray
(
labels
).
flatten
().
astype
(
np
.
int16
)
labels
=
np
.
asarray
(
labels
).
flatten
().
astype
(
np
.
int16
)
return
(
torch
.
stack
(
imgs
,
0
),
torch
.
from_numpy
(
labels
),
lengths
)
return
(
torch
.
stack
(
imgs
,
0
),
torch
.
from_numpy
(
labels
),
lengths
)
def
train
():
def
Greedy_Decode_Eval
(
Net
,
datasets
,
args
):
args
=
get_parser
()
epoch_size
=
len
(
datasets
)
//
args
.
test_batch_size
batch_iterator
=
iter
(
DataLoader
(
datasets
,
args
.
test_batch_size
,
shuffle
=
True
,
num_workers
=
args
.
num_workers
,
collate_fn
=
collate_fn
))
Tp
=
0
Tn_1
=
0
Tn_2
=
0
t1
=
time
.
time
()
for
i
in
range
(
epoch_size
):
# load train data
images
,
labels
,
lengths
=
next
(
batch_iterator
)
start
=
0
targets
=
[]
for
length
in
lengths
:
label
=
labels
[
start
:
start
+
length
]
targets
.
append
(
label
)
start
+=
length
targets
=
np
.
array
([
el
.
numpy
()
for
el
in
targets
])
if
args
.
cuda
:
images
=
Variable
(
images
.
cuda
())
else
:
images
=
Variable
(
images
)
# forward
Net
.
eval
()
prebs
=
Net
(
images
)
# greedy decode
prebs
=
prebs
.
cpu
().
detach
().
numpy
()
preb_labels
=
[]
for
i
in
range
(
prebs
.
shape
[
0
]):
preb
=
prebs
[
i
,
:,
:]
preb_label
=
[]
for
j
in
range
(
preb
.
shape
[
1
]):
preb_label
.
append
(
np
.
argmax
(
preb
[:,
j
],
axis
=
0
))
no_repeat_blank_label
=
[]
pre_c
=
preb_label
[
0
]
if
pre_c
!=
len
(
CHARS
)
-
1
:
no_repeat_blank_label
.
append
(
pre_c
)
for
c
in
preb_label
:
# dropout repeate label and blank label
if
(
pre_c
==
c
)
or
(
c
==
len
(
CHARS
)
-
1
):
if
c
==
len
(
CHARS
)
-
1
:
pre_c
=
c
continue
no_repeat_blank_label
.
append
(
c
)
pre_c
=
c
preb_labels
.
append
(
no_repeat_blank_label
)
for
i
,
label
in
enumerate
(
preb_labels
):
if
len
(
label
)
!=
len
(
targets
[
i
]):
Tn_1
+=
1
continue
if
(
np
.
asarray
(
targets
[
i
])
==
np
.
asarray
(
label
)).
all
():
Tp
+=
1
else
:
Tn_2
+=
1
Acc
=
Tp
*
1.0
/
(
Tp
+
Tn_1
+
Tn_2
)
print
(
"[Info] Test Accuracy: {} [{}:{}:{}:{}]"
.
format
(
Acc
,
Tp
,
Tn_1
,
Tn_2
,
(
Tp
+
Tn_1
+
Tn_2
)))
t2
=
time
.
time
()
print
(
"[Info] Test Speed: {}s 1/{}]"
.
format
((
t2
-
t1
)
/
len
(
datasets
),
len
(
datasets
)))
def
train
(
args
):
T_length
=
18
# args.lpr_max_len
T_length
=
18
# args.lpr_max_len
epoch
=
0
+
args
.
resume_epoch
epoch
=
0
+
args
.
resume_epoch
loss_val
=
0
loss_val
=
0
...
@@ -121,8 +149,6 @@ def train():
...
@@ -121,8 +149,6 @@ def train():
print
(
"initial net weights successful!"
)
print
(
"initial net weights successful!"
)
# define optimizer
# define optimizer
# optimizer = optim.SGD(lprnet.parameters(), lr=args.learning_rate,
# momentum=args.momentum, weight_decay=args.weight_decay)
optimizer
=
optim
.
RMSprop
(
lprnet
.
parameters
(),
lr
=
args
.
learning_rate
,
alpha
=
0.9
,
eps
=
1e-08
,
optimizer
=
optim
.
RMSprop
(
lprnet
.
parameters
(),
lr
=
args
.
learning_rate
,
alpha
=
0.9
,
eps
=
1e-08
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
train_img_dirs
=
os
.
path
.
expanduser
(
args
.
train_img_dirs
)
train_img_dirs
=
os
.
path
.
expanduser
(
args
.
train_img_dirs
)
...
@@ -148,17 +174,14 @@ def train():
...
@@ -148,17 +174,14 @@ def train():
epoch
+=
1
epoch
+=
1
if
iteration
!=
0
and
iteration
%
args
.
save_interval
==
0
:
if
iteration
!=
0
and
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'LPRNet_'
+
'
_
iteration_'
+
repr
(
iteration
)
+
'.pth'
)
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'LPRNet_'
+
'iteration_'
+
repr
(
iteration
)
+
'.pth'
)
if
(
iteration
+
1
)
%
args
.
test_interval
==
0
:
if
(
iteration
+
1
)
%
args
.
test_interval
==
0
:
Greedy_Decode_Eval
(
lprnet
,
test_dataset
,
args
)
Greedy_Decode_Eval
(
lprnet
,
test_dataset
,
args
)
# lprnet.train() # should be switch to train mode
start_time
=
time
.
time
()
start_time
=
time
.
time
()
# load train data
# load train data
images
,
labels
,
lengths
=
next
(
batch_iterator
)
images
,
labels
,
lengths
=
next
(
batch_iterator
)
# labels = np.array([el.numpy() for el in labels]).T
# print(labels)
# get ctc parameters
# get ctc parameters
input_lengths
,
target_lengths
=
sparse_tuple_for_ctc
(
T_length
,
lengths
)
input_lengths
,
target_lengths
=
sparse_tuple_for_ctc
(
T_length
,
lengths
)
# update lr
# update lr
...
@@ -173,12 +196,8 @@ def train():
...
@@ -173,12 +196,8 @@ def train():
# forward
# forward
logits
=
lprnet
(
images
)
logits
=
lprnet
(
images
)
# print(logits.size())
log_probs
=
logits
.
permute
(
2
,
0
,
1
)
# for ctc loss: T x N x C
log_probs
=
logits
.
permute
(
2
,
0
,
1
)
# for ctc loss: T x N x C
# print(labels.shape)
log_probs
=
log_probs
.
log_softmax
(
2
).
requires_grad_
()
log_probs
=
log_probs
.
log_softmax
(
2
).
requires_grad_
()
# log_probs = log_probs.detach().requires_grad_()
# print(log_probs.shape)
# backprop
# backprop
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
loss
=
ctc_loss
(
log_probs
,
labels
,
input_lengths
=
input_lengths
,
target_lengths
=
target_lengths
)
loss
=
ctc_loss
(
log_probs
,
labels
,
input_lengths
=
input_lengths
,
target_lengths
=
target_lengths
)
...
@@ -199,67 +218,31 @@ def train():
...
@@ -199,67 +218,31 @@ def train():
# save final parameters
# save final parameters
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'Final_LPRNet_model.pth'
)
torch
.
save
(
lprnet
.
state_dict
(),
args
.
save_folder
+
'Final_LPRNet_model.pth'
)
def
Greedy_Decode_Eval
(
Net
,
datasets
,
args
):
def
get_parser
():
# TestNet = Net.eval()
parser
=
argparse
.
ArgumentParser
(
description
=
'parameters to train net'
)
epoch_size
=
len
(
datasets
)
//
args
.
test_batch_size
parser
.
add_argument
(
'--max_epoch'
,
default
=
15
,
help
=
'epoch to train the network'
)
batch_iterator
=
iter
(
DataLoader
(
datasets
,
args
.
test_batch_size
,
shuffle
=
True
,
num_workers
=
args
.
num_workers
,
collate_fn
=
collate_fn
))
parser
.
add_argument
(
'--img_size'
,
default
=
[
94
,
24
],
help
=
'the image size'
)
parser
.
add_argument
(
'--train_img_dirs'
,
default
=
"data/train"
,
help
=
'the train images path'
)
Tp
=
0
parser
.
add_argument
(
'--test_img_dirs'
,
default
=
"data/test"
,
help
=
'the test images path'
)
Tn_1
=
0
parser
.
add_argument
(
'--dropout_rate'
,
default
=
0.5
,
help
=
'dropout rate.'
)
Tn_2
=
0
parser
.
add_argument
(
'--learning_rate'
,
default
=
0.1
,
help
=
'base value of learning rate.'
)
t1
=
time
.
time
()
parser
.
add_argument
(
'--lpr_max_len'
,
default
=
8
,
help
=
'license plate number max length.'
)
for
i
in
range
(
epoch_size
):
parser
.
add_argument
(
'--train_batch_size'
,
default
=
64
,
help
=
'training batch size.'
)
# load train data
parser
.
add_argument
(
'--test_batch_size'
,
default
=
10
,
help
=
'testing batch size.'
)
images
,
labels
,
lengths
=
next
(
batch_iterator
)
parser
.
add_argument
(
'--phase_train'
,
default
=
True
,
type
=
bool
,
help
=
'train or test phase flag.'
)
start
=
0
parser
.
add_argument
(
'--num_workers'
,
default
=
8
,
type
=
int
,
help
=
'Number of workers used in dataloading'
)
targets
=
[]
parser
.
add_argument
(
'--cuda'
,
default
=
True
,
type
=
bool
,
help
=
'Use cuda to train model'
)
for
length
in
lengths
:
parser
.
add_argument
(
'--resume_epoch'
,
default
=
10
,
type
=
int
,
help
=
'resume iter for retraining'
)
label
=
labels
[
start
:
start
+
length
]
parser
.
add_argument
(
'--save_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for save model state dict'
)
targets
.
append
(
label
)
parser
.
add_argument
(
'--test_interval'
,
default
=
2000
,
type
=
int
,
help
=
'interval for evaluate'
)
start
+=
length
parser
.
add_argument
(
'--momentum'
,
default
=
0.9
,
type
=
float
,
help
=
'momentum'
)
targets
=
np
.
array
([
el
.
numpy
()
for
el
in
targets
])
parser
.
add_argument
(
'--weight_decay'
,
default
=
2e-5
,
type
=
float
,
help
=
'Weight decay for SGD'
)
parser
.
add_argument
(
'--lr_schedule'
,
default
=
[
4
,
8
,
12
,
14
,
16
],
help
=
'schedule for learning rate.'
)
if
args
.
cuda
:
parser
.
add_argument
(
'--save_folder'
,
default
=
'./weights/'
,
help
=
'Location to save checkpoint models'
)
images
=
Variable
(
images
.
cuda
())
parser
.
add_argument
(
'--pretrained_model'
,
default
=
'./weights/Final_LPRNet_model.pth'
,
help
=
'pretrained base model'
)
else
:
args
=
parser
.
parse_args
()
images
=
Variable
(
images
)
return
args
# forward
prebs
=
Net
(
images
)
# greedy decode
prebs
=
prebs
.
cpu
().
detach
().
numpy
()
preb_labels
=
list
()
for
i
in
range
(
prebs
.
shape
[
0
]):
preb
=
prebs
[
i
,
:,
:]
preb_label
=
list
()
for
j
in
range
(
preb
.
shape
[
1
]):
preb_label
.
append
(
np
.
argmax
(
preb
[:,
j
],
axis
=
0
))
no_repeat_blank_label
=
list
()
pre_c
=
preb_label
[
0
]
if
pre_c
!=
len
(
CHARS
)
-
1
:
no_repeat_blank_label
.
append
(
pre_c
)
for
c
in
preb_label
:
# dropout repeate label and blank label
if
(
pre_c
==
c
)
or
(
c
==
len
(
CHARS
)
-
1
):
if
c
==
len
(
CHARS
)
-
1
:
pre_c
=
c
continue
no_repeat_blank_label
.
append
(
c
)
pre_c
=
c
preb_labels
.
append
(
no_repeat_blank_label
)
for
i
,
label
in
enumerate
(
preb_labels
):
if
len
(
label
)
!=
len
(
targets
[
i
]):
Tn_1
+=
1
continue
if
(
np
.
asarray
(
targets
[
i
])
==
np
.
asarray
(
label
)).
all
():
Tp
+=
1
else
:
Tn_2
+=
1
Acc
=
Tp
*
1.0
/
(
Tp
+
Tn_1
+
Tn_2
)
print
(
"[Info] Test Accuracy: {} [{}:{}:{}:{}]"
.
format
(
Acc
,
Tp
,
Tn_1
,
Tn_2
,
(
Tp
+
Tn_1
+
Tn_2
)))
t2
=
time
.
time
()
print
(
"[Info] Test Speed: {}s 1/{}]"
.
format
((
t2
-
t1
)
/
len
(
datasets
),
len
(
datasets
)))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
train
()
args
=
get_parser
()
train
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment