"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3c312e215c6aae05b4df97f1543dcb428053776b"
Commit 360c27f9 authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

[Refactor] Refactor 3D Seg Dataset

parent 1039ad0e
-# dataset settings
-dataset_type = 'S3DISSegDataset'
-data_root = './data/s3dis/'
+# For S3DIS seg we usually do 13-class segmentation
 class_names = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
                'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
+metainfo = dict(CLASSES=class_names)
+dataset_type = 'S3DISSegDataset'
+data_root = 'data/s3dis/'
+input_modality = dict(use_lidar=True, use_camera=False)
+data_prefix = dict(
+    pts='points',
+    pts_instance_mask='instance_mask',
+    pts_semantic_mask='semantic_mask')
 file_client_args = dict(backend='disk')
 # Uncomment the following if use ceph or other file clients.
 # See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
@@ -15,29 +22,27 @@ file_client_args = dict(backend='disk')
 #         'data/s3dis/':
 #         's3://openmmlab/datasets/detection3d/s3dis_processed/'
 #     }))
 num_points = 4096
 train_area = [1, 2, 3, 4, 6]
 test_area = 5
 train_pipeline = [
     dict(
         type='LoadPointsFromFile',
-        file_client_args=file_client_args,
         coord_type='DEPTH',
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
     dict(
         type='LoadAnnotations3D',
-        file_client_args=file_client_args,
         with_bbox_3d=False,
         with_label_3d=False,
         with_mask_3d=False,
-        with_seg_3d=True),
-    dict(
-        type='PointSegClassMapping',
-        valid_cat_ids=tuple(range(len(class_names))),
-        max_cat_id=13),
+        with_seg_3d=True,
+        file_client_args=file_client_args),
+    dict(type='PointSegClassMapping'),
     dict(
         type='IndoorPatchPointSample',
         num_points=num_points,
@@ -47,18 +52,17 @@ train_pipeline = [
         enlarge_size=0.2,
         min_unique_num=None),
     dict(type='NormalizePointsColor', color_mean=None),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
+    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
 ]
 test_pipeline = [
     dict(
         type='LoadPointsFromFile',
-        file_client_args=file_client_args,
         coord_type='DEPTH',
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
     dict(type='NormalizePointsColor', color_mean=None),
     dict(
         # a wrapper in order to successfully call test function
@@ -78,12 +82,8 @@ test_pipeline = [
             sync_2d=False,
             flip_ratio_bev_horizontal=0.0,
             flip_ratio_bev_vertical=0.0),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 # construct a pipeline for data and gt loading in show function
 # please keep its loading function consistent with test_pipeline (e.g. client)
@@ -91,69 +91,58 @@ test_pipeline = [
 eval_pipeline = [
     dict(
         type='LoadPointsFromFile',
-        file_client_args=file_client_args,
         coord_type='DEPTH',
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
-    dict(
-        type='LoadAnnotations3D',
-        file_client_args=file_client_args,
-        with_bbox_3d=False,
-        with_label_3d=False,
-        with_mask_3d=False,
-        with_seg_3d=True),
-    dict(
-        type='PointSegClassMapping',
-        valid_cat_ids=tuple(range(len(class_names))),
-        max_cat_id=13),
-    dict(
-        type='DefaultFormatBundle3D',
-        with_label=False,
-        class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
+    dict(type='NormalizePointsColor', color_mean=None),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
-data = dict(
-    samples_per_gpu=8,
-    workers_per_gpu=4,
-    # train on area 1, 2, 3, 4, 6
-    # test on area 5
-    train=dict(
+# train on area 1, 2, 3, 4, 6
+# test on area 5
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
         ann_files=[
             data_root + f's3dis_infos_Area_{i}.pkl' for i in train_area
         ],
+        metainfo=metainfo,
+        data_prefix=data_prefix,
         pipeline=train_pipeline,
-        classes=class_names,
-        test_mode=False,
+        modality=input_modality,
         ignore_index=len(class_names),
         scene_idxs=[
            data_root + f'seg_info/Area_{i}_resampled_scene_idxs.npy'
            for i in train_area
        ],
-        file_client_args=file_client_args),
-    val=dict(
+        test_mode=False))
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_files=data_root + f's3dis_infos_Area_{test_area}.pkl',
+        metainfo=metainfo,
+        data_prefix=data_prefix,
        pipeline=test_pipeline,
-        classes=class_names,
-        test_mode=True,
+        modality=input_modality,
        ignore_index=len(class_names),
        scene_idxs=data_root +
        f'seg_info/Area_{test_area}_resampled_scene_idxs.npy',
-        file_client_args=file_client_args),
+        test_mode=True))
+val_dataloader = test_dataloader
-    test=dict(
-        type=dataset_type,
-        data_root=data_root,
-        ann_files=data_root + f's3dis_infos_Area_{test_area}.pkl',
-        pipeline=test_pipeline,
-        classes=class_names,
-        test_mode=True,
-        ignore_index=len(class_names),
-        file_client_args=file_client_args))
-evaluation = dict(pipeline=eval_pipeline)
+val_evaluator = dict(type='SegMetric')
+test_evaluator = val_evaluator
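The dataloader refactor above is mechanical: `samples_per_gpu`/`workers_per_gpu` become `batch_size`/`num_workers`, each split gains an explicit sampler, and the per-split dataset dict moves under a `dataset` key. A minimal sketch of that mapping for migrating similar configs by hand; `migrate_loader` is a hypothetical helper written for illustration, not part of this commit:

# Hypothetical helper sketching the old-style -> new-style dataloader
# mapping used above; illustration only, not part of this commit.
def migrate_loader(old_data_cfg, split, shuffle):
    dataset_cfg = dict(old_data_cfg[split])
    # Keys the refactored datasets no longer accept directly.
    dataset_cfg.pop('classes', None)
    dataset_cfg.pop('file_client_args', None)
    return dict(
        batch_size=old_data_cfg['samples_per_gpu'],
        num_workers=old_data_cfg['workers_per_gpu'],
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=shuffle),
        dataset=dataset_cfg)

old_data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    train=dict(type='S3DISSegDataset', classes=('ceiling', 'floor')))
print(migrate_loader(old_data, 'train', shuffle=True)['batch_size'])  # 8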
-# dataset settings
-dataset_type = 'ScanNetSegDataset'
-data_root = './data/scannet/'
+# For ScanNet seg we usually do 20-class segmentation
 class_names = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
                'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
                'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink',
                'bathtub', 'otherfurniture')
+metainfo = dict(CLASSES=class_names)
+dataset_type = 'ScanNetSegDataset'
+data_root = 'data/scannet/'
+input_modality = dict(use_lidar=True, use_camera=False)
+data_prefix = dict(
+    pts='points',
+    pts_instance_mask='instance_mask',
+    pts_semantic_mask='semantic_mask')
 file_client_args = dict(backend='disk')
 # Uncomment the following if use ceph or other file clients.
 # See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
 # for more details.
 # file_client_args = dict(
 #     backend='petrel',
 #     path_mapping=dict({
 #         './data/scannet/':
 #         's3://openmmlab/datasets/detection3d/scannet_processed/',
 #         'data/scannet/':
 #         's3://openmmlab/datasets/detection3d/scannet_processed/'
 #     }))
 num_points = 8192
 train_pipeline = [
     dict(
@@ -13,18 +33,16 @@ train_pipeline = [
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
     dict(
         type='LoadAnnotations3D',
         with_bbox_3d=False,
         with_label_3d=False,
         with_mask_3d=False,
-        with_seg_3d=True),
-    dict(
-        type='PointSegClassMapping',
-        valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
-                       33, 34, 36, 39),
-        max_cat_id=40),
+        with_seg_3d=True,
+        file_client_args=file_client_args),
+    dict(type='PointSegClassMapping'),
     dict(
         type='IndoorPatchPointSample',
         num_points=num_points,
@@ -34,8 +52,7 @@ train_pipeline = [
         enlarge_size=0.2,
         min_unique_num=None),
     dict(type='NormalizePointsColor', color_mean=None),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
+    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
 ]
 test_pipeline = [
     dict(
@@ -44,7 +61,8 @@ test_pipeline = [
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
     dict(type='NormalizePointsColor', color_mean=None),
     dict(
         # a wrapper in order to successfully call test function
@@ -64,12 +82,8 @@ test_pipeline = [
             sync_2d=False,
             flip_ratio_bev_horizontal=0.0,
             flip_ratio_bev_vertical=0.0),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 # construct a pipeline for data and gt loading in show function
 # please keep its loading function consistent with test_pipeline (e.g. client)
@@ -81,52 +95,45 @@ eval_pipeline = [
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
-    dict(
-        type='LoadAnnotations3D',
-        with_bbox_3d=False,
-        with_label_3d=False,
-        with_mask_3d=False,
-        with_seg_3d=True),
-    dict(
-        type='PointSegClassMapping',
-        valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
-                       33, 34, 36, 39),
-        max_cat_id=40),
-    dict(
-        type='DefaultFormatBundle3D',
-        with_label=False,
-        class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
+    dict(type='NormalizePointsColor', color_mean=None),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
-data = dict(
-    samples_per_gpu=8,
-    workers_per_gpu=4,
-    train=dict(
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_file=data_root + 'scannet_infos_train.pkl',
+        ann_file='scannet_infos_train.pkl',
+        metainfo=metainfo,
+        data_prefix=data_prefix,
         pipeline=train_pipeline,
-        classes=class_names,
-        test_mode=False,
+        modality=input_modality,
         ignore_index=len(class_names),
-        scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy'),
-    val=dict(
+        scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy',
+        test_mode=False))
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_file=data_root + 'scannet_infos_val.pkl',
+        ann_file='scannet_infos_val.pkl',
+        metainfo=metainfo,
+        data_prefix=data_prefix,
        pipeline=test_pipeline,
-        classes=class_names,
-        test_mode=True,
-        ignore_index=len(class_names)),
+        modality=input_modality,
+        ignore_index=len(class_names),
+        test_mode=True))
+val_dataloader = test_dataloader
-    test=dict(
-        type=dataset_type,
-        data_root=data_root,
-        ann_file=data_root + 'scannet_infos_val.pkl',
-        pipeline=test_pipeline,
-        classes=class_names,
-        test_mode=True,
-        ignore_index=len(class_names)))
-evaluation = dict(pipeline=eval_pipeline)
+val_evaluator = dict(type='SegMetric')
+test_evaluator = val_evaluator
 # Copyright (c) OpenMMLab. All rights reserved.
 from .det3d_data_sample import Det3DDataSample
+from .point_data import PointData

-__all__ = ['Det3DDataSample']
+__all__ = ['Det3DDataSample', 'PointData']
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.data import InstanceData, PixelData
+from mmengine.data import InstanceData

 from mmdet.core.data_structures import DetDataSample
+from .point_data import PointData


 class Det3DDataSample(DetDataSample):
@@ -43,19 +44,15 @@ class Det3DDataSample(DetDataSample):
             `use_lidar=True, use_camera=True`, the 3D predictions based on
             images are saved in `img_pred_instances_3d` to distinguish them
             from `pts_pred_instances_3d`, which is based on the point cloud.
-        - ``gt_pts_sem_seg``(PixelData): Ground truth of point cloud
-            semantic segmentation.
-        - ``pred_pts_sem_seg``(PixelData): Prediction of point cloud
-            semantic segmentation.
-        - ``gt_pts_panoptic_seg``(PixelData): Ground truth of point cloud
-            panoptic segmentation.
-        - ``pred_pts_panoptic_seg``(PixelData): Predicted of point cloud
-            panoptic segmentation.
+        - ``gt_pts_seg``(PointData): Ground truth of point cloud
+            segmentation.
+        - ``pred_pts_seg``(PointData): Prediction of point cloud
+            segmentation.
         - ``eval_ann_info``(dict): Raw annotation, which will be passed to
-            evaluator and do the online evaluation.
+            the evaluator for online evaluation.

     Examples:
-        >>> from mmengine.data import InstanceData, PixelData
+        >>> from mmengine.data import InstanceData

         >>> from mmdet3d.core import Det3DDataSample
         >>> from mmdet3d.core.bbox import BaseInstance3DBoxes
@@ -128,38 +125,33 @@ class Det3DDataSample(DetDataSample):
         >>> assert 'bboxes' in data_sample.gt_instances_3d

         >>> data_sample = Det3DDataSample()
-        >>> gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(1, 2, 4))
-        >>> gt_pts_panoptic_seg = PixelData(**gt_pts_panoptic_seg_data)
-        >>> data_sample.gt_pts_panoptic_seg = gt_pts_panoptic_seg
+        >>> gt_pts_seg_data = dict(
+        ...     pts_instance_mask=torch.rand(2),
+        ...     pts_semantic_mask=torch.rand(2))
+        >>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
         >>> print(data_sample)
         <Det3DDataSample(
             META INFORMATION
             DATA FIELDS
-            _gt_pts_panoptic_seg: <PixelData(
+            gt_pts_seg: <PointData(
                     META INFORMATION
                     DATA FIELDS
-                    panoptic_seg: tensor([[[0.9875, 0.3012, 0.5534, 0.9593],
-                                 [0.1251, 0.1911, 0.8058, 0.2566]]])
-                ) at 0x7fb0d93543d0>
-            gt_pts_panoptic_seg: <PixelData(
+                    pts_instance_mask: tensor([0.0576, 0.3067])
+                    pts_semantic_mask: tensor([0.9267, 0.7455])
+                ) at 0x7f654a9c1590>
+            _gt_pts_seg: <PointData(
                     META INFORMATION
                     DATA FIELDS
-                    panoptic_seg: tensor([[[0.9875, 0.3012, 0.5534, 0.9593],
-                                 [0.1251, 0.1911, 0.8058, 0.2566]]])
-                ) at 0x7fb0d93543d0>
-        ) at 0x7fb0d9354280>
-        >>> data_sample = Det3DDataSample()
-        >>> gt_pts_sem_seg_data = dict(segm_seg=torch.rand(2, 2, 2))
-        >>> gt_pts_sem_seg = PixelData(**gt_pts_sem_seg_data)
-        >>> data_sample.gt_pts_sem_seg = gt_pts_sem_seg
-        >>> assert 'gt_pts_sem_seg' in data_sample
-        >>> assert 'segm_seg' in data_sample.gt_pts_sem_seg
+                    pts_instance_mask: tensor([0.0576, 0.3067])
+                    pts_semantic_mask: tensor([0.9267, 0.7455])
+                ) at 0x7f654a9c1590>
+        ) at 0x7f654a9c1550>
     """

     @property
@@ -211,49 +203,25 @@ class Det3DDataSample(DetDataSample):
         del self._img_pred_instances_3d

     @property
-    def gt_pts_sem_seg(self) -> PixelData:
-        return self._gt_pts_sem_seg
-
-    @gt_pts_sem_seg.setter
-    def gt_pts_sem_seg(self, value: PixelData):
-        self.set_field(value, '_gt_pts_sem_seg', dtype=PixelData)
-
-    @gt_pts_sem_seg.deleter
-    def gt_pts_sem_seg(self):
-        del self._gt_pts_sem_seg
-
-    @property
-    def pred_pts_sem_seg(self) -> PixelData:
-        return self._pred_pts_sem_seg
-
-    @pred_pts_sem_seg.setter
-    def pred_pts_sem_seg(self, value: PixelData):
-        self.set_field(value, '_pred_pts_sem_seg', dtype=PixelData)
-
-    @pred_pts_sem_seg.deleter
-    def pred_pts_sem_seg(self):
-        del self._pred_pts_sem_seg
-
-    @property
-    def gt_pts_panoptic_seg(self) -> PixelData:
-        return self._gt_pts_panoptic_seg
-
-    @gt_pts_panoptic_seg.setter
-    def gt_pts_panoptic_seg(self, value: PixelData):
-        self.set_field(value, '_gt_pts_panoptic_seg', dtype=PixelData)
-
-    @gt_pts_panoptic_seg.deleter
-    def gt_pts_panoptic_seg(self):
-        del self._gt_pts_panoptic_seg
+    def gt_pts_seg(self) -> PointData:
+        return self._gt_pts_seg
+
+    @gt_pts_seg.setter
+    def gt_pts_seg(self, value: PointData):
+        self.set_field(value, '_gt_pts_seg', dtype=PointData)
+
+    @gt_pts_seg.deleter
+    def gt_pts_seg(self):
+        del self._gt_pts_seg

     @property
-    def pred_pts_panoptic_seg(self) -> PixelData:
-        return self._pred_pts_panoptic_seg
+    def pred_pts_seg(self) -> PointData:
+        return self._pred_pts_seg

-    @pred_pts_panoptic_seg.setter
-    def pred_pts_panoptic_seg(self, value: PixelData):
-        self.set_field(value, '_pred_pts_panoptic_seg', dtype=PixelData)
+    @pred_pts_seg.setter
+    def pred_pts_seg(self, value: PointData):
+        self.set_field(value, '_pred_pts_seg', dtype=PointData)

-    @pred_pts_panoptic_seg.deleter
-    def pred_pts_panoptic_seg(self):
-        del self._pred_pts_panoptic_seg
+    @pred_pts_seg.deleter
+    def pred_pts_seg(self):
+        del self._pred_pts_seg
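The property blocks above all reduce to one pattern: a private slot guarded by `set_field`, which type-checks on assignment. A stripped-down sketch of that contract with the mmengine machinery stubbed out (`FakeDataSample` is illustrative only; upstream the enforced dtype is `PointData`, not `dict`):

class FakeDataSample:
    """Toy stand-in showing what set_field(..., dtype=...) enforces."""

    def set_field(self, value, name, dtype):
        if not isinstance(value, dtype):
            raise TypeError(f'{name} expects {dtype.__name__}, '
                            f'got {type(value).__name__}')
        object.__setattr__(self, name, value)

    @property
    def gt_pts_seg(self):
        return self._gt_pts_seg

    @gt_pts_seg.setter
    def gt_pts_seg(self, value):
        self.set_field(value, '_gt_pts_seg', dtype=dict)  # PointData upstream


sample = FakeDataSample()
sample.gt_pts_seg = {'pts_semantic_mask': [0, 1]}  # ok
# sample.gt_pts_seg = [0, 1]  # would raise TypeError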
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sized
from typing import Union

import numpy as np
import torch
from mmengine.data import BaseDataElement

IndexType = Union[str, slice, int, list, torch.LongTensor,
                  torch.cuda.LongTensor, torch.BoolTensor,
                  torch.cuda.BoolTensor, np.ndarray]


class PointData(BaseDataElement):
    """Data structure for point-level annotations or predictions.

    All data items in ``data_fields`` of ``PointData`` meet the following
    requirements:

    - They are all one-dimensional.
    - They should all have the same length.

    Notice: ``PointData`` behaves like ``InstanceData``.

    Examples:
        >>> metainfo = dict(
        ...     sample_id=random.randint(0, 100))
        >>> points = np.random.randint(0, 255, (100, 3))
        >>> point_data = PointData(metainfo=metainfo,
        ...     points=points)
        >>> print(len(point_data))
        100
        >>> # slice
        >>> slice_data = point_data[10:60]
        >>> assert len(slice_data) == 50
        >>> # set
        >>> point_data.pts_semantic_mask = torch.randint(0, 255, (100,))
        >>> point_data.pts_instance_mask = torch.randint(0, 255, (100,))
        >>> assert tuple(point_data.pts_semantic_mask.shape) == (100,)
        >>> assert tuple(point_data.pts_instance_mask.shape) == (100,)
    """

    def __setattr__(self, name: str, value: Sized):
        """``__setattr__`` is only used to set data.

        The value must have ``__len__`` and its length must agree with the
        current length of this ``PointData``.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            assert isinstance(value, Sized), \
                'value must contain `__len__` attribute'
            if len(self) > 0:
                assert len(value) == len(self), 'The length of ' \
                                                f'values {len(value)} is ' \
                                                'not consistent with ' \
                                                'the length of this ' \
                                                ':obj:`PointData` ' \
                                                f'{len(self)}'
            super().__setattr__(name, value)

    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'PointData':
        """
        Args:
            item (str, obj:`slice`, obj:`torch.LongTensor`,
                obj:`torch.BoolTensor`): Get the corresponding values
                according to ``item``.

        Returns:
            obj:`PointData`: Corresponding values.
        """
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item)
        assert isinstance(
            item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
                   torch.BoolTensor, torch.cuda.BoolTensor))

        if isinstance(item, str):
            return getattr(self, item)

        if type(item) == int:
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # keep the dimension
                item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the ' \
                                    'values along the first dimension.'
            if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
                assert len(item) == len(self), 'The shape of the ' \
                                               'input(BoolTensor) ' \
                                               f'{len(item)} ' \
                                               'does not match the shape ' \
                                               'of the indexed tensor ' \
                                               'in results_field ' \
                                               f'{len(self)} at ' \
                                               'first dimension.'

            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    # convert indexes from a BoolTensor
                    if isinstance(item,
                                  (torch.BoolTensor, torch.cuda.BoolTensor)):
                        indexes = torch.nonzero(item).view(
                            -1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        'support slicing with `bool`')
        else:
            # item is a slice
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data  # type:ignore

    def __len__(self) -> int:
        """int: The length of ``PointData``."""
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        else:
            return 0
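A quick usage sketch for the boolean indexing implemented above. It assumes this commit's environment, where `PointData` is re-exported from `mmdet3d.core` (per the `__init__` and `Pack3DDetInputs` diffs):

import torch

from mmdet3d.core import PointData  # re-exported in this commit

point_data = PointData()
point_data.pts_semantic_mask = torch.tensor([0, 2, 2, 1])
point_data.pts_instance_mask = torch.tensor([7, 7, 8, 8])

keep = point_data.pts_semantic_mask == 2  # BoolTensor of length 4
subset = point_data[keep]                 # all fields filtered together
assert len(subset) == 2
assert subset.pts_instance_mask.tolist() == [7, 8]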
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import DATASETS, PIPELINES, build_dataset
-from .custom_3d_seg import Custom3DSegDataset
 from .det3d_dataset import Det3DDataset
 from .kitti_dataset import KittiDataset
 from .kitti_mono_dataset import KittiMonoDataset
@@ -22,6 +21,7 @@ from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
 from .s3dis_dataset import S3DISDataset, S3DISSegDataset
 from .scannet_dataset import (ScanNetDataset, ScanNetInstanceSegDataset,
                               ScanNetSegDataset)
+from .seg3d_dataset import Seg3DDataset
 from .semantickitti_dataset import SemanticKITTIDataset
 from .sunrgbd_dataset import SUNRGBDDataset
 from .utils import get_loading_pipeline
@@ -36,7 +36,7 @@ __all__ = [
     'IndoorPatchPointSample', 'IndoorPointSample', 'PointSample',
     'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
     'ScanNetSegDataset', 'ScanNetInstanceSegDataset', 'SemanticKITTIDataset',
-    'Det3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
+    'Det3DDataset', 'Seg3DDataset', 'LoadPointsFromMultiSweeps',
     'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
     'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
     'ObjectNameFilter', 'AffineResize', 'RandomShiftScale',
...
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import warnings
from os import path as osp

import mmcv
import numpy as np
from torch.utils.data import Dataset

from mmdet3d.registry import DATASETS
from .pipelines import Compose
from .utils import extract_result_dict, get_loading_pipeline


@DATASETS.register_module()
class Custom3DSegDataset(Dataset):
    """Customized 3D dataset for semantic segmentation task.

    This is the base dataset of the ScanNet and S3DIS datasets.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        palette (list[list[int]], optional): The palette of segmentation map.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.CLASSES) to
            be consistent with the PointSegClassMapping function in the
            pipeline. Defaults to None.
        scene_idxs (np.ndarray | str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
    """
    # names of all classes used for the task
    CLASSES = None

    # class_ids used for training
    VALID_CLASS_IDS = None

    # all possible class_ids in loaded segmentation mask
    ALL_CLASS_IDS = None

    # official color for visualization
    PALETTE = None

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 palette=None,
                 modality=None,
                 test_mode=False,
                 ignore_index=None,
                 scene_idxs=None,
                 file_client_args=dict(backend='disk')):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.file_client = mmcv.FileClient(**file_client_args)

        # load annotations
        if hasattr(self.file_client, 'get_local_path'):
            with self.file_client.get_local_path(self.ann_file) as local_path:
                self.data_infos = self.load_annotations(open(local_path, 'rb'))
        else:
            warnings.warn(
                'The used MMCV version does not have get_local_path. '
                f'We treat the {self.ann_file} as local paths and it '
                'might cause errors if the path is not a local path. '
                'Please use MMCV>= 1.3.16 if you meet errors.')
            self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        self.ignore_index = len(self.CLASSES) if \
            ignore_index is None else ignore_index

        self.scene_idxs = self.get_scene_idxs(scene_idxs)
        self.CLASSES, self.PALETTE = \
            self.get_classes_and_palette(classes, palette)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        # loading data from a file-like object needs file format
        return mmcv.load(ann_file, file_format='pkl')

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict: Data information that will be passed to the data
                preprocessing pipelines. It includes the following keys:

                - sample_idx (str): Sample index.
                - pts_filename (str): Filename of point clouds.
                - file_name (str): Filename of point clouds.
                - ann_info (dict): Annotation info.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = osp.join(self.data_root, info['pts_path'])

        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
        return input_dict

    def pre_pipeline(self, results):
        """Initialization before data preparation.

        Args:
            results (dict): Dict before data preprocessing.

                - img_fields (list): Image fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
        """
        results['img_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        results['bbox3d_fields'] = []

    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Training data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        This function is taken from MMSegmentation.

        Args:
            classes (Sequence[str] | str): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
                Defaults to None.
            palette (Sequence[Sequence[int]] | np.ndarray):
                The palette of segmentation map. If None is given, a random
                palette will be generated. Defaults to None.
        """
        if classes is None:
            self.custom_classes = False
            # map id in the loaded mask to label used for training
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            # map label to category name
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(self.CLASSES)
            }
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # update valid_class_ids
            self.VALID_CLASS_IDS = [
                self.VALID_CLASS_IDS[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(class_names)
            }

            # modify palette for visualization
            palette = [
                self.PALETTE[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

        return class_names, palette

    def get_scene_idxs(self, scene_idxs):
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points.
        """
        if self.test_mode:
            # when testing, we load one whole scene every time
            return np.arange(len(self.data_infos)).astype(np.int32)

        # we may need to re-sample different scenes according to scene_idxs
        # this is necessary for indoor scene segmentation such as ScanNet
        if scene_idxs is None:
            scene_idxs = np.arange(len(self.data_infos))
        if isinstance(scene_idxs, str):
            with self.file_client.get_local_path(scene_idxs) as local_path:
                scene_idxs = np.load(local_path)
        else:
            scene_idxs = np.array(scene_idxs)

        return scene_idxs.astype(np.int32)

    def format_results(self,
                       outputs,
                       pklfile_prefix=None,
                       submission_prefix=None):
        """Format the results to pkl file.

        Args:
            outputs (list[dict]): Testing results of the dataset.
            pklfile_prefix (str): The prefix of pkl files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (outputs, tmp_dir), outputs is the detection results,
                tmp_dir is the temporary directory created for saving pkl
                files when ``pklfile_prefix`` is not specified.
        """
        if pklfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            pklfile_prefix = osp.join(tmp_dir.name, 'results')
        out = f'{pklfile_prefix}.pkl'
        mmcv.dump(outputs, out)
        return outputs, tmp_dir

    def evaluate(self,
                 results,
                 metric=None,
                 logger=None,
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluate in the semantic segmentation protocol.

        Args:
            results (list[dict]): List of results.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Defaults to None.
            show (bool, optional): Whether to visualize.
                Defaults to False.
            out_dir (str, optional): Path to save the visualization results.
                Defaults to None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: Evaluation results.
        """
        from mmdet3d.core.evaluation import seg_eval
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'

        load_pipeline = self._get_pipeline(pipeline)
        pred_sem_masks = [result['semantic_mask'] for result in results]
        gt_sem_masks = [
            self._extract_data(
                i, load_pipeline, 'pts_semantic_mask', load_annos=True)
            for i in range(len(self.data_infos))
        ]
        ret_dict = seg_eval(
            gt_sem_masks,
            pred_sem_masks,
            self.label2cat,
            self.ignore_index,
            logger=logger)

        if show:
            self.show(pred_sem_masks, out_dir, pipeline=pipeline)

        return ret_dict

    def _rand_another(self, idx):
        """Randomly get another item with the same flag.

        Returns:
            int: Another index of item with the same flag.
        """
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def _build_default_pipeline(self):
        """Build the default pipeline for this dataset."""
        raise NotImplementedError('_build_default_pipeline is not implemented '
                                  f'for dataset {self.__class__.__name__}')

    def _get_pipeline(self, pipeline):
        """Get data loading pipeline in self.show/evaluate function.

        Args:
            pipeline (list[dict]): Input pipeline. If None is given,
                get from self.pipeline.
        """
        if pipeline is None:
            if not hasattr(self, 'pipeline') or self.pipeline is None:
                warnings.warn(
                    'Use default pipeline for data loading, this may cause '
                    'errors when data is on ceph')
                return self._build_default_pipeline()
            loading_pipeline = get_loading_pipeline(self.pipeline.transforms)
            return Compose(loading_pipeline)
        return Compose(pipeline)

    def _extract_data(self, index, pipeline, key, load_annos=False):
        """Load data using input pipeline and extract data according to key.

        Args:
            index (int): Index for accessing the target data.
            pipeline (:obj:`Compose`): Composed data loading pipeline.
            key (str | list[str]): One single or a list of data key.
            load_annos (bool): Whether to load data annotations.
                If True, need to set self.test_mode as False before loading.

        Returns:
            np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor]:
                A single or a list of loaded data.
        """
        assert pipeline is not None, 'data loading pipeline is not provided'
        # when we want to load ground-truth via pipeline (e.g. bbox, seg mask)
        # we need to set self.test_mode as False so that we have 'annos'
        if load_annos:
            original_test_mode = self.test_mode
            self.test_mode = False
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = pipeline(input_dict)

        # extract data items according to keys
        if isinstance(key, str):
            data = extract_result_dict(example, key)
        else:
            data = [extract_result_dict(example, k) for k in key]
        if load_annos:
            self.test_mode = original_test_mode

        return data

    def __len__(self):
        """Return the length of scene_idxs.

        Returns:
            int: Length of data infos.
        """
        return len(self.scene_idxs)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        In indoor scene segmentation tasks, each scene contains millions of
        points, but we sample fewer than 10k points within a patch each time.
        Therefore, we use `scene_idxs` to re-sample different rooms.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        scene_idx = self.scene_idxs[idx]  # map to scene idx
        if self.test_mode:
            return self.prepare_test_data(scene_idx)
        while True:
            data = self.prepare_train_data(scene_idx)
            if data is None:
                idx = self._rand_another(idx)
                scene_idx = self.scene_idxs[idx]  # map to scene idx
                continue
            return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. In 3D datasets they are all the same, thus are all
        zeros.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
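One detail of the base class above worth keeping in mind when reading the refactor: `__len__` returns `len(self.scene_idxs)`, not the number of scenes, so large rooms are visited several times per epoch. A self-contained illustration with made-up numbers (the real index files, e.g. `Area_1_resampled_scene_idxs.npy`, are precomputed offline; proportional resampling is one common scheme, assumed here):

import numpy as np

# Hypothetical point counts for 3 scenes; num_points per patch as in config.
num_points_per_scene = np.array([40960, 8192, 16384])
num_points = 4096

# Sample each scene proportionally to its point count, so an epoch sees
# roughly all points once.
repeats = np.round(num_points_per_scene / num_points).astype(np.int32)
scene_idxs = np.repeat(np.arange(len(repeats)), repeats)
print(len(scene_idxs))  # 16 samples per "epoch"
print(scene_idxs)       # [0 0 0 0 0 0 0 0 0 0 1 1 2 2 2 2]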
@@ -6,7 +6,7 @@ from mmcv import BaseTransform
 from mmcv.transforms import to_tensor
 from mmengine import InstanceData

-from mmdet3d.core import Det3DDataSample
+from mmdet3d.core import Det3DDataSample, PointData
 from mmdet3d.core.bbox import BaseInstance3DBoxes
 from mmdet3d.core.points import BasePoints
 from mmdet3d.registry import TRANSFORMS
@@ -143,7 +143,7 @@ class Pack3DDetInputs(BaseTransform):
         data_sample = Det3DDataSample()
         gt_instances_3d = InstanceData()
         gt_instances = InstanceData()
-        seg_data = dict()
+        gt_pts_seg = PointData()

         img_metas = {}
         for key in self.meta_keys:
@@ -161,7 +161,7 @@ class Pack3DDetInputs(BaseTransform):
             elif key in self.INSTANCEDATA_2D_KEYS:
                 gt_instances[self._remove_prefix(key)] = results[key]
             elif key in self.SEG_KEYS:
-                seg_data[self._remove_prefix(key)] = results[key]
+                gt_pts_seg[self._remove_prefix(key)] = results[key]
             else:
                 raise NotImplementedError(f'Please modified '
                                           f'`Pack3DDetInputs` '
@@ -170,7 +170,7 @@ class Pack3DDetInputs(BaseTransform):
         data_sample.gt_instances_3d = gt_instances_3d
         data_sample.gt_instances = gt_instances
-        data_sample.seg_data = seg_data
+        data_sample.gt_pts_seg = gt_pts_seg
         if 'eval_ann_info' in results:
             data_sample.eval_ann_info = results['eval_ann_info']
         else:
...
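The practical gain over the old `seg_data = dict()` is that `PointData` validates what it stores: every field must be sized and share one length, so a mask/point-count mismatch now fails at packing time rather than deep inside the loss. A quick demonstration (same `mmdet3d.core` import assumption as above; field names here are illustrative):

import torch

from mmdet3d.core import PointData

gt_pts_seg = PointData()
gt_pts_seg['semantic_mask'] = torch.zeros(4096, dtype=torch.long)
try:
    gt_pts_seg['instance_mask'] = torch.zeros(2048, dtype=torch.long)
except AssertionError as err:
    print('rejected mismatched field:', err)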
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence
+from typing import List

 import mmcv
 import numpy as np
@@ -270,22 +270,6 @@ class PointSegClassMapping(BaseTransform):
             segmentation mask. Defaults to 40.
     """

-    def __init__(self,
-                 valid_cat_ids: Sequence[int],
-                 max_cat_id: int = 40) -> None:
-        assert max_cat_id >= np.max(valid_cat_ids), \
-            'max_cat_id should be greater than maximum id in valid_cat_ids'
-
-        self.valid_cat_ids = valid_cat_ids
-        self.max_cat_id = int(max_cat_id)
-
-        # build cat_id to class index mapping
-        neg_cls = len(valid_cat_ids)
-        self.cat_id2class = np.ones(
-            self.max_cat_id + 1, dtype=np.int) * neg_cls
-        for cls_idx, cat_id in enumerate(valid_cat_ids):
-            self.cat_id2class[cat_id] = cls_idx
-
     def transform(self, results: dict) -> None:
         """Call function to map original semantic class to valid category ids.
@@ -301,9 +285,19 @@ class PointSegClassMapping(BaseTransform):
         assert 'pts_semantic_mask' in results
         pts_semantic_mask = results['pts_semantic_mask']

-        converted_pts_sem_mask = self.cat_id2class[pts_semantic_mask]
+        assert 'label_mapping' in results
+        label_mapping = results['label_mapping']
+        converted_pts_sem_mask = \
+            np.array([label_mapping[mask] for mask in pts_semantic_mask])

         results['pts_semantic_mask'] = converted_pts_sem_mask
+
+        # 'eval_ann_info' will be passed to evaluator
+        if 'eval_ann_info' in results:
+            assert 'pts_semantic_mask' in results['eval_ann_info']
+            results['eval_ann_info']['pts_semantic_mask'] = \
+                converted_pts_sem_mask
+
         return results

     def __repr__(self):
@@ -315,17 +309,17 @@ class PointSegClassMapping(BaseTransform):

 @TRANSFORMS.register_module()
-class NormalizePointsColor(object):
+class NormalizePointsColor(BaseTransform):
     """Normalize color of points.

     Args:
         color_mean (list[float]): Mean color of the point cloud.
     """

-    def __init__(self, color_mean):
+    def __init__(self, color_mean: List[float]) -> None:
         self.color_mean = color_mean

-    def __call__(self, results):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to normalize color of points.

         Args:
@@ -337,7 +331,7 @@ class NormalizePointsColor(object):
                 - points (:obj:`BasePoints`): Points after color normalization.
         """
-        points = results['points']
+        points = input_dict['points']
         assert points.attribute_dims is not None and \
             'color' in points.attribute_dims.keys(), \
             'Expect points have color attribute'
@@ -345,8 +339,8 @@ class NormalizePointsColor(object):
             points.color = points.color - \
                 points.color.new_tensor(self.color_mean)
         points.color = points.color / 255.0
-        results['points'] = points
-        return results
+        input_dict['points'] = points
+        return input_dict

     def __repr__(self):
         """str: Return a string that describes the module."""
...
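`PointSegClassMapping` is now stateless: it expects whoever built `results` (the refactored seg dataset) to have injected a `label_mapping` dict. A sketch of that contract for the S3DIS case, where ids 0-12 are kept and everything else goes to `ignore_index = 13`; the mapping construction here is illustrative, the real one lives in the dataset:

import numpy as np

ignore_index = 13
label_mapping = {cat_id: ignore_index for cat_id in range(14)}
label_mapping.update({cat_id: cat_id for cat_id in range(13)})

results = dict(
    pts_semantic_mask=np.array([0, 5, 13, 12]),
    label_mapping=label_mapping)

converted = np.array(
    [label_mapping[m] for m in results['pts_semantic_mask']])
print(converted)  # [ 0  5 13 12]  (13 is the ignored 'stair'-like id)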
 # Copyright (c) OpenMMLab. All rights reserved.
 import random
 import warnings
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple, Union

 import cv2
 import numpy as np
@@ -18,7 +18,7 @@ from .data_augment_utils import noise_per_object_v3_

 @TRANSFORMS.register_module()
-class RandomDropPointsColor(object):
+class RandomDropPointsColor(BaseTransform):
     r"""Randomly set the color of points to all zeros.

     Once this transform is executed, all the points' color will be dropped.
@@ -30,12 +30,12 @@ class RandomDropPointsColor(object):
             Defaults to 0.2.
     """

-    def __init__(self, drop_ratio=0.2):
+    def __init__(self, drop_ratio: float = 0.2) -> None:
         assert isinstance(drop_ratio, (int, float)) and 0 <= drop_ratio <= 1, \
             f'invalid drop_ratio value {drop_ratio}'
         self.drop_ratio = drop_ratio

-    def __call__(self, input_dict):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to drop point colors.

         Args:
@@ -224,7 +224,7 @@ class RandomFlip3D(RandomFlip):

 @TRANSFORMS.register_module()
-class RandomJitterPoints(object):
+class RandomJitterPoints(BaseTransform):
     """Randomly jitter point coordinates.

     Different from the global translation in ``GlobalRotScaleTrans``, here we
@@ -246,8 +246,8 @@ class RandomJitterPoints(object):
     """

     def __init__(self,
-                 jitter_std=[0.01, 0.01, 0.01],
-                 clip_range=[-0.05, 0.05]):
+                 jitter_std: List[float] = [0.01, 0.01, 0.01],
+                 clip_range: List[float] = [-0.05, 0.05]) -> None:
         seq_types = (list, tuple, np.ndarray)
         if not isinstance(jitter_std, seq_types):
             assert isinstance(jitter_std, (int, float)), \
@@ -262,7 +262,7 @@ class RandomJitterPoints(object):
             clip_range = [-clip_range, clip_range]
         self.clip_range = clip_range

-    def __call__(self, input_dict):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to jitter all the points in the scene.

         Args:
@@ -780,10 +780,10 @@ class GlobalRotScaleTrans(BaseTransform):

 @TRANSFORMS.register_module()
-class PointShuffle(object):
+class PointShuffle(BaseTransform):
     """Shuffle input points."""

-    def __call__(self, input_dict):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to shuffle points.

         Args:
@@ -1113,7 +1113,7 @@ class IndoorPointSample(PointSample):

 @TRANSFORMS.register_module()
-class IndoorPatchPointSample(object):
+class IndoorPatchPointSample(BaseTransform):
     r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
     github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_.
@@ -1152,15 +1152,15 @@ class IndoorPatchPointSample(object):
     """

     def __init__(self,
-                 num_points,
-                 block_size=1.5,
-                 sample_rate=None,
-                 ignore_index=None,
-                 use_normalized_coord=False,
-                 num_try=10,
-                 enlarge_size=0.2,
-                 min_unique_num=None,
-                 eps=1e-2):
+                 num_points: int,
+                 block_size: float = 1.5,
+                 sample_rate: Optional[float] = None,
+                 ignore_index: Optional[int] = None,
+                 use_normalized_coord: bool = False,
+                 num_try: int = 10,
+                 enlarge_size: float = 0.2,
+                 min_unique_num: Optional[int] = None,
+                 eps: float = 1e-2) -> None:
         self.num_points = num_points
         self.block_size = block_size
         self.ignore_index = ignore_index
@@ -1175,8 +1175,10 @@ class IndoorPatchPointSample(object):
                 "'sample_rate' has been deprecated and will be removed in "
                 'the future. Please remove them from your code.')

-    def _input_generation(self, coords, patch_center, coord_max, attributes,
-                          attribute_dims, point_type):
+    def _input_generation(self, coords: np.ndarray, patch_center: np.ndarray,
+                          coord_max: np.ndarray, attributes: np.ndarray,
+                          attribute_dims: dict,
+                          point_type: type) -> BasePoints:
         """Generating model input.

         Generate input by subtracting patch center and adding additional
@@ -1216,7 +1218,8 @@ class IndoorPatchPointSample(object):
         return points

-    def _patch_points_sampling(self, points, sem_mask):
+    def _patch_points_sampling(self, points: BasePoints,
+                               sem_mask: np.ndarray) -> BasePoints:
         """Patch points sampling.

         First sample a valid patch.
@@ -1316,7 +1319,7 @@ class IndoorPatchPointSample(object):
         return points, choices

-    def __call__(self, results):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to sample points in indoor scenes.

         Args:
@@ -1326,22 +1329,33 @@ class IndoorPatchPointSample(object):
                 dict: Results after sampling, 'points', 'pts_instance_mask'
                 and 'pts_semantic_mask' keys are updated in the result dict.
         """
-        points = results['points']
+        points = input_dict['points']

-        assert 'pts_semantic_mask' in results.keys(), \
+        assert 'pts_semantic_mask' in input_dict.keys(), \
             'semantic mask should be provided in training and evaluation'
-        pts_semantic_mask = results['pts_semantic_mask']
+        pts_semantic_mask = input_dict['pts_semantic_mask']

         points, choices = self._patch_points_sampling(points,
                                                       pts_semantic_mask)

-        results['points'] = points
-        results['pts_semantic_mask'] = pts_semantic_mask[choices]
+        input_dict['points'] = points
+        input_dict['pts_semantic_mask'] = pts_semantic_mask[choices]
+        # 'eval_ann_info' will be passed to evaluator
+        if 'eval_ann_info' in input_dict:
+            input_dict['eval_ann_info']['pts_semantic_mask'] = \
+                pts_semantic_mask[choices]

-        pts_instance_mask = results.get('pts_instance_mask', None)
+        pts_instance_mask = input_dict.get('pts_instance_mask', None)
         if pts_instance_mask is not None:
-            results['pts_instance_mask'] = pts_instance_mask[choices]
+            input_dict['pts_instance_mask'] = pts_instance_mask[choices]
+            # 'eval_ann_info' will be passed to evaluator
+            if 'eval_ann_info' in input_dict:
+                input_dict['eval_ann_info']['pts_instance_mask'] = \
+                    pts_instance_mask[choices]

-        return results
+        return input_dict

     def __repr__(self):
         """str: Return a string that describes the module."""
@@ -1358,14 +1372,14 @@ class IndoorPatchPointSample(object):

 @TRANSFORMS.register_module()
-class BackgroundPointsFilter(object):
+class BackgroundPointsFilter(BaseTransform):
     """Filter background points near the bounding box.

     Args:
         bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
     """

-    def __init__(self, bbox_enlarge_range):
+    def __init__(self, bbox_enlarge_range: Union[Tuple[float], float]) -> None:
         assert (is_tuple_of(bbox_enlarge_range, float)
                 and len(bbox_enlarge_range) == 3) \
             or isinstance(bbox_enlarge_range, float), \
@@ -1376,7 +1390,7 @@ class BackgroundPointsFilter(object):
         self.bbox_enlarge_range = np.array(
             bbox_enlarge_range, dtype=np.float32)[np.newaxis, :]

-    def __call__(self, input_dict):
+    def transform(self, input_dict: dict) -> dict:
         """Call function to filter points by the range.

         Args:
...
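Every class touched above migrates from a bare `__call__` to the `BaseTransform` interface, whose `__call__` delegates to `transform()`, so pipelines keep invoking transforms exactly as before. A minimal sketch of the contract (assumes an mmcv version that ships `mmcv.transforms.BaseTransform`, as this commit's imports do; `AddOne` is a toy example):

from mmcv.transforms import BaseTransform


class AddOne(BaseTransform):
    """Toy transform following the migrated interface."""

    def transform(self, results: dict) -> dict:
        results['value'] = results.get('value', 0) + 1
        return results


t = AddOne()
print(t(dict(value=41)))  # BaseTransform.__call__ dispatches to transform()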
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp from os import path as osp
from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
from mmdet3d.core import show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS from mmdet3d.registry import DATASETS
from .custom_3d_seg import Custom3DSegDataset
from .det3d_dataset import Det3DDataset from .det3d_dataset import Det3DDataset
from .pipelines import Compose from .pipelines import Compose
from .seg3d_dataset import Seg3DDataset
@DATASETS.register_module() @DATASETS.register_module()
...@@ -153,7 +153,7 @@ class S3DISDataset(Det3DDataset): ...@@ -153,7 +153,7 @@ class S3DISDataset(Det3DDataset):
return Compose(pipeline) return Compose(pipeline)
class _S3DISSegDataset(Custom3DSegDataset): class _S3DISSegDataset(Seg3DDataset):
r"""S3DIS Dataset for Semantic Segmentation Task. r"""S3DIS Dataset for Semantic Segmentation Task.
This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
...@@ -185,114 +185,44 @@ class _S3DISSegDataset(Custom3DSegDataset): ...@@ -185,114 +185,44 @@ class _S3DISSegDataset(Custom3DSegDataset):
data. For scenes with many points, we may sample it several times. data. For scenes with many points, we may sample it several times.
Defaults to None. Defaults to None.
""" """
CLASSES = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', METAINFO = {
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter') 'CLASSES':
('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
VALID_CLASS_IDS = tuple(range(13)) 'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'),
'PALETTE': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
ALL_CLASS_IDS = tuple(range(14)) # possibly with 'stair' class [255, 0, 255], [100, 100, 255], [200, 200, 100],
[170, 120, 200], [255, 0, 0], [200, 100, 100],
PALETTE = [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0], [10, 200, 100], [200, 200, 200], [50, 50, 50]],
[255, 0, 255], [100, 100, 255], [200, 200, 100], 'valid_class_ids':
[170, 120, 200], [255, 0, 0], [200, 100, 100], [10, 200, 100], tuple(range(13)),
[200, 200, 200], [50, 50, 50]] 'all_class_ids':
tuple(range(14)) # possibly with 'stair' class
}
def __init__(self, def __init__(self,
data_root, data_root: Optional[str] = None,
ann_file, ann_file: str = '',
pipeline=None, metainfo: Optional[dict] = None,
classes=None, data_prefix: dict = dict(
                 palette=None,                  pts='points', img='', pts_instance_mask='', pts_semantic_mask=''),
modality=None, pipeline: List[Union[dict, Callable]] = [],
test_mode=False, modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None, ignore_index=None,
scene_idxs=None, scene_idxs=None,
**kwargs): test_mode=False,
**kwargs) -> None:
super().__init__( super().__init__(
data_root=data_root, data_root=data_root,
ann_file=ann_file, ann_file=ann_file,
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
classes=classes,
palette=palette,
modality=modality, modality=modality,
test_mode=test_mode,
ignore_index=ignore_index, ignore_index=ignore_index,
scene_idxs=scene_idxs, scene_idxs=scene_idxs,
test_mode=test_mode,
**kwargs) **kwargs)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
pts_semantic_mask_path = osp.join(self.data_root,
info['pts_semantic_mask_path'])
anns_results = dict(pts_semantic_mask_path=pts_semantic_mask_path)
return anns_results
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict(
type='DefaultFormatBundle3D',
with_label=False,
class_names=self.CLASSES),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
pipeline = self._get_pipeline(pipeline)
for i, result in enumerate(results):
data_info = self.data_infos[i]
pts_path = data_info['pts_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points, gt_sem_mask = self._extract_data(
i, pipeline, ['points', 'pts_semantic_mask'], load_annos=True)
points = points.numpy()
pred_sem_mask = result['semantic_mask'].numpy()
show_seg_result(points, gt_sem_mask,
pred_sem_mask, out_dir, file_name,
np.array(self.PALETTE), self.ignore_index, show)
def get_scene_idxs(self, scene_idxs): def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling. """Compute scene_idxs for data sampling.
...@@ -341,16 +271,17 @@ class S3DISSegDataset(_S3DISSegDataset): ...@@ -341,16 +271,17 @@ class S3DISSegDataset(_S3DISSegDataset):
""" """
def __init__(self, def __init__(self,
data_root, data_root: Optional[str] = None,
ann_files, ann_files: str = '',
pipeline=None, metainfo: Optional[dict] = None,
classes=None, data_prefix: dict = dict(
                 palette=None,                  pts='points', img='', pts_instance_mask='', pts_semantic_mask=''),
modality=None, pipeline: List[Union[dict, Callable]] = [],
test_mode=False, modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None, ignore_index=None,
scene_idxs=None, scene_idxs=None,
**kwargs): test_mode=False,
**kwargs) -> None:
# make sure that ann_files and scene_idxs have same length # make sure that ann_files and scene_idxs have same length
ann_files = self._check_ann_files(ann_files) ann_files = self._check_ann_files(ann_files)
...@@ -360,45 +291,45 @@ class S3DISSegDataset(_S3DISSegDataset): ...@@ -360,45 +291,45 @@ class S3DISSegDataset(_S3DISSegDataset):
super().__init__( super().__init__(
data_root=data_root, data_root=data_root,
ann_file=ann_files[0], ann_file=ann_files[0],
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
classes=classes,
palette=palette,
modality=modality, modality=modality,
test_mode=test_mode,
ignore_index=ignore_index, ignore_index=ignore_index,
scene_idxs=scene_idxs[0], scene_idxs=scene_idxs[0],
test_mode=test_mode,
**kwargs) **kwargs)
datasets = [ datasets = [
_S3DISSegDataset( _S3DISSegDataset(
data_root=data_root, data_root=data_root,
ann_file=ann_files[i], ann_file=ann_files[i],
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
classes=classes,
palette=palette,
modality=modality, modality=modality,
test_mode=test_mode,
ignore_index=ignore_index, ignore_index=ignore_index,
scene_idxs=scene_idxs[i], scene_idxs=scene_idxs[i],
test_mode=test_mode,
**kwargs) for i in range(len(ann_files)) **kwargs) for i in range(len(ann_files))
] ]
# data_infos and scene_idxs need to be concat # data_list and scene_idxs need to be concat
self.concat_data_infos([dst.data_infos for dst in datasets]) self.concat_data_list([dst.data_list for dst in datasets])
self.concat_scene_idxs([dst.scene_idxs for dst in datasets]) self.concat_scene_idxs([dst.scene_idxs for dst in datasets])
# set group flag for the sampler # set group flag for the sampler
if not self.test_mode: if not self.test_mode:
self._set_group_flag() self._set_group_flag()
def concat_data_infos(self, data_infos): def concat_data_list(self, data_lists):
"""Concat data_infos from several datasets to form self.data_infos. """Concat data_list from several datasets to form self.data_list.
Args: Args:
data_infos (list[list[dict]]) data_lists (list[list[dict]])
""" """
self.data_infos = [ self.data_list = [
info for one_data_infos in data_infos for info in one_data_infos data for data_list in data_lists for data in data_list
] ]
def concat_scene_idxs(self, scene_idxs): def concat_scene_idxs(self, scene_idxs):
......
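Since S3DIS spans several areas, `S3DISSegDataset` builds one inner dataset per annotation file and then flattens their data lists via `concat_data_list`. A pure-Python sketch of that flattening; the scene names below are made up:

```python
# One inner dataset per S3DIS area produces one data list; the wrapper
# simply flattens them into a single self.data_list.
area_lists = [
    [dict(scene='Area_1_office_1'), dict(scene='Area_1_office_2')],
    [dict(scene='Area_2_office_1')],
]
data_list = [data for one_list in area_lists for data in one_list]
assert [d['scene'] for d in data_list] == [
    'Area_1_office_1', 'Area_1_office_2', 'Area_2_office_1'
]
```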
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import warnings import warnings
from os import path as osp from os import path as osp
from typing import Callable, List, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
from mmdet3d.core import instance_seg_eval, show_result, show_seg_result from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS from mmdet3d.registry import DATASETS
from .custom_3d_seg import Custom3DSegDataset
from .det3d_dataset import Det3DDataset from .det3d_dataset import Det3DDataset
from .pipelines import Compose from .pipelines import Compose
from .seg3d_dataset import Seg3DDataset
@DATASETS.register_module() @DATASETS.register_module()
...@@ -193,7 +192,7 @@ class ScanNetDataset(Det3DDataset): ...@@ -193,7 +192,7 @@ class ScanNetDataset(Det3DDataset):
@DATASETS.register_module() @DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset): class ScanNetSegDataset(Seg3DDataset):
r"""ScanNet Dataset for Semantic Segmentation Task. r"""ScanNet Dataset for Semantic Segmentation Task.
This class serves as the API for experiments on the ScanNet Dataset. This class serves as the API for experiments on the ScanNet Dataset.
...@@ -221,135 +220,64 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -221,135 +220,64 @@ class ScanNetSegDataset(Custom3DSegDataset):
data. For scenes with many points, we may sample it several times. data. For scenes with many points, we may sample it several times.
Defaults to None. Defaults to None.
""" """
CLASSES = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', METAINFO = {
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'CLASSES':
'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink', ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'bathtub', 'otherfurniture') 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
VALID_CLASS_IDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 'otherfurniture'),
33, 34, 36, 39) 'PALETTE': [
[174, 199, 232],
ALL_CLASS_IDS = tuple(range(41)) [152, 223, 138],
[31, 119, 180],
PALETTE = [ [255, 187, 120],
[174, 199, 232], [188, 189, 34],
[152, 223, 138], [140, 86, 75],
[31, 119, 180], [255, 152, 150],
[255, 187, 120], [214, 39, 40],
[188, 189, 34], [197, 176, 213],
[140, 86, 75], [148, 103, 189],
[255, 152, 150], [196, 156, 148],
[214, 39, 40], [23, 190, 207],
[197, 176, 213], [247, 182, 210],
[148, 103, 189], [219, 219, 141],
[196, 156, 148], [255, 127, 14],
[23, 190, 207], [158, 218, 229],
[247, 182, 210], [44, 160, 44],
[219, 219, 141], [112, 128, 144],
[255, 127, 14], [227, 119, 194],
[158, 218, 229], [82, 84, 163],
[44, 160, 44], ],
[112, 128, 144], 'valid_class_ids': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
[227, 119, 194], 28, 33, 34, 36, 39),
[82, 84, 163], 'all_class_ids':
] tuple(range(41)),
}
def __init__(self, def __init__(self,
data_root, data_root: Optional[str] = None,
ann_file, ann_file: str = '',
pipeline=None, metainfo: Optional[dict] = None,
classes=None, data_prefix: dict = dict(
                 palette=None,                  pts='points', img='', pts_instance_mask='', pts_semantic_mask=''),
modality=None, pipeline: List[Union[dict, Callable]] = [],
test_mode=False, modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index=None, ignore_index=None,
scene_idxs=None, scene_idxs=None,
**kwargs): test_mode=False,
**kwargs) -> None:
super().__init__( super().__init__(
data_root=data_root, data_root=data_root,
ann_file=ann_file, ann_file=ann_file,
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
classes=classes,
palette=palette,
modality=modality, modality=modality,
test_mode=test_mode,
ignore_index=ignore_index, ignore_index=ignore_index,
scene_idxs=scene_idxs, scene_idxs=scene_idxs,
test_mode=test_mode,
**kwargs) **kwargs)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
pts_semantic_mask_path = osp.join(self.data_root,
info['pts_semantic_mask_path'])
anns_results = dict(pts_semantic_mask_path=pts_semantic_mask_path)
return anns_results
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict(
type='DefaultFormatBundle3D',
with_label=False,
class_names=self.CLASSES),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
pipeline = self._get_pipeline(pipeline)
for i, result in enumerate(results):
data_info = self.data_infos[i]
pts_path = data_info['pts_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points, gt_sem_mask = self._extract_data(
i, pipeline, ['points', 'pts_semantic_mask'], load_annos=True)
points = points.numpy()
pred_sem_mask = result['semantic_mask'].numpy()
show_seg_result(points, gt_sem_mask,
pred_sem_mask, out_dir, file_name,
np.array(self.PALETTE), self.ignore_index, show)
def get_scene_idxs(self, scene_idxs): def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling. """Compute scene_idxs for data sampling.
...@@ -362,191 +290,65 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -362,191 +290,65 @@ class ScanNetSegDataset(Custom3DSegDataset):
return super().get_scene_idxs(scene_idxs) return super().get_scene_idxs(scene_idxs)
def format_results(self, results, txtfile_prefix=None):
r"""Format the results to txt file. Refer to `ScanNet documentation
<http://kaldir.vc.in.tum.de/scannet_benchmark/documentation>`_.
Args:
            results (list[dict]): Testing results of the dataset.
txtfile_prefix (str): The prefix of saved files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporal directory created for saving submission
files when ``submission_prefix`` is not specified.
"""
import mmcv
if txtfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
txtfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
mmcv.mkdir_or_exist(txtfile_prefix)
# need to map network output to original label idx
        pred2label = np.zeros(len(self.VALID_CLASS_IDS)).astype(np.int64)
for original_label, output_idx in self.label_map.items():
if output_idx != self.ignore_index:
pred2label[output_idx] = original_label
outputs = []
for i, result in enumerate(results):
info = self.data_infos[i]
sample_idx = info['point_cloud']['lidar_idx']
            pred_sem_mask = result['semantic_mask'].numpy().astype(np.int64)
pred_label = pred2label[pred_sem_mask]
curr_file = f'{txtfile_prefix}/{sample_idx}.txt'
np.savetxt(curr_file, pred_label, fmt='%d')
outputs.append(dict(seg_mask=pred_label))
return outputs, tmp_dir
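The removed `format_results` hinges on inverting `label_map` (original id -> train id) so that predictions can be written back in original dataset ids. A small numpy sketch of that inversion; the ids and ignore index below are made up, not real ScanNet metainfo:

```python
import numpy as np

valid_class_ids = (1, 2, 3)  # original ids that were kept for training
ignore_index = 3             # train ids are 0..2; 3 marks ignored points
label_map = {0: ignore_index, 1: 0, 2: 1, 3: 2, 4: ignore_index}

# Build the inverse lookup: train id -> original id.
pred2label = np.zeros(len(valid_class_ids), dtype=np.int64)
for original_label, output_idx in label_map.items():
    if output_idx != ignore_index:
        pred2label[output_idx] = original_label

pred_sem_mask = np.array([0, 2, 1, 0])  # network output in train ids
print(pred2label[pred_sem_mask])        # [1 3 2 1] in original ids
```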
@DATASETS.register_module() @DATASETS.register_module()
class ScanNetInstanceSegDataset(Custom3DSegDataset): class ScanNetInstanceSegDataset(Seg3DDataset):
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'garbagebin')
VALID_CLASS_IDS = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)
ALL_CLASS_IDS = tuple(range(41))
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns: METAINFO = {
dict: annotation information consists of the following keys: 'CLASSES':
- pts_semantic_mask_path (str): Path of semantic masks. ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
- pts_instance_mask_path (str): Path of instance masks. 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
""" 'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
        # Use index to get the annos, thus the evalhook could also use this api     'PALETTE': [
info = self.data_infos[index] [174, 199, 232],
[152, 223, 138],
pts_instance_mask_path = osp.join(self.data_root, [31, 119, 180],
info['pts_instance_mask_path']) [255, 187, 120],
pts_semantic_mask_path = osp.join(self.data_root, [188, 189, 34],
info['pts_semantic_mask_path']) [140, 86, 75],
[255, 152, 150],
anns_results = dict( [214, 39, 40],
pts_instance_mask_path=pts_instance_mask_path, [197, 176, 213],
pts_semantic_mask_path=pts_semantic_mask_path) [148, 103, 189],
return anns_results [196, 156, 148],
[23, 190, 207],
def get_classes_and_palette(self, classes=None, palette=None): [247, 182, 210],
"""Get class names of current dataset. Palette is simply ignored for [219, 219, 141],
instance segmentation. [255, 127, 14],
[158, 218, 229],
Args: [44, 160, 44],
classes (Sequence[str] | str | None): If classes is None, use [112, 128, 144],
default CLASSES defined by builtin dataset. If classes is a [227, 119, 194],
string, take it as a file name. The file contains the name of [82, 84, 163],
classes where each line contains one class name. If classes is ],
a tuple or list, override the CLASSES defined by the dataset. 'valid_class_ids':
Defaults to None. (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
            palette (Sequence[Sequence[int]] | np.ndarray | None):         [227, 119, 194],
The palette of segmentation map. If None is given, random tuple(range(41))
palette will be generated. Defaults to None. }
"""
if classes is not None:
return classes, None
return self.CLASSES, None
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=True,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=40),
dict(
type='DefaultFormatBundle3D',
with_label=False,
class_names=self.CLASSES),
dict(
type='Collect3D',
keys=['points', 'pts_semantic_mask', 'pts_instance_mask'])
]
return Compose(pipeline)
def evaluate(self,
results,
metric=None,
options=None,
logger=None,
show=False,
out_dir=None,
pipeline=None):
"""Evaluation in instance segmentation protocol.
Args:
results (list[dict]): List of results.
metric (str | list[str]): Metrics to be evaluated.
options (dict, optional): options for instance_seg_eval.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Defaults to None.
show (bool, optional): Whether to visualize.
Defaults to False.
out_dir (str, optional): Path to save the visualization results.
Defaults to None.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
Returns: def __init__(self,
dict: Evaluation results. data_root: Optional[str] = None,
""" ann_file: str = '',
assert isinstance( metainfo: Optional[dict] = None,
results, list), f'Expect results to be list, got {type(results)}.' data_prefix: dict = dict(
        assert len(results) > 0, 'Expect length of results > 0.'                      pts='points', img='', pts_instance_mask='', pts_semantic_mask=''),
assert len(results) == len(self.data_infos) pipeline: List[Union[dict, Callable]] = [],
assert isinstance( modality: dict = dict(use_lidar=True, use_camera=False),
results[0], dict test_mode=False,
), f'Expect elements in results to be dict, got {type(results[0])}.' ignore_index=None,
scene_idxs=None,
load_pipeline = self._get_pipeline(pipeline) file_client_args=dict(backend='disk'),
pred_instance_masks = [result['instance_mask'] for result in results] **kwargs) -> None:
pred_instance_labels = [result['instance_label'] for result in results] super().__init__(
pred_instance_scores = [result['instance_score'] for result in results] data_root=data_root,
gt_semantic_masks, gt_instance_masks = zip(*[ ann_file=ann_file,
self._extract_data( metainfo=metainfo,
index=i, pipeline=pipeline,
pipeline=load_pipeline, data_prefix=data_prefix,
key=['pts_semantic_mask', 'pts_instance_mask'], modality=modality,
load_annos=True) for i in range(len(self.data_infos)) test_mode=test_mode,
]) ignore_index=ignore_index,
ret_dict = instance_seg_eval( scene_idxs=scene_idxs,
gt_semantic_masks, file_client_args=file_client_args,
gt_instance_masks, **kwargs)
pred_instance_masks,
pred_instance_labels,
pred_instance_scores,
valid_class_ids=self.VALID_CLASS_IDS,
class_labels=self.CLASSES,
options=options,
logger=logger)
if show:
raise NotImplementedError('show is not implemented for now')
return ret_dict
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union
import mmcv
import numpy as np
from mmengine.dataset import BaseDataset
from mmdet3d.registry import DATASETS
@DATASETS.register_module()
class Seg3DDataset(BaseDataset):
"""Base Class for 3D semantic segmentation dataset.
    This is the base dataset of the ScanNet, S3DIS and SemanticKITTI datasets.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(pts='points', img='', pts_instance_mask='',
            pts_semantic_mask='').
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to [].
modality (dict, optional): Modality to specify the sensor data used
as input, it usually has following keys.
- use_camera: bool
- use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)`
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
        load_eval_anns (bool): Whether to load annotations in test_mode.
            If True, the annotations will be saved in ``eval_ann_info``,
            which can be used by the Evaluator.
file_client_args (dict): Configuration of file client.
Defaults to `dict(backend='disk')`.
"""
METAINFO = {
'CLASSES': None, # names of all classes data used for the task
'PALETTE': None, # official color for visualization
'valid_class_ids': None, # class_ids used for training
'all_class_ids': None, # all possible class_ids in loaded seg mask
}
def __init__(self,
data_root: Optional[str] = None,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_prefix: dict = dict(
pts='points',
img='',
pts_instance_mask='',
                     pts_semantic_mask=''),
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False),
ignore_index: Optional[int] = None,
scene_idxs: Optional[str] = None,
test_mode: bool = False,
load_eval_anns: bool = True,
file_client_args: dict = dict(backend='disk'),
**kwargs) -> None:
# init file client
self.file_client = mmcv.FileClient(**file_client_args)
self.modality = modality
self.load_eval_anns = load_eval_anns
        # TODO: we maintain the ignore_index attribute,
        # but we may consider removing it in the future
self.ignore_index = len(self.METAINFO['CLASSES']) if \
ignore_index is None else ignore_index
        if metainfo is None:
            metainfo = dict()
        # Get label mapping for custom classes
        new_classes = metainfo.get('CLASSES', None)
self.label_mapping, self.label2cat, valid_class_ids = \
self.get_label_mapping(new_classes)
metainfo['label_mapping'] = self.label_mapping
metainfo['label2cat'] = self.label2cat
metainfo['valid_class_ids'] = valid_class_ids
# generate palette if it is not defined based on
# label mapping, otherwise directly use palette
# defined in dataset config.
palette = metainfo.get('PALETTE', None)
updated_palette = self._update_palette(new_classes, palette)
metainfo['PALETTE'] = updated_palette
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
pipeline=pipeline,
test_mode=test_mode,
**kwargs)
self.scene_idxs = self.get_scene_idxs(scene_idxs)
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
    def get_label_mapping(self,
                          new_classes: Optional[Sequence] = None) -> tuple:
"""Get label mapping.
        The ``label_mapping`` is a dictionary whose keys are the old label
        ids and whose values are the new label ids. It is used to change
        point labels in the loading pipeline. A non-trivial mapping is only
        built when the old classes in cls.METAINFO differ from the new
        classes in self._metainfo and neither of them is None.
Args:
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
        Returns:
            tuple: The mapping from old label ids to new label ids, the
                mapping from new label ids to category names, and the
                valid class ids.
"""
        old_classes = self.METAINFO.get('CLASSES', None)
if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)):
label_mapping = {}
if not set(new_classes).issubset(old_classes):
raise ValueError(
f'new classes {new_classes} is not a '
f'subset of CLASSES {old_classes} in METAINFO.')
# obtain true id from valid_class_ids
valid_class_ids = [
self.METAINFO['valid_class_ids'][old_classes.index(cls_name)]
for cls_name in new_classes
]
label_mapping = {
cls_id: self.ignore_index
for cls_id in self.METAINFO['all_class_ids']
}
label_mapping.update(
{cls_id: i
for i, cls_id in enumerate(valid_class_ids)})
label2cat = {i: cat_name for i, cat_name in enumerate(new_classes)}
else:
label_mapping = {
cls_id: self.ignore_index
for cls_id in self.METAINFO['all_class_ids']
}
label_mapping.update({
cls_id: i
for i, cls_id in enumerate(self.METAINFO['valid_class_ids'])
})
# map label to category name
label2cat = {
i: cat_name
for i, cat_name in enumerate(self.METAINFO['CLASSES'])
}
valid_class_ids = self.METAINFO['valid_class_ids']
return label_mapping, label2cat, valid_class_ids
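To make the mapping above concrete, here is a small sketch of what `get_label_mapping` produces when a user keeps only a subset of the classes; all ids, class names, and the ignore index are toy values, not real ScanNet or S3DIS metainfo:

```python
all_class_ids = tuple(range(6))   # ids that may appear in raw seg masks
valid_class_ids = (1, 2, 4)       # original ids of ('wall', 'floor', 'door')
old_classes = ('wall', 'floor', 'door')
new_classes = ('floor', 'door')   # user-specified subset
ignore_index = len(new_classes)   # illustrative choice

# Original ids of the kept classes, in the user-given order.
subset_ids = [valid_class_ids[old_classes.index(c)] for c in new_classes]
label_mapping = {cls_id: ignore_index for cls_id in all_class_ids}
label_mapping.update({cls_id: i for i, cls_id in enumerate(subset_ids)})
# 'floor' (raw id 2) -> 0, 'door' (raw id 4) -> 1, everything else ignored
assert label_mapping == {0: 2, 1: 2, 2: 0, 3: 2, 4: 1, 5: 2}
```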
def _update_palette(self, new_classes, palette) -> list:
"""Update palette according to metainfo.
If length of palette is equal to classes, just return the palette.
If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
Returns:
Sequence: Palette for current dataset.
"""
if palette is None:
# If palette is not defined, it generate a palette according
# to the original PALETTE and classes.
            old_classes = self.METAINFO.get('CLASSES', None)
palette = [
self.METAINFO['PALETTE'][old_classes.index(cls_name)]
for cls_name in new_classes
]
return palette
        # palette should match classes
        if len(palette) == len(new_classes):
            return palette
        else:
            raise ValueError('Once PALETTE is set in metainfo, it should '
                             'match CLASSES in metainfo')
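A matching sketch of the palette subsetting done above when classes are customized and no palette is supplied; the colors and class names are toy values:

```python
original_classes = ('wall', 'floor', 'door')
original_palette = [[255, 0, 0], [0, 255, 0], [0, 0, 255]]
new_classes = ('door', 'wall')  # user keeps a subset, reordered

# Pick the color of each kept class from the original palette.
palette = [
    original_palette[original_classes.index(name)] for name in new_classes
]
assert palette == [[0, 0, 255], [255, 0, 0]]
```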
def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info.
        Convert all relative paths of needed modality data files to
        absolute paths, and attach the label mapping so that it can be
        used directly by the loading pipeline.
Args:
info (dict): Raw info dict.
        Returns:
            dict: The processed info dict, where all required paths have
                been converted to absolute paths.
"""
if self.modality['use_lidar']:
info['lidar_points']['lidar_path'] = \
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])
if self.modality['use_camera']:
for cam_id, img_info in info['images'].items():
if 'img_path' in img_info:
img_info['img_path'] = osp.join(
self.data_prefix.get('img', ''), img_info['img_path'])
if 'pts_instance_mask_path' in info:
info['pts_instance_mask_path'] = \
osp.join(self.data_prefix.get('pts_instance_mask', ''),
info['pts_instance_mask_path'])
if 'pts_semantic_mask_path' in info:
info['pts_semantic_mask_path'] = \
osp.join(self.data_prefix.get('pts_semantic_mask', ''),
info['pts_semantic_mask_path'])
# Add label_mapping to input dict for directly
# use it in PointSegClassMapping pipeline
info['label_mapping'] = self.label_mapping
# 'eval_ann_info' will be updated in loading pipelines
if self.test_mode and self.load_eval_anns:
info['eval_ann_info'] = dict()
return info
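A rough illustration of the path handling in `parse_data_info`; in the real class the prefixes have already been joined with `data_root` by mmengine's `BaseDataset`, and the file names below are made up:

```python
from os import path as osp

data_prefix = dict(pts='points', pts_semantic_mask='semantic_mask')
info = dict(
    lidar_points=dict(lidar_path='office_1.bin'),
    pts_semantic_mask_path='office_1.bin')

# Prepend the per-modality prefixes, as parse_data_info does.
info['lidar_points']['lidar_path'] = osp.join(
    data_prefix.get('pts', ''), info['lidar_points']['lidar_path'])
info['pts_semantic_mask_path'] = osp.join(
    data_prefix.get('pts_semantic_mask', ''),
    info['pts_semantic_mask_path'])

print(info['lidar_points']['lidar_path'])  # points/office_1.bin
print(info['pts_semantic_mask_path'])      # semantic_mask/office_1.bin
```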
def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
"""
if self.test_mode:
# when testing, we load one whole scene every time
return np.arange(len(self.data_list)).astype(np.int32)
# we may need to re-sample different scenes according to scene_idxs
# this is necessary for indoor scene segmentation such as ScanNet
if scene_idxs is None:
scene_idxs = np.arange(len(self.data_list))
if isinstance(scene_idxs, str):
with self.file_client.get_local_path(scene_idxs) as local_path:
scene_idxs = np.load(local_path)
else:
scene_idxs = np.array(scene_idxs)
return scene_idxs.astype(np.int32)
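One way such a precomputed `scene_idxs` can arise is by repeating each scene roughly in proportion to its point count, so large scenes are patch-sampled several times per epoch. A hedged numpy sketch; the counts are invented, and this is not the generation code shipped with the datasets:

```python
import numpy as np

points_per_scene = np.array([10000, 40000, 50000])
num_points = 10000  # points sampled per patch

# Repeat each scene index proportionally to its size (at least once).
repeats = np.maximum(points_per_scene // num_points, 1)
scene_idxs = np.repeat(np.arange(len(points_per_scene)),
                       repeats).astype(np.int32)
print(scene_idxs)  # [0 1 1 1 1 2 2 2 2 2] -> len(dataset) == 10
```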
def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus are all
zeros.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp from typing import Callable, List, Optional, Union
from mmdet3d.registry import DATASETS from mmdet3d.registry import DATASETS
from .det3d_dataset import Det3DDataset from .seg3d_dataset import Seg3DDataset
@DATASETS.register_module() @DATASETS.register_module()
class SemanticKITTIDataset(Det3DDataset): class SemanticKITTIDataset(Seg3DDataset):
r"""SemanticKITTI Dataset. r"""SemanticKITTI Dataset.
This class serves as the API for experiments on the SemanticKITTI Dataset This class serves as the API for experiments on the SemanticKITTI Dataset
...@@ -36,75 +36,38 @@ class SemanticKITTIDataset(Det3DDataset): ...@@ -36,75 +36,38 @@ class SemanticKITTIDataset(Det3DDataset):
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
""" """
CLASSES = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', METAINFO = {
'person', 'bicyclist', 'motorcyclist', 'road', 'parking', 'CLASSES': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
'trunck', 'terrian', 'pole', 'traffic-sign') 'parking', 'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'),
'valid_class_ids':
tuple(range(20)),
'all_class_ids':
tuple(range(20))
}
def __init__(self, def __init__(self,
data_root, data_root: Optional[str] = None,
ann_file, ann_file: str = '',
pipeline=None, metainfo: Optional[dict] = None,
classes=None, data_prefix: dict = dict(
                 modality=None,                  pts='points', img='', pts_instance_mask='', pts_semantic_mask=''),
box_type_3d='Lidar', pipeline: List[Union[dict, Callable]] = [],
filter_empty_gt=False, modality: dict = dict(use_lidar=True, use_camera=False),
test_mode=False): ignore_index=None,
scene_idxs=None,
test_mode=False,
**kwargs) -> None:
super().__init__( super().__init__(
data_root=data_root, data_root=data_root,
ann_file=ann_file, ann_file=ann_file,
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
classes=classes,
modality=modality, modality=modality,
box_type_3d=box_type_3d, ignore_index=ignore_index,
filter_empty_gt=filter_empty_gt, scene_idxs=scene_idxs,
test_mode=test_mode) test_mode=test_mode,
**kwargs)
def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index.
- pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds.
- ann_info (dict): Annotation info.
"""
info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
pts_filename = osp.join(self.data_root, info['pts_path'])
input_dict = dict(
pts_filename=pts_filename,
sample_idx=sample_idx,
file_name=pts_filename)
if not self.test_mode:
annos = self.get_ann_info(index)
input_dict['ann_info'] = annos
if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any():
return None
return input_dict
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
pts_semantic_mask_path = osp.join(self.data_root,
info['pts_semantic_mask_path'])
anns_results = dict(pts_semantic_mask_path=pts_semantic_mask_path)
return anns_results
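With `SemanticKITTIDataset` now inheriting from `Seg3DDataset`, the removed `get_data_info`/`get_ann_info` overrides are covered by mmengine's `BaseDataset`. A hedged usage sketch, assuming the info file and data layout below exist on disk:

```python
from mmdet3d.datasets import SemanticKITTIDataset

dataset = SemanticKITTIDataset(
    data_root='./data/semantickitti/',
    ann_file='semantickitti_infos.pkl',
    data_prefix=dict(
        pts='sequences/00/velodyne',
        pts_semantic_mask='sequences/00/labels'),
    pipeline=[])

# get_data_info comes from mmengine.dataset.BaseDataset; the paths have
# already been joined with the prefixes by parse_data_info.
info = dataset.get_data_info(0)
print(info['lidar_points']['lidar_path'], info['pts_semantic_mask_path'])
```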
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase from unittest import TestCase
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from mmengine.data import InstanceData, PixelData from mmengine.data import InstanceData
from mmdet3d.core.data_structures import Det3DDataSample from mmdet3d.core.data_structures import Det3DDataSample, PointData
def _equal(a, b): def _equal(a, b):
...@@ -86,47 +87,34 @@ class TestDet3DataSample(TestCase): ...@@ -86,47 +87,34 @@ class TestDet3DataSample(TestCase):
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d, assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d']) img_pred_instances_3d_data['scores_3d'])
# test gt_panoptic_seg # test gt_seg
gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4)) gt_pts_seg_data = dict(
gt_pts_panoptic_seg = PixelData(**gt_pts_panoptic_seg_data) pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
det3d_data_sample.gt_pts_panoptic_seg = gt_pts_panoptic_seg gt_pts_seg = PointData(**gt_pts_seg_data)
assert 'gt_pts_panoptic_seg' in det3d_data_sample det3d_data_sample.gt_pts_seg = gt_pts_seg
assert _equal(det3d_data_sample.gt_pts_panoptic_seg.panoptic_seg, assert 'gt_pts_seg' in det3d_data_sample
gt_pts_panoptic_seg_data['panoptic_seg']) assert _equal(det3d_data_sample.gt_pts_seg.pts_instance_mask,
gt_pts_seg_data['pts_instance_mask'])
# test pred_panoptic_seg assert _equal(det3d_data_sample.gt_pts_seg.pts_semantic_mask,
pred_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4)) gt_pts_seg_data['pts_semantic_mask'])
pred_pts_panoptic_seg = PixelData(**pred_pts_panoptic_seg_data)
det3d_data_sample.pred_pts_panoptic_seg = pred_pts_panoptic_seg # test pred_seg
assert 'pred_pts_panoptic_seg' in det3d_data_sample pred_pts_seg_data = dict(
assert _equal(det3d_data_sample.pred_pts_panoptic_seg.panoptic_seg, pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
pred_pts_panoptic_seg_data['panoptic_seg']) pred_pts_seg = PointData(**pred_pts_seg_data)
det3d_data_sample.pred_pts_seg = pred_pts_seg
# test gt_sem_seg assert 'pred_pts_seg' in det3d_data_sample
gt_pts_sem_seg_data = dict(segm_seg=torch.rand(5, 4, 2)) assert _equal(det3d_data_sample.pred_pts_seg.pts_instance_mask,
gt_pts_sem_seg = PixelData(**gt_pts_sem_seg_data) pred_pts_seg_data['pts_instance_mask'])
det3d_data_sample.gt_pts_sem_seg = gt_pts_sem_seg assert _equal(det3d_data_sample.pred_pts_seg.pts_semantic_mask,
assert 'gt_pts_sem_seg' in det3d_data_sample pred_pts_seg_data['pts_semantic_mask'])
assert _equal(det3d_data_sample.gt_pts_sem_seg.segm_seg,
gt_pts_sem_seg_data['segm_seg'])
# test pred_segm_seg
pred_pts_sem_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
pred_pts_sem_seg = PixelData(**pred_pts_sem_seg_data)
det3d_data_sample.pred_pts_sem_seg = pred_pts_sem_seg
assert 'pred_pts_sem_seg' in det3d_data_sample
assert _equal(det3d_data_sample.pred_pts_sem_seg.segm_seg,
pred_pts_sem_seg_data['segm_seg'])
# test type error # test type error
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
det3d_data_sample.pred_instances_3d = torch.rand(2, 4) det3d_data_sample.pred_instances_3d = torch.rand(2, 4)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
det3d_data_sample.pred_pts_panoptic_seg = torch.rand(2, 4) det3d_data_sample.pred_pts_seg = torch.rand(20)
with pytest.raises(AssertionError):
det3d_data_sample.pred_pts_sem_seg = torch.rand(2, 4)
def test_deleter(self): def test_deleter(self):
tmp_instances_3d_data = dict( tmp_instances_3d_data = dict(
...@@ -157,17 +145,10 @@ class TestDet3DataSample(TestCase): ...@@ -157,17 +145,10 @@ class TestDet3DataSample(TestCase):
del det3d_data_sample.img_pred_instances_3d del det3d_data_sample.img_pred_instances_3d
assert 'img_pred_instances_3d' not in det3d_data_sample assert 'img_pred_instances_3d' not in det3d_data_sample
pred_pts_panoptic_seg_data = torch.rand(5, 4) pred_pts_seg_data = dict(
pred_pts_panoptic_seg_data = PixelData(data=pred_pts_panoptic_seg_data) pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
det3d_data_sample.pred_pts_panoptic_seg_data = \ pred_pts_seg = PointData(**pred_pts_seg_data)
pred_pts_panoptic_seg_data det3d_data_sample.pred_pts_seg = pred_pts_seg
assert 'pred_pts_panoptic_seg_data' in det3d_data_sample assert 'pred_pts_seg' in det3d_data_sample
del det3d_data_sample.pred_pts_panoptic_seg_data del det3d_data_sample.pred_pts_seg
assert 'pred_pts_panoptic_seg_data' not in det3d_data_sample assert 'pred_pts_seg' not in det3d_data_sample
pred_pts_sem_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
pred_pts_sem_seg = PixelData(**pred_pts_sem_seg_data)
det3d_data_sample.pred_pts_sem_seg = pred_pts_sem_seg
assert 'pred_pts_sem_seg' in det3d_data_sample
del det3d_data_sample.pred_pts_sem_seg
assert 'pred_pts_sem_seg' not in det3d_data_sample
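The tests above replace `PixelData` with `PointData`: per-point masks are 1-D fields that share their first (point) dimension, rather than 2-D pixel maps. A minimal standalone sketch using the same import path as the test file:

```python
import torch

from mmdet3d.core.data_structures import Det3DDataSample, PointData

sample = Det3DDataSample()
# All fields in one PointData must share the number of points (100 here).
sample.gt_pts_seg = PointData(
    pts_semantic_mask=torch.randint(0, 13, (100, )),
    pts_instance_mask=torch.randint(0, 10, (100, )))
print(sample.gt_pts_seg.pts_semantic_mask.shape)  # torch.Size([100])
```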
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
import torch
from mmdet3d.datasets import S3DISSegDataset
from mmdet3d.utils import register_all_modules
def _generate_s3dis_seg_dataset_config():
data_root = './tests/data/s3dis/'
ann_file = 's3dis_infos.pkl'
classes = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
palette = [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
[255, 0, 255], [100, 100, 255], [200, 200, 100],
[170, 120, 200], [255, 0, 0], [200, 100, 100], [10, 200, 100],
[200, 200, 200], [50, 50, 50]]
scene_idxs = [0 for _ in range(20)]
modality = dict(use_lidar=True, use_camera=False)
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(type='PointSegClassMapping'),
dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.0,
ignore_index=len(classes),
use_normalized_coord=True,
enlarge_size=0.2,
min_unique_num=None),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
data_prefix = dict(
pts='points',
pts_instance_mask='instance_mask',
pts_semantic_mask='semantic_mask')
return (data_root, ann_file, classes, palette, scene_idxs, data_prefix,
pipeline, modality)
class TestS3DISDataset(unittest.TestCase):
def test_s3dis_seg(self):
np.random.seed(0)
data_root, ann_file, classes, palette, scene_idxs, data_prefix, \
            pipeline, modality = _generate_s3dis_seg_dataset_config()
register_all_modules()
s3dis_seg_dataset = S3DISSegDataset(
data_root,
ann_file,
metainfo=dict(CLASSES=classes, PALETTE=palette),
data_prefix=data_prefix,
pipeline=pipeline,
modality=modality,
scene_idxs=scene_idxs)
input_dict = s3dis_seg_dataset.prepare_data(0)
points = input_dict['inputs']['points']
data_sample = input_dict['data_sample']
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
expected_points = torch.tensor([[
0.0000, 0.0000, 3.1720, 0.4706, 0.4431, 0.3725, 0.4624, 0.7502,
0.9543
],
[
0.2880, -0.5900, 0.0650, 0.3451,
0.3373, 0.3490, 0.5119, 0.5518,
0.0196
],
[
0.1570, 0.6000, 3.1700, 0.4941,
0.4667, 0.3569, 0.4893, 0.9519,
0.9537
],
[
-0.1320, 0.3950, 0.2720, 0.3216,
0.2863, 0.2275, 0.4397, 0.8830,
0.0818
],
[
-0.4860, -0.0640, 3.1710, 0.3843,
0.3725, 0.3059, 0.3789, 0.7286,
0.9540
]])
expected_pts_semantic_mask = np.array([0, 1, 0, 8, 0])
assert torch.allclose(points, expected_points, 1e-2)
self.assertTrue(
(pts_semantic_mask.numpy() == expected_pts_semantic_mask).all())
...@@ -6,7 +6,74 @@ import torch ...@@ -6,7 +6,74 @@ import torch
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from mmdet3d.core import DepthInstance3DBoxes from mmdet3d.core import DepthInstance3DBoxes
from mmdet3d.datasets import ScanNetDataset from mmdet3d.datasets import ScanNetDataset, ScanNetSegDataset
from mmdet3d.utils import register_all_modules
def _generate_scannet_seg_dataset_config():
data_root = './tests/data/scannet/'
ann_file = 'scannet_infos.pkl'
classes = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink',
'bathtub', 'otherfurniture')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
scene_idxs = [0 for _ in range(20)]
modality = dict(use_lidar=True, use_camera=False)
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(type='PointSegClassMapping'),
dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.5,
ignore_index=len(classes),
use_normalized_coord=True,
enlarge_size=0.2,
min_unique_num=None),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
data_prefix = dict(
pts='points',
pts_instance_mask='instance_mask',
pts_semantic_mask='semantic_mask')
return (data_root, ann_file, classes, palette, scene_idxs, data_prefix,
pipeline, modality)
def _generate_scannet_dataset_config(): def _generate_scannet_dataset_config():
...@@ -92,3 +159,54 @@ class TestScanNetDataset(unittest.TestCase): ...@@ -92,3 +159,54 @@ class TestScanNetDataset(unittest.TestCase):
# all instance have been filtered by classes # all instance have been filtered by classes
self.assertEqual(len(ann_info['gt_labels_3d']), 27) self.assertEqual(len(ann_info['gt_labels_3d']), 27)
self.assertEqual(len(no_class_scannet_dataset.metainfo['CLASSES']), 1) self.assertEqual(len(no_class_scannet_dataset.metainfo['CLASSES']), 1)
def test_scannet_seg(self):
np.random.seed(0)
data_root, ann_file, classes, palette, scene_idxs, data_prefix, \
            pipeline, modality = _generate_scannet_seg_dataset_config()
register_all_modules()
scannet_seg_dataset = ScanNetSegDataset(
data_root,
ann_file,
metainfo=dict(CLASSES=classes, PALETTE=palette),
data_prefix=data_prefix,
pipeline=pipeline,
modality=modality,
scene_idxs=scene_idxs)
input_dict = scannet_seg_dataset.prepare_data(0)
points = input_dict['inputs']['points']
data_sample = input_dict['data_sample']
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
expected_points = torch.tensor([[
0.0000, 0.0000, 1.2427, 0.6118, 0.5529, 0.4471, -0.6462, -1.0046,
0.4280
],
[
0.1553, -0.0074, 1.6077, 0.5882,
0.6157, 0.5569, -0.6001, -1.0068,
0.5537
],
[
0.1518, 0.6016, 0.6548, 0.1490,
0.1059, 0.0431, -0.6012, -0.8309,
0.2255
],
[
-0.7494, 0.1033, 0.6756, 0.5216,
0.4353, 0.3333, -0.8687, -0.9748,
0.2327
],
[
-0.6836, -0.0203, 0.5884, 0.5765,
0.5020, 0.4510, -0.8491, -1.0105,
0.2027
]])
expected_pts_semantic_mask = np.array([13, 13, 12, 2, 0])
assert torch.allclose(points, expected_points, 1e-2)
self.assertTrue(
(pts_semantic_mask.numpy() == expected_pts_semantic_mask).all())
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
from mmdet3d.datasets import SemanticKITTIDataset
from mmdet3d.utils import register_all_modules
def _generate_semantickitti_dataset_config():
data_root = './tests/data/semantickitti/'
ann_file = 'semantickitti_infos.pkl'
classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
'sidewalk', 'other-ground', 'building', 'fence', 'vegetation',
'trunck', 'terrian', 'pole', 'traffic-sign')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
modality = dict(use_lidar=True, use_camera=False)
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
shift_height=True,
load_dim=4,
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
seg_3d_dtype=np.int32),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
data_prefix = dict(
pts='sequences/00/velodyne', pts_semantic_mask='sequences/00/labels')
return (data_root, ann_file, classes, palette, data_prefix, pipeline,
modality)
class TestSemanticKITTIDataset(unittest.TestCase):
def test_semantickitti(self):
np.random.seed(0)
data_root, ann_file, classes, palette, data_prefix, \
            pipeline, modality = _generate_semantickitti_dataset_config()
register_all_modules()
semantickitti_dataset = SemanticKITTIDataset(
data_root,
ann_file,
metainfo=dict(CLASSES=classes, PALETTE=palette),
data_prefix=data_prefix,
pipeline=pipeline,
modality=modality)
input_dict = semantickitti_dataset.prepare_data(0)
points = input_dict['inputs']['points']
data_sample = input_dict['data_sample']
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
self.assertEqual(points.shape[0], pts_semantic_mask.shape[0])