"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "73281b4f98233bc11e068ce50e7f627e6ad6f84e"
Unverified Commit 25a1bcae authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Efficient implementation of PointSegClassMapping (#489)

* efficient mapping by matrix indexing

* add pointsegmapping unit test

* add max_cat_id args everywhere PointSegClassMapping is called in codebase

* add default value for max_cat_id for safety

* add assertion of max_cat_id
parent e108340c
...@@ -22,7 +22,8 @@ train_pipeline = [ ...@@ -22,7 +22,8 @@ train_pipeline = [
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))), valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=num_points, num_points=num_points,
...@@ -65,7 +66,8 @@ eval_pipeline = [ ...@@ -65,7 +66,8 @@ eval_pipeline = [
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))), valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
with_label=False, with_label=False,
......
...@@ -21,7 +21,8 @@ train_pipeline = [ ...@@ -21,7 +21,8 @@ train_pipeline = [
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='IndoorPointSample', num_points=40000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
......
...@@ -23,7 +23,8 @@ train_pipeline = [ ...@@ -23,7 +23,8 @@ train_pipeline = [
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)), 33, 34, 36, 39),
max_cat_id=40),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=num_points, num_points=num_points,
...@@ -67,7 +68,8 @@ eval_pipeline = [ ...@@ -67,7 +68,8 @@ eval_pipeline = [
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)), 33, 34, 36, 39),
max_cat_id=40),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
with_label=False, with_label=False,
......
...@@ -201,7 +201,8 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for ...@@ -201,7 +201,8 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for
dict( dict(
type='PointSegClassMapping', # Declare valid categories, refer to mmdet3d.datasets.pipelines.point_seg_class_mapping for more details type='PointSegClassMapping', # Declare valid categories, refer to mmdet3d.datasets.pipelines.point_seg_class_mapping for more details
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39), # all valid categories ids
max_cat_id=40), # max possible category id in input segmentation mask
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
num_points=40000), # Number of points to be sampled num_points=40000), # Number of points to be sampled
dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes
...@@ -283,7 +284,8 @@ data = dict( ...@@ -283,7 +284,8 @@ data = dict(
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)), 28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='IndoorPointSample', num_points=40000),
dict( dict(
type='IndoorFlipData', type='IndoorFlipData',
......
...@@ -237,10 +237,23 @@ class PointSegClassMapping(object): ...@@ -237,10 +237,23 @@ class PointSegClassMapping(object):
Args: Args:
valid_cat_ids (tuple[int]): A tuple of valid category. valid_cat_ids (tuple[int]): A tuple of valid category.
max_cat_id (int): The max possible cat_id in input segmentation mask.
Defaults to 40.
""" """
def __init__(self, valid_cat_ids, max_cat_id=40):
    """Build the lookup table from raw category ids to class indices.

    Args:
        valid_cat_ids (tuple[int]): A tuple of valid category ids. The
            position of an id in this sequence becomes its class index.
        max_cat_id (int): The max possible cat_id in input segmentation
            mask. Defaults to 40.

    Raises:
        AssertionError: If ``max_cat_id`` is smaller than the largest id
            in ``valid_cat_ids``.
    """
    # Equality is accepted: the table below holds max_cat_id + 1 entries,
    # so index ``max_cat_id`` itself is addressable.
    assert max_cat_id >= np.max(valid_cat_ids), \
        'max_cat_id should be greater than maximum id in valid_cat_ids'

    self.valid_cat_ids = valid_cat_ids
    self.max_cat_id = int(max_cat_id)

    # build cat_id to class index mapping
    # Every id that is not listed in valid_cat_ids maps to ``neg_cls``.
    neg_cls = len(valid_cat_ids)
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy; the builtin selects the same dtype.
    self.cat_id2class = np.full(self.max_cat_id + 1, neg_cls, dtype=int)
    for cls_idx, cat_id in enumerate(valid_cat_ids):
        self.cat_id2class[cat_id] = cls_idx
def __call__(self, results): def __call__(self, results):
"""Call function to map original semantic class to valid category ids. """Call function to map original semantic class to valid category ids.
...@@ -256,22 +269,17 @@ class PointSegClassMapping(object): ...@@ -256,22 +269,17 @@ class PointSegClassMapping(object):
""" """
assert 'pts_semantic_mask' in results assert 'pts_semantic_mask' in results
pts_semantic_mask = results['pts_semantic_mask'] pts_semantic_mask = results['pts_semantic_mask']
neg_cls = len(self.valid_cat_ids)
for i in range(pts_semantic_mask.shape[0]): converted_pts_sem_mask = self.cat_id2class[pts_semantic_mask]
if pts_semantic_mask[i] in self.valid_cat_ids:
converted_id = self.valid_cat_ids.index(pts_semantic_mask[i])
pts_semantic_mask[i] = converted_id
else:
pts_semantic_mask[i] = neg_cls
results['pts_semantic_mask'] = pts_semantic_mask results['pts_semantic_mask'] = converted_pts_sem_mask
return results return results
def __repr__(self):
    """str: Return a string that describes the module."""
    cls_name = self.__class__.__name__
    # Report both constructor arguments, in constructor order.
    return (f'{cls_name}(valid_cat_ids={self.valid_cat_ids}, '
            f'max_cat_id={self.max_cat_id})')
......
...@@ -117,7 +117,8 @@ class _S3DISSegDataset(Custom3DSegDataset): ...@@ -117,7 +117,8 @@ class _S3DISSegDataset(Custom3DSegDataset):
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS), valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
with_label=False, with_label=False,
......
...@@ -276,7 +276,8 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -276,7 +276,8 @@ class ScanNetSegDataset(Custom3DSegDataset):
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS), valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
with_label=False, with_label=False,
......
...@@ -39,7 +39,8 @@ def test_seg_getitem(): ...@@ -39,7 +39,8 @@ def test_seg_getitem():
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))), valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=5, num_points=5,
......
...@@ -302,7 +302,8 @@ def test_seg_getitem(): ...@@ -302,7 +302,8 @@ def test_seg_getitem():
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)), 28, 33, 34, 36, 39),
max_cat_id=40),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=5, num_points=5,
...@@ -542,7 +543,8 @@ def test_seg_evaluate(): ...@@ -542,7 +543,8 @@ def test_seg_evaluate():
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)), 28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
] ]
...@@ -606,7 +608,8 @@ def test_seg_show(): ...@@ -606,7 +608,8 @@ def test_seg_show():
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)), 28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
] ]
......
...@@ -128,7 +128,8 @@ def test_scannet_seg_pipeline(): ...@@ -128,7 +128,8 @@ def test_scannet_seg_pipeline():
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)), 28, 33, 34, 36, 39),
max_cat_id=40),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=5, num_points=5,
...@@ -197,7 +198,8 @@ def test_s3dis_seg_pipeline(): ...@@ -197,7 +198,8 @@ def test_s3dis_seg_pipeline():
with_seg_3d=True), with_seg_3d=True),
dict( dict(
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))), valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict( dict(
type='IndoorPatchPointSample', type='IndoorPatchPointSample',
num_points=5, num_points=5,
......
...@@ -70,7 +70,7 @@ def test_indoor_seg_sample(): ...@@ -70,7 +70,7 @@ def test_indoor_seg_sample():
scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 1.0, 20, True) scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 1.0, 20, True)
scannet_seg_class_mapping = \ scannet_seg_class_mapping = \
PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
24, 28, 33, 34, 36, 39)) 24, 28, 33, 34, 36, 39), 40)
scannet_results = dict() scannet_results = dict()
scannet_points = np.fromfile( scannet_points = np.fromfile(
'./tests/data/scannet/points/scene0000_00.bin', './tests/data/scannet/points/scene0000_00.bin',
......
...@@ -217,7 +217,7 @@ def test_load_segmentation_mask(): ...@@ -217,7 +217,7 @@ def test_load_segmentation_mask():
# Convert class_id to label and assign ignore_index # Convert class_id to label and assign ignore_index
scannet_seg_class_mapping = \ scannet_seg_class_mapping = \
PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
24, 28, 33, 34, 36, 39)) 24, 28, 33, 34, 36, 39), 40)
scannet_results = scannet_seg_class_mapping(scannet_results) scannet_results = scannet_seg_class_mapping(scannet_results)
scannet_pts_semantic_mask = scannet_results['pts_semantic_mask'] scannet_pts_semantic_mask = scannet_results['pts_semantic_mask']
...@@ -250,7 +250,7 @@ def test_load_segmentation_mask(): ...@@ -250,7 +250,7 @@ def test_load_segmentation_mask():
assert s3dis_pts_semantic_mask.shape == (100, ) assert s3dis_pts_semantic_mask.shape == (100, )
# Convert class_id to label and assign ignore_index # Convert class_id to label and assign ignore_index
s3dis_seg_class_mapping = PointSegClassMapping(tuple(range(13))) s3dis_seg_class_mapping = PointSegClassMapping(tuple(range(13)), 13)
s3dis_results = s3dis_seg_class_mapping(s3dis_results) s3dis_results = s3dis_seg_class_mapping(s3dis_results)
s3dis_pts_semantic_mask = s3dis_results['pts_semantic_mask'] s3dis_pts_semantic_mask = s3dis_results['pts_semantic_mask']
...@@ -288,6 +288,39 @@ def test_load_points_from_multi_sweeps(): ...@@ -288,6 +288,39 @@ def test_load_points_from_multi_sweeps():
assert points.shape == (403, 4) assert points.shape == (403, 4)
def test_point_seg_class_mapping():
    """Check PointSegClassMapping argument validation, mapping and repr."""
    # max_cat_id must not be smaller than the max id in valid_cat_ids
    with pytest.raises(AssertionError):
        PointSegClassMapping([1, 2, 5], 4)

    # boundary: max_cat_id equal to the max valid id is accepted, and every
    # id outside valid_cat_ids is mapped to len(valid_cat_ids)
    boundary_mapping = PointSegClassMapping((1, 2, 5), 5)
    boundary_results = boundary_mapping(
        dict(pts_semantic_mask=np.array([0, 1, 2, 3, 4, 5])))
    assert np.array_equal(boundary_results['pts_semantic_mask'],
                          np.array([3, 0, 1, 3, 3, 2]))

    sem_mask = np.array([
        16, 22, 2, 3, 7, 3, 16, 2, 16, 3, 1, 0, 6, 22, 3, 1, 2, 16, 1, 1, 1,
        38, 7, 25, 16, 25, 3, 40, 38, 3, 33, 6, 16, 6, 16, 1, 38, 1, 1, 2, 8,
        0, 18, 15, 0, 0, 40, 40, 1, 2, 3, 16, 33, 2, 2, 2, 7, 3, 14, 22, 4, 22,
        15, 24, 2, 40, 3, 2, 8, 3, 1, 6, 40, 6, 0, 15, 4, 7, 6, 0, 1, 16, 14,
        3, 0, 1, 1, 16, 38, 2, 15, 6, 4, 1, 16, 2, 3, 3, 3, 2
    ])
    valid_cat_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
                     34, 36, 39)
    point_seg_class_mapping = PointSegClassMapping(valid_cat_ids, 40)
    input_dict = dict(pts_semantic_mask=sem_mask)
    results = point_seg_class_mapping(input_dict)
    mapped_sem_mask = results['pts_semantic_mask']
    expected_sem_mask = np.array([
        13, 20, 1, 2, 6, 2, 13, 1, 13, 2, 0, 20, 5, 20, 2, 0, 1, 13, 0, 0, 0,
        20, 6, 20, 13, 20, 2, 20, 20, 2, 16, 5, 13, 5, 13, 0, 20, 0, 0, 1, 7,
        20, 20, 20, 20, 20, 20, 20, 0, 1, 2, 13, 16, 1, 1, 1, 6, 2, 12, 20, 3,
        20, 20, 14, 1, 20, 2, 1, 7, 2, 0, 5, 20, 5, 20, 20, 3, 6, 5, 20, 0, 13,
        12, 2, 20, 0, 0, 13, 20, 1, 20, 5, 3, 0, 13, 1, 2, 2, 2, 1
    ])
    repr_str = repr(point_seg_class_mapping)
    expected_repr_str = f'PointSegClassMapping(valid_cat_ids={valid_cat_ids}'\
                        ', max_cat_id=40)'

    # array_equal also fails on a shape mismatch instead of broadcasting
    assert np.array_equal(mapped_sem_mask, expected_sem_mask)
    assert repr_str == expected_repr_str
def test_normalize_points_color(): def test_normalize_points_color():
coord = np.array([[68.137, 3.358, 2.516], [67.697, 3.55, 2.501], coord = np.array([[68.137, 3.358, 2.516], [67.697, 3.55, 2.501],
[67.649, 3.76, 2.5], [66.414, 3.901, 2.459], [67.649, 3.76, 2.5], [66.414, 3.901, 2.459],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment