Commit a8863510 authored by Yizhou Wang

v1.0: first commit

parent 16d8dda7
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
fig = plt.figure(figsize=(8, 8))
fp = FontProperties(fname=r"assets/fontawesome-free-5.12.0-desktop/otfs/solid-900.otf")
symbols = {
'pedestrian': "\uf554",
'cyclist': "\uf84a",
'car': "\uf1b9",
}
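# The Font Awesome glyphs above are presumably used to mark detected object
# classes on the radar plots. A minimal usage sketch (the helper below is
# hypothetical, not part of this commit, and assumes the .otf path above exists):
def draw_class_symbol(ax, x, y, class_name, color='white'):
    # Render the Font Awesome glyph for `class_name` at (x, y); fall back to
    # plain text for classes without a glyph in `symbols`.
    glyph = symbols.get(class_name, class_name)
    ax.text(x, y, glyph, fontproperties=fp, color=color,
            fontsize=14, ha='center', va='center')
# e.g. draw_class_symbol(plt.gca(), 64, 32, 'pedestrian')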
import matplotlib.pyplot as plt
def visualize_ols_hist(olss_flatten):
_ = plt.hist(olss_flatten, bins='auto') # arguments are passed to np.histogram
plt.title("OLS Distribution")
plt.show()
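# "OLS" here presumably refers to object location similarity scores used for
# evaluation, assumed to lie in [0, 1]; a quick synthetic sanity check:
if __name__ == '__main__':
    import numpy as np
    fake_ols = np.clip(np.random.normal(0.6, 0.2, size=1000), 0.0, 1.0)  # made-up scores
    visualize_ols_hist(fake_ols)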
import numpy as np
import matplotlib.pyplot as plt
from rodnet.core.radar_processing.chirp_ops import chirp_amp
def visualize_radar_chirp(chirp, radar_data_type):
"""
Visualize radar data of one chirp
:param chirp: (w x h x 2) or (2 x w x h)
:param radar_data_type: current available types include 'RI', 'RISEP', 'AP', 'APSEP'
:return:
"""
chirp_abs = chirp_amp(chirp, radar_data_type)
plt.imshow(chirp_abs)
plt.show()
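# chirp_amp is imported from rodnet.core.radar_processing.chirp_ops and is not
# included in this commit. The sketch below is only an assumption of what it
# roughly computes: a single (w, h) amplitude map from the two data channels,
# covering both layouts mentioned in the docstring above.
def chirp_amp_sketch(chirp, radar_data_type='RI'):
    chirp = np.asarray(chirp)
    if chirp.shape[-1] == 2:                      # (w, h, 2) layout
        c0, c1 = chirp[..., 0], chirp[..., 1]
    else:                                         # (2, w, h) layout
        c0, c1 = chirp[0], chirp[1]
    if radar_data_type in ('RI', 'RISEP'):
        return np.sqrt(c0 ** 2 + c1 ** 2)         # magnitude from real/imag parts
    return c0                                     # 'AP'/'APSEP': amplitude channel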
def visualize_radar_chirps(chirps, radar_data_type):
"""
Visualize radar data of multiple chirps
:param chirps: (N x w x h x 2) or (N x 2 x w x h)
:param radar_data_type: current available types include 'RI', 'RISEP', 'AP', 'APSEP'
:return:
"""
num_chirps, c0, c1, c2 = chirps.shape
if c2 == 2:
chirps_abs = np.zeros((num_chirps, c0, c1))
elif c0 == 2:
chirps_abs = np.zeros((num_chirps, c1, c2))
else:
raise ValueError("invalid chirps shape: expected (N, w, h, 2) or (N, 2, w, h), got %s" % str(chirps.shape))
for chirp_id in range(num_chirps):
chirps_abs[chirp_id, :, :] = chirp_amp(chirps[chirp_id, :, :, :], radar_data_type)
chirp_abs_avg = np.mean(chirps_abs, axis=0)
plt.imshow(chirp_abs_avg)
plt.show()
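# Quick way to exercise the two helpers above on synthetic data; the shapes
# follow the docstrings, and 'RI' is assumed to be a valid radar_data_type.
if __name__ == '__main__':
    fake_chirps = np.random.randn(4, 2, 128, 128)   # (N, 2, w, h)
    visualize_radar_chirp(fake_chirps[0], 'RI')
    visualize_radar_chirps(fake_chirps, 'RI')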
def visualize_fuse_crdets(chirp, obj_dicts, figname=None, viz=False):
chirp_abs = chirp_amp(chirp)
chirp_shape = chirp_abs.shape
plt.figure()
plt.imshow(chirp_abs, vmin=0, vmax=1, origin='lower')
for obj_id, obj_dict in enumerate(obj_dicts):
plt.scatter(obj_dict['angle_id'], obj_dict['range_id'], s=10, c='white')
try:
text = str(obj_dict['object_id']) + ' ' + obj_dict['class']
except KeyError:
text = str(obj_dict['object_id'])
plt.text(obj_dict['angle_id'] + 5, obj_dict['range_id'], text, color='white', fontsize=10)
plt.xlim(0, chirp_shape[1])
plt.ylim(0, chirp_shape[0])
if viz:
plt.show()
else:
plt.savefig(figname)
plt.close()
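# Based on the keys accessed above, each entry of obj_dicts looks roughly like
# the dictionaries below (values are made up for illustration; 'class' is
# optional, hence the KeyError fallback).
example_obj_dicts = [
    {'object_id': 0, 'range_id': 45, 'angle_id': 60, 'class': 'pedestrian'},
    {'object_id': 1, 'range_id': 80, 'angle_id': 30, 'class': 'car'},
]
# e.g. visualize_fuse_crdets(chirp, example_obj_dicts, figname='fused_dets.png')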
def visualize_fuse_crdets_compare(img_path, chirp, c_dicts, r_dicts, cr_dicts, figname=None, viz=False):
chirp_abs = chirp_amp(chirp)
chirp_shape = chirp_abs.shape
fig_local = plt.figure()
fig_local.set_size_inches(16, 4)
fig_local.add_subplot(1, 4, 1)
im = plt.imread(img_path)
plt.imshow(im)
fig_local.add_subplot(1, 4, 2)
plt.imshow(chirp_abs, vmin=0, vmax=1, origin='lower')
for obj_id, obj_dict in enumerate(c_dicts):
plt.scatter(obj_dict['angle_id'], obj_dict['range_id'], s=10, c='white')
try:
obj_dict['object_id']
except KeyError:
obj_dict['object_id'] = ''
try:
text = str(obj_dict['object_id']) + ' ' + obj_dict['class']
except KeyError:
text = str(obj_dict['object_id'])
plt.text(obj_dict['angle_id'] + 5, obj_dict['range_id'], text, color='white', fontsize=10)
plt.xlim(0, chirp_shape[1])
plt.ylim(0, chirp_shape[0])
fig_local.add_subplot(1, 4, 3)
plt.imshow(chirp_abs, vmin=0, vmax=1, origin='lower')
for obj_id, obj_dict in enumerate(r_dicts):
plt.scatter(obj_dict['angle_id'], obj_dict['range_id'], s=10, c='white')
try:
obj_dict['object_id']
except KeyError:
obj_dict['object_id'] = ''
try:
text = str(obj_dict['object_id']) + ' ' + obj_dict['class']
except KeyError:
text = str(obj_dict['object_id'])
plt.text(obj_dict['angle_id'] + 5, obj_dict['range_id'], text, color='white', fontsize=10)
plt.xlim(0, chirp_shape[1])
plt.ylim(0, chirp_shape[0])
fig_local.add_subplot(1, 4, 4)
plt.imshow(chirp_abs, vmin=0, vmax=1, origin='lower')
for obj_id, obj_dict in enumerate(cr_dicts):
plt.scatter(obj_dict['angle_id'], obj_dict['range_id'], s=10, c='white')
try:
obj_dict['object_id']
except KeyError:
obj_dict['object_id'] = '%.2f' % obj_dict['confidence']
try:
text = str(obj_dict['object_id']) + ' ' + obj_dict['class']
except KeyError:
text = str(obj_dict['object_id'])
plt.text(obj_dict['angle_id'] + 5, obj_dict['range_id'], text, color='white', fontsize=10)
plt.xlim(0, chirp_shape[1])
plt.ylim(0, chirp_shape[0])
if viz:
plt.show()
else:
plt.savefig(figname)
plt.close()
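# The three per-subplot loops above differ only in how a missing 'object_id' is
# filled in. If this gets refactored later, a shared helper along these lines
# (a sketch, not part of this commit) would remove the duplication:
def _plot_dets(obj_dicts, chirp_shape, default_id_fn=lambda d: ''):
    for obj_dict in obj_dicts:
        plt.scatter(obj_dict['angle_id'], obj_dict['range_id'], s=10, c='white')
        obj_id = obj_dict.get('object_id', default_id_fn(obj_dict))
        text = '%s %s' % (obj_id, obj_dict['class']) if 'class' in obj_dict else str(obj_id)
        plt.text(obj_dict['angle_id'] + 5, obj_dict['range_id'], text, color='white', fontsize=10)
    plt.xlim(0, chirp_shape[1])
    plt.ylim(0, chirp_shape[0])
# e.g. _plot_dets(cr_dicts, chirp_shape, default_id_fn=lambda d: '%.2f' % d['confidence'])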
def visualize_anno_ramap(chirp, obj_info, figname, viz=False):
chirp_abs = chirp_amp(chirp)
plt.figure()
plt.imshow(chirp_abs, vmin=0, vmax=1, origin='lower')
for obj in obj_info:
rng_idx, agl_idx, class_id = obj
if class_id >= 0:
try:
cla_str = class_table[class_id]  # class_table maps class ids to class names
except (IndexError, KeyError):
continue
else:
continue
plt.scatter(agl_idx, rng_idx, s=10, c='white')
plt.text(agl_idx + 5, rng_idx, cla_str, color='white', fontsize=10)
if viz:
plt.show()
else:
plt.savefig(figname)
plt.close()
import os
import sys
import shutil
import numpy as np
import json
import pickle
import argparse
from cruw.cruw import CRUW
from rodnet.core.confidence_map import generate_confmap, normalize_confmap, add_noise_channel
from rodnet.utils.load_configs import load_configs_from_file
from rodnet.utils.visualization import visualize_confmap
SPLITS_LIST = ['train', 'valid', 'test', 'demo']
def parse_args():
parser = argparse.ArgumentParser(description='Prepare RODNet data.')
parser.add_argument('--config', type=str, dest='config', help='configuration file path')
parser.add_argument('--data_root', type=str, help='directory to the prepared data')
parser.add_argument('--split', type=str, dest='split', help='choose from train, valid, test, demo (comma-separated for multiple splits)')
parser.add_argument('--out_data_dir', type=str, default='./data',
help='data directory to save the prepared data')
parser.add_argument('--overwrite', action="store_true", help="overwrite prepared data if exist")
args = parser.parse_args()
return args
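# Typical invocation (paths and config name are placeholders; --split accepts a
# comma-separated list, which is split in the main block below):
#   python prepare_data.py --config configs/<config_file>.py --data_root <path_to_CRUW> \
#       --split train,valid --out_data_dir ./data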
def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, overwrite=False):
"""
Prepare pickle data for RODNet training and testing
:param dataset: dataset object
:param config_dict: rodnet configurations
:param data_dir: output directory of the processed data
:param split: train, valid, test, demo, etc.
:param viz: whether visualize the prepared data
:param overwrite: whether overwrite the existing prepared data
:return:
"""
camera_configs = dataset.sensor_cfg.camera_cfg
radar_configs = dataset.sensor_cfg.radar_cfg
n_chirp = radar_configs['n_chirps']
n_class = dataset.object_cfg.n_class
data_root = config_dict['dataset_cfg']['data_root']
anno_root = config_dict['dataset_cfg']['anno_root']
set_cfg = config_dict['dataset_cfg'][split]
sets_seqs = set_cfg['seqs']
if overwrite:
if os.path.exists(os.path.join(data_dir, split)):
shutil.rmtree(os.path.join(data_dir, split))
os.makedirs(os.path.join(data_dir, split))
for seq in sets_seqs:
seq_path = os.path.join(data_root, seq)
seq_anno_path = os.path.join(anno_root, seq + '.json')
save_path = os.path.join(save_dir, seq + '.pkl')
print("Sequence %s saving to %s" % (seq_path, save_path))
try:
if not overwrite and os.path.exists(save_path):
print("%s already exists, skip" % save_path)
continue
image_dir = os.path.join(seq_path, camera_configs['image_folder'])
image_paths = sorted([os.path.join(image_dir, name) for name in os.listdir(image_dir) if
name.endswith(camera_configs['ext'])])
n_frame = len(image_paths)
radar_dir = os.path.join(seq_path, dataset.sensor_cfg.radar_cfg['chirp_folder'])
if radar_configs['data_type'] == 'RI' or radar_configs['data_type'] == 'AP':
radar_paths = sorted([os.path.join(radar_dir, name) for name in os.listdir(radar_dir) if
name.endswith(dataset.sensor_cfg.radar_cfg['ext'])])
n_radar_frame = len(radar_paths)
assert n_frame == n_radar_frame
elif radar_configs['data_type'] == 'RISEP' or radar_configs['data_type'] == 'APSEP':
radar_paths_chirp = []
for chirp_id in range(n_chirp):
chirp_dir = os.path.join(radar_dir, '%04d' % chirp_id)
paths = sorted([os.path.join(chirp_dir, name) for name in os.listdir(chirp_dir) if
name.endswith(config_dict['dataset_cfg']['radar_cfg']['ext'])])
n_radar_frame = len(paths)
assert n_frame == n_radar_frame
radar_paths_chirp.append(paths)
radar_paths = []
for frame_id in range(n_frame):
frame_paths = []
for chirp_id in range(n_chirp):
frame_paths.append(radar_paths_chirp[chirp_id][frame_id])
radar_paths.append(frame_paths)
else:
raise ValueError
data_dict = dict(
data_root=data_root,
data_path=seq_path,
seq_name=seq,
n_frame=n_frame,
image_paths=image_paths,
radar_paths=radar_paths,
anno=None,
)
if split == 'demo':
# no labels need to be saved
pickle.dump(data_dict, open(save_path, 'wb'))
continue
else:
with open(os.path.join(seq_anno_path), 'r') as f:
anno = json.load(f)
anno_obj = {}
anno_obj['metadata'] = anno['metadata']
anno_obj['confmaps'] = []
for metadata_frame in anno['metadata']:
n_obj = metadata_frame['rad_h']['n_objects']
obj_info = metadata_frame['rad_h']['obj_info']
if n_obj == 0:
confmap_gt = np.zeros(
(n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize']),
dtype=float)
confmap_gt[-1, :, :] = 1.0 # initialize the noise channel
else:
confmap_gt = generate_confmap(n_obj, obj_info, dataset, config_dict)
confmap_gt = normalize_confmap(confmap_gt)
confmap_gt = add_noise_channel(confmap_gt, dataset, config_dict)
assert confmap_gt.shape == (
n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize'])
if viz:
visualize_confmap(confmap_gt)
anno_obj['confmaps'].append(confmap_gt)
# end frames loop
anno_obj['confmaps'] = np.array(anno_obj['confmaps'])
data_dict['anno'] = anno_obj
# save pkl files
pickle.dump(data_dict, open(save_path, 'wb'))
# end of this sequence
except Exception as e:
print("Error while preparing %s: %s" % (seq_path, e))
if __name__ == "__main__":
args = parse_args()
data_root = args.data_root
splits = args.split.split(',')
out_data_dir = args.out_data_dir
overwrite = args.overwrite
dataset = CRUW(data_root=data_root)
config_dict = load_configs_from_file(args.config)
radar_configs = dataset.sensor_cfg.radar_cfg
for split in splits:
if split not in SPLITS_LIST:
raise ValueError("unrecognized split: %s" % split)
for split in splits:
save_dir = os.path.join(out_data_dir, split)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print('Preparing %s sets ...' % split)
prepare_data(dataset, config_dict, out_data_dir, split, save_dir, viz=False, overwrite=overwrite)
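# The prepared per-sequence pickle can be inspected directly; the keys below
# follow the data_dict built in prepare_data() (the path is a placeholder):
#   import pickle
#   with open('./data/train/<seq_name>.pkl', 'rb') as f:
#       data_dict = pickle.load(f)
#   print(data_dict['seq_name'], data_dict['n_frame'])
#   if data_dict['anno'] is not None:
#       print(data_dict['anno']['confmaps'].shape)  # (n_frame, n_class + 1, rsize, asize)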
import os
import time
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from cruw.cruw import CRUW
from rodnet.datasets.CRDataset import CRDataset
from rodnet.datasets.collate_functions import cr_collate
from rodnet.core.post_processing import post_process, post_process_single_frame
from rodnet.core.post_processing import write_dets_results, write_dets_results_single_frame
from rodnet.core.post_processing import ConfmapStack
from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo_gt
from rodnet.utils.load_configs import load_configs_from_file
from rodnet.utils.solve_dir import create_random_model_name
"""
Example:
python test.py -m HG -dd /mnt/ssd2/rodnet/data/ -ld /mnt/ssd2/rodnet/checkpoints/ \
-md HG-20200122-104604 -rd /mnt/ssd2/rodnet/results/
"""
def parse_args():
parser = argparse.ArgumentParser(description='Test RODNet.')
parser.add_argument('--config', type=str, help='choose rodnet model configurations')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
parser.add_argument('--checkpoint', type=str, help='path to the saved trained model')
parser.add_argument('--res_dir', type=str, default='./results/', help='directory to save testing results')
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
parser.add_argument('--demo', action="store_true", help='run in demo mode without ground-truth annotations (default: test with GT)')
parser.add_argument('--symbol', action="store_true", help='use symbol or text+score')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
sybl = args.symbol
config_dict = load_configs_from_file(args.config)
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
radar_configs = dataset.sensor_cfg.radar_cfg
range_grid = dataset.range_grid
angle_grid = dataset.angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_configs = config_dict['model_cfg']
if model_configs['type'] == 'CDC':
from rodnet.models import RODNetCDC as RODNet
elif model_configs['type'] == 'HG':
from rodnet.models import RODNetHG as RODNet
elif model_configs['type'] == 'HGwI':
from rodnet.models import RODNetHGwI as RODNet
else:
raise NotImplementedError
# parameter settings
dataset_configs = config_dict['dataset_cfg']
train_configs = config_dict['train_cfg']
test_configs = config_dict['test_cfg']
win_size = train_configs['win_size']
n_class = dataset.object_cfg.n_class
confmap_shape = (n_class, radar_configs['ramap_rsize'], radar_configs['ramap_asize'])
if 'stacked_num' in model_configs:
stacked_num = model_configs['stacked_num']
else:
stacked_num = None
if args.checkpoint is not None and os.path.exists(args.checkpoint):
checkpoint_path = args.checkpoint
else:
raise ValueError("No trained model found.")
if args.use_noise_channel:
n_class_test = n_class + 1
else:
n_class_test = n_class
print("Building model ... (%s)" % model_configs)
if model_configs['type'] == 'CDC':
rodnet = RODNet(n_class_test).cuda()
elif model_configs['type'] == 'HG':
rodnet = RODNet(n_class_test, stacked_num=stacked_num).cuda()
elif model_configs['type'] == 'HGwI':
rodnet = RODNet(n_class_test, stacked_num=stacked_num).cuda()
else:
raise TypeError
checkpoint = torch.load(checkpoint_path)
if 'optimizer_state_dict' in checkpoint:
rodnet.load_state_dict(checkpoint['model_state_dict'])
else:
rodnet.load_state_dict(checkpoint)
if 'model_name' in checkpoint:
model_name = checkpoint['model_name']
else:
model_name = create_random_model_name(model_configs['name'], checkpoint_path)
rodnet.eval()
test_res_dir = os.path.join(args.res_dir, model_name)
if not os.path.exists(test_res_dir):
os.makedirs(test_res_dir)
# save current checkpoint path
weight_log_path = os.path.join(test_res_dir, 'weight_name.txt')
with open(weight_log_path, 'a+') as f:
f.write(checkpoint_path + '\n')
total_time = 0
total_count = 0
data_root = dataset_configs['data_root']
if not args.demo:
seq_names = dataset_configs['test']['seqs']
else:
seq_names = dataset_configs['demo']['seqs']
print(seq_names)
for seq_name in seq_names:
seq_res_dir = os.path.join(test_res_dir, seq_name)
if not os.path.exists(seq_res_dir):
os.makedirs(seq_res_dir)
seq_res_viz_dir = os.path.join(seq_res_dir, 'rod_viz')
if not os.path.exists(seq_res_viz_dir):
os.makedirs(seq_res_viz_dir)
f = open(os.path.join(seq_res_dir, 'rod_res.txt'), 'w')
f.close()
for subset in seq_names:
print(subset)
if not args.demo:
crdata_test = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='test',
noise_channel=args.use_noise_channel, subset=subset, is_random_chirp=False)
else:
crdata_test = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='demo',
noise_channel=args.use_noise_channel, subset=subset, is_random_chirp=False)
print("Length of testing data: %d" % len(crdata_test))
dataloader = DataLoader(crdata_test, batch_size=1, shuffle=False, num_workers=0, collate_fn=cr_collate)
seq_names = crdata_test.seq_names
index_mapping = crdata_test.index_mapping
init_genConfmap = ConfmapStack(confmap_shape)
iter_ = init_genConfmap
for i in range(train_configs['win_size'] - 1):
while iter_.next is not None:
iter_ = iter_.next
iter_.next = ConfmapStack(confmap_shape)
load_tic = time.time()
for iter, data_dict in enumerate(dataloader):
load_time = time.time() - load_tic
data = data_dict['radar_data']
image_paths = data_dict['image_paths'][0]
seq_name = data_dict['seq_names'][0]
if not args.demo:
confmap_gt = data_dict['anno']['confmaps']
obj_info = data_dict['anno']['obj_infos']
else:
confmap_gt = None
obj_info = None
save_path = os.path.join(test_res_dir, seq_name, 'rod_res.txt')
start_frame_name = image_paths[0].split('/')[-1].split('.')[0]
end_frame_name = image_paths[-1].split('/')[-1].split('.')[0]
start_frame_id = int(start_frame_name)
end_frame_id = int(end_frame_name)
print("Testing %s: %s-%s" % (seq_name, start_frame_name, end_frame_name))
tic = time.time()
confmap_pred = rodnet(data.float().cuda())
if stacked_num is not None:
confmap_pred = confmap_pred[-1].cpu().detach().numpy() # (1, 4, 32, 128, 128)
else:
confmap_pred = confmap_pred.cpu().detach().numpy()
if args.use_noise_channel:
confmap_pred = confmap_pred[:, :n_class, :, :, :]
infer_time = time.time() - tic
total_time += infer_time
iter_ = init_genConfmap
for i in range(confmap_pred.shape[2]):
if iter_.next is None and i != confmap_pred.shape[2] - 1:
iter_.next = ConfmapStack(confmap_shape)
iter_.append(confmap_pred[0, :, i, :, :])
iter_ = iter_.next
process_tic = time.time()
for i in range(test_configs['test_stride']):
total_count += 1
res_final = post_process_single_frame(init_genConfmap.confmap, dataset, config_dict)
cur_frame_id = start_frame_id + i
write_dets_results_single_frame(res_final, cur_frame_id, save_path, dataset)
confmap_pred_0 = init_genConfmap.confmap
res_final_0 = res_final
img_path = image_paths[i]
radar_input = chirp_amp(data.numpy()[0, :, i, :, :], radar_configs['data_type'])
fig_name = os.path.join(test_res_dir, seq_name, 'rod_viz', '%010d.jpg' % (cur_frame_id))
if confmap_gt is not None:
confmap_gt_0 = confmap_gt[0, :, i, :, :]
visualize_test_img(fig_name, img_path, radar_input, confmap_pred_0, confmap_gt_0, res_final_0,
dataset, sybl=sybl)
else:
visualize_test_img_wo_gt(fig_name, img_path, radar_input, confmap_pred_0, res_final_0,
dataset, sybl=sybl)
init_genConfmap = init_genConfmap.next
if iter == len(dataloader) - 1:
offset = test_configs['test_stride']
cur_frame_id = start_frame_id + offset
while init_genConfmap is not None:
total_count += 1
res_final = post_process_single_frame(init_genConfmap.confmap, dataset, config_dict)
write_dets_results_single_frame(res_final, cur_frame_id, save_path, dataset)
confmap_pred_0 = init_genConfmap.confmap
res_final_0 = res_final
img_path = image_paths[offset]
radar_input = chirp_amp(data.numpy()[0, :, offset, :, :], radar_configs['data_type'])
fig_name = os.path.join(test_res_dir, seq_name, 'rod_viz', '%010d.jpg' % (cur_frame_id))
if confmap_gt is not None:
confmap_gt_0 = confmap_gt[0, :, offset, :, :]
visualize_test_img(fig_name, img_path, radar_input, confmap_pred_0, confmap_gt_0, res_final_0,
dataset, sybl=sybl)
else:
visualize_test_img_wo_gt(fig_name, img_path, radar_input, confmap_pred_0, res_final_0,
dataset, sybl=sybl)
init_genConfmap = init_genConfmap.next
offset += 1
cur_frame_id += 1
if init_genConfmap is None:
init_genConfmap = ConfmapStack(confmap_shape)
proc_time = time.time() - process_tic
print("Load time: %.4f | Inference time: %.4f | Process time: %.4f" % (load_time, infer_time, proc_time))
load_tic = time.time()
print("ave time: %f" % (total_time / total_count))
import os
import time
import json
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from cruw.cruw import CRUW
from rodnet.datasets.CRDataset import CRDataset
from rodnet.datasets.CRDatasetSM import CRDatasetSM
from rodnet.datasets.CRDataLoader import CRDataLoader
from rodnet.datasets.collate_functions import cr_collate
from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.solve_dir import create_dir_for_new_model
from rodnet.utils.load_configs import load_configs_from_file
from rodnet.utils.visualization import visualize_train_img
def parse_args():
parser = argparse.ArgumentParser(description='Train RODNet.')
parser.add_argument('--config', type=str, help='configuration file path')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
parser.add_argument('--log_dir', type=str, default='./checkpoints/', help='directory to save trained model')
parser.add_argument('--resume_from', type=str, default=None, help='path to the trained model')
parser.add_argument('--save_memory', action="store_true", help="use customized dataloader to save memory")
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
args = parser.parse_args()
return args
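# Typical training command, mirroring the test-time example above (paths and
# config name are placeholders):
#   python train.py --config configs/<config_file>.py --data_dir ./data/ --log_dir ./checkpoints/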
if __name__ == "__main__":
args = parse_args()
config_dict = load_configs_from_file(args.config)
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
radar_configs = dataset.sensor_cfg.radar_cfg
range_grid = dataset.range_grid
angle_grid = dataset.angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_cfg = config_dict['model_cfg']
if model_cfg['type'] == 'CDC':
from rodnet.models import RODNetCDC as RODNet
elif model_cfg['type'] == 'HG':
from rodnet.models import RODNetHG as RODNet
elif model_cfg['type'] == 'HGwI':
from rodnet.models import RODNetHGwI as RODNet
else:
raise NotImplementedError
if not os.path.exists(args.log_dir):
os.makedirs(args.log_dir)
train_model_path = args.log_dir
# create / load models
cp_path = None
epoch_start = 0
iter_start = 0
if args.resume_from is not None and os.path.exists(args.resume_from):
cp_path = args.resume_from
model_dir, model_name = create_dir_for_new_model(model_cfg['name'], train_model_path)
else:
model_dir, model_name = create_dir_for_new_model(model_cfg['name'], train_model_path)
train_viz_path = os.path.join(model_dir, 'train_viz')
if not os.path.exists(train_viz_path):
os.makedirs(train_viz_path)
writer = SummaryWriter(model_dir)
save_config_dict = {
'args': vars(args),
'config_dict': config_dict,
}
config_json_name = os.path.join(model_dir, 'config-' + time.strftime("%Y%m%d-%H%M%S") + '.json')
with open(config_json_name, 'w') as fp:
json.dump(save_config_dict, fp)
train_log_name = os.path.join(model_dir, "train.log")
with open(train_log_name, 'w'):
pass
n_class = dataset.object_cfg.n_class
n_epoch = config_dict['train_cfg']['n_epoch']
batch_size = config_dict['train_cfg']['batch_size']
lr = config_dict['train_cfg']['lr']
if 'stacked_num' in model_cfg:
stacked_num = model_cfg['stacked_num']
else:
stacked_num = None
print("Building dataloader ... (Mode: %s)" % ("save_memory" if args.save_memory else "normal"))
if not args.save_memory:
crdata_train = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='train',
noise_channel=args.use_noise_channel)
seq_names = crdata_train.seq_names
index_mapping = crdata_train.index_mapping
dataloader = DataLoader(crdata_train, batch_size, shuffle=True, num_workers=0, collate_fn=cr_collate)
# crdata_valid = CRDataset(os.path.join(args.data_dir, 'data_details'),
# os.path.join(args.data_dir, 'confmaps_gt'),
# win_size=win_size, set_type='valid', stride=8)
# seq_names_valid = crdata_valid.seq_names
# index_mapping_valid = crdata_valid.index_mapping
# dataloader_valid = DataLoader(crdata_valid, batch_size=batch_size, shuffle=True, num_workers=0)
else:
crdata_train = CRDatasetSM(data_root=args.data_dir, config_dict=config_dict, split='train',
noise_channel=args.use_noise_channel)
seq_names = crdata_train.seq_names
index_mapping = crdata_train.index_mapping
dataloader = CRDataLoader(crdata_train, shuffle=True, noise_channel=args.use_noise_channel)
# crdata_valid = CRDatasetSM(os.path.join(args.data_dir, 'data_details'),
# os.path.join(args.data_dir, 'confmaps_gt'),
# win_size=win_size, set_type='train', stride=8, is_Memory_Limit=True)
# seq_names_valid = crdata_valid.seq_names
# index_mapping_valid = crdata_valid.index_mapping
# dataloader_valid = CRDataLoader(crdata_valid, batch_size=batch_size, shuffle=True)
if args.use_noise_channel:
n_class_train = n_class + 1
else:
n_class_train = n_class
print("Building model ... (%s)" % model_cfg)
if model_cfg['type'] == 'CDC':
if 'mnet_cfg' in model_cfg:
rodnet = RODNet(n_class_train, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train).cuda()
criterion = nn.MSELoss()
elif model_cfg['type'] == 'HG':
if 'mnet_cfg' in model_cfg:
rodnet = RODNet(n_class_train, stacked_num=stacked_num, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'HGwI':
if 'mnet_cfg' in model_cfg:
rodnet = RODNet(n_class_train, stacked_num=stacked_num, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss()
else:
raise TypeError
# criterion = FocalLoss(focusing_param=8, balance_param=0.25)
optimizer = optim.Adam(rodnet.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=config_dict['train_cfg']['lr_step'], gamma=0.1)
iter_count = 0
if cp_path is not None:
checkpoint = torch.load(cp_path)
if 'optimizer_state_dict' in checkpoint:
rodnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_start = checkpoint['epoch'] + 1
iter_start = checkpoint['iter'] + 1
loss_cp = checkpoint['loss']
if 'iter_count' in checkpoint:
iter_count = checkpoint['iter_count']
else:
rodnet.load_state_dict(checkpoint)
# print training configurations
print("Model name: %s" % model_name)
print("Number of sequences to train: %d" % crdata_train.n_seq)
print("Training dataset length: %d" % len(crdata_train))
print("Batch size: %d" % batch_size)
print("Number of iterations in each epoch: %d" % int(len(crdata_train) / batch_size))
for epoch in range(epoch_start, n_epoch):
tic_load = time.time()
# if epoch == epoch_start:
# dataloader_start = iter_start
# else:
# dataloader_start = 0
for iter, data_dict in enumerate(dataloader):
data = data_dict['radar_data']
image_paths = data_dict['image_paths']
confmap_gt = data_dict['anno']['confmaps']
if not data_dict['status']:
# in case load npy fail
print("Warning: Loading NPY data failed! Skip this iteration")
tic_load = time.time()
continue
tic = time.time()
optimizer.zero_grad() # zero the parameter gradients
confmap_preds = rodnet(data.float().cuda())
loss_confmap = 0
if stacked_num is not None:
for i in range(stacked_num):
loss_cur = criterion(confmap_preds[i], confmap_gt.float().cuda())
loss_confmap += loss_cur
loss_confmap.backward()
optimizer.step()
else:
loss_confmap = criterion(confmap_preds, confmap_gt.float().cuda())
loss_confmap.backward()
optimizer.step()
if iter % config_dict['train_cfg']['log_step'] == 0:
# print statistics
print('epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f' %
(epoch + 1, iter + 1, loss_confmap.item(), tic - tic_load, time.time() - tic))
with open(train_log_name, 'a+') as f_log:
f_log.write('epoch %2d, iter %4d: loss: %.8f | load time: %.4f | backward time: %.4f\n' %
(epoch + 1, iter + 1, loss_confmap.item(), tic - tic_load, time.time() - tic))
if stacked_num is not None:
writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count)
confmap_pred = confmap_preds[stacked_num - 1].cpu().detach().numpy()
else:
writer.add_scalar('loss/loss_all', loss_confmap.item(), iter_count)
confmap_pred = confmap_preds.cpu().detach().numpy()
if 'mnet_cfg' in model_cfg:
chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, 0, :, :], radar_configs['data_type'])
else:
chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, :, :], radar_configs['data_type'])
if True:
# draw train images
fig_name = os.path.join(train_viz_path,
'%03d_%010d_%06d.png' % (epoch + 1, iter_count, iter + 1))
img_path = image_paths[0][0]
visualize_train_img(fig_name, img_path, chirp_amp_curr,
confmap_pred[0, :n_class, 0, :, :],
confmap_gt[0, :n_class, 0, :, :])
else:
writer.add_image('images/ramap', heatmap2rgb(chirp_amp_curr), iter_count)
writer.add_image('images/confmap_pred', prob2image(confmap_pred[0, :, 0, :, :]), iter_count)
writer.add_image('images/confmap_gt', prob2image(confmap_gt[0, :, 0, :, :]), iter_count)
# TODO: combine three images together
# writer.add_images('')
if (iter + 1) % config_dict['train_cfg']['save_step'] == 0:
# validate current model
# print("validing current model ...")
# validate()
# save current model
print("saving current model ...")
status_dict = {
'model_name': model_name,
'epoch': epoch,
'iter': iter,
'model_state_dict': rodnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_confmap,
'iter_count': iter_count,
}
save_model_path = '%s/epoch_%02d_iter_%010d.pkl' % (model_dir, epoch + 1, iter_count + 1)
torch.save(status_dict, save_model_path)
iter_count += 1
tic_load = time.time()
# save current model
print("saving current epoch model ...")
status_dict = {
'model_name': model_name,
'epoch': epoch,
'iter': iter,
'model_state_dict': rodnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_confmap,
'iter_count': iter_count,
}
save_model_path = '%s/epoch_%02d_final.pkl' % (model_dir, epoch + 1)
torch.save(status_dict, save_model_path)
scheduler.step()
print('Training Finished.')