Commit 97a4e42e authored by Shaoshuai Shi
Browse files

Support using the logger to record intermediate losses/states during training

parent 57b19553
...@@ -43,6 +43,10 @@ def parse_config(): ...@@ -43,6 +43,10 @@ def parse_config():
parser.add_argument('--start_epoch', type=int, default=0, help='') parser.add_argument('--start_epoch', type=int, default=0, help='')
parser.add_argument('--num_epochs_to_eval', type=int, default=0, help='number of checkpoints to be evaluated') parser.add_argument('--num_epochs_to_eval', type=int, default=0, help='number of checkpoints to be evaluated')
parser.add_argument('--save_to_file', action='store_true', default=False, help='') parser.add_argument('--save_to_file', action='store_true', default=False, help='')
parser.add_argument('--use_tqdm_to_record', action='store_true', default=False, help='if True, the intermediate losses will not be logged to file, only tqdm will be used')
parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')
args = parser.parse_args() args = parser.parse_args()
...@@ -131,13 +135,19 @@ def main(): ...@@ -131,13 +135,19 @@ def main():
it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger) it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger)
last_epoch = start_epoch + 1 last_epoch = start_epoch + 1
else: else:
ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth')) ckpt_list = glob.glob(str(ckpt_dir / '*.pth'))
if len(ckpt_list) > 0: if len(ckpt_list) > 0:
ckpt_list.sort(key=os.path.getmtime) ckpt_list.sort(key=os.path.getmtime)
it, start_epoch = model.load_params_with_optimizer( while len(ckpt_list) > 0:
ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger try:
) it, start_epoch = model.load_params_with_optimizer(
last_epoch = start_epoch + 1 ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
)
last_epoch = start_epoch + 1
break
except:
ckpt_list = ckpt_list[:-1]
model.train() # before wrap to DistributedDataParallel to support fixed some parameters model.train() # before wrap to DistributedDataParallel to support fixed some parameters
if dist_train: if dist_train:
...@@ -169,7 +179,11 @@ def main(): ...@@ -169,7 +179,11 @@ def main():
lr_warmup_scheduler=lr_warmup_scheduler, lr_warmup_scheduler=lr_warmup_scheduler,
ckpt_save_interval=args.ckpt_save_interval, ckpt_save_interval=args.ckpt_save_interval,
max_ckpt_save_num=args.max_ckpt_save_num, max_ckpt_save_num=args.max_ckpt_save_num,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
logger=logger,
logger_iter_interval=args.logger_iter_interval,
ckpt_save_time_interval=args.ckpt_save_time_interval,
use_logger_to_record=not args.use_tqdm_to_record
) )
if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory: if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
......
...@@ -9,17 +9,22 @@ from pcdet.utils import common_utils, commu_utils ...@@ -9,17 +9,22 @@ from pcdet.utils import common_utils, commu_utils
def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg, def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False): rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False,
use_logger_to_record=False, logger=None, logger_iter_interval=50, cur_epoch=None,
total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300):
if total_it_each_epoch == len(train_loader): if total_it_each_epoch == len(train_loader):
dataloader_iter = iter(train_loader) dataloader_iter = iter(train_loader)
ckpt_save_cnt = 1
start_it = accumulated_iter % total_it_each_epoch
if rank == 0: if rank == 0:
pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True) pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
data_time = common_utils.AverageMeter() data_time = common_utils.AverageMeter()
batch_time = common_utils.AverageMeter() batch_time = common_utils.AverageMeter()
forward_time = common_utils.AverageMeter() forward_time = common_utils.AverageMeter()
for cur_it in range(total_it_each_epoch): for cur_it in range(start_it, total_it_each_epoch):
end = time.time() end = time.time()
try: try:
batch = next(dataloader_iter) batch = next(dataloader_iter)
...@@ -66,21 +71,50 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac ...@@ -66,21 +71,50 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
data_time.update(avg_data_time) data_time.update(avg_data_time)
forward_time.update(avg_forward_time) forward_time.update(avg_forward_time)
batch_time.update(avg_batch_time) batch_time.update(avg_batch_time)
disp_dict.update({ disp_dict.update({
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})', 'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})' 'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
}) })
pbar.update() if use_logger_to_record:
pbar.set_postfix(dict(total_it=accumulated_iter)) if accumulated_iter % logger_iter_interval == 0 or cur_it == start_it or cur_it + 1 == total_it_each_epoch:
tbar.set_postfix(disp_dict) trained_time_past_all = tbar.format_dict['elapsed']
tbar.refresh() second_each_iter = pbar.format_dict['elapsed'] / max(cur_it - start_it + 1, 1.0)
trained_time_each_epoch = pbar.format_dict['elapsed']
remaining_second_each_epoch = second_each_iter * (total_it_each_epoch - cur_it)
remaining_second_all = second_each_iter * ((total_epochs - cur_epoch) * total_it_each_epoch - cur_it)
disp_str = ', '.join([f'{key}={val}' for key, val in disp_dict.items() if key != 'lr'])
disp_str += f', lr={disp_dict["lr"]}'
batch_size = batch.get('batch_size', None)
logger.info(f'epoch: {cur_epoch}/{total_epochs}, acc_iter={accumulated_iter}, cur_iter={cur_it}/{total_it_each_epoch}, batch_size={batch_size}, '
f'time_cost(epoch): {tbar.format_interval(trained_time_each_epoch)}/{tbar.format_interval(remaining_second_each_epoch)}, '
f'time_cost(all): {tbar.format_interval(trained_time_past_all)}/{tbar.format_interval(remaining_second_all)}, '
f'{disp_str}')
else:
pbar.update()
pbar.set_postfix(dict(total_it=accumulated_iter))
tbar.set_postfix(disp_dict)
tbar.refresh()
if tb_log is not None: if tb_log is not None:
tb_log.add_scalar('train/loss', loss, accumulated_iter) tb_log.add_scalar('train/loss', loss, accumulated_iter)
tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter) tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
for key, val in tb_dict.items(): for key, val in tb_dict.items():
tb_log.add_scalar('train/' + key, val, accumulated_iter) tb_log.add_scalar('train/' + key, val, accumulated_iter)
# save intermediate ckpt every {ckpt_save_time_interval} seconds
time_past_this_epoch = pbar.format_dict['elapsed']
if time_past_this_epoch // ckpt_save_time_interval >= ckpt_save_cnt:
ckpt_name = ckpt_save_dir / 'latest_model'
save_checkpoint(
checkpoint_state(model, optimizer, cur_epoch, accumulated_iter), filename=ckpt_name,
)
logger.info(f'Save latest model to {ckpt_name}')
ckpt_save_cnt += 1
if rank == 0: if rank == 0:
pbar.close() pbar.close()
return accumulated_iter return accumulated_iter
...@@ -89,7 +123,8 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac ...@@ -89,7 +123,8 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg, def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None, start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50, lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
merge_all_iters_to_one_epoch=False): merge_all_iters_to_one_epoch=False,
use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None):
accumulated_iter = start_iter accumulated_iter = start_iter
with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar: with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
total_it_each_epoch = len(train_loader) total_it_each_epoch = len(train_loader)
...@@ -115,7 +150,12 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_ ...@@ -115,7 +150,12 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
rank=rank, tbar=tbar, tb_log=tb_log, rank=rank, tbar=tbar, tb_log=tb_log,
leave_pbar=(cur_epoch + 1 == total_epochs), leave_pbar=(cur_epoch + 1 == total_epochs),
total_it_each_epoch=total_it_each_epoch, total_it_each_epoch=total_it_each_epoch,
dataloader_iter=dataloader_iter dataloader_iter=dataloader_iter,
cur_epoch=cur_epoch, total_epochs=total_epochs,
use_logger_to_record=use_logger_to_record,
logger=logger, logger_iter_interval=logger_iter_interval,
ckpt_save_dir=ckpt_save_dir, ckpt_save_time_interval=ckpt_save_time_interval
) )
# save trained model # save trained model
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment