Commit 1d635cff authored by Yizhou Wang's avatar Yizhou Wang
Browse files

add utils for changing configs from args

parent 3a94bd8d
...@@ -59,7 +59,7 @@ Please refer to [`cruw-devit`](https://github.com/yizhou-wang/cruw-devkit) repos ...@@ -59,7 +59,7 @@ Please refer to [`cruw-devit`](https://github.com/yizhou-wang/cruw-devkit) repos
```commandline ```commandline
git clone https://github.com/yizhou-wang/cruw-devkit.git git clone https://github.com/yizhou-wang/cruw-devkit.git
cd cruw-devkit cd cruw-devkit
pip install -e . pip install .
cd .. cd ..
``` ```
......
...@@ -20,11 +20,103 @@ def load_configs_from_file(config_path): ...@@ -20,11 +20,103 @@ def load_configs_from_file(config_path):
return cfg_dict return cfg_dict
def parse_cfgs(parser):
# dataset_cfg
parser.add_argument('--data_root', type=str,
help='directory to the dataset (will overwrite data_root in config file)')
# model_cfg
parser.add_argument('--model_type', type=str, help='model type')
parser.add_argument('--model_name', type=str, help='model name or exp name')
parser.add_argument('--max_dets', type=int, help='max detection per frome')
parser.add_argument('--peak_thres', type=float, help='peak threshold')
parser.add_argument('--ols_thres', type=float, help='OLS thres')
parser.add_argument('--stacked_num', type=int, help='number of stack for HG')
parser.add_argument('--mnet_cfg', type=tuple, help='MNet configuration')
parser.add_argument('--dcn', type=bool, help='whether use TDC')
# train_cfg
parser.add_argument('--n_epoch', type=int, help='number of training epochs')
parser.add_argument('--batch_size', type=int, help='batch size')
parser.add_argument('--lr', type=float, help='learning rate')
parser.add_argument('--lr_step', type=int, help='step for learning rate decreasing')
parser.add_argument('--win_size', type=int, help='window size for RF images')
parser.add_argument('--train_step', type=int, help='training step within RF snippets')
parser.add_argument('--train_stride', type=int, help='training stride between RF snippets')
parser.add_argument('--log_step', type=int, help='step for printing out log info')
parser.add_argument('--save_step', type=int, help='step for saving checkpoints')
# test_cfg
parser.add_argument('--test_step', type=int, help='testing step within RF snippets')
parser.add_argument('--test_stride', type=int, help='testing stride between RF snippets')
parser.add_argument('--rr_min', type=float, help='range of range min value')
parser.add_argument('--rr_max', type=float, help='range of range max value')
parser.add_argument('--ra_min', type=float, help='range of angle min value')
parser.add_argument('--ra_max', type=float, help='range of angle max value')
return parser
def update_config_dict(config_dict, args): def update_config_dict(config_dict, args):
# dataset_cfg
if hasattr(args, 'data_root') and args.data_root is not None:
data_root_old = config_dict['dataset_cfg']['base_root'] data_root_old = config_dict['dataset_cfg']['base_root']
config_dict['dataset_cfg']['base_root'] = args.data_root config_dict['dataset_cfg']['base_root'] = args.data_root
config_dict['dataset_cfg']['data_root'] = config_dict['dataset_cfg']['data_root'].replace(data_root_old, config_dict['dataset_cfg']['data_root'] = config_dict['dataset_cfg']['data_root'].replace(data_root_old,
args.data_root) args.data_root)
config_dict['dataset_cfg']['anno_root'] = config_dict['dataset_cfg']['anno_root'].replace(data_root_old, config_dict['dataset_cfg']['anno_root'] = config_dict['dataset_cfg']['anno_root'].replace(data_root_old,
args.data_root) args.data_root)
# model_cfg
if hasattr(args, 'model_type') and args.model_type is not None:
config_dict['model_cfg']['type'] = args.model_type
if hasattr(args, 'model_name') and args.model_name is not None:
config_dict['model_cfg']['name'] = args.model_name
if hasattr(args, 'max_dets') and args.max_dets is not None:
config_dict['model_cfg']['max_dets'] = args.max_dets
if hasattr(args, 'peak_thres') and args.peak_thres is not None:
config_dict['model_cfg']['peak_thres'] = args.peak_thres
if hasattr(args, 'ols_thres') and args.ols_thres is not None:
config_dict['model_cfg']['ols_thres'] = args.ols_thres
if hasattr(args, 'stacked_num') and args.stacked_num is not None:
config_dict['model_cfg']['stacked_num'] = args.stacked_num
if hasattr(args, 'mnet_cfg') and args.mnet_cfg is not None:
config_dict['model_cfg']['mnet_cfg'] = args.mnet_cfg
if hasattr(args, 'dcn') and args.dcn is not None:
config_dict['model_cfg']['dcn'] = args.dcn
# train_cfg
if hasattr(args, 'n_epoch') and args.n_epoch is not None:
config_dict['train_cfg']['n_epoch'] = args.n_epoch
if hasattr(args, 'batch_size') and args.batch_size is not None:
config_dict['train_cfg']['batch_size'] = args.batch_size
if hasattr(args, 'lr') and args.lr is not None:
config_dict['train_cfg']['lr'] = args.lr
if hasattr(args, 'lr_step') and args.lr_step is not None:
config_dict['train_cfg']['lr_step'] = args.lr_step
if hasattr(args, 'win_size') and args.win_size is not None:
config_dict['train_cfg']['win_size'] = args.win_size
if hasattr(args, 'train_step') and args.train_step is not None:
config_dict['train_cfg']['train_step'] = args.train_step
if hasattr(args, 'train_stride') and args.train_stride is not None:
config_dict['train_cfg']['train_stride'] = args.train_stride
if hasattr(args, 'log_step') and args.log_step is not None:
config_dict['train_cfg']['log_step'] = args.log_step
if hasattr(args, 'save_step') and args.save_step is not None:
config_dict['train_cfg']['save_step'] = args.save_step
# test_cfg
if hasattr(args, 'test_step') and args.test_step is not None:
config_dict['test_cfg']['test_step'] = args.test_step
if hasattr(args, 'test_stride') and args.test_stride is not None:
config_dict['test_cfg']['test_stride'] = args.test_stride
if hasattr(args, 'rr_min') and args.rr_min is not None:
config_dict['test_cfg']['rr_min'] = args.rr_min
if hasattr(args, 'rr_max') and args.rr_max is not None:
config_dict['test_cfg']['rr_max'] = args.rr_max
if hasattr(args, 'ra_min') and args.ra_min is not None:
config_dict['test_cfg']['ra_min'] = args.ra_min
if hasattr(args, 'ra_max') and args.ra_max is not None:
config_dict['test_cfg']['ra_max'] = args.ra_max
return config_dict return config_dict
import os import os
import time import time
import argparse import argparse
import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -15,7 +14,7 @@ from rodnet.core.post_processing import write_dets_results, write_dets_results_s ...@@ -15,7 +14,7 @@ from rodnet.core.post_processing import write_dets_results, write_dets_results_s
from rodnet.core.post_processing import ConfmapStack from rodnet.core.post_processing import ConfmapStack
from rodnet.core.radar_processing import chirp_amp from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo_gt from rodnet.utils.visualization import visualize_test_img, visualize_test_img_wo_gt
from rodnet.utils.load_configs import load_configs_from_file from rodnet.utils.load_configs import load_configs_from_file, parse_cfgs, update_config_dict
from rodnet.utils.solve_dir import create_random_model_name from rodnet.utils.solve_dir import create_random_model_name
""" """
...@@ -27,6 +26,7 @@ Example: ...@@ -27,6 +26,7 @@ Example:
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Test RODNet.') parser = argparse.ArgumentParser(description='Test RODNet.')
parser.add_argument('--config', type=str, help='choose rodnet model configurations') parser.add_argument('--config', type=str, help='choose rodnet model configurations')
parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021') parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data') parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
...@@ -35,6 +35,8 @@ def parse_args(): ...@@ -35,6 +35,8 @@ def parse_args():
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not") parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
parser.add_argument('--demo', action="store_true", help='False: test with GT, True: demo without GT') parser.add_argument('--demo', action="store_true", help='False: test with GT, True: demo without GT')
parser.add_argument('--symbol', action="store_true", help='use symbol or text+score') parser.add_argument('--symbol', action="store_true", help='use symbol or text+score')
parser = parse_cfgs(parser)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -44,6 +46,8 @@ if __name__ == "__main__": ...@@ -44,6 +46,8 @@ if __name__ == "__main__":
sybl = args.symbol sybl = args.symbol
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
config_dict = update_config_dict(config_dict, args) # update configs by args
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config) dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
range_grid = dataset.range_grid range_grid = dataset.range_grid
......
...@@ -18,12 +18,13 @@ from rodnet.datasets.CRDataLoader import CRDataLoader ...@@ -18,12 +18,13 @@ from rodnet.datasets.CRDataLoader import CRDataLoader
from rodnet.datasets.collate_functions import cr_collate from rodnet.datasets.collate_functions import cr_collate
from rodnet.core.radar_processing import chirp_amp from rodnet.core.radar_processing import chirp_amp
from rodnet.utils.solve_dir import create_dir_for_new_model from rodnet.utils.solve_dir import create_dir_for_new_model
from rodnet.utils.load_configs import load_configs_from_file, update_config_dict from rodnet.utils.load_configs import load_configs_from_file, parse_cfgs, update_config_dict
from rodnet.utils.visualization import visualize_train_img from rodnet.utils.visualization import visualize_train_img
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Train RODNet.') parser = argparse.ArgumentParser(description='Train RODNet.')
parser.add_argument('--config', type=str, help='configuration file path') parser.add_argument('--config', type=str, help='configuration file path')
parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021') parser.add_argument('--sensor_config', type=str, default='sensor_config_rod2021')
parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data') parser.add_argument('--data_dir', type=str, default='./data/', help='directory to the prepared data')
...@@ -31,6 +32,8 @@ def parse_args(): ...@@ -31,6 +32,8 @@ def parse_args():
parser.add_argument('--resume_from', type=str, default=None, help='path to the trained model') parser.add_argument('--resume_from', type=str, default=None, help='path to the trained model')
parser.add_argument('--save_memory', action="store_true", help="use customized dataloader to save memory") parser.add_argument('--save_memory', action="store_true", help="use customized dataloader to save memory")
parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not") parser.add_argument('--use_noise_channel', action="store_true", help="use noise channel or not")
parser = parse_cfgs(parser)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -38,6 +41,8 @@ def parse_args(): ...@@ -38,6 +41,8 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
config_dict = load_configs_from_file(args.config) config_dict = load_configs_from_file(args.config)
config_dict = update_config_dict(config_dict, args) # update configs by args
# dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root']) # dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'])
dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config) dataset = CRUW(data_root=config_dict['dataset_cfg']['base_root'], sensor_config_name=args.sensor_config)
radar_configs = dataset.sensor_cfg.radar_cfg radar_configs = dataset.sensor_cfg.radar_cfg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment