Commit 1d3dead7 authored by Yizhou Wang's avatar Yizhou Wang
Browse files

update base code for ROD2021

parent 81f1e0ac
...@@ -9,7 +9,7 @@ def chirp_amp(chirp, radar_data_type): ...@@ -9,7 +9,7 @@ def chirp_amp(chirp, radar_data_type):
:return: amplitude map for the input chirp (w x h) :return: amplitude map for the input chirp (w x h)
""" """
c0, c1, c2 = chirp.shape c0, c1, c2 = chirp.shape
if radar_data_type == 'RI' or radar_data_type == 'RISEP': if radar_data_type == 'RI' or radar_data_type == 'RISEP' or radar_data_type == 'ROD2021':
if c0 == 2: if c0 == 2:
chirp_abs = np.sqrt(chirp[0, :, :] ** 2 + chirp[1, :, :] ** 2) chirp_abs = np.sqrt(chirp[0, :, :] ** 2 + chirp[1, :, :] ** 2)
elif c2 == 2: elif c2 == 2:
......
...@@ -7,7 +7,7 @@ from tqdm import tqdm ...@@ -7,7 +7,7 @@ from tqdm import tqdm
from torch.utils import data from torch.utils import data
from .loaders import list_pkl_filenames from .loaders import list_pkl_filenames, list_pkl_filenames_from_prepared
class CRDataset(data.Dataset): class CRDataset(data.Dataset):
...@@ -61,7 +61,8 @@ class CRDataset(data.Dataset): ...@@ -61,7 +61,8 @@ class CRDataset(data.Dataset):
if subset is not None: if subset is not None:
self.data_files = [subset + '.pkl'] self.data_files = [subset + '.pkl']
else: else:
self.data_files = list_pkl_filenames(config_dict['dataset_cfg'], split) # self.data_files = list_pkl_filenames(config_dict['dataset_cfg'], split)
self.data_files = list_pkl_filenames_from_prepared(data_dir, split)
self.seq_names = [name.split('.')[0] for name in self.data_files] self.seq_names = [name.split('.')[0] for name in self.data_files]
self.n_seq = len(self.seq_names) self.n_seq = len(self.seq_names)
...@@ -142,8 +143,15 @@ class CRDataset(data.Dataset): ...@@ -142,8 +143,15 @@ class CRDataset(data.Dataset):
data_dict['image_paths'].append(image_paths[frameid]) data_dict['image_paths'].append(image_paths[frameid])
else: else:
raise TypeError raise TypeError
elif radar_configs['data_type'] == 'ROD2021':
radar_npy_win = np.zeros((self.win_size, ramap_rsize, ramap_asize, 2), dtype=np.float32)
chirp_id = 0 # only use chirp 0 for training
for idx, frameid in enumerate(
range(data_id, data_id + self.win_size * self.step, self.step)):
radar_npy_win[idx, :, :, :] = np.load(radar_paths[frameid][chirp_id])
data_dict['image_paths'].append(image_paths[frameid])
else: else:
raise ValueError raise NotImplementedError
except: except:
# in case load npy fail # in case load npy fail
data_dict['status'] = False data_dict['status'] = False
...@@ -202,10 +210,3 @@ class CRDataset(data.Dataset): ...@@ -202,10 +210,3 @@ class CRDataset(data.Dataset):
data_dict['anno'] = None data_dict['anno'] = None
return data_dict return data_dict
if __name__ == "__main__":
dataset = CRDataset('./data/data_details', stride=16)
print(len(dataset))
for i in range(len(dataset)):
continue
from .parse_pkl import list_pkl_filenames from .parse_pkl import list_pkl_filenames, list_pkl_filenames_from_prepared
from .read_rod_results import load_rodnet_res, load_vgg_res from .read_rod_results import load_rodnet_res, load_vgg_res
...@@ -6,3 +6,8 @@ def list_pkl_filenames(dataset_configs, split): ...@@ -6,3 +6,8 @@ def list_pkl_filenames(dataset_configs, split):
seqs = dataset_configs[split]['seqs'] seqs = dataset_configs[split]['seqs']
seqs_pkl_names = [name + '.pkl' for name in seqs] seqs_pkl_names = [name + '.pkl' for name in seqs]
return seqs_pkl_names return seqs_pkl_names
def list_pkl_filenames_from_prepared(data_dir, split):
seqs_pkl_names = sorted(os.listdir(os.path.join(data_dir, split)))
return seqs_pkl_names
...@@ -6,7 +6,9 @@ import json ...@@ -6,7 +6,9 @@ import json
import pickle import pickle
import argparse import argparse
from cruw.cruw import CRUW from cruw import CRUW
from cruw.annotation.init_json import init_meta_json
from cruw.mapping import ra2idx
from rodnet.core.confidence_map import generate_confmap, normalize_confmap, add_noise_channel from rodnet.core.confidence_map import generate_confmap, normalize_confmap, add_noise_channel
from rodnet.utils.load_configs import load_configs_from_file from rodnet.utils.load_configs import load_configs_from_file
...@@ -27,6 +29,52 @@ def parse_args(): ...@@ -27,6 +29,52 @@ def parse_args():
return args return args
def load_anno_txt(txt_path, n_frame, dataset):
folder_name_dict = dict(
cam_0='IMAGES_0',
rad_h='RADAR_RA_H'
)
anno_dict = init_meta_json(n_frame, folder_name_dict)
with open(txt_path, 'r') as f:
data = f.readlines()
for line in data:
frame_id, r, a, class_name = line.rstrip().split()
frame_id = int(frame_id)
r = float(r)
a = float(a)
rid, aid = ra2idx(r, a, dataset.range_grid, dataset.angle_grid)
anno_dict[frame_id]['rad_h']['n_objects'] += 1
anno_dict[frame_id]['rad_h']['obj_info']['categories'].append(class_name)
anno_dict[frame_id]['rad_h']['obj_info']['centers'].append([r, a])
anno_dict[frame_id]['rad_h']['obj_info']['center_ids'].append([rid, aid])
anno_dict[frame_id]['rad_h']['obj_info']['scores'].append(1.0)
return anno_dict
def generate_confmaps(metadata_dict, n_class, viz):
confmaps = []
for metadata_frame in metadata_dict:
n_obj = metadata_frame['rad_h']['n_objects']
obj_info = metadata_frame['rad_h']['obj_info']
if n_obj == 0:
confmap_gt = np.zeros(
(n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize']),
dtype=float)
confmap_gt[-1, :, :] = 1.0 # initialize noise channal
else:
confmap_gt = generate_confmap(n_obj, obj_info, dataset, config_dict)
confmap_gt = normalize_confmap(confmap_gt)
confmap_gt = add_noise_channel(confmap_gt, dataset, config_dict)
assert confmap_gt.shape == (
n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize'])
if viz:
visualize_confmap(confmap_gt)
confmaps.append(confmap_gt)
confmaps = np.array(confmaps)
return confmaps
def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, overwrite=False): def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, overwrite=False):
""" """
Prepare pickle data for RODNet training and testing Prepare pickle data for RODNet training and testing
...@@ -34,6 +82,7 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -34,6 +82,7 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
:param config_dict: rodnet configurations :param config_dict: rodnet configurations
:param data_dir: output directory of the processed data :param data_dir: output directory of the processed data
:param split: train, valid, test, demo, etc. :param split: train, valid, test, demo, etc.
:param save_dir: output directory of the prepared data
:param viz: whether visualize the prepared data :param viz: whether visualize the prepared data
:param overwrite: whether overwrite the existing prepared data :param overwrite: whether overwrite the existing prepared data
:return: :return:
...@@ -46,7 +95,10 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -46,7 +95,10 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
data_root = config_dict['dataset_cfg']['data_root'] data_root = config_dict['dataset_cfg']['data_root']
anno_root = config_dict['dataset_cfg']['anno_root'] anno_root = config_dict['dataset_cfg']['anno_root']
set_cfg = config_dict['dataset_cfg'][split] set_cfg = config_dict['dataset_cfg'][split]
sets_seqs = set_cfg['seqs'] if 'seqs' not in set_cfg:
sets_seqs = sorted(os.listdir(os.path.join(data_root, set_cfg['subdir'])))
else:
sets_seqs = set_cfg['seqs']
if overwrite: if overwrite:
if os.path.exists(os.path.join(data_dir, split)): if os.path.exists(os.path.join(data_dir, split)):
...@@ -54,8 +106,8 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -54,8 +106,8 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
os.makedirs(os.path.join(data_dir, split)) os.makedirs(os.path.join(data_dir, split))
for seq in sets_seqs: for seq in sets_seqs:
seq_path = os.path.join(data_root, seq) seq_path = os.path.join(data_root, set_cfg['subdir'], seq)
seq_anno_path = os.path.join(anno_root, seq + '.json') seq_anno_path = os.path.join(anno_root, set_cfg['subdir'], seq + config_dict['dataset_cfg']['anno_ext'])
save_path = os.path.join(save_dir, seq + '.pkl') save_path = os.path.join(save_dir, seq + '.pkl')
print("Sequence %s saving to %s" % (seq_path, save_path)) print("Sequence %s saving to %s" % (seq_path, save_path))
...@@ -89,6 +141,16 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -89,6 +141,16 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
for chirp_id in range(n_chirp): for chirp_id in range(n_chirp):
frame_paths.append(radar_paths_chirp[chirp_id][frame_id]) frame_paths.append(radar_paths_chirp[chirp_id][frame_id])
radar_paths.append(frame_paths) radar_paths.append(frame_paths)
elif radar_configs['data_type'] == 'ROD2021':
assert len(os.listdir(radar_dir)) == n_frame * len(radar_configs['chirp_ids'])
radar_paths = []
for frame_id in range(n_frame):
chirp_paths = []
for chirp_id in radar_configs['chirp_ids']:
path = os.path.join(radar_dir, '%06d_%04d.' % (frame_id, chirp_id) +
dataset.sensor_cfg.radar_cfg['ext'])
chirp_paths.append(path)
radar_paths.append(chirp_paths)
else: else:
raise ValueError raise ValueError
...@@ -107,35 +169,19 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -107,35 +169,19 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
pickle.dump(data_dict, open(save_path, 'wb')) pickle.dump(data_dict, open(save_path, 'wb'))
continue continue
else: else:
with open(os.path.join(seq_anno_path), 'r') as f:
anno = json.load(f)
anno_obj = {} anno_obj = {}
anno_obj['metadata'] = anno['metadata'] if config_dict['dataset_cfg']['anno_ext'] == '.txt':
anno_obj['confmaps'] = [] anno_obj['metadata'] = load_anno_txt(seq_anno_path, n_frame, dataset)
for metadata_frame in anno['metadata']:
n_obj = metadata_frame['rad_h']['n_objects']
obj_info = metadata_frame['rad_h']['obj_info']
if n_obj == 0:
confmap_gt = np.zeros(
(n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize']),
dtype=float)
confmap_gt[-1, :, :] = 1.0 # initialize noise channal
else:
confmap_gt = generate_confmap(n_obj, obj_info, dataset, config_dict)
confmap_gt = normalize_confmap(confmap_gt)
confmap_gt = add_noise_channel(confmap_gt, dataset, config_dict)
assert confmap_gt.shape == (
n_class + 1, radar_configs['ramap_rsize'], radar_configs['ramap_asize'])
if viz:
visualize_confmap(confmap_gt)
anno_obj['confmaps'].append(confmap_gt)
# end objects loop
anno_obj['confmaps'] = np.array(anno_obj['confmaps'])
data_dict['anno'] = anno_obj
elif config_dict['dataset_cfg']['anno_ext'] == '.json':
with open(os.path.join(seq_anno_path), 'r') as f:
anno = json.load(f)
anno_obj['metadata'] = anno['metadata']
else:
raise
anno_obj['confmaps'] = generate_confmaps(anno_obj['metadata'], n_class, viz)
data_dict['anno'] = anno_obj
# save pkl files # save pkl files
pickle.dump(data_dict, open(save_path, 'wb')) pickle.dump(data_dict, open(save_path, 'wb'))
# end frames loop # end frames loop
...@@ -151,7 +197,7 @@ if __name__ == "__main__": ...@@ -151,7 +197,7 @@ if __name__ == "__main__":
out_data_dir = args.out_data_dir out_data_dir = args.out_data_dir
overwrite = args.overwrite overwrite = args.overwrite
dataset = CRUW(data_root=data_root) dataset = CRUW(data_root=data_root, sensor_config_name='sensor_config_rod2021')
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
......
...@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import StepLR ...@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from cruw.cruw import CRUW from cruw import CRUW
from rodnet.datasets.CRDataset import CRDataset from rodnet.datasets.CRDataset import CRDataset
from rodnet.datasets.CRDatasetSM import CRDatasetSM from rodnet.datasets.CRDatasetSM import CRDatasetSM
...@@ -37,16 +37,13 @@ def parse_args(): ...@@ -37,16 +37,13 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root']) # dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name='sensor_config_rod2021')
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
range_grid = dataset.range_grid range_grid = dataset.range_grid
angle_grid = dataset.angle_grid angle_grid = dataset.angle_grid
# config_dict['mappings'] = {}
# config_dict['mappings']['range_grid'] = range_grid.tolist()
# config_dict['mappings']['angle_grid'] = angle_grid.tolist()
model_cfg = config_dict['model_cfg'] model_cfg = config_dict['model_cfg']
if model_cfg['type'] == 'CDC': if model_cfg['type'] == 'CDC':
from rodnet.models import RODNetCDC as RODNet from rodnet.models import RODNetCDC as RODNet
elif model_cfg['type'] == 'HG': elif model_cfg['type'] == 'HG':
...@@ -132,27 +129,16 @@ if __name__ == "__main__": ...@@ -132,27 +129,16 @@ if __name__ == "__main__":
print("Building model ... (%s)" % model_cfg) print("Building model ... (%s)" % model_cfg)
if model_cfg['type'] == 'CDC': if model_cfg['type'] == 'CDC':
if 'mnet_cfg' in model_cfg: rodnet = RODNet(n_class_train).cuda()
rodnet = RODNet(n_class_train, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train).cuda()
criterion = nn.MSELoss() criterion = nn.MSELoss()
elif model_cfg['type'] == 'HG': elif model_cfg['type'] == 'HG':
if 'mnet_cfg' in model_cfg: rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
rodnet = RODNet(n_class_train, stacked_num=stacked_num, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss() criterion = nn.BCELoss()
elif model_cfg['type'] == 'HGwI': elif model_cfg['type'] == 'HGwI':
if 'mnet_cfg' in model_cfg: rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
rodnet = RODNet(n_class_train, stacked_num=stacked_num, mnet_cfg=model_cfg['mnet_cfg']).cuda()
else:
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss() criterion = nn.BCELoss()
else: else:
raise TypeError raise TypeError
# criterion = FocalLoss(focusing_param=8, balance_param=0.25)
optimizer = optim.Adam(rodnet.parameters(), lr=lr) optimizer = optim.Adam(rodnet.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=config_dict['train_cfg']['lr_step'], gamma=0.1) scheduler = StepLR(optimizer, step_size=config_dict['train_cfg']['lr_step'], gamma=0.1)
...@@ -232,21 +218,13 @@ if __name__ == "__main__": ...@@ -232,21 +218,13 @@ if __name__ == "__main__":
else: else:
chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, :, :], radar_configs['data_type']) chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, :, :], radar_configs['data_type'])
if True: # draw train images
# draw train images fig_name = os.path.join(train_viz_path,
fig_name = os.path.join(train_viz_path, '%03d_%010d_%06d.png' % (epoch + 1, iter_count, iter + 1))
'%03d_%010d_%06d.png' % (epoch + 1, iter_count, iter + 1)) img_path = image_paths[0][0]
img_path = image_paths[0][0] visualize_train_img(fig_name, img_path, chirp_amp_curr,
visualize_train_img(fig_name, img_path, chirp_amp_curr, confmap_pred[0, :n_class, 0, :, :],
confmap_pred[0, :n_class, 0, :, :], confmap_gt[0, :n_class, 0, :, :])
confmap_gt[0, :n_class, 0, :, :])
else:
writer.add_image('images/ramap', heatmap2rgb(chirp_amp_curr), iter_count)
writer.add_image('images/confmap_pred', prob2image(confmap_pred[0, :, 0, :, :]), iter_count)
writer.add_image('images/confmap_gt', prob2image(confmap_gt[0, :, 0, :, :]), iter_count)
# TODO: combine three images together
# writer.add_images('')
if (iter + 1) % config_dict['train_cfg']['save_step'] == 0: if (iter + 1) % config_dict['train_cfg']['save_step'] == 0:
# validate current model # validate current model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment