Commit 4f2843ce authored by Yizhou Wang's avatar Yizhou Wang
Browse files

update test.py script

parent d2cb288d
......@@ -28,4 +28,4 @@ def write_dets_results_single_frame(res, data_id, save_path, dataset):
row_id = res[d, 1]
col_id = res[d, 2]
conf = res[d, 3]
f.write("%d %s %d %d %s\n" % (data_id, get_class_name(cla_id, classes), row_id, col_id, conf))
f.write("%d %s %d %d %.4f\n" % (data_id, get_class_name(cla_id, classes), row_id, col_id, conf))
......@@ -160,6 +160,10 @@ class CRDataset(data.Dataset):
raise TypeError
else:
raise NotImplementedError
data_dict['start_frame'] = data_id
data_dict['end_frame'] = data_id + self.win_size * self.step - 1
except:
# in case load npy fail
data_dict['status'] = False
......
......@@ -159,6 +159,9 @@ class CRDatasetSM(data.Dataset):
else:
raise ValueError
data_dict['start_frame'] = data_id
data_dict['end_frame'] = data_id + self.win_size * self.step - 1
except:
# in case load npy fail
data_dict['status'] = False
......
......@@ -17,12 +17,6 @@ from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo
from rodnet.utils.load_configs import load_configs_from_file, parse_cfgs, update_config_dict
from rodnet.utils.solve_dir import create_random_model_name
"""
Example:
python test.py -m HG -dd /mnt/ssd2/rodnet/data/ -ld /mnt/ssd2/rodnet/checkpoints/ \
-md HG-20200122-104604 -rd /mnt/ssd2/rodnet/results/
"""
def parse_args():
parser = argparse.ArgumentParser(description='Test RODNet.')
......@@ -53,14 +47,20 @@ if __name__ == "__main__":
range_grid = dataset.range_grid
angle_grid = dataset.angle_grid
model_configs = config_dict['model_cfg']
model_cfg = config_dict['model_cfg']
if model_configs['type'] == 'CDC':
if model_cfg['type'] == 'CDC':
from rodnet.models import RODNetCDC as RODNet
elif model_configs['type'] == 'HG':
elif model_cfg['type'] == 'HG':
from rodnet.models import RODNetHG as RODNet
elif model_configs['type'] == 'HGwI':
elif model_cfg['type'] == 'HGwI':
from rodnet.models import RODNetHGwI as RODNet
elif model_cfg['type'] == 'CDCv2':
from rodnet.models import RODNetCDCDCN as RODNet
elif model_cfg['type'] == 'HGv2':
from rodnet.models import RODNetHGDCN as RODNet
elif model_cfg['type'] == 'HGwIv2':
from rodnet.models import RODNetHGwIDCN as RODNet
else:
raise NotImplementedError
......@@ -73,8 +73,8 @@ if __name__ == "__main__":
n_class = dataset.object_cfg.n_class
confmap_shape = (n_class, radar_configs['ramap_rsize'], radar_configs['ramap_asize'])
if 'stacked_num' in model_configs:
stacked_num = model_configs['stacked_num']
if 'stacked_num' in model_cfg:
stacked_num = model_cfg['stacked_num']
else:
stacked_num = None
......@@ -88,13 +88,28 @@ if __name__ == "__main__":
else:
n_class_test = n_class
print("Building model ... (%s)" % model_configs)
if model_configs['type'] == 'CDC':
rodnet = RODNet(n_class_test).cuda()
elif model_configs['type'] == 'HG':
rodnet = RODNet(n_class_test, stacked_num=stacked_num).cuda()
elif model_configs['type'] == 'HGwI':
rodnet = RODNet(n_class_test, stacked_num=stacked_num).cuda()
print("Building model ... (%s)" % model_cfg)
if model_cfg['type'] == 'CDC':
rodnet = RODNet(in_channels=2, n_class=n_class_test).cuda()
elif model_cfg['type'] == 'HG':
rodnet = RODNet(in_channels=2, n_class=n_class_test, stacked_num=stacked_num).cuda()
elif model_cfg['type'] == 'HGwI':
rodnet = RODNet(in_channels=2, n_class=n_class_test, stacked_num=stacked_num).cuda()
elif model_cfg['type'] == 'CDCv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_test,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
elif model_cfg['type'] == 'HGv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_test, stacked_num=stacked_num,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
elif model_cfg['type'] == 'HGwIv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_test, stacked_num=stacked_num,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
else:
raise TypeError
......@@ -106,7 +121,7 @@ if __name__ == "__main__":
if 'model_name' in checkpoint:
model_name = checkpoint['model_name']
else:
model_name = create_random_model_name(model_configs['name'], checkpoint_path)
model_name = create_random_model_name(model_cfg['name'], checkpoint_path)
rodnet.eval()
test_res_dir = os.path.join(os.path.join(args.res_dir, model_name))
......@@ -167,7 +182,11 @@ if __name__ == "__main__":
for iter, data_dict in enumerate(dataloader):
load_time = time.time() - load_tic
data = data_dict['radar_data']
try:
image_paths = data_dict['image_paths'][0]
except:
print('warning: fail to load RGB images, will not visualize results')
image_paths = None
seq_name = data_dict['seq_names'][0]
if not args.demo:
confmap_gt = data_dict['anno']['confmaps']
......@@ -177,12 +196,10 @@ if __name__ == "__main__":
obj_info = None
save_path = os.path.join(test_res_dir, seq_name, 'rod_res.txt')
start_frame_name = image_paths[0].split('/')[-1].split('.')[0]
end_frame_name = image_paths[-1].split('/')[-1].split('.')[0]
start_frame_id = int(start_frame_name)
end_frame_id = int(end_frame_name)
print("Testing %s: %s-%s" % (seq_name, start_frame_name, end_frame_name))
start_frame_id = data_dict['start_frame'].item()
end_frame_id = data_dict['end_frame'].item()
tic = time.time()
confmap_pred = rodnet(data.float().cuda())
if stacked_num is not None:
......@@ -211,6 +228,7 @@ if __name__ == "__main__":
write_dets_results_single_frame(res_final, cur_frame_id, save_path, dataset)
confmap_pred_0 = init_genConfmap.confmap
res_final_0 = res_final
if image_paths is not None:
img_path = image_paths[i]
radar_input = chirp_amp(data.numpy()[0, :, i, :, :], radar_configs['data_type'])
fig_name = os.path.join(test_res_dir, seq_name, 'rod_viz', '%010d.jpg' % (cur_frame_id))
......@@ -232,12 +250,14 @@ if __name__ == "__main__":
write_dets_results_single_frame(res_final, cur_frame_id, save_path, dataset)
confmap_pred_0 = init_genConfmap.confmap
res_final_0 = res_final
if image_paths is not None:
img_path = image_paths[offset]
radar_input = chirp_amp(data.numpy()[0, :, offset, :, :], radar_configs['data_type'])
fig_name = os.path.join(test_res_dir, seq_name, 'rod_viz', '%010d.jpg' % (cur_frame_id))
if confmap_gt is not None:
confmap_gt_0 = confmap_gt[0, :, offset, :, :]
visualize_test_img(fig_name, img_path, radar_input, confmap_pred_0, confmap_gt_0, res_final_0,
visualize_test_img(fig_name, img_path, radar_input, confmap_pred_0, confmap_gt_0,
res_final_0,
dataset, sybl=sybl)
else:
visualize_test_img_wo_gt(fig_name, img_path, radar_input, confmap_pred_0, res_final_0,
......@@ -250,7 +270,8 @@ if __name__ == "__main__":
init_genConfmap = ConfmapStack(confmap_shape)
proc_time = time.time() - process_tic
print("Load time: %.4f | Inference time: %.4f | Process time: %.4f" % (load_time, infer_time, proc_time))
print("Testing %s: frame %4d to %4d | Load time: %.4f | Inference time: %.4f | Process time: %.4f" %
(seq_name, start_frame_id, end_frame_id, load_time, infer_time, proc_time))
load_tic = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment