Unverified Commit 1483517a authored by Christian Reisinger's avatar Christian Reisinger Committed by GitHub
Browse files

Gt sampler sweeps (#721)



* Fixes #718

* FIX: replaced wrong part

* #719 Fixes index error max_dt while inference: No gt_db_sampling -> MAX_SWEEPS-1 > # unique dt (e.g. first frame)

* move max_sweeps filter to PointFeatureEncoder with new config FILTER_SWEEPS

* update to lowercase
Co-authored-by: default avatarChristian Fruhwirth-Reisinger <christian.reisinger@student.tugraz.at>
Co-authored-by: default avatarShaoshuai Shi <shaoshuaics@gmail.com>
parent 1f5b7872
...@@ -142,15 +142,6 @@ class DatasetTemplate(torch_data.Dataset): ...@@ -142,15 +142,6 @@ class DatasetTemplate(torch_data.Dataset):
if data_dict.get('gt_boxes2d', None) is not None: if data_dict.get('gt_boxes2d', None) is not None:
data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected] data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected]
if 'timestamp' in self.dataset_cfg.POINT_FEATURE_ENCODING.get('src_feature_list'):
if data_dict.get('points', None) is not None:
max_sweeps = self.dataset_cfg.get('MAX_SWEEPS', 1)
idx = self.dataset_cfg.POINT_FEATURE_ENCODING.get('src_feature_list').index('timestamp')
dt = np.round(data_dict['points'][:, idx], 2)
if np.unique(dt).shape[0] == max_sweeps:
max_dt = sorted(np.unique(dt))[max_sweeps-1]
data_dict['points'] = data_dict['points'][dt <= max_dt]
if data_dict.get('points', None) is not None: if data_dict.get('points', None) is not None:
data_dict = self.point_feature_encoder.forward(data_dict) data_dict = self.point_feature_encoder.forward(data_dict)
......
...@@ -30,6 +30,14 @@ class PointFeatureEncoder(object): ...@@ -30,6 +30,14 @@ class PointFeatureEncoder(object):
data_dict['points'] data_dict['points']
) )
data_dict['use_lead_xyz'] = use_lead_xyz data_dict['use_lead_xyz'] = use_lead_xyz
if self.point_encoding_config.get('filter_sweeps', False) and 'timestamp' in self.src_feature_list:
max_sweeps = self.point_encoding_config.max_sweeps
idx = self.src_feature_list.index('timestamp')
dt = np.round(data_dict['points'][:, idx], 2)
max_dt = sorted(np.unique(dt))[min(len(np.unique(dt))-1, max_sweeps-1)]
data_dict['points'] = data_dict['points'][dt <= max_dt]
return data_dict return data_dict
def absolute_coordinates_encoding(self, points=None): def absolute_coordinates_encoding(self, points=None):
...@@ -44,4 +52,5 @@ class PointFeatureEncoder(object): ...@@ -44,4 +52,5 @@ class PointFeatureEncoder(object):
idx = self.src_feature_list.index(x) idx = self.src_feature_list.index(x)
point_feature_list.append(points[:, idx:idx+1]) point_feature_list.append(points[:, idx:idx+1])
point_features = np.concatenate(point_feature_list, axis=1) point_features = np.concatenate(point_feature_list, axis=1)
return point_features, True return point_features, True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment