ModelZoo / crnn_pytorch · Commit 4c93f0ed
authored Dec 26, 2023 by dengjf

update code

Pipeline #685 canceled with stages
Showing 2 changed files with 391 additions and 0 deletions:
  train_ddp.py  (+242 −0)
  utils.py      (+149 −0)
train_ddp.py  (new file, 0 → 100644)
from __future__ import print_function
from __future__ import division

import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
# from warpctc_pytorch import CTCLoss
from torch.nn import CTCLoss
import os
import utils
import dataset
from datetime import datetime
import models.crnn as crnn
import time
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument('--trainRoot', required=True, help='path to training dataset')
parser.add_argument('--valRoot', required=True, help='path to validation dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
# TODO(meijieru): epoch -> iter
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--pretrained', default='', help="path to pretrained model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
parser.add_argument('--expr_dir', default='expr', help='where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='interval between log displays')
parser.add_argument('--n_test_disp', type=int, default=10, help='number of samples to display when testing')
parser.add_argument('--valInterval', type=int, default=500, help='interval between validations')
parser.add_argument('--saveInterval', type=int, default=500, help='interval between checkpoint saves')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate (not used by adadelta)')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adam', action='store_true', help='whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--manualSeed', type=int, default=1234, help='seed to reproduce experiments')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
parser.add_argument('--local_rank', type=int, default=1, help='local rank passed by the launcher')
parser.add_argument('--world-size', default=4, type=int, help='number of distributed processes')
opt = parser.parse_args()
print(opt)
rank = int(os.environ["RANK"])
local_rank = opt.local_rank
world_size = int(os.environ['WORLD_SIZE'])
print(f"rank: {rank}, local_rank: {local_rank}, world_size: {world_size}")

dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")

if not os.path.exists(opt.expr_dir):
    os.makedirs(opt.expr_dir)
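# RANK and WORLD_SIZE are read from the environment and --local_rank arrives
# as a flag, which matches the legacy `torch.distributed.launch` contract
# (the launcher also exports MASTER_ADDR/MASTER_PORT for init_process_group).
# A plausible single-node launch, with placeholder LMDB paths, would be:
#
#   python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py \
#       --trainRoot <train_lmdb> --valRoot <val_lmdb> --cuda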
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
# train_dataset = dataset.lmdbDataset(root=opt.trainroot)
train_dataset = dataset.lmdbDataset(root=opt.trainRoot)
assert train_dataset

# each rank draws its own shard of the dataset via the distributed sampler
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=False, sampler=sampler, num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))

test_dataset = dataset.lmdbDataset(
    root=opt.valRoot, transform=dataset.resizeNormalize((100, 32)))

nclass = len(opt.alphabet) + 1  # +1 for the CTC blank
nc = 1                          # grayscale input

converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()
# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.pretrained != '':
    print('loading pretrained model from %s' % opt.pretrained)
    crnn.load_state_dict(torch.load(opt.pretrained))
print(crnn)
# ddp model
# placeholder buffers; utils.loadData resizes them in place to each incoming batch
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda(local_rank)
    # crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    image = image.cuda(local_rank)
    criterion = criterion.cuda(local_rank)

# wrap for distributed training; DDP all-reduces gradients across ranks
crnn = DDP(crnn, device_ids=[local_rank], find_unused_parameters=True)

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()
# setup optimizer
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
def val(net, dataset, criterion, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        i += 1
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image).permute(1, 0, 2)  # (T, N, C) for CTCLoss
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        # preds = preds.squeeze(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    # NOTE: assumes every batch is full; a smaller final batch inflates the denominator
    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
def trainBatch(net, criterion, optimizer):
    # NOTE: utils.AverageMeter is referenced here but is not among the
    # definitions in the utils.py added by this commit
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    end = time.time()
    data = next(train_iter)
    data_time.update((time.time() - end) * 1000)
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image).permute(1, 0, 2)  # (T, N, C) for CTCLoss
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()

    batch_time.update((time.time() - end) * 1000)
    fps = (batch_size / batch_time.val) * 1000
    # msg is assembled but never printed in this commit
    msg = 'Time {batch_time.val:.3f}ms (avg_time:{batch_time.avg:.3f}ms)\t' \
          'Data {data_time.val:.3f}ms ({data_time.avg:.3f}ms)\t' \
          'Fps {fps:.3f}\t'.format(batch_time=batch_time, data_time=data_time, fps=fps)
    return cost
for epoch in range(opt.nepoch):
    sampler.set_epoch(epoch)
    train_iter = iter(train_loader)
    i = 0
    time_all = 0
    while i < len(train_loader):
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if dist.get_rank() == 0:
            print('\r[%d/%d][%d/%d] Loss: %f' %
                  (epoch, opt.nepoch, i, len(train_loader), loss_avg.val()), end='')
            loss_avg.reset()

        # if local_rank == 0:
        #     val(crnn, test_dataset, criterion)

        if i % opt.saveInterval == 0 and local_rank == 0:
            torch.save(crnn.state_dict(),
                       '{0}/netCRNN_{1}_{2}.pth'.format(opt.expr_dir, epoch, i))
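Both trainBatch and val hand the permuted predictions straight to torch.nn.CTCLoss, which expects log-probabilities of shape (T, N, C) plus per-sample input and target lengths; whether models.crnn applies log_softmax internally is not visible in this commit. A minimal standalone sketch of that call contract (all sizes are assumptions; T=26 is what a 100-pixel-wide CRNN input typically produces):

import torch
from torch.nn import CTCLoss

T, N, C = 26, 4, 37                                      # time steps, batch, 36 chars + blank
log_probs = torch.randn(T, N, C).log_softmax(2)          # (T, N, C) log-probabilities
targets = torch.randint(1, C, (N, 5), dtype=torch.int)   # label indices; 0 is the CTC blank
input_lengths = torch.full((N,), T, dtype=torch.int)     # T frames per sample
target_lengths = torch.full((N,), 5, dtype=torch.int)    # ground-truth text lengths

loss = CTCLoss()(log_probs, targets, input_lengths, target_lengths)
print(loss.item())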
utils.py  (new file, 0 → 100644)
#!/usr/bin/python
# encoding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable
import collections
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=True):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1
    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        if isinstance(text, str):
            text = [
                self.dict[char.lower() if self._ignore_case else char]
                for char in text
            ]
            length = [len(text)]
        # collections.Iterable was removed in Python 3.10; use collections.abc
        elif isinstance(text, collections.abc.Iterable):
            length = [len(s) for s in text]
            text = ''.join(text)
            text, _ = self.encode(text)
        return (torch.IntTensor(text), torch.IntTensor(length))
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, \
                "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), \
                "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
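# Illustration (not part of the module): with the default alphabet
# '0123456789abcdefghijklmnopqrstuvwxyz', a frame-wise prediction
# t = [13, 13, 0, 11, 0, 30, 30] with length 7 decodes as
#   raw=True  -> 'cc-a-tt'  (index 0 maps to the appended '-' blank)
#   raw=False -> 'cat'      (adjacent repeats collapsed, blanks removed)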
class averager(object):
    """Compute average for `torch.Variable` and `torch.Tensor`. """

    def __init__(self):
        self.reset()

    def add(self, v):
        if isinstance(v, Variable):
            count = v.data.numel()
            v = v.data.sum()
        elif isinstance(v, torch.Tensor):
            count = v.numel()
            v = v.sum()

        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
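# Example (illustrative only): averager tracks a running mean over every
# element it has seen, e.g.
#   avg = averager()
#   avg.add(torch.tensor([2.0, 4.0]))   # sum=6.0,  count=2
#   avg.add(torch.tensor(6.0))          # sum=12.0, count=3
#   avg.val()                           # -> 4.0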
def oneHot(v, v_length, nc):
    batchSize = v_length.size(0)
    maxLength = v_length.max()
    v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
    acc = 0
    for i in range(batchSize):
        length = v_length[i]
        label = v[acc:acc + length].view(-1, 1).long()
        v_onehot[i, :length].scatter_(1, label, 1.0)
        acc += length
    return v_onehot
def loadData(v, data):
    # resize the shared buffer in place to the incoming batch, then copy
    v.data.resize_(data.size()).copy_(data)
def prettyPrint(v):
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    # .item() replaces the pre-0.4 `.data[0]` indexing, which fails on 0-dim tensors
    print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(), v.mean().item()))
def assureRatio(img):
    """Ensure imgH <= imgW."""
    b, c, h, w = img.size()
    if h > w:
        main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
        img = main(img)
    return img
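As a quick round-trip check of strLabelConverter (a standalone sketch assuming the default 36-character alphabet from train_ddp.py):

from utils import strLabelConverter

conv = strLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')
t, l = conv.encode(['cat', 'dog'])   # flat 1-based label indices and per-sample lengths
print(t.tolist(), l.tolist())        # [13, 11, 30, 14, 25, 17] [3, 3]
print(conv.decode(t, l, raw=False))  # ['cat', 'dog']
# note: raw=False collapses adjacent repeats, so a word like 'book' would
# decode to 'bok'; encoded ground truth is not always exactly recoverable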