Commit 012ad57c authored by yizhou-wang's avatar yizhou-wang
Browse files

Update prepare_data.py

parent 9957f4aa
...@@ -21,7 +21,9 @@ def parse_args(): ...@@ -21,7 +21,9 @@ def parse_args():
parser = argparse.ArgumentParser(description='Prepare RODNet data.') parser = argparse.ArgumentParser(description='Prepare RODNet data.')
parser.add_argument('--config', type=str, dest='config', help='configuration file path') parser.add_argument('--config', type=str, dest='config', help='configuration file path')
parser.add_argument('--data_root', type=str, help='directory to the prepared data') parser.add_argument('--data_root', type=str, help='directory to the prepared data')
parser.add_argument('--split', type=str, dest='split', help='choose from train, valid, test, supertest') parser.add_argument('--sensor_config', type=str, default='sensor_config')
parser.add_argument('--split', type=str, dest='split', default='',
help='choose from train, valid, test, supertest')
parser.add_argument('--out_data_dir', type=str, default='./data', parser.add_argument('--out_data_dir', type=str, default='./data',
help='data directory to save the prepared data') help='data directory to save the prepared data')
parser.add_argument('--overwrite', action="store_true", help="overwrite prepared data if exist") parser.add_argument('--overwrite', action="store_true", help="overwrite prepared data if exist")
...@@ -94,6 +96,13 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -94,6 +96,13 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
data_root = config_dict['dataset_cfg']['data_root'] data_root = config_dict['dataset_cfg']['data_root']
anno_root = config_dict['dataset_cfg']['anno_root'] anno_root = config_dict['dataset_cfg']['anno_root']
if split == None:
set_cfg = {
'subdir': '',
'seqs': sorted(os.listdir(data_root))
}
sets_seqs = sorted(os.listdir(data_root))
else:
set_cfg = config_dict['dataset_cfg'][split] set_cfg = config_dict['dataset_cfg'][split]
if 'seqs' not in set_cfg: if 'seqs' not in set_cfg:
sets_seqs = sorted(os.listdir(os.path.join(data_root, set_cfg['subdir']))) sets_seqs = sorted(os.listdir(os.path.join(data_root, set_cfg['subdir'])))
...@@ -193,14 +202,21 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove ...@@ -193,14 +202,21 @@ def prepare_data(dataset, config_dict, data_dir, split, save_dir, viz=False, ove
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
data_root = args.data_root data_root = args.data_root
if args.splits == '':
splits = None
else:
splits = args.split.split(',') splits = args.split.split(',')
out_data_dir = args.out_data_dir out_data_dir = args.out_data_dir
overwrite = args.overwrite overwrite = args.overwrite
dataset = CRUW(data_root=data_root, sensor_config_name='sensor_config_rod2021') dataset = CRUW(data_root=data_root, sensor_config_name=args.sensor_config)
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
if splits == None:
prepare_data(dataset, config_dict, out_data_dir, split=None, save_dir=out_data_dir, viz=False,
overwrite=overwrite)
else:
for split in splits: for split in splits:
if split not in SPLITS_LIST: if split not in SPLITS_LIST:
raise TypeError("split %s cannot be recognized" % split) raise TypeError("split %s cannot be recognized" % split)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment