Commit 7c85adb5 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support to train single-frame model with multi-frame GT database on WOD

parent edb82b9f
...@@ -9,9 +9,8 @@ ...@@ -9,9 +9,8 @@
``` ```
python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml --update_info_only python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml --update_info_only
``` ```
* Generate multi-frame GT database for copy-paste augmentation of multi-frame training * Generate multi-frame GT database for copy-paste augmentation of multi-frame training. There is also a faster version with parallel data generation by adding `--use_parallel`, but you need to read the codes and rename the file after getting the results.
``` ```
# There is also a faster version with parallel data generation by adding `--use_parallel`, but you need to read the codes and rename the file after getting the results
python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_gt_database --cfg_file tools/cfgs/dataset_configs/waymo_dataset_multiframe.yaml python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_gt_database --cfg_file tools/cfgs/dataset_configs/waymo_dataset_multiframe.yaml
``` ```
This will generate the new files like the following (the last three lines under `data/waymo`): This will generate the new files like the following (the last three lines under `data/waymo`):
......
...@@ -30,6 +30,13 @@ class DataBaseSampler(object): ...@@ -30,6 +30,13 @@ class DataBaseSampler(object):
for db_info_path in sampler_cfg.DB_INFO_PATH: for db_info_path in sampler_cfg.DB_INFO_PATH:
db_info_path = self.root_path.resolve() / db_info_path db_info_path = self.root_path.resolve() / db_info_path
if not db_info_path.exists():
assert len(sampler_cfg.DB_INFO_PATH) == 1
sampler_cfg.DB_INFO_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_INFO_PATH']
sampler_cfg.DB_DATA_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_DATA_PATH']
db_info_path = self.root_path.resolve() / sampler_cfg.DB_INFO_PATH[0]
sampler_cfg.NUM_POINT_FEATURES = sampler_cfg.BACKUP_DB_INFO['NUM_POINT_FEATURES']
with open(str(db_info_path), 'rb') as f: with open(str(db_info_path), 'rb') as f:
infos = pickle.load(f) infos = pickle.load(f)
[self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names] [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
...@@ -403,9 +410,14 @@ class DataBaseSampler(object): ...@@ -403,9 +410,14 @@ class DataBaseSampler(object):
obj_points = np.concatenate(obj_points_list, axis=0) obj_points = np.concatenate(obj_points_list, axis=0)
sampled_gt_names = np.array([x['name'] for x in total_valid_sampled_dict]) sampled_gt_names = np.array([x['name'] for x in total_valid_sampled_dict])
if self.sampler_cfg.get('FILTER_OBJ_POINTS_BY_TIMESTAMP', False) or obj_points.shape[-1] != points.shape[-1]:
if self.sampler_cfg.get('FILTER_OBJ_POINTS_BY_TIMESTAMP', False): if self.sampler_cfg.get('FILTER_OBJ_POINTS_BY_TIMESTAMP', False):
min_time = min(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1]) min_time = min(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
max_time = max(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1]) max_time = max(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
else:
assert obj_points.shape[-1] == points.shape[-1] + 1
# transform multi-frame GT points to single-frame GT points
min_time = max_time = 0.0
time_mask = np.logical_and(obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6) time_mask = np.logical_and(obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6)
obj_points = obj_points[time_mask] obj_points = obj_points[time_mask]
...@@ -414,7 +426,7 @@ class DataBaseSampler(object): ...@@ -414,7 +426,7 @@ class DataBaseSampler(object):
sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH
) )
points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes) points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes)
points = np.concatenate([obj_points, points], axis=0) points = np.concatenate([obj_points[:, :points.shape[-1]], points], axis=0)
gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0) gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0)
gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0) gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0)
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
......
...@@ -155,6 +155,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -155,6 +155,7 @@ class WaymoDataset(DatasetTemplate):
for sequence_file in self.sample_sequence_list for sequence_file in self.sample_sequence_list
] ]
# process_single_sequence(sample_sequence_file_list[0])
with multiprocessing.Pool(num_workers) as p: with multiprocessing.Pool(num_workers) as p:
sequence_infos = list(tqdm(p.imap(process_single_sequence, sample_sequence_file_list), sequence_infos = list(tqdm(p.imap(process_single_sequence, sample_sequence_file_list),
total=len(sample_sequence_file_list))) total=len(sample_sequence_file_list)))
...@@ -369,7 +370,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -369,7 +370,7 @@ class WaymoDataset(DatasetTemplate):
point_offset_cnt = 0 point_offset_cnt = 0
stacked_gt_points = [] stacked_gt_points = []
for k in tqdm(range(0, len(infos), sampled_interval)): for k in tqdm(range(0, len(infos), sampled_interval)):
print('gt_database sample: %d/%d' % (k + 1, len(infos))) # print('gt_database sample: %d/%d' % (k + 1, len(infos)))
info = infos[k] info = infos[k]
pc_info = info['point_cloud'] pc_info = info['point_cloud']
......
...@@ -210,7 +210,7 @@ def process_single_sequence(sequence_file, save_path, sampled_interval, has_labe ...@@ -210,7 +210,7 @@ def process_single_sequence(sequence_file, save_path, sampled_interval, has_labe
sequence_infos = [] sequence_infos = []
if pkl_file.exists(): if pkl_file.exists():
sequence_infos = pickle.load(open(pkl_file, 'rb')) sequence_infos = pickle.load(open(pkl_file, 'rb'))
sequence_infos_old = None
if not update_info_only: if not update_info_only:
print('Skip sequence since it has been processed before: %s' % pkl_file) print('Skip sequence since it has been processed before: %s' % pkl_file)
return sequence_infos return sequence_infos
...@@ -248,7 +248,7 @@ def process_single_sequence(sequence_file, save_path, sampled_interval, has_labe ...@@ -248,7 +248,7 @@ def process_single_sequence(sequence_file, save_path, sampled_interval, has_labe
annotations = generate_labels(frame, pose=pose) annotations = generate_labels(frame, pose=pose)
info['annos'] = annotations info['annos'] = annotations
if update_info_only: if update_info_only and sequence_infos_old is not None:
assert info['frame_id'] == sequence_infos_old[cnt]['frame_id'] assert info['frame_id'] == sequence_infos_old[cnt]['frame_id']
num_points_of_each_lidar = sequence_infos_old[cnt]['num_points_of_each_lidar'] num_points_of_each_lidar = sequence_infos_old[cnt]['num_points_of_each_lidar']
else: else:
......
...@@ -32,6 +32,12 @@ DATA_AUGMENTOR: ...@@ -32,6 +32,12 @@ DATA_AUGMENTOR:
DB_DATA_PATH: DB_DATA_PATH:
- waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy - waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy
BACKUP_DB_INFO:
# if the above DB_INFO cannot be found, will use this backup one
DB_INFO_PATH: waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0.pkl
DB_DATA_PATH: waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_global.npy
NUM_POINT_FEATURES: 6
PREPARE: { PREPARE: {
filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'], filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1], filter_by_difficulty: [-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment