OpenDAS / OpenPCDet — Commits

Commit 97a4e42e
Authored Aug 17, 2022 by Shaoshuai Shi

support to use logger to store intermediate losses/states during training

parent 57b19553
Changes: 2 changed files with 69 additions and 15 deletions (+69 -15)

tools/train.py                       +20 -6
tools/train_utils/train_utils.py     +49 -9
tools/train.py  (view file @ 97a4e42e)
@@ -44,6 +44,10 @@ def parse_config():
     parser.add_argument('--num_epochs_to_eval', type=int, default=0, help='number of checkpoints to be evaluated')
     parser.add_argument('--save_to_file', action='store_true', default=False, help='')
+    parser.add_argument('--use_tqdm_to_record', action='store_true', default=False, help='if True, the intermediate losses will not be logged to file, only tqdm will be used')
+    parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
+    parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')

     args = parser.parse_args()
     cfg_from_yaml_file(args.cfg_file, cfg)
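Three new options control the recording behavior: --logger_iter_interval sets how often (in iterations) intermediate losses are written to the log file, --ckpt_save_time_interval sets how often (in seconds) a latest_model checkpoint is written, and --use_tqdm_to_record restores the previous tqdm-only progress display. A typical invocation might look like this (the config path is illustrative, not part of this commit):

    python train.py --cfg_file cfgs/kitti_models/second.yaml --logger_iter_interval 50 --ckpt_save_time_interval 300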
@@ -131,13 +135,19 @@ def main():
         it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger)
         last_epoch = start_epoch + 1
     else:
-        ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
+        ckpt_list = glob.glob(str(ckpt_dir / '*.pth'))
         if len(ckpt_list) > 0:
             ckpt_list.sort(key=os.path.getmtime)
-            it, start_epoch = model.load_params_with_optimizer(
-                ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
-            )
-            last_epoch = start_epoch + 1
+            while len(ckpt_list) > 0:
+                try:
+                    it, start_epoch = model.load_params_with_optimizer(
+                        ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
+                    )
+                    last_epoch = start_epoch + 1
+                    break
+                except:
+                    ckpt_list = ckpt_list[:-1]

     model.train()  # before wrap to DistributedDataParallel to support fixed some parameters
     if dist_train:
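The widened glob ('*.pth' instead of '*checkpoint_epoch_*.pth') now also matches the periodically written latest_model.pth, which can be truncated if training dies mid-save, so the new while/try loop discards unreadable checkpoints from newest to oldest instead of crashing. A minimal standalone sketch of the same fallback pattern, with torch.load standing in for model.load_params_with_optimizer():

    import glob
    import os

    import torch


    def load_newest_valid_checkpoint(ckpt_dir):
        # Sort candidates oldest -> newest by modification time, then try the
        # newest; on failure (e.g. a half-written file), drop it and retry.
        ckpt_list = sorted(glob.glob(os.path.join(ckpt_dir, '*.pth')), key=os.path.getmtime)
        while ckpt_list:
            try:
                return torch.load(ckpt_list[-1], map_location='cpu')
            except Exception:
                ckpt_list = ckpt_list[:-1]
        return None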
@@ -169,7 +179,11 @@ def main():
         lr_warmup_scheduler=lr_warmup_scheduler,
         ckpt_save_interval=args.ckpt_save_interval,
         max_ckpt_save_num=args.max_ckpt_save_num,
-        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
+        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
+        logger=logger,
+        logger_iter_interval=args.logger_iter_interval,
+        ckpt_save_time_interval=args.ckpt_save_time_interval,
+        use_logger_to_record=not args.use_tqdm_to_record
     )

     if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
tools/train_utils/train_utils.py  (view file @ 97a4e42e)
@@ -9,17 +9,22 @@ from pcdet.utils import common_utils, commu_utils
 def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
-                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False):
+                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False,
+                    use_logger_to_record=False, logger=None, logger_iter_interval=50, cur_epoch=None,
+                    total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300):
     if total_it_each_epoch == len(train_loader):
         dataloader_iter = iter(train_loader)

+    ckpt_save_cnt = 1
+    start_it = accumulated_iter % total_it_each_epoch
     if rank == 0:
         pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
         data_time = common_utils.AverageMeter()
         batch_time = common_utils.AverageMeter()
         forward_time = common_utils.AverageMeter()

-    for cur_it in range(total_it_each_epoch):
+    for cur_it in range(start_it, total_it_each_epoch):
         end = time.time()
         try:
             batch = next(dataloader_iter)
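Because train_one_epoch() can now be entered from a checkpoint saved mid-epoch, start_it = accumulated_iter % total_it_each_epoch recovers the position inside the current epoch and the loop runs range(start_it, total_it_each_epoch). A worked example with hypothetical numbers:

    # Hypothetical numbers to illustrate the resume arithmetic:
    total_it_each_epoch = 1000        # iterations per epoch
    accumulated_iter = 2350           # global iteration restored from the checkpoint

    start_it = accumulated_iter % total_it_each_epoch   # -> 350
    # The epoch loop then runs range(350, 1000), skipping the 350
    # iterations that were already finished before the interruption.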
@@ -66,11 +71,29 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
             data_time.update(avg_data_time)
             forward_time.update(avg_forward_time)
             batch_time.update(avg_batch_time)
             disp_dict.update({
                 'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
                 'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
             })

-            pbar.update()
-            pbar.set_postfix(dict(total_it=accumulated_iter))
-            tbar.set_postfix(disp_dict)
+            if use_logger_to_record:
+                if accumulated_iter % logger_iter_interval == 0 or cur_it == start_it or cur_it + 1 == total_it_each_epoch:
+                    trained_time_past_all = tbar.format_dict['elapsed']
+                    second_each_iter = pbar.format_dict['elapsed'] / max(cur_it - start_it + 1, 1.0)
+
+                    trained_time_each_epoch = pbar.format_dict['elapsed']
+                    remaining_second_each_epoch = second_each_iter * (total_it_each_epoch - cur_it)
+                    remaining_second_all = second_each_iter * ((total_epochs - cur_epoch) * total_it_each_epoch - cur_it)
+
+                    disp_str = ', '.join([f'{key}={val}' for key, val in disp_dict.items() if key != 'lr'])
+                    disp_str += f', lr={disp_dict["lr"]}'
+                    batch_size = batch.get('batch_size', None)
+                    logger.info(f'epoch: {cur_epoch}/{total_epochs}, acc_iter={accumulated_iter}, cur_iter={cur_it}/{total_it_each_epoch}, batch_size={batch_size}, '
+                                f'time_cost(epoch): {tbar.format_interval(trained_time_each_epoch)}/{tbar.format_interval(remaining_second_each_epoch)}, '
+                                f'time_cost(all): {tbar.format_interval(trained_time_past_all)}/{tbar.format_interval(remaining_second_all)}, '
+                                f'{disp_str}')
+            else:
+                pbar.update()
+                pbar.set_postfix(dict(total_it=accumulated_iter))
+                tbar.set_postfix(disp_dict)
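The logging branch derives its ETA figures from tqdm's format_dict['elapsed'] counters: per-iteration time comes from the inner epoch bar, total elapsed time from the outer epoch-level bar, and remaining time is a linear extrapolation. The same arithmetic in isolation, with hypothetical values:

    # Standalone version of the ETA arithmetic above (values are hypothetical).
    elapsed_epoch = 420.0          # seconds: pbar.format_dict['elapsed']
    cur_it, start_it = 350, 0
    total_it_each_epoch = 1000
    cur_epoch, total_epochs = 2, 80

    second_each_iter = elapsed_epoch / max(cur_it - start_it + 1, 1.0)
    remaining_second_each_epoch = second_each_iter * (total_it_each_epoch - cur_it)
    remaining_second_all = second_each_iter * ((total_epochs - cur_epoch) * total_it_each_epoch - cur_it)

    print(f'~{second_each_iter:.2f}s/it, epoch ETA {remaining_second_each_epoch / 60:.1f} min, '
          f'total ETA {remaining_second_all / 3600:.1f} h')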
@@ -81,6 +104,17 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
                 tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                 for key, val in tb_dict.items():
                     tb_log.add_scalar('train/' + key, val, accumulated_iter)
+
+            # save intermediate ckpt every {ckpt_save_time_interval} seconds
+            time_past_this_epoch = pbar.format_dict['elapsed']
+            if time_past_this_epoch // ckpt_save_time_interval >= ckpt_save_cnt:
+                ckpt_name = ckpt_save_dir / 'latest_model'
+                save_checkpoint(
+                    checkpoint_state(model, optimizer, cur_epoch, accumulated_iter), filename=ckpt_name,
+                )
+                logger.info(f'Save latest model to {ckpt_name}')
+                ckpt_save_cnt += 1
+
     if rank == 0:
         pbar.close()
     return accumulated_iter
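Intermediate checkpointing is keyed to wall-clock time rather than iteration count: integer-dividing the elapsed seconds by ckpt_save_time_interval and comparing against a running counter fires at most once per interval, however fast or slow iterations are, and repeatedly overwrites a single latest_model file. A minimal runnable sketch of the same counter pattern (save_latest() is a hypothetical stand-in for save_checkpoint(checkpoint_state(...))):

    import time


    def save_latest():
        # Hypothetical stand-in for save_checkpoint(checkpoint_state(...)).
        print('checkpoint written')


    def run_with_periodic_saves(total_iters, save_interval_s=300.0):
        # Mirror of the counter pattern above: elapsed // interval >= cnt
        # fires at most once per interval, regardless of iteration speed.
        start = time.time()
        save_cnt = 1
        for _ in range(total_iters):
            time.sleep(0.01)  # placeholder for one training step
            elapsed = time.time() - start
            if elapsed // save_interval_s >= save_cnt:
                save_latest()
                save_cnt += 1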
@@ -89,7 +123,8 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
 def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
                 start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
                 lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
-                merge_all_iters_to_one_epoch=False):
+                merge_all_iters_to_one_epoch=False, use_logger_to_record=False, logger=None,
+                logger_iter_interval=None, ckpt_save_time_interval=None):
     accumulated_iter = start_iter
     with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
         total_it_each_epoch = len(train_loader)
@@ -115,7 +150,12 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
                 rank=rank, tbar=tbar, tb_log=tb_log,
                 leave_pbar=(cur_epoch + 1 == total_epochs),
                 total_it_each_epoch=total_it_each_epoch,
-                dataloader_iter=dataloader_iter
+                dataloader_iter=dataloader_iter,
+                cur_epoch=cur_epoch, total_epochs=total_epochs,
+                use_logger_to_record=use_logger_to_record,
+                logger=logger, logger_iter_interval=logger_iter_interval,
+                ckpt_save_dir=ckpt_save_dir,
+                ckpt_save_time_interval=ckpt_save_time_interval
             )

             # save trained model
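End to end, the commit threads the new options from the CLI through train_model() into train_one_epoch(). A hedged sketch of how a caller might exercise the new keyword arguments directly; all training objects (model, optimizer, loaders, schedulers, ckpt_save_dir) must come from your own setup code, and logger is an ordinary Python logging.Logger:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('train')

    train_model(
        model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
        start_epoch=0, total_epochs=80, start_iter=0, rank=0,
        tb_log=None, ckpt_save_dir=ckpt_save_dir,
        use_logger_to_record=True,    # False keeps the old tqdm-only display
        logger=logger,
        logger_iter_interval=50,      # one log line every 50 iterations
        ckpt_save_time_interval=300,  # refresh latest_model every 5 minutes
    )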