Unverified Commit e94c91d3 authored by jihan.yang's avatar jihan.yang Committed by GitHub
Browse files

refix disp_dict when distributed (#700)

parent 65554a52
...@@ -61,19 +61,16 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac ...@@ -61,19 +61,16 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
avg_forward_time = commu_utils.average_reduce_value(cur_forward_time) avg_forward_time = commu_utils.average_reduce_value(cur_forward_time)
avg_batch_time = commu_utils.average_reduce_value(cur_batch_time) avg_batch_time = commu_utils.average_reduce_value(cur_batch_time)
# log to console and tensorboard
if rank == 0: if rank == 0:
data_time.update(avg_data_time) data_time.update(avg_data_time)
forward_time.update(avg_forward_time) forward_time.update(avg_forward_time)
batch_time.update(avg_batch_time) batch_time.update(avg_batch_time)
disp_dict.update({
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
})
disp_dict.update({
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
})
# log to console and tensorboard
if rank == 0:
pbar.update() pbar.update()
pbar.set_postfix(dict(total_it=accumulated_iter)) pbar.set_postfix(dict(total_it=accumulated_iter))
tbar.set_postfix(disp_dict) tbar.set_postfix(disp_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment