Unverified Commit 9287164e authored by xizaoqu's avatar xizaoqu Committed by GitHub
Browse files

[Feature] Add panoptic segmentation loading. (#2223)

* add panoptic loading

* update

* update

* update

* add semantic config

* fix test

* fix test

* update

* get seg label mapping
parent 4cfae3f0
# dataset settings
dataset_type = 'SemanticKITTIDataset'
data_root = 'data/semantickitti/'
# The 20 merged training class names; the first entry is 'unlabeled'.
# NOTE(review): 'trunck' and 'terrian' look like typos for 'trunk' and
# 'terrain' — kept as-is since these display strings may be matched elsewhere.
class_names = [
'unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground',
'building', 'fence', 'vegetation', 'trunck', 'terrian', 'pole',
'traffic-sign'
]
# 20 display colors, one per class in the same order as class_names
# (presumably RGB triplets — confirm against the visualizer).
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
# Mapping from raw SemanticKITTI label ids (0-259) to the 20 training class
# ids above. Entries annotated "mapped" collapse outlier/moving classes into
# a kept class (or into 0, "unlabeled").
labels_map = {
0: 0, # "unlabeled"
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
10: 1, # "car"
11: 2, # "bicycle"
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
15: 3, # "motorcycle"
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
18: 4, # "truck"
20: 5, # "other-vehicle"
30: 6, # "person"
31: 7, # "bicyclist"
32: 8, # "motorcyclist"
40: 9, # "road"
44: 10, # "parking"
48: 11, # "sidewalk"
49: 12, # "other-ground"
50: 13, # "building"
51: 14, # "fence"
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
60: 9, # "lane-marking" to "road" ---------------------mapped
70: 15, # "vegetation"
71: 16, # "trunk"
72: 17, # "terrain"
80: 18, # "pole"
81: 19, # "traffic-sign"
99: 0, # "other-object" to "unlabeled" ----------------mapped
252: 1, # "moving-car" to "car" ------------------------mapped
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
254: 6, # "moving-person" to "person" ------------------mapped
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
256: 5, # "moving-on-rails" mapped to "other-vehicle" --mapped
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
258: 4, # "moving-truck" to "truck" --------------------mapped
259: 5 # "moving-other-vehicle" to "other-vehicle" ----mapped
}
# Consumed by SemanticKITTIDataset.get_seg_label_mapping, which builds a
# lookup table of size max_label + 1; max_label is the largest raw id above.
metainfo = dict(
classes=class_names,
palette=palette,
seg_label_mapping=labels_map,
max_label=259)
# Only LiDAR input is used; no camera images are loaded.
input_modality = dict(use_lidar=True, use_camera=False)
# Default backend reads point clouds and labels from the local filesystem.
file_client_args = dict(backend='disk')
# Uncomment the following if you use Ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/semantickitti/':
# 's3://semantickitti/',
# }))
# Training pipeline: load points and per-point semantic labels, remap raw
# label ids to train ids, then apply flip / rotate / scale augmentation.
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_seg_3d=True,
# Offset used to split semantic and instance ids that are packed
# into a single SemanticKITTI label value.
seg_offset=2**16,
dataset_type='semantickitti'),
# Remaps raw label ids via metainfo['seg_label_mapping'].
dict(type='PointSegClassMapping', ),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
# rotation range is approximately [-pi/4, pi/4]
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.1, 0.1, 0.1],
),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
# Test-time pipeline: same loading and label remapping as training, but
# without geometric augmentation.
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_seg_3d=True,
# Offset splitting semantic/instance ids packed into one label.
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_seg_3d=True,
# Offset splitting semantic/instance ids packed into one label.
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
# times=1 means the wrapped dataset is used as-is each epoch.
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='train_infos.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality)),
)
test_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='valid_infos.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
test_mode=True,
)),
)
# Validation shares the test loader; both splits use the same seg metric.
val_dataloader = test_dataloader
val_evaluator = dict(type='SegMetric')
test_evaluator = val_evaluator
...@@ -107,14 +107,7 @@ class Seg3DDataset(BaseDataset): ...@@ -107,14 +107,7 @@ class Seg3DDataset(BaseDataset):
metainfo['palette'] = updated_palette metainfo['palette'] = updated_palette
# construct seg_label_mapping for semantic mask # construct seg_label_mapping for semantic mask
seg_max_cat_id = len(self.METAINFO['seg_all_class_ids']) self.seg_label_mapping = self.get_seg_label_mapping(metainfo)
seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
neg_label = len(seg_valid_cat_ids)
seg_label_mapping = np.ones(
seg_max_cat_id + 1, dtype=np.int) * neg_label
for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
seg_label_mapping[cat_id] = cls_idx
self.seg_label_mapping = seg_label_mapping
super().__init__( super().__init__(
ann_file=ann_file, ann_file=ann_file,
...@@ -192,6 +185,29 @@ class Seg3DDataset(BaseDataset): ...@@ -192,6 +185,29 @@ class Seg3DDataset(BaseDataset):
return label_mapping, label2cat, valid_class_ids return label_mapping, label2cat, valid_class_ids
def get_seg_label_mapping(self, metainfo=None):
    """Get segmentation label mapping.

    The ``seg_label_mapping`` is an array, its indices are the old label
    ids and its values are the new label ids, and is specifically used
    for changing point labels in PointSegClassMapping.

    Args:
        metainfo (dict, optional): Meta information to set
            seg_label_mapping. Defaults to None. Unused here; kept so
            subclasses that need it can share the same signature.

    Returns:
        np.ndarray: The mapping from old classes to new classes.
    """
    seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
    seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
    # Ids that are not in the valid set map to this "ignore" label.
    neg_label = len(seg_valid_cat_ids)
    # BUG FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
    # use the explicit fixed-width dtype instead.
    seg_label_mapping = np.ones(
        seg_max_cat_id + 1, dtype=np.int64) * neg_label
    for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
        seg_label_mapping[cat_id] = cls_idx
    return seg_label_mapping
def _update_palette(self, new_classes: list, palette: Union[None, def _update_palette(self, new_classes: list, palette: Union[None,
list]) -> list: list]) -> list:
"""Update palette according to metainfo. """Update palette according to metainfo.
......
...@@ -51,7 +51,7 @@ class SemanticKITTIDataset(Seg3DDataset): ...@@ -51,7 +51,7 @@ class SemanticKITTIDataset(Seg3DDataset):
'seg_valid_class_ids': 'seg_valid_class_ids':
tuple(range(20)), tuple(range(20)),
'seg_all_class_ids': 'seg_all_class_ids':
tuple(range(20)) tuple(range(20)),
} }
def __init__(self, def __init__(self,
...@@ -81,3 +81,9 @@ class SemanticKITTIDataset(Seg3DDataset): ...@@ -81,3 +81,9 @@ class SemanticKITTIDataset(Seg3DDataset):
scene_idxs=scene_idxs, scene_idxs=scene_idxs,
test_mode=test_mode, test_mode=test_mode,
**kwargs) **kwargs)
def get_seg_label_mapping(self, metainfo):
    """Get segmentation label mapping for SemanticKITTI.

    Builds a lookup table whose indices are the raw label ids
    (0 to ``metainfo['max_label']``) and whose values are the train class
    ids from ``metainfo['seg_label_mapping']``; raw ids absent from that
    dict stay 0 ("unlabeled").

    Args:
        metainfo (dict): Must contain ``max_label`` (int, largest raw
            label id) and ``seg_label_mapping`` (dict: raw id -> train id).

    Returns:
        np.ndarray: Lookup table from raw label ids to train class ids.
    """
    # BUG FIX: np.zeros defaults to float64, which would propagate float
    # class labels downstream; labels are ids, so use an integer dtype.
    seg_label_mapping = np.zeros(metainfo['max_label'] + 1, dtype=np.int64)
    for raw_id, train_id in metainfo['seg_label_mapping'].items():
        seg_label_mapping[raw_id] = train_id
    return seg_label_mapping
...@@ -758,6 +758,8 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -758,6 +758,8 @@ class LoadAnnotations3D(LoadAnnotations):
Only when `with_mask_3d` is True. Only when `with_mask_3d` is True.
- pts_semantic_mask_path (str): Path of semantic mask file. - pts_semantic_mask_path (str): Path of semantic mask file.
Only when `with_seg_3d` is True. Only when `with_seg_3d` is True.
- pts_panoptic_mask_path (str): Path of panoptic mask file.
Only when both `with_panoptic_3d` is True.
Added Keys: Added Keys:
...@@ -795,9 +797,15 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -795,9 +797,15 @@ class LoadAnnotations3D(LoadAnnotations):
with_mask (bool): Whether to load 2D instance masks. Defaults to False. with_mask (bool): Whether to load 2D instance masks. Defaults to False.
with_seg (bool): Whether to load 2D semantic masks. Defaults to False. with_seg (bool): Whether to load 2D semantic masks. Defaults to False.
with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False. with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False.
with_panoptic_3d (bool): Whether to load 3D panoptic masks for points.
Defaults to False.
poly2mask (bool): Whether to convert polygon annotations to bitmasks. poly2mask (bool): Whether to convert polygon annotations to bitmasks.
Defaults to True. Defaults to True.
seg_3d_dtype (dtype): Dtype of 3D semantic masks. Defaults to int64. seg_3d_dtype (dtype): Dtype of 3D semantic masks. Defaults to int64.
seg_offset (int): The offset to split semantic and instance labels from
panoptic labels. Defaults to None.
dataset_type (str): Type of dataset used for splitting semantic and
instance labels. Defaults to None.
file_client_args (dict): Arguments to instantiate a FileClient. file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details. See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk'). Defaults to dict(backend='disk').
...@@ -815,8 +823,11 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -815,8 +823,11 @@ class LoadAnnotations3D(LoadAnnotations):
with_mask: bool = False, with_mask: bool = False,
with_seg: bool = False, with_seg: bool = False,
with_bbox_depth: bool = False, with_bbox_depth: bool = False,
with_panoptic_3d: bool = False,
poly2mask: bool = True, poly2mask: bool = True,
seg_3d_dtype: np.dtype = np.int64, seg_3d_dtype: np.dtype = np.int64,
seg_offset: int = None,
dataset_type: str = None,
file_client_args: dict = dict(backend='disk') file_client_args: dict = dict(backend='disk')
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -832,7 +843,10 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -832,7 +843,10 @@ class LoadAnnotations3D(LoadAnnotations):
self.with_attr_label = with_attr_label self.with_attr_label = with_attr_label
self.with_mask_3d = with_mask_3d self.with_mask_3d = with_mask_3d
self.with_seg_3d = with_seg_3d self.with_seg_3d = with_seg_3d
self.with_panoptic_3d = with_panoptic_3d
self.seg_3d_dtype = seg_3d_dtype self.seg_3d_dtype = seg_3d_dtype
self.seg_offset = seg_offset
self.dataset_type = dataset_type
self.file_client = None self.file_client = None
def _load_bboxes_3d(self, results: dict) -> dict: def _load_bboxes_3d(self, results: dict) -> dict:
...@@ -938,10 +952,57 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -938,10 +952,57 @@ class LoadAnnotations3D(LoadAnnotations):
pts_semantic_mask = np.fromfile( pts_semantic_mask = np.fromfile(
pts_semantic_mask_path, dtype=np.int64) pts_semantic_mask_path, dtype=np.int64)
if self.dataset_type == 'semantickitti':
pts_semantic_mask = pts_semantic_mask.astype(np.int64)
pts_semantic_mask = pts_semantic_mask % self.seg_offset
# nuScenes loads semantic and panoptic labels from different files.
results['pts_semantic_mask'] = pts_semantic_mask
# 'eval_ann_info' will be passed to evaluator
if 'eval_ann_info' in results:
results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
return results
def _load_panoptic_3d(self, results: dict) -> dict:
    """Private function to load 3D panoptic segmentation annotations.

    Args:
        results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
            Must contain 'pts_panoptic_mask_path'.

    Returns:
        dict: The dict with 'pts_semantic_mask' and 'pts_instance_mask'
        added (and mirrored into 'eval_ann_info' when present).

    Raises:
        ValueError: If ``self.dataset_type`` is not a supported dataset.
    """
    pts_panoptic_mask_path = results['pts_panoptic_mask_path']

    if self.file_client is None:
        self.file_client = mmengine.FileClient(**self.file_client_args)
    try:
        mask_bytes = self.file_client.get(pts_panoptic_mask_path)
        # add .copy() to fix read-only bug
        pts_panoptic_mask = np.frombuffer(
            mask_bytes, dtype=self.seg_3d_dtype).copy()
    except ConnectionError:
        mmengine.check_file_exist(pts_panoptic_mask_path)
        pts_panoptic_mask = np.fromfile(
            pts_panoptic_mask_path, dtype=np.int64)

    if self.dataset_type == 'semantickitti':
        # Semantic id is packed in the low bits; take it via modulo.
        pts_semantic_mask = pts_panoptic_mask.astype(np.int64)
        pts_semantic_mask = pts_semantic_mask % self.seg_offset
    elif self.dataset_type == 'nuscenes':
        # BUG FIX: the original divided `pts_semantic_mask`, which is
        # undefined at this point; the panoptic label must be divided.
        pts_semantic_mask = pts_panoptic_mask // self.seg_offset
    else:
        raise ValueError(
            f'Unsupported dataset_type {self.dataset_type} for '
            'splitting panoptic labels.')
    results['pts_semantic_mask'] = pts_semantic_mask

    # We can directly take panoptic labels as instance ids.
    pts_instance_mask = pts_panoptic_mask.astype(np.int64)
    results['pts_instance_mask'] = pts_instance_mask

    # 'eval_ann_info' will be passed to evaluator
    if 'eval_ann_info' in results:
        results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
        results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
    return results
def _load_bboxes(self, results: dict) -> None: def _load_bboxes(self, results: dict) -> None:
...@@ -989,11 +1050,12 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -989,11 +1050,12 @@ class LoadAnnotations3D(LoadAnnotations):
results = self._load_labels_3d(results) results = self._load_labels_3d(results)
if self.with_attr_label: if self.with_attr_label:
results = self._load_attr_labels(results) results = self._load_attr_labels(results)
if self.with_panoptic_3d:
results = self._load_panoptic_3d(results)
if self.with_mask_3d: if self.with_mask_3d:
results = self._load_masks_3d(results) results = self._load_masks_3d(results)
if self.with_seg_3d: if self.with_seg_3d:
results = self._load_semantic_seg_3d(results) results = self._load_semantic_seg_3d(results)
return results return results
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -1005,12 +1067,15 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -1005,12 +1067,15 @@ class LoadAnnotations3D(LoadAnnotations):
repr_str += f'{indent_str}with_attr_label={self.with_attr_label}, ' repr_str += f'{indent_str}with_attr_label={self.with_attr_label}, '
repr_str += f'{indent_str}with_mask_3d={self.with_mask_3d}, ' repr_str += f'{indent_str}with_mask_3d={self.with_mask_3d}, '
repr_str += f'{indent_str}with_seg_3d={self.with_seg_3d}, ' repr_str += f'{indent_str}with_seg_3d={self.with_seg_3d}, '
repr_str += f'{indent_str}with_panoptic_3d={self.with_panoptic_3d}, '
repr_str += f'{indent_str}with_bbox={self.with_bbox}, ' repr_str += f'{indent_str}with_bbox={self.with_bbox}, '
repr_str += f'{indent_str}with_label={self.with_label}, ' repr_str += f'{indent_str}with_label={self.with_label}, '
repr_str += f'{indent_str}with_mask={self.with_mask}, ' repr_str += f'{indent_str}with_mask={self.with_mask}, '
repr_str += f'{indent_str}with_seg={self.with_seg}, ' repr_str += f'{indent_str}with_seg={self.with_seg}, '
repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, ' repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
repr_str += f'{indent_str}poly2mask={self.poly2mask})' repr_str += f'{indent_str}poly2mask={self.poly2mask})'
repr_str += f'{indent_str}seg_offset={self.seg_offset})'
return repr_str return repr_str
......
...@@ -160,7 +160,11 @@ def create_dummy_data_info(with_ann=True): ...@@ -160,7 +160,11 @@ def create_dummy_data_info(with_ann=True):
0 0
}], }],
'plane': 'plane':
None None,
'pts_semantic_mask_path':
'tests/data/semantickitti/sequences/00/labels/000000.label',
'pts_panoptic_mask_path':
'tests/data/semantickitti/sequences/00/labels/000000.label',
} }
if with_ann: if with_ann:
data_info['ann_info'] = ann_info data_info['ann_info'] = ann_info
......
...@@ -36,6 +36,44 @@ def _generate_semantickitti_dataset_config(): ...@@ -36,6 +36,44 @@ def _generate_semantickitti_dataset_config():
[227, 119, 194], [227, 119, 194],
[82, 84, 163], [82, 84, 163],
] ]
seg_label_mapping = {
0: 0, # "unlabeled"
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
10: 1, # "car"
11: 2, # "bicycle"
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
15: 3, # "motorcycle"
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
18: 4, # "truck"
20: 5, # "other-vehicle"
30: 6, # "person"
31: 7, # "bicyclist"
32: 8, # "motorcyclist"
40: 9, # "road"
44: 10, # "parking"
48: 11, # "sidewalk"
49: 12, # "other-ground"
50: 13, # "building"
51: 14, # "fence"
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
60: 9, # "lane-marking" to "road" ---------------------mapped
70: 15, # "vegetation"
71: 16, # "trunk"
72: 17, # "terrain"
80: 18, # "pole"
81: 19, # "traffic-sign"
99: 0, # "other-object" to "unlabeled" ----------------mapped
252: 1, # "moving-car" to "car" ------------------------mapped
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
254: 6, # "moving-person" to "person" ------------------mapped
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
256: 5, # "moving-on-rails" mapped to "other-vehic------mapped
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
258: 4, # "moving-truck" to "truck" --------------------mapped
259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped
}
max_label = 259
modality = dict(use_lidar=True, use_camera=False) modality = dict(use_lidar=True, use_camera=False)
pipeline = [ pipeline = [
dict( dict(
...@@ -51,6 +89,7 @@ def _generate_semantickitti_dataset_config(): ...@@ -51,6 +89,7 @@ def _generate_semantickitti_dataset_config():
with_mask_3d=False, with_mask_3d=False,
with_seg_3d=True, with_seg_3d=True,
seg_3d_dtype=np.int32), seg_3d_dtype=np.int32),
dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask']) dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
] ]
...@@ -58,21 +97,26 @@ def _generate_semantickitti_dataset_config(): ...@@ -58,21 +97,26 @@ def _generate_semantickitti_dataset_config():
pts='sequences/00/velodyne', pts_semantic_mask='sequences/00/labels') pts='sequences/00/velodyne', pts_semantic_mask='sequences/00/labels')
return (data_root, ann_file, classes, palette, data_prefix, pipeline, return (data_root, ann_file, classes, palette, data_prefix, pipeline,
modality) modality, seg_label_mapping, max_label)
class TestSemanticKITTIDataset(unittest.TestCase): class TestSemanticKITTIDataset(unittest.TestCase):
def test_semantickitti(self): def test_semantickitti(self):
data_root, ann_file, classes, palette, data_prefix, \ (data_root, ann_file, classes, palette, data_prefix, pipeline,
pipeline, modality, = _generate_semantickitti_dataset_config() modality, seg_label_mapping,
max_label) = _generate_semantickitti_dataset_config()
register_all_modules() register_all_modules()
np.random.seed(0) np.random.seed(0)
semantickitti_dataset = SemanticKITTIDataset( semantickitti_dataset = SemanticKITTIDataset(
data_root, data_root,
ann_file, ann_file,
metainfo=dict(classes=classes, palette=palette), metainfo=dict(
classes=classes,
palette=palette,
seg_label_mapping=seg_label_mapping,
max_label=max_label),
data_prefix=data_prefix, data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
modality=modality) modality=modality)
...@@ -83,3 +127,13 @@ class TestSemanticKITTIDataset(unittest.TestCase): ...@@ -83,3 +127,13 @@ class TestSemanticKITTIDataset(unittest.TestCase):
data_sample = input_dict['data_samples'] data_sample = input_dict['data_samples']
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
self.assertEqual(points.shape[0], pts_semantic_mask.shape[0]) self.assertEqual(points.shape[0], pts_semantic_mask.shape[0])
expected_pts_semantic_mask = np.array([
13., 13., 13., 15., 15., 13., 0., 13., 15., 13., 13., 15., 16., 0.,
15., 13., 13., 13., 13., 0., 13., 13., 13., 13., 13., 15., 13.,
16., 13., 15., 15., 18., 13., 15., 15., 15., 16., 15., 13., 13.,
15., 13., 18., 15., 13., 15., 13., 15., 15., 13.
])
self.assertTrue(
(pts_semantic_mask.numpy() == expected_pts_semantic_mask).all())
...@@ -58,6 +58,10 @@ class TestLoadAnnotations3D(unittest.TestCase): ...@@ -58,6 +58,10 @@ class TestLoadAnnotations3D(unittest.TestCase):
load_anns_transform = LoadAnnotations3D( load_anns_transform = LoadAnnotations3D(
with_bbox_3d=True, with_bbox_3d=True,
with_label_3d=True, with_label_3d=True,
with_panoptic_3d=True,
seg_offset=2**16,
dataset_type='semantickitti',
seg_3d_dtype=np.uint32,
file_client_args=file_client_args) file_client_args=file_client_args)
self.assertIs(load_anns_transform.with_seg, False) self.assertIs(load_anns_transform.with_seg, False)
self.assertIs(load_anns_transform.with_bbox_3d, True) self.assertIs(load_anns_transform.with_bbox_3d, True)
...@@ -69,10 +73,29 @@ class TestLoadAnnotations3D(unittest.TestCase): ...@@ -69,10 +73,29 @@ class TestLoadAnnotations3D(unittest.TestCase):
torch.tensor(7.2650)) torch.tensor(7.2650))
self.assertIn('gt_labels_3d', info) self.assertIn('gt_labels_3d', info)
assert_allclose(info['gt_labels_3d'], torch.tensor([1])) assert_allclose(info['gt_labels_3d'], torch.tensor([1]))
self.assertIn('pts_semantic_mask', info)
self.assertIn('pts_instance_mask', info)
assert_allclose(
info['pts_semantic_mask'],
np.array([
50, 50, 50, 70, 70, 50, 0, 50, 70, 50, 50, 70, 71, 52, 70, 50,
50, 50, 50, 0, 50, 50, 50, 50, 50, 70, 50, 71, 50, 70, 70, 80,
50, 70, 70, 70, 71, 70, 50, 50, 70, 50, 80, 70, 50, 70, 50, 70,
70, 50
]))
assert_allclose(
info['pts_instance_mask'],
np.array([
50, 50, 50, 70, 70, 50, 0, 50, 70, 50, 50, 70, 71, 52, 70, 50,
50, 50, 50, 0, 50, 50, 50, 50, 50, 70, 50, 71, 50, 70, 70, 80,
50, 70, 70, 70, 71, 70, 50, 50, 70, 50, 80, 70, 50, 70, 50, 70,
70, 50
]))
repr_str = repr(load_anns_transform) repr_str = repr(load_anns_transform)
self.assertIn('with_bbox_3d=True', repr_str) self.assertIn('with_bbox_3d=True', repr_str)
self.assertIn('with_label_3d=True', repr_str) self.assertIn('with_label_3d=True', repr_str)
self.assertIn('with_bbox_depth=False', repr_str) self.assertIn('with_bbox_depth=False', repr_str)
self.assertIn('with_panoptic_3d=True', repr_str)
class TestPointSegClassMapping(unittest.TestCase): class TestPointSegClassMapping(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment