You need to sign in or sign up before continuing.
Commit 1d635cff authored by Yizhou Wang's avatar Yizhou Wang
Browse files

add utils for changing configs from args

parent 3a94bd8d
...@@ -59,7 +59,7 @@ Please refer to [`cruw-devit`](https://github.com/yizhou-wang/cruw-devkit) repos ...@@ -59,7 +59,7 @@ Please refer to [`cruw-devit`](https://github.com/yizhou-wang/cruw-devkit) repos
```commandline ```commandline
git clone https://github.com/yizhou-wang/cruw-devkit.git git clone https://github.com/yizhou-wang/cruw-devkit.git
cd cruw-devkit cd cruw-devkit
pip install -e . pip install .
cd .. cd ..
``` ```
......
...@@ -20,11 +20,103 @@ def load_configs_from_file(config_path): ...@@ -20,11 +20,103 @@ def load_configs_from_file(config_path):
return cfg_dict return cfg_dict
def parse_cfgs(parser):
# dataset_cfg
parser.add_argument('--data_root', type=str,
help='directory to the dataset (will overwrite data_root in config file)')
# model_cfg
parser.add_argument('--model_type', type=str, help='model type')
parser.add_argument('--model_name', type=str, help='model name or exp name')
parser.add_argument('--max_dets', type=int, help='max detection per frome')
parser.add_argument('--peak_thres', type=float, help='peak threshold')
parser.add_argument('--ols_thres', type=float, help='OLS thres')
parser.add_argument('--stacked_num', type=int, help='number of stack for HG')
parser.add_argument('--mnet_cfg', type=tuple, help='MNet configuration')
parser.add_argument('--dcn', type=bool, help='whether use TDC')
# train_cfg
parser.add_argument('--n_epoch', type=int, help='number of training epochs')
parser.add_argument('--batch_size', type=int, help='batch size')
parser.add_argument('--lr', type=float, help='learning rate')
parser.add_argument('--lr_step', type=int, help='step for learning rate decreasing')
parser.add_argument('--win_size', type=int, help='window size for RF images')
parser.add_argument('--train_step', type=int, help='training step within RF snippets')
parser.add_argument('--train_stride', type=int, help='training stride between RF snippets')
parser.add_argument('--log_step', type=int, help='step for printing out log info')
parser.add_argument('--save_step', type=int, help='step for saving checkpoints')
# test_cfg
parser.add_argument('--test_step', type=int, help='testing step within RF snippets')
parser.add_argument('--test_stride', type=int, help='testing stride between RF snippets')
parser.add_argument('--rr_min', type=float, help='range of range min value')
parser.add_argument('--rr_max', type=float, help='range of range max value')
parser.add_argument('--ra_min', type=float, help='range of angle min value')
parser.add_argument('--ra_max', type=float, help='range of angle max value')
return parser
def update_config_dict(config_dict, args): def update_config_dict(config_dict, args):
data_root_old = config_dict['dataset_cfg']['base_root'] # dataset_cfg
config_dict['dataset_cfg']['base_root'] = args.data_root if hasattr(args, 'data_root') and args.data_root is not None:
config_dict['dataset_cfg']['data_root'] = config_dict['dataset_cfg']['data_root'].replace(data_root_old, data_root_old = config_dict['dataset_cfg']['base_root']
args.data_root) config_dict['dataset_cfg']['base_root'] = args.data_root
config_dict['dataset_cfg']['anno_root'] = config_dict['dataset_cfg']['anno_root'].replace(data_root_old, config_dict['dataset_cfg']['data_root'] = config_dict['dataset_cfg']['data_root'].replace(data_root_old,
args.data_root) args.data_root)
config_dict['dataset_cfg']['anno_root'] = config_dict['dataset_cfg']['anno_root'].replace(data_root_old,
args.data_root)
# model_cfg
if hasattr(args, 'model_type') and args.model_type is not None:
config_dict['model_cfg']['type'] = args.model_type
if hasattr(args, 'model_name') and args.model_name is not None:
config_dict['model_cfg']['name'] = args.model_name
if hasattr(args, 'max_dets') and args.max_dets is not None:
config_dict['model_cfg']['max_dets'] = args.max_dets
if hasattr(args, 'peak_thres') and args.peak_thres is not None:
config_dict['model_cfg']['peak_thres'] = args.peak_thres
if hasattr(args, 'ols_thres') and args.ols_thres is not None:
config_dict['model_cfg']['ols_thres'] = args.ols_thres
if hasattr(args, 'stacked_num') and args.stacked_num is not None:
config_dict['model_cfg']['stacked_num'] = args.stacked_num
if hasattr(args, 'mnet_cfg') and args.mnet_cfg is not None:
config_dict['model_cfg']['mnet_cfg'] = args.mnet_cfg
if hasattr(args, 'dcn') and args.dcn is not None:
config_dict['model_cfg']['dcn'] = args.dcn
# train_cfg
if hasattr(args, 'n_epoch') and args.n_epoch is not None:
config_dict['train_cfg']['n_epoch'] = args.n_epoch
if hasattr(args, 'batch_size') and args.batch_size is not None:
config_dict['train_cfg']['batch_size'] = args.batch_size
if hasattr(args, 'lr') and args.lr is not None:
config_dict['train_cfg']['lr'] = args.lr
if hasattr(args, 'lr_step') and args.lr_step is not None:
config_dict['train_cfg']['lr_step'] = args.lr_step
if hasattr(args, 'win_size') and args.win_size is not None:
config_dict['train_cfg']['win_size'] = args.win_size
if hasattr(args, 'train_step') and args.train_step is not None:
config_dict['train_cfg']['train_step'] = args.train_step
if hasattr(args, 'train_stride') and args.train_stride is not None:
config_dict['train_cfg']['train_stride'] = args.train_stride
if hasattr(args, 'log_step') and args.log_step is not None:
config_dict['train_cfg']['log_step'] = args.log_step
if hasattr(args, 'save_step') and args.save_step is not None:
config_dict['train_cfg']['save_step'] = args.save_step
# test_cfg
if hasattr(args, 'test_step') and args.test_step is not None:
config_dict['test_cfg']['test_step'] = args.test_step
if hasattr(args, 'test_stride') and args.test_stride is not None:
config_dict['test_cfg']['test_stride'] = args.test_stride
if hasattr(args, 'rr_min') and args.rr_min is not None:
config_dict['test_cfg']['rr_min'] = args.rr_min
if hasattr(args, 'rr_max') and args.rr_max is not None:
config_dict['test_cfg']['rr_max'] = args.rr_max
if hasattr(args, 'ra_min') and args.ra_min is not None:
config_dict['test_cfg']['ra_min'] = args.ra_min
if hasattr(args, 'ra_max') and args.ra_max is not None:
config_dict['test_cfg']['ra_max'] = args.ra_max
return config_dict return config_dict
import os import os
import time import time
import argparse import argparse
import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -15,7 +14,7 @@ from rodnet.core.post_processing import write_dets_results, write_dets_results_s ...@@ -15,7 +14,7 @@ from rodnet.core.post_processing import write_dets_results, write_dets_results_s
from rodnet.core.post_processing import ConfmapStack from rodnet.core.post_processing import ConfmapStack
from rodnet.core.radar_processing import chirp_amp from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo_gt from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo_gt
from rodnet.utils.load_configs import load_configs_from_file from rodnet.utils.load_configs import load_configs_from_file, parse_cfgs, update_config_dict
from rodnet.utils.solve_dir import create_random_model_name from rodnet.utils.solve_dir import create_random_model_name
""" """
...@@ -27,6 +26,7 @@ Example: ...@@ -27,6 +26,7 @@ Example:
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Test RODNet.') parser = argparse.ArgumentParser(description='Test RODNet.')
parser.add_argument('--config', type=str, help='choose rodnet model configurations') parser.add_argument('--config', type=str, help='choose rodnet model configurations')
parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021') parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data') parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
...@@ -35,6 +35,8 @@ def parse_args(): ...@@ -35,6 +35,8 @@ def parse_args():
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not") parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
parser.add_argument('--demo', action="store_true", help='False: test with GT, True: demo without GT') parser.add_argument('--demo', action="store_true", help='False: test with GT, True: demo without GT')
parser.add_argument('--symbol', action="store_true", help='use symbol or text+score') parser.add_argument('--symbol', action="store_true", help='use symbol or text+score')
parser = parse_cfgs(parser)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -44,6 +46,8 @@ if __name__ == "__main__": ...@@ -44,6 +46,8 @@ if __name__ == "__main__":
sybl = args.symbol sybl = args.symbol
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
config_dict = update_config_dict(config_dict, args) # update configs by args
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config) dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
range_grid = dataset.range_grid range_grid = dataset.range_grid
......
...@@ -18,12 +18,13 @@ from rodnet.datasets.CRDataLoader import CRDataLoader ...@@ -18,12 +18,13 @@ from rodnet.datasets.CRDataLoader import CRDataLoader
from rodnet.datasets.collate_functions import cr_collate from rodnet.datasets.collate_functions import cr_collate
from rodnet.core.radar_processing import chirp_amp from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.solve_dir import create_dir_for_new_model from rodnet.utils.solve_dir import create_dir_for_new_model
from rodnet.utils.load_configs import load_configs_from_file, update_config_dict from rodnet.utils.load_configs import load_configs_from_file, parse_cfgs, update_config_dict
from rodnet.utils.visualization import visualize_train_img from rodnet.utils.visualization import visualize_train_img
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Train RODNet.') parser = argparse.ArgumentParser(description='Train RODNet.')
parser.add_argument('--config', type=str, help='configuration file path') parser.add_argument('--config', type=str, help='configuration file path')
parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021') parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data') parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
...@@ -31,6 +32,8 @@ def parse_args(): ...@@ -31,6 +32,8 @@ def parse_args():
parser.add_argument('--resume_from', type=str, default=None, help='path to the trained model') parser.add_argument('--resume_from', type=str, default=None, help='path to the trained model')
parser.add_argument('--save_memory', action="store_true", help="use customized dataloader to save memory") parser.add_argument('--save_memory', action="store_true", help="use customized dataloader to save memory")
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not") parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
parser = parse_cfgs(parser)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -38,6 +41,8 @@ def parse_args(): ...@@ -38,6 +41,8 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
config_dict = update_config_dict(config_dict, args) # update configs by args
# dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root']) # dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config) dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment