Commit cc8336e3 authored by Yizhou Wang's avatar Yizhou Wang
Browse files

add average loss to log and tensorboard

parent 1d635cff
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import time import time
import json import json
import argparse import argparse
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -173,6 +174,8 @@ if __name__ == "__main__": ...@@ -173,6 +174,8 @@ if __name__ == "__main__":
scheduler = StepLR(optimizer, step_size=config_dict['train_cfg']['lr_step'], gamma=0.1) scheduler = StepLR(optimizer, step_size=config_dict['train_cfg']['lr_step'], gamma=0.1)
iter_count = 0 iter_count = 0
loss_ave = 0
if cp_path is not None: if cp_path is not None:
checkpoint = torch.load(cp_path) checkpoint = torch.load(cp_path)
if 'optimizer_state_dict' in checkpoint: if 'optimizer_state_dict' in checkpoint:
...@@ -229,13 +232,15 @@ if __name__ == "__main__": ...@@ -229,13 +232,15 @@ if __name__ == "__main__":
loss_confmap.backward() loss_confmap.backward()
optimizer.step() optimizer.step()
loss_ave = np.average([loss_ave, loss_confmap.item()], weights=[iter_count, 1])
if iter % config_dict['train_cfg']['log_step'] == 0: if iter % config_dict['train_cfg']['log_step'] == 0:
# print statistics # print statistics
print('epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f' % print('epoch %2d, iter %4d: loss: %.6f (%.4f) | load time: %.2f | backward time: %.2f' %
(epoch + 1, iter + 1, loss_confmap.item(), tic - tic_load, time.time() - tic)) (epoch + 1, iter + 1, loss_confmap.item(), loss_ave, tic - tic_load, time.time() - tic))
with open(train_log_name, 'a+') as f_log: with open(train_log_name, 'a+') as f_log:
f_log.write('epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f\n' % f_log.write('epoch %2d, iter %4d: loss: %.6f (%.4f) | load time: %.2f | backward time: %.2f\n' %
(epoch + 1, iter + 1, loss_confmap.item(), tic - tic_load, time.time() - tic)) (epoch + 1, iter + 1, loss_confmap.item(), loss_ave, tic - tic_load, time.time() - tic))
if stacked_num is not None: if stacked_num is not None:
writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count) writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count)
...@@ -243,6 +248,8 @@ if __name__ == "__main__": ...@@ -243,6 +248,8 @@ if __name__ == "__main__":
else: else:
writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count) writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count)
confmap_pred = confmap_preds.cpu().detach().numpy() confmap_pred = confmap_preds.cpu().detach().numpy()
writer.add_scalar('loss/loss_ave', loss_ave, iter_count)
if 'mnet_cfg' in model_cfg: if 'mnet_cfg' in model_cfg:
chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, 0, :, :], radar_configs['data_type']) chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, 0, :, :], radar_configs['data_type'])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment