Unverified Commit 62e43671 authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix] fix nuscenes datasets (#1837)

* fix nuscenes datasets

* fix fcos3d bugs
parent 77d16764
...@@ -53,9 +53,7 @@ test_pipeline = [ ...@@ -53,9 +53,7 @@ test_pipeline = [
] ]
train_dataloader = dict( train_dataloader = dict(
batch_size=2, batch_size=2, num_workers=2, dataset=dict(pipeline=train_pipeline))
num_workers=2,
dataset=dict(dataset=dict(pipeline=train_pipeline)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
......
...@@ -192,7 +192,8 @@ class Det3DDataset(BaseDataset): ...@@ -192,7 +192,8 @@ class Det3DDataset(BaseDataset):
'bbox_3d': 'gt_bboxes_3d', 'bbox_3d': 'gt_bboxes_3d',
'depth': 'depths', 'depth': 'depths',
'center_2d': 'centers_2d', 'center_2d': 'centers_2d',
'attr_label': 'attr_labels' 'attr_label': 'attr_labels',
'velocity': 'velocities',
} }
instances = info['instances'] instances = info['instances']
# empty gt # empty gt
...@@ -209,14 +210,18 @@ class Det3DDataset(BaseDataset): ...@@ -209,14 +210,18 @@ class Det3DDataset(BaseDataset):
self.label_mapping[item] for item in temp_anns self.label_mapping[item] for item in temp_anns
] ]
if ann_name in name_mapping: if ann_name in name_mapping:
ann_name = name_mapping[ann_name] mapped_ann_name = name_mapping[ann_name]
else:
mapped_ann_name = ann_name
if 'label' in ann_name: if 'label' in ann_name:
temp_anns = np.array(temp_anns).astype(np.int64) temp_anns = np.array(temp_anns).astype(np.int64)
else: elif ann_name in name_mapping:
temp_anns = np.array(temp_anns).astype(np.float32) temp_anns = np.array(temp_anns).astype(np.float32)
else:
temp_anns = np.array(temp_anns)
ann_info[ann_name] = temp_anns ann_info[mapped_ann_name] = temp_anns
ann_info['instances'] = info['instances'] ann_info['instances'] = info['instances']
return ann_info return ann_info
......
...@@ -85,6 +85,27 @@ class NuScenesDataset(Det3DDataset): ...@@ -85,6 +85,27 @@ class NuScenesDataset(Det3DDataset):
test_mode=test_mode, test_mode=test_mode,
**kwargs) **kwargs)
def _filter_with_mask(self, ann_info):
"""Remove annotations that do not need to be cared.
Args:
ann_info (dict): Dict of annotation infos.
Returns:
dict: Annotations after filtering.
"""
filtered_annotations = {}
if self.use_valid_flag:
filter_mask = ann_info['bbox_3d_isvalid']
else:
filter_mask = ann_info['num_lidar_pts'] > 0
for key in ann_info.keys():
if key != 'instances':
filtered_annotations[key] = (ann_info[key][filter_mask])
else:
filtered_annotations[key] = ann_info[key]
return filtered_annotations
def parse_ann_info(self, info: dict) -> dict: def parse_ann_info(self, info: dict) -> dict:
"""Get annotation info according to the given index. """Get annotation info according to the given index.
...@@ -99,66 +120,51 @@ class NuScenesDataset(Det3DDataset): ...@@ -99,66 +120,51 @@ class NuScenesDataset(Det3DDataset):
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
""" """
ann_info = super().parse_ann_info(info) ann_info = super().parse_ann_info(info)
if ann_info is None: if ann_info is not None:
# empty instance
anns_results = dict() ann_info = self._filter_with_mask(ann_info)
anns_results['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
anns_results['gt_labels_3d'] = np.zeros(0, dtype=np.int64) if self.with_velocity:
return anns_results gt_bboxes_3d = ann_info['gt_bboxes_3d']
gt_velocities = ann_info['velocities']
if self.use_valid_flag: nan_mask = np.isnan(gt_velocities[:, 0])
mask = ann_info['bbox_3d_isvalid'] gt_velocities[nan_mask] = [0.0, 0.0]
else: gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities],
mask = ann_info['num_lidar_pts'] > 0 axis=-1)
gt_bboxes_3d = ann_info['gt_bboxes_3d'][mask] ann_info['gt_bboxes_3d'] = gt_bboxes_3d
gt_labels_3d = ann_info['gt_labels_3d'][mask]
if 'gt_bboxes' in ann_info:
gt_bboxes = ann_info['gt_bboxes'][mask]
gt_labels = ann_info['gt_labels'][mask]
attr_labels = ann_info['attr_labels'][mask]
else: else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32) # empty instance
gt_labels = np.array([], dtype=np.int64) ann_info = dict()
attr_labels = np.array([], dtype=np.int64) if self.with_velocity:
ann_info['gt_bboxes_3d'] = np.zeros((0, 9), dtype=np.float32)
if 'centers_2d' in ann_info: else:
centers_2d = ann_info['centers_2d'][mask] ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
depths = ann_info['depths'][mask] ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
else:
centers_2d = np.zeros((0, 2), dtype=np.float32) if self.task == 'mono3d':
depths = np.zeros((0), dtype=np.float32) ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
ann_info['gt_bboxes_labels'] = np.array(0, dtype=np.int64)
if self.with_velocity: ann_info['attr_labels'] = np.array(0, dtype=np.int64)
gt_velocity = ann_info['velocity'][mask] ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
nan_mask = np.isnan(gt_velocity[:, 0]) ann_info['depths'] = np.zeros((0), dtype=np.float32)
gt_velocity[nan_mask] = [0.0, 0.0]
gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1)
# the nuscenes box center is [0.5, 0.5, 0.5], we change it to be # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
# the same as KITTI (0.5, 0.5, 0) # the same as KITTI (0.5, 0.5, 0)
# TODO: Unify the coordinates # TODO: Unify the coordinates
if self.task == 'mono3d': if self.task == 'mono3d':
gt_bboxes_3d = CameraInstance3DBoxes( gt_bboxes_3d = CameraInstance3DBoxes(
gt_bboxes_3d, ann_info['gt_bboxes_3d'],
box_dim=gt_bboxes_3d.shape[-1], box_dim=ann_info['gt_bboxes_3d'].shape[-1],
origin=(0.5, 0.5, 0.5)) origin=(0.5, 0.5, 0.5))
else: else:
gt_bboxes_3d = LiDARInstance3DBoxes( gt_bboxes_3d = LiDARInstance3DBoxes(
gt_bboxes_3d, ann_info['gt_bboxes_3d'],
box_dim=gt_bboxes_3d.shape[-1], box_dim=ann_info['gt_bboxes_3d'].shape[-1],
origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d) origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
anns_results = dict( ann_info['gt_bboxes_3d'] = gt_bboxes_3d
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
attr_labels=attr_labels,
centers_2d=centers_2d,
depths=depths)
return anns_results return ann_info
def parse_data_info(self, info: dict) -> dict: def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info. """Process the raw data info.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment