Unverified Commit 62e43671 authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix] fix nuscenes datasets (#1837)

* fix nuscenes datasets

* fix fcos3d bugs
parent 77d16764
......@@ -53,9 +53,7 @@ test_pipeline = [
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(dataset=dict(pipeline=train_pipeline)))
batch_size=2, num_workers=2, dataset=dict(pipeline=train_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
......
......@@ -192,7 +192,8 @@ class Det3DDataset(BaseDataset):
'bbox_3d': 'gt_bboxes_3d',
'depth': 'depths',
'center_2d': 'centers_2d',
'attr_label': 'attr_labels'
'attr_label': 'attr_labels',
'velocity': 'velocities',
}
instances = info['instances']
# empty gt
......@@ -209,14 +210,18 @@ class Det3DDataset(BaseDataset):
self.label_mapping[item] for item in temp_anns
]
if ann_name in name_mapping:
ann_name = name_mapping[ann_name]
mapped_ann_name = name_mapping[ann_name]
else:
mapped_ann_name = ann_name
if 'label' in ann_name:
temp_anns = np.array(temp_anns).astype(np.int64)
else:
elif ann_name in name_mapping:
temp_anns = np.array(temp_anns).astype(np.float32)
else:
temp_anns = np.array(temp_anns)
ann_info[ann_name] = temp_anns
ann_info[mapped_ann_name] = temp_anns
ann_info['instances'] = info['instances']
return ann_info
......
......@@ -85,6 +85,27 @@ class NuScenesDataset(Det3DDataset):
test_mode=test_mode,
**kwargs)
def _filter_with_mask(self, ann_info):
"""Remove annotations that do not need to be cared.
Args:
ann_info (dict): Dict of annotation infos.
Returns:
dict: Annotations after filtering.
"""
filtered_annotations = {}
if self.use_valid_flag:
filter_mask = ann_info['bbox_3d_isvalid']
else:
filter_mask = ann_info['num_lidar_pts'] > 0
for key in ann_info.keys():
if key != 'instances':
filtered_annotations[key] = (ann_info[key][filter_mask])
else:
filtered_annotations[key] = ann_info[key]
return filtered_annotations
def parse_ann_info(self, info: dict) -> dict:
"""Get annotation info according to the given index.
......@@ -99,66 +120,51 @@ class NuScenesDataset(Det3DDataset):
- gt_labels_3d (np.ndarray): Labels of ground truths.
"""
ann_info = super().parse_ann_info(info)
if ann_info is None:
# empty instance
anns_results = dict()
anns_results['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
anns_results['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
return anns_results
if ann_info is not None:
if self.use_valid_flag:
mask = ann_info['bbox_3d_isvalid']
else:
mask = ann_info['num_lidar_pts'] > 0
gt_bboxes_3d = ann_info['gt_bboxes_3d'][mask]
gt_labels_3d = ann_info['gt_labels_3d'][mask]
if 'gt_bboxes' in ann_info:
gt_bboxes = ann_info['gt_bboxes'][mask]
gt_labels = ann_info['gt_labels'][mask]
attr_labels = ann_info['attr_labels'][mask]
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
attr_labels = np.array([], dtype=np.int64)
ann_info = self._filter_with_mask(ann_info)
if 'centers_2d' in ann_info:
centers_2d = ann_info['centers_2d'][mask]
depths = ann_info['depths'][mask]
if self.with_velocity:
gt_bboxes_3d = ann_info['gt_bboxes_3d']
gt_velocities = ann_info['velocities']
nan_mask = np.isnan(gt_velocities[:, 0])
gt_velocities[nan_mask] = [0.0, 0.0]
gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities],
axis=-1)
ann_info['gt_bboxes_3d'] = gt_bboxes_3d
else:
centers_2d = np.zeros((0, 2), dtype=np.float32)
depths = np.zeros((0), dtype=np.float32)
# empty instance
ann_info = dict()
if self.with_velocity:
gt_velocity = ann_info['velocity'][mask]
nan_mask = np.isnan(gt_velocity[:, 0])
gt_velocity[nan_mask] = [0.0, 0.0]
gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1)
ann_info['gt_bboxes_3d'] = np.zeros((0, 9), dtype=np.float32)
else:
ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
if self.task == 'mono3d':
ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
ann_info['gt_bboxes_labels'] = np.array(0, dtype=np.int64)
ann_info['attr_labels'] = np.array(0, dtype=np.int64)
ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
ann_info['depths'] = np.zeros((0), dtype=np.float32)
# the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
# the same as KITTI (0.5, 0.5, 0)
# TODO: Unify the coordinates
if self.task == 'mono3d':
gt_bboxes_3d = CameraInstance3DBoxes(
gt_bboxes_3d,
box_dim=gt_bboxes_3d.shape[-1],
ann_info['gt_bboxes_3d'],
box_dim=ann_info['gt_bboxes_3d'].shape[-1],
origin=(0.5, 0.5, 0.5))
else:
gt_bboxes_3d = LiDARInstance3DBoxes(
gt_bboxes_3d,
box_dim=gt_bboxes_3d.shape[-1],
ann_info['gt_bboxes_3d'],
box_dim=ann_info['gt_bboxes_3d'].shape[-1],
origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
attr_labels=attr_labels,
centers_2d=centers_2d,
depths=depths)
ann_info['gt_bboxes_3d'] = gt_bboxes_3d
return anns_results
return ann_info
def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment