Commit d0140132 authored by Yizhou Wang's avatar Yizhou Wang
Browse files

v1.1 code for RODNet J-STSP version

parent 9266cc35
......@@ -11,7 +11,7 @@ from cruw.annotation.init_json import init_meta_json
from cruw.mapping import ra2idx
from rodnet.core.confidence_map import generate_confmap, normalize_confmap, add_noise_channel
from rodnet.utils.load_configs import load_configs_from_file
from rodnet.utils.load_configs import load_configs_from_file, update_config_dict
from rodnet.utils.visualization import visualize_confmap
SPLITS_LIST = ['train', 'valid', 'test', 'demo']
......@@ -20,7 +20,8 @@ SPLITS_LIST = ['train', 'valid', 'test', 'demo']
def parse_args():
parser = argparse.ArgumentParser(description='Prepare RODNet data.')
parser.add_argument('--config', type=str, dest='config', help='configuration file path')
parser.add_argument('--data_root', type=str, help='directory to the prepared data')
parser.add_argument('--data_root', type=str,
help='directory to the dataset (will overwrite data_root in config file)')
parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021')
parser.add_argument('--split', type=str, dest='split', default='',
help='choose from train, valid, test, supertest')
......@@ -220,6 +221,7 @@ if __name__ == "__main__":
dataset = CRUW(data_root=data_root, sensor_config_name=args.sensor_config)
config_dict = load_configs_from_file(args.config)
config_dict = update_config_dict(config_dict, args) # update configs by args
radar_configs = dataset.sensor_cfg.radar_cfg
if splits == None:
......
......@@ -18,7 +18,7 @@ from rodnet.datasets.CRDataLoader import CRDataLoader
from rodnet.datasets.collate_functions import cr_collate
from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.solve_dir import create_dir_for_new_model
from rodnet.utils.load_configs import load_configs_from_file
from rodnet.utils.load_configs import load_configs_from_file, update_config_dict
from rodnet.utils.visualization import visualize_train_img
......@@ -51,6 +51,12 @@ if __name__ == "__main__":
from rodnet.models import RODNetHG as RODNet
elif model_cfg['type'] == 'HGwI':
from rodnet.models import RODNetHGwI as RODNet
elif model_cfg['type'] == 'CDCv2':
from rodnet.models import RODNetCDCDCN as RODNet
elif model_cfg['type'] == 'HGv2':
from rodnet.models import RODNetHGDCN as RODNet
elif model_cfg['type'] == 'HGwIv2':
from rodnet.models import RODNetHGwIDCN as RODNet
else:
raise NotImplementedError
......@@ -110,7 +116,7 @@ if __name__ == "__main__":
# dataloader_valid = DataLoader(crdata_valid, batch_size=batch_size, shuffle=True, num_workers=0)
else:
crdata_train = CRDatasetSM(data_root=args.data_dir, config_dict=config_dict, split='train',
crdata_train = CRDatasetSM(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='train',
noise_channel=args.use_noise_channel)
seq_names = crdata_train.seq_names
index_mapping = crdata_train.index_mapping
......@@ -130,13 +136,31 @@ if __name__ == "__main__":
print("Building model ... (%s)" % model_cfg)
if model_cfg['type'] == 'CDC':
rodnet = RODNet(n_class_train).cuda()
criterion = nn.MSELoss()
rodnet = RODNet(in_channels=2, n_class=n_class_train).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'HG':
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
rodnet = RODNet(in_channels=2, n_class=n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'HGwI':
rodnet = RODNet(n_class_train, stacked_num=stacked_num).cuda()
rodnet = RODNet(in_channels=2, n_class=n_class_train, stacked_num=stacked_num).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'CDCv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_train,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'HGv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_train, stacked_num=stacked_num,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
criterion = nn.BCELoss()
elif model_cfg['type'] == 'HGwIv2':
in_chirps = len(radar_configs['chirp_ids'])
rodnet = RODNet(in_channels=in_chirps, n_class=n_class_train, stacked_num=stacked_num,
mnet_cfg=config_dict['model_cfg']['mnet_cfg'],
dcn=config_dict['model_cfg']['dcn']).cuda()
criterion = nn.BCELoss()
else:
raise TypeError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment