Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
360c27f9
Commit
360c27f9
authored
Jul 15, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor] Refactor 3D Seg Dataset
parent
1039ad0e
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1202 additions
and
1261 deletions
+1202
-1261
configs/_base_/datasets/s3dis_seg-3d-13class.py
configs/_base_/datasets/s3dis_seg-3d-13class.py
+51
-62
configs/_base_/datasets/scannet_seg-3d-20class.py
configs/_base_/datasets/scannet_seg-3d-20class.py
+65
-58
mmdet3d/core/data_structures/__init__.py
mmdet3d/core/data_structures/__init__.py
+2
-1
mmdet3d/core/data_structures/det3d_data_sample.py
mmdet3d/core/data_structures/det3d_data_sample.py
+36
-68
mmdet3d/core/data_structures/point_data.py
mmdet3d/core/data_structures/point_data.py
+162
-0
mmdet3d/datasets/__init__.py
mmdet3d/datasets/__init__.py
+2
-2
mmdet3d/datasets/custom_3d_seg.py
mmdet3d/datasets/custom_3d_seg.py
+0
-463
mmdet3d/datasets/pipelines/formating.py
mmdet3d/datasets/pipelines/formating.py
+4
-4
mmdet3d/datasets/pipelines/loading.py
mmdet3d/datasets/pipelines/loading.py
+18
-24
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+49
-35
mmdet3d/datasets/s3dis_dataset.py
mmdet3d/datasets/s3dis_dataset.py
+50
-119
mmdet3d/datasets/scannet_dataset.py
mmdet3d/datasets/scannet_dataset.py
+107
-305
mmdet3d/datasets/seg3d_dataset.py
mmdet3d/datasets/seg3d_dataset.py
+279
-0
mmdet3d/datasets/semantickitti_dataset.py
mmdet3d/datasets/semantickitti_dataset.py
+31
-68
tests/data/s3dis/s3dis_infos.pkl
tests/data/s3dis/s3dis_infos.pkl
+0
-0
tests/data/semantickitti/semantickitti_infos.pkl
tests/data/semantickitti/semantickitti_infos.pkl
+0
-0
tests/test_core/test_data_structure/test_det_data_sample.py
tests/test_core/test_data_structure/test_det_data_sample.py
+32
-51
tests/test_data/test_datasets/test_s3dis_dataset.py
tests/test_data/test_datasets/test_s3dis_dataset.py
+110
-0
tests/test_data/test_datasets/test_scannet_dataset.py
tests/test_data/test_datasets/test_scannet_dataset.py
+119
-1
tests/test_data/test_datasets/test_semantickitti_dataset.py
tests/test_data/test_datasets/test_semantickitti_dataset.py
+85
-0
No files found.
configs/_base_/datasets/s3dis_seg-3d-13class.py
View file @
360c27f9
# dataset settings
# For S3DIS seg we usually do 13-class segmentation
dataset_type
=
'S3DISSegDataset'
data_root
=
'./data/s3dis/'
class_names
=
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
class_names
=
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
)
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
)
metainfo
=
dict
(
CLASSES
=
class_names
)
dataset_type
=
'S3DISSegDataset'
data_root
=
'data/s3dis/'
input_modality
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
data_prefix
=
dict
(
pts
=
'points'
,
pts_instance_mask
=
'instance_mask'
,
pts_semantic_mask
=
'semantic_mask'
)
file_client_args
=
dict
(
backend
=
'disk'
)
file_client_args
=
dict
(
backend
=
'disk'
)
# Uncomment the following if use ceph or other file clients.
# Uncomment the following if use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
...
@@ -15,29 +22,27 @@ file_client_args = dict(backend='disk')
...
@@ -15,29 +22,27 @@ file_client_args = dict(backend='disk')
# 'data/s3dis/':
# 'data/s3dis/':
# 's3://openmmlab/datasets/detection3d/s3dis_processed/'
# 's3://openmmlab/datasets/detection3d/s3dis_processed/'
# }))
# }))
num_points
=
4096
num_points
=
4096
train_area
=
[
1
,
2
,
3
,
4
,
6
]
train_area
=
[
1
,
2
,
3
,
4
,
6
]
test_area
=
5
test_area
=
5
train_pipeline
=
[
train_pipeline
=
[
dict
(
dict
(
type
=
'LoadPointsFromFile'
,
type
=
'LoadPointsFromFile'
,
file_client_args
=
file_client_args
,
coord_type
=
'DEPTH'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
file_client_args
=
file_client_args
),
dict
(
dict
(
type
=
'LoadAnnotations3D'
,
type
=
'LoadAnnotations3D'
,
file_client_args
=
file_client_args
,
with_bbox_3d
=
False
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
with_seg_3d
=
True
,
dict
(
file_client_args
=
file_client_args
),
type
=
'PointSegClassMapping'
,
dict
(
type
=
'PointSegClassMapping'
),
valid_cat_ids
=
tuple
(
range
(
len
(
class_names
))),
max_cat_id
=
13
),
dict
(
dict
(
type
=
'IndoorPatchPointSample'
,
type
=
'IndoorPatchPointSample'
,
num_points
=
num_points
,
num_points
=
num_points
,
...
@@ -47,18 +52,17 @@ train_pipeline = [
...
@@ -47,18 +52,17 @@ train_pipeline = [
enlarge_size
=
0.2
,
enlarge_size
=
0.2
,
min_unique_num
=
None
),
min_unique_num
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
]
test_pipeline
=
[
test_pipeline
=
[
dict
(
dict
(
type
=
'LoadPointsFromFile'
,
type
=
'LoadPointsFromFile'
,
file_client_args
=
file_client_args
,
coord_type
=
'DEPTH'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
file_client_args
=
file_client_args
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
dict
(
# a wrapper in order to successfully call test function
# a wrapper in order to successfully call test function
...
@@ -78,12 +82,8 @@ test_pipeline = [
...
@@ -78,12 +82,8 @@ test_pipeline = [
sync_2d
=
False
,
sync_2d
=
False
,
flip_ratio_bev_horizontal
=
0.0
,
flip_ratio_bev_horizontal
=
0.0
,
flip_ratio_bev_vertical
=
0.0
),
flip_ratio_bev_vertical
=
0.0
),
dict
(
]),
type
=
'DefaultFormatBundle3D'
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
class_names
=
class_names
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
])
]
]
# construct a pipeline for data and gt loading in show function
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
# please keep its loading function consistent with test_pipeline (e.g. client)
...
@@ -91,69 +91,58 @@ test_pipeline = [
...
@@ -91,69 +91,58 @@ test_pipeline = [
eval_pipeline
=
[
eval_pipeline
=
[
dict
(
dict
(
type
=
'LoadPointsFromFile'
,
type
=
'LoadPointsFromFile'
,
file_client_args
=
file_client_args
,
coord_type
=
'DEPTH'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
dict
(
file_client_args
=
file_client_args
),
type
=
'LoadAnnotations3D'
,
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
file_client_args
=
file_client_args
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
,
valid_cat_ids
=
tuple
(
range
(
len
(
class_names
))),
max_cat_id
=
13
),
dict
(
type
=
'DefaultFormatBundle3D'
,
with_label
=
False
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
]
data
=
dict
(
# train on area 1, 2, 3, 4, 6
samples_per_gpu
=
8
,
# test on area 5
workers_per_gpu
=
4
,
train_dataloader
=
dict
(
# train on area 1, 2, 3, 4, 6
batch_size
=
8
,
# test on area 5
num_workers
=
4
,
train
=
dict
(
persistent_workers
=
True
,
sampler
=
dict
(
type
=
'DefaultSampler'
,
shuffle
=
True
),
dataset
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
data_root
=
data_root
,
data_root
=
data_root
,
ann_files
=
[
ann_files
=
[
data_root
+
f
's3dis_infos_Area_
{
i
}
.pkl'
for
i
in
train_area
data_root
+
f
's3dis_infos_Area_
{
i
}
.pkl'
for
i
in
train_area
],
],
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
train_pipeline
,
pipeline
=
train_pipeline
,
classes
=
class_names
,
modality
=
input_modality
,
test_mode
=
False
,
ignore_index
=
len
(
class_names
),
ignore_index
=
len
(
class_names
),
scene_idxs
=
[
scene_idxs
=
[
data_root
+
f
'seg_info/Area_
{
i
}
_resampled_scene_idxs.npy'
data_root
+
f
'seg_info/Area_
{
i
}
_resampled_scene_idxs.npy'
for
i
in
train_area
for
i
in
train_area
],
],
file_client_args
=
file_client_args
),
test_mode
=
False
))
val
=
dict
(
test_dataloader
=
dict
(
batch_size
=
1
,
num_workers
=
1
,
persistent_workers
=
True
,
drop_last
=
False
,
sampler
=
dict
(
type
=
'DefaultSampler'
,
shuffle
=
False
),
dataset
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
data_root
=
data_root
,
data_root
=
data_root
,
ann_files
=
data_root
+
f
's3dis_infos_Area_
{
test_area
}
.pkl'
,
ann_files
=
data_root
+
f
's3dis_infos_Area_
{
test_area
}
.pkl'
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
test_pipeline
,
pipeline
=
test_pipeline
,
classes
=
class_names
,
modality
=
input_modality
,
test_mode
=
True
,
ignore_index
=
len
(
class_names
),
ignore_index
=
len
(
class_names
),
scene_idxs
=
data_root
+
scene_idxs
=
data_root
+
f
'seg_info/Area_
{
test_area
}
_resampled_scene_idxs.npy'
,
f
'seg_info/Area_
{
test_area
}
_resampled_scene_idxs.npy'
,
file_client_args
=
file_client_args
),
test_mode
=
True
))
test
=
dict
(
val_dataloader
=
test_dataloader
type
=
dataset_type
,
data_root
=
data_root
,
ann_files
=
data_root
+
f
's3dis_infos_Area_
{
test_area
}
.pkl'
,
pipeline
=
test_pipeline
,
classes
=
class_names
,
test_mode
=
True
,
ignore_index
=
len
(
class_names
),
file_client_args
=
file_client_args
))
evaluation
=
dict
(
pipeline
=
eval_pipeline
)
val_evaluator
=
dict
(
type
=
'SegMetric'
)
test_evaluator
=
val_evaluator
configs/_base_/datasets/scannet_seg-3d-20class.py
View file @
360c27f9
# dataset settings
# For ScanNet seg we usually do 20-class segmentation
dataset_type
=
'ScanNetSegDataset'
data_root
=
'./data/scannet/'
class_names
=
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
class_names
=
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'otherfurniture'
)
'bathtub'
,
'otherfurniture'
)
metainfo
=
dict
(
CLASSES
=
class_names
)
dataset_type
=
'ScanNetSegDataset'
data_root
=
'data/scannet/'
input_modality
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
data_prefix
=
dict
(
pts
=
'points'
,
pts_instance_mask
=
'instance_mask'
,
pts_semantic_mask
=
'semantic_mask'
)
file_client_args
=
dict
(
backend
=
'disk'
)
# Uncomment the following if use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/scannet/':
# 's3://openmmlab/datasets/detection3d/scannet_processed/',
# 'data/scannet/':
# 's3://openmmlab/datasets/detection3d/scannet_processed/'
# }))
num_points
=
8192
num_points
=
8192
train_pipeline
=
[
train_pipeline
=
[
dict
(
dict
(
...
@@ -13,18 +33,16 @@ train_pipeline = [
...
@@ -13,18 +33,16 @@ train_pipeline = [
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
file_client_args
=
file_client_args
),
dict
(
dict
(
type
=
'LoadAnnotations3D'
,
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
with_seg_3d
=
True
,
dict
(
file_client_args
=
file_client_args
),
type
=
'PointSegClassMapping'
,
dict
(
type
=
'PointSegClassMapping'
),
valid_cat_ids
=
(
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
33
,
34
,
36
,
39
),
max_cat_id
=
40
),
dict
(
dict
(
type
=
'IndoorPatchPointSample'
,
type
=
'IndoorPatchPointSample'
,
num_points
=
num_points
,
num_points
=
num_points
,
...
@@ -34,8 +52,7 @@ train_pipeline = [
...
@@ -34,8 +52,7 @@ train_pipeline = [
enlarge_size
=
0.2
,
enlarge_size
=
0.2
,
min_unique_num
=
None
),
min_unique_num
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
]
test_pipeline
=
[
test_pipeline
=
[
dict
(
dict
(
...
@@ -44,7 +61,8 @@ test_pipeline = [
...
@@ -44,7 +61,8 @@ test_pipeline = [
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
file_client_args
=
file_client_args
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
dict
(
# a wrapper in order to successfully call test function
# a wrapper in order to successfully call test function
...
@@ -64,12 +82,8 @@ test_pipeline = [
...
@@ -64,12 +82,8 @@ test_pipeline = [
sync_2d
=
False
,
sync_2d
=
False
,
flip_ratio_bev_horizontal
=
0.0
,
flip_ratio_bev_horizontal
=
0.0
,
flip_ratio_bev_vertical
=
0.0
),
flip_ratio_bev_vertical
=
0.0
),
dict
(
]),
type
=
'DefaultFormatBundle3D'
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
class_names
=
class_names
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
])
]
]
# construct a pipeline for data and gt loading in show function
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
# please keep its loading function consistent with test_pipeline (e.g. client)
...
@@ -81,52 +95,45 @@ eval_pipeline = [
...
@@ -81,52 +95,45 @@ eval_pipeline = [
shift_height
=
False
,
shift_height
=
False
,
use_color
=
True
,
use_color
=
True
,
load_dim
=
6
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
dict
(
file_client_args
=
file_client_args
),
type
=
'LoadAnnotations3D'
,
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
with_bbox_3d
=
False
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
,
valid_cat_ids
=
(
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
33
,
34
,
36
,
39
),
max_cat_id
=
40
),
dict
(
type
=
'DefaultFormatBundle3D'
,
with_label
=
False
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
]
data
=
dict
(
train_dataloader
=
dict
(
samples_per_gpu
=
8
,
batch_size
=
8
,
workers_per_gpu
=
4
,
num_workers
=
4
,
train
=
dict
(
persistent_workers
=
True
,
sampler
=
dict
(
type
=
'DefaultSampler'
,
shuffle
=
True
),
dataset
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
data_root
+
'scannet_infos_train.pkl'
,
ann_file
=
'scannet_infos_train.pkl'
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
train_pipeline
,
pipeline
=
train_pipeline
,
classes
=
class_names
,
modality
=
input_modality
,
test_mode
=
False
,
ignore_index
=
len
(
class_names
),
ignore_index
=
len
(
class_names
),
scene_idxs
=
data_root
+
'seg_info/train_resampled_scene_idxs.npy'
),
scene_idxs
=
data_root
+
'seg_info/train_resampled_scene_idxs.npy'
,
val
=
dict
(
test_mode
=
False
))
test_dataloader
=
dict
(
batch_size
=
1
,
num_workers
=
1
,
persistent_workers
=
True
,
drop_last
=
False
,
sampler
=
dict
(
type
=
'DefaultSampler'
,
shuffle
=
False
),
dataset
=
dict
(
type
=
dataset_type
,
type
=
dataset_type
,
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
data_root
+
'scannet_infos_val.pkl'
,
ann_file
=
'scannet_infos_val.pkl'
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
test_pipeline
,
pipeline
=
test_pipeline
,
classes
=
class_names
,
modality
=
input_modality
,
test_mode
=
True
,
ignore_index
=
len
(
class_names
),
ignore_index
=
len
(
class_names
)),
test_mode
=
True
))
test
=
dict
(
val_dataloader
=
test_dataloader
type
=
dataset_type
,
data_root
=
data_root
,
ann_file
=
data_root
+
'scannet_infos_val.pkl'
,
pipeline
=
test_pipeline
,
classes
=
class_names
,
test_mode
=
True
,
ignore_index
=
len
(
class_names
)))
evaluation
=
dict
(
pipeline
=
eval_pipeline
)
val_evaluator
=
dict
(
type
=
'SegMetric'
)
test_evaluator
=
val_evaluator
mmdet3d/core/data_structures/__init__.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.det3d_data_sample
import
Det3DDataSample
from
.det3d_data_sample
import
Det3DDataSample
from
.point_data
import
PointData
__all__
=
[
'Det3DDataSample'
]
__all__
=
[
'Det3DDataSample'
,
'PointData'
]
mmdet3d/core/data_structures/det3d_data_sample.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
mmengine.data
import
InstanceData
,
PixelData
from
mmengine.data
import
InstanceData
from
mmdet.core.data_structures
import
DetDataSample
from
mmdet.core.data_structures
import
DetDataSample
from
.point_data
import
PointData
class
Det3DDataSample
(
DetDataSample
):
class
Det3DDataSample
(
DetDataSample
):
...
@@ -43,19 +44,15 @@ class Det3DDataSample(DetDataSample):
...
@@ -43,19 +44,15 @@ class Det3DDataSample(DetDataSample):
`use_lidar=True, use_camera=True`, the 3D predictions based on
`use_lidar=True, use_camera=True`, the 3D predictions based on
image are saved in `img_pred_instances_3d` to distinguish with
image are saved in `img_pred_instances_3d` to distinguish with
`pts_pred_instances_3d` which based on point cloud.
`pts_pred_instances_3d` which based on point cloud.
- ``gt_pts_sem_seg``(PixelData): Ground truth of point cloud
- ``gt_pts_seg``(PointData): Ground truth of point cloud
semantic segmentation.
segmentation.
- ``pred_pts_sem_seg``(PixelData): Prediction of point cloud
- ``pred_pts_seg``(PointData): Prediction of point cloud
semantic segmentation.
segmentation.
- ``gt_pts_panoptic_seg``(PixelData): Ground truth of point cloud
panoptic segmentation.
- ``pred_pts_panoptic_seg``(PixelData): Predicted of point cloud
panoptic segmentation.
- ``eval_ann_info``(dict): Raw annotation, which will be passed to
- ``eval_ann_info``(dict): Raw annotation, which will be passed to
evaluator and do the online evaluation.
evaluator and do the online evaluation.
Examples:
Examples:
>>> from mmengine.data import InstanceData
, PixelData
>>> from mmengine.data import InstanceData
>>> from mmdet3d.core import Det3DDataSample
>>> from mmdet3d.core import Det3DDataSample
>>> from mmdet3d.core.bbox import BaseInstance3DBoxes
>>> from mmdet3d.core.bbox import BaseInstance3DBoxes
...
@@ -128,38 +125,33 @@ class Det3DDataSample(DetDataSample):
...
@@ -128,38 +125,33 @@ class Det3DDataSample(DetDataSample):
>>> assert 'bboxes' in data_sample.gt_instances_3d
>>> assert 'bboxes' in data_sample.gt_instances_3d
>>> data_sample = Det3DDataSample()
>>> data_sample = Det3DDataSample()
>>> gt_pts_panoptic_seg_data = dict(panoptic_seg=torch.rand(1, 2, 4))
... gt_pts_seg_data = dict(
>>> gt_pts_panoptic_seg = PixelData(**gt_pts_panoptic_seg_data)
... pts_instance_mask=torch.rand(2),
>>> data_sample.gt_pts_panoptic_seg = gt_pts_panoptic_seg
... pts_semantic_mask=torch.rand(2))
>>> data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)
>>> print(data_sample)
>>> print(data_sample)
<Det3DDataSample(
<Det3DDataSample(
META INFORMATION
META INFORMATION
DATA FIELDS
DATA FIELDS
_
gt_pts_
panoptic_
seg: <P
ixel
Data(
gt_pts_seg: <P
oint
Data(
META INFORMATION
META INFORMATION
DATA FIELDS
DATA FIELDS
p
anoptic_seg: tensor([[[0.9875, 0.3012, 0.5534, 0.9593],
p
ts_instance_mask: tensor([0.0576, 0.3067])
[0.1251, 0.1911, 0.8058, 0.2566]]
])
pts_semantic_mask: tensor([0.9267, 0.7455
])
) at 0x7f
b0d93543d
0>
) at 0x7f
654a9c159
0>
gt_pts_
panoptic_
seg: <P
ixel
Data(
_
gt_pts_seg: <P
oint
Data(
META INFORMATION
META INFORMATION
DATA FIELDS
DATA FIELDS
panoptic_seg: tensor([[[0.9875, 0.3012, 0.5534, 0.9593],
pts_instance_mask: tensor([0.0576, 0.3067])
[0.1251, 0.1911, 0.8058, 0.2566]]])
pts_semantic_mask: tensor([0.9267, 0.7455])
) at 0x7fb0d93543d0>
) at 0x7f654a9c1590>
) at 0x7fb0d9354280>
) at 0x7f654a9c1550>
>>> data_sample = Det3DDataSample()
>>> gt_pts_sem_seg_data = dict(segm_seg=torch.rand(2, 2, 2))
>>> gt_pts_sem_seg = PixelData(**gt_pts_sem_seg_data)
>>> data_sample.gt_pts_sem_seg = gt_pts_sem_seg
>>> assert 'gt_pts_sem_seg' in data_sample
>>> assert 'segm_seg' in data_sample.gt_pts_sem_seg
"""
"""
@
property
@
property
...
@@ -211,49 +203,25 @@ class Det3DDataSample(DetDataSample):
...
@@ -211,49 +203,25 @@ class Det3DDataSample(DetDataSample):
del
self
.
_img_pred_instances_3d
del
self
.
_img_pred_instances_3d
@
property
@
property
def
gt_pts_sem_seg
(
self
)
->
PixelData
:
def
gt_pts_seg
(
self
)
->
PointData
:
return
self
.
_gt_pts_sem_seg
return
self
.
_gt_pts_seg
@
gt_pts_sem_seg
.
setter
def
gt_pts_sem_seg
(
self
,
value
:
PixelData
):
self
.
set_field
(
value
,
'_gt_pts_sem_seg'
,
dtype
=
PixelData
)
@
gt_pts_sem_seg
.
deleter
def
gt_pts_sem_seg
(
self
):
del
self
.
_gt_pts_sem_seg
@
property
def
pred_pts_sem_seg
(
self
)
->
PixelData
:
return
self
.
_pred_pts_sem_seg
@
pred_pts_sem_seg
.
setter
def
pred_pts_sem_seg
(
self
,
value
:
PixelData
):
self
.
set_field
(
value
,
'_pred_pts_sem_seg'
,
dtype
=
PixelData
)
@
pred_pts_sem_seg
.
deleter
def
pred_pts_sem_seg
(
self
):
del
self
.
_pred_pts_sem_seg
@
property
def
gt_pts_panoptic_seg
(
self
)
->
PixelData
:
return
self
.
_gt_pts_panoptic_seg
@
gt_pts_
panoptic_
seg
.
setter
@
gt_pts_seg
.
setter
def
gt_pts_
panoptic_
seg
(
self
,
value
:
P
ixel
Data
):
def
gt_pts_seg
(
self
,
value
:
P
oint
Data
):
self
.
set_field
(
value
,
'_gt_pts_
panoptic_
seg'
,
dtype
=
P
ixel
Data
)
self
.
set_field
(
value
,
'_gt_pts_seg'
,
dtype
=
P
oint
Data
)
@
gt_pts_
panoptic_
seg
.
deleter
@
gt_pts_seg
.
deleter
def
gt_pts_
panoptic_
seg
(
self
):
def
gt_pts_seg
(
self
):
del
self
.
_gt_pts_
panoptic_
seg
del
self
.
_gt_pts_seg
@
property
@
property
def
pred_pts_
panoptic_
seg
(
self
)
->
P
ixel
Data
:
def
pred_pts_seg
(
self
)
->
P
oint
Data
:
return
self
.
_pred_pts_
panoptic_
seg
return
self
.
_pred_pts_seg
@
pred_pts_
panoptic_
seg
.
setter
@
pred_pts_seg
.
setter
def
pred_pts_
panoptic_
seg
(
self
,
value
:
P
ixel
Data
):
def
pred_pts_seg
(
self
,
value
:
P
oint
Data
):
self
.
set_field
(
value
,
'_pred_pts_
panoptic_
seg'
,
dtype
=
P
ixel
Data
)
self
.
set_field
(
value
,
'_pred_pts_seg'
,
dtype
=
P
oint
Data
)
@
pred_pts_
panoptic_
seg
.
deleter
@
pred_pts_seg
.
deleter
def
pred_pts_
panoptic_
seg
(
self
):
def
pred_pts_seg
(
self
):
del
self
.
_pred_pts_
panoptic_
seg
del
self
.
_pred_pts_seg
mmdet3d/core/data_structures/point_data.py
0 → 100644
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
from
collections.abc
import
Sized
from
typing
import
Union
import
numpy
as
np
import
torch
from
mmengine.data
import
BaseDataElement
IndexType
=
Union
[
str
,
slice
,
int
,
list
,
torch
.
LongTensor
,
torch
.
cuda
.
LongTensor
,
torch
.
BoolTensor
,
torch
.
cuda
.
BoolTensor
,
np
.
ndarray
]
class
PointData
(
BaseDataElement
):
"""Data structure for point-level annnotations or predictions.
All data items in ``data_fields`` of ``PointData`` meet the following
requirements:
- They are all one dimension.
- They should have the same length.
Notice: ``PointData`` behaves like `InstanceData`.
Examples:
>>> metainfo = dict(
... sample_id=random.randint(0, 100))
>>> points = np.random.randint(0, 255, (100, 3))
>>> point_data = PointData(metainfo=metainfo,
... points=points)
>>> print(len(point_data))
>>> (100)
>>> # slice
>>> slice_data = pixel_data[10:60]
>>> assert slice_data.shape == (50,)
>>> # set
>>> point_data.pts_semantic_mask = torch.randint(0, 255, (100))
>>> point_data.pts_instance_mask = torch.randint(0, 255, (100))
>>> assert tuple(point_data.pts_semantic_mask.shape) == (100)
>>> assert tuple(point_data.pts_instance_mask.shape) == (100)
"""
def
__setattr__
(
self
,
name
:
str
,
value
:
Sized
):
"""setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length
of PointData.
"""
if
name
in
(
'_metainfo_fields'
,
'_data_fields'
):
if
not
hasattr
(
self
,
name
):
super
().
__setattr__
(
name
,
value
)
else
:
raise
AttributeError
(
f
'
{
name
}
has been used as a '
f
'private attribute, which is immutable. '
)
else
:
assert
isinstance
(
value
,
Sized
),
'value must contain `_len__` attribute'
if
len
(
self
)
>
0
:
assert
len
(
value
)
==
len
(
self
),
f
'the length of '
\
f
'values
{
len
(
value
)
}
is '
\
f
'not consistent with'
\
f
' the length of this '
\
f
':obj:`PointData` '
\
f
'
{
len
(
self
)
}
'
super
().
__setattr__
(
name
,
value
)
__setitem__
=
__setattr__
def
__getitem__
(
self
,
item
:
IndexType
)
->
'PointData'
:
"""
Args:
item (str, obj:`slice`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
get the corresponding values according to item.
Returns:
obj:`PointData`: Corresponding values.
"""
if
isinstance
(
item
,
list
):
item
=
np
.
array
(
item
)
if
isinstance
(
item
,
np
.
ndarray
):
item
=
torch
.
from_numpy
(
item
)
assert
isinstance
(
item
,
(
str
,
slice
,
int
,
torch
.
LongTensor
,
torch
.
cuda
.
LongTensor
,
torch
.
BoolTensor
,
torch
.
cuda
.
BoolTensor
))
if
isinstance
(
item
,
str
):
return
getattr
(
self
,
item
)
if
type
(
item
)
==
int
:
if
item
>=
len
(
self
)
or
item
<
-
len
(
self
):
# type:ignore
raise
IndexError
(
f
'Index
{
item
}
out of range!'
)
else
:
# keep the dimension
item
=
slice
(
item
,
None
,
len
(
self
))
new_data
=
self
.
__class__
(
metainfo
=
self
.
metainfo
)
if
isinstance
(
item
,
torch
.
Tensor
):
assert
item
.
dim
()
==
1
,
'Only support to get the'
\
' values along the first dimension.'
if
isinstance
(
item
,
(
torch
.
BoolTensor
,
torch
.
cuda
.
BoolTensor
)):
assert
len
(
item
)
==
len
(
self
),
f
'The shape of the'
\
f
' input(BoolTensor)) '
\
f
'
{
len
(
item
)
}
'
\
f
' does not match the shape '
\
f
'of the indexed tensor '
\
f
'in results_filed '
\
f
'
{
len
(
self
)
}
at '
\
f
'first dimension. '
for
k
,
v
in
self
.
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
new_data
[
k
]
=
v
[
item
]
elif
isinstance
(
v
,
np
.
ndarray
):
new_data
[
k
]
=
v
[
item
.
cpu
().
numpy
()]
elif
isinstance
(
v
,
(
str
,
list
,
tuple
))
or
(
hasattr
(
v
,
'__getitem__'
)
and
hasattr
(
v
,
'cat'
)):
# convert to indexes from boolTensor
if
isinstance
(
item
,
(
torch
.
BoolTensor
,
torch
.
cuda
.
BoolTensor
)):
indexes
=
torch
.
nonzero
(
item
).
view
(
-
1
).
cpu
().
numpy
().
tolist
()
else
:
indexes
=
item
.
cpu
().
numpy
().
tolist
()
slice_list
=
[]
if
indexes
:
for
index
in
indexes
:
slice_list
.
append
(
slice
(
index
,
None
,
len
(
v
)))
else
:
slice_list
.
append
(
slice
(
None
,
0
,
None
))
r_list
=
[
v
[
s
]
for
s
in
slice_list
]
if
isinstance
(
v
,
(
str
,
list
,
tuple
)):
new_value
=
r_list
[
0
]
for
r
in
r_list
[
1
:]:
new_value
=
new_value
+
r
else
:
new_value
=
v
.
cat
(
r_list
)
new_data
[
k
]
=
new_value
else
:
raise
ValueError
(
f
'The type of `
{
k
}
` is `
{
type
(
v
)
}
`, which has no '
'attribute of `cat`, so it does not '
f
'support slice with `bool`'
)
else
:
# item is a slice
for
k
,
v
in
self
.
items
():
new_data
[
k
]
=
v
[
item
]
return
new_data
# type:ignore
def
__len__
(
self
)
->
int
:
"""int: the length of PointData"""
if
len
(
self
.
_data_fields
)
>
0
:
return
len
(
self
.
values
()[
0
])
else
:
return
0
mmdet3d/datasets/__init__.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.builder
import
DATASETS
,
PIPELINES
,
build_dataset
from
.builder
import
DATASETS
,
PIPELINES
,
build_dataset
from
.custom_3d_seg
import
Custom3DSegDataset
from
.det3d_dataset
import
Det3DDataset
from
.det3d_dataset
import
Det3DDataset
from
.kitti_dataset
import
KittiDataset
from
.kitti_dataset
import
KittiDataset
from
.kitti_mono_dataset
import
KittiMonoDataset
from
.kitti_mono_dataset
import
KittiMonoDataset
...
@@ -22,6 +21,7 @@ from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
...
@@ -22,6 +21,7 @@ from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
from
.s3dis_dataset
import
S3DISDataset
,
S3DISSegDataset
from
.s3dis_dataset
import
S3DISDataset
,
S3DISSegDataset
from
.scannet_dataset
import
(
ScanNetDataset
,
ScanNetInstanceSegDataset
,
from
.scannet_dataset
import
(
ScanNetDataset
,
ScanNetInstanceSegDataset
,
ScanNetSegDataset
)
ScanNetSegDataset
)
from
.seg3d_dataset
import
Seg3DDataset
from
.semantickitti_dataset
import
SemanticKITTIDataset
from
.semantickitti_dataset
import
SemanticKITTIDataset
from
.sunrgbd_dataset
import
SUNRGBDDataset
from
.sunrgbd_dataset
import
SUNRGBDDataset
from
.utils
import
get_loading_pipeline
from
.utils
import
get_loading_pipeline
...
@@ -36,7 +36,7 @@ __all__ = [
...
@@ -36,7 +36,7 @@ __all__ = [
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'PointSample'
,
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'PointSample'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'ScanNetDataset'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'ScanNetDataset'
,
'ScanNetSegDataset'
,
'ScanNetInstanceSegDataset'
,
'SemanticKITTIDataset'
,
'ScanNetSegDataset'
,
'ScanNetInstanceSegDataset'
,
'SemanticKITTIDataset'
,
'Det3DDataset'
,
'
Custom3D
SegDataset'
,
'LoadPointsFromMultiSweeps'
,
'Det3DDataset'
,
'Seg
3D
Dataset'
,
'LoadPointsFromMultiSweeps'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'VoxelBasedPointSampler'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'VoxelBasedPointSampler'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'RandomJitterPoints'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'RandomJitterPoints'
,
'ObjectNameFilter'
,
'AffineResize'
,
'RandomShiftScale'
,
'ObjectNameFilter'
,
'AffineResize'
,
'RandomShiftScale'
,
...
...
mmdet3d/datasets/custom_3d_seg.py
deleted
100644 → 0
View file @
1039ad0e
# Copyright (c) OpenMMLab. All rights reserved.
import
tempfile
import
warnings
from
os
import
path
as
osp
import
mmcv
import
numpy
as
np
from
torch.utils.data
import
Dataset
from
mmdet3d.registry
import
DATASETS
from
.pipelines
import
Compose
from
.utils
import
extract_result_dict
,
get_loading_pipeline
@DATASETS.register_module()
class Custom3DSegDataset(Dataset):
    """Customized 3D dataset for semantic segmentation task.

    This is the base dataset of ScanNet and S3DIS dataset.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        palette (list[list[int]], optional): The palette of segmentation map.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.CLASSES) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray | str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
    """

    # names of all classes data used for the task
    CLASSES = None

    # class_ids used for training; subclasses override this
    VALID_CLASS_IDS = None

    # all possible class_ids in loaded segmentation mask
    ALL_CLASS_IDS = None

    # official color for visualization (one RGB triple per class)
    PALETTE = None
def
__init__
(
self
,
data_root
,
ann_file
,
pipeline
=
None
,
classes
=
None
,
palette
=
None
,
modality
=
None
,
test_mode
=
False
,
ignore_index
=
None
,
scene_idxs
=
None
,
file_client_args
=
dict
(
backend
=
'disk'
)):
super
().
__init__
()
self
.
data_root
=
data_root
self
.
ann_file
=
ann_file
self
.
test_mode
=
test_mode
self
.
modality
=
modality
self
.
file_client
=
mmcv
.
FileClient
(
**
file_client_args
)
# load annotations
if
hasattr
(
self
.
file_client
,
'get_local_path'
):
with
self
.
file_client
.
get_local_path
(
self
.
ann_file
)
as
local_path
:
self
.
data_infos
=
self
.
load_annotations
(
open
(
local_path
,
'rb'
))
else
:
warnings
.
warn
(
'The used MMCV version does not have get_local_path. '
f
'We treat the
{
self
.
ann_file
}
as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.'
)
self
.
data_infos
=
self
.
load_annotations
(
self
.
ann_file
)
if
pipeline
is
not
None
:
self
.
pipeline
=
Compose
(
pipeline
)
self
.
ignore_index
=
len
(
self
.
CLASSES
)
if
\
ignore_index
is
None
else
ignore_index
self
.
scene_idxs
=
self
.
get_scene_idxs
(
scene_idxs
)
self
.
CLASSES
,
self
.
PALETTE
=
\
self
.
get_classes_and_palette
(
classes
,
palette
)
# set group flag for the sampler
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
def
load_annotations
(
self
,
ann_file
):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations.
"""
# loading data from a file-like object needs file format
return
mmcv
.
load
(
ann_file
,
file_format
=
'pkl'
)
def
get_data_info
(
self
,
index
):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index.
- pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds.
- ann_info (dict): Annotation info.
"""
info
=
self
.
data_infos
[
index
]
sample_idx
=
info
[
'point_cloud'
][
'lidar_idx'
]
pts_filename
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_path'
])
input_dict
=
dict
(
pts_filename
=
pts_filename
,
sample_idx
=
sample_idx
,
file_name
=
pts_filename
)
if
not
self
.
test_mode
:
annos
=
self
.
get_ann_info
(
index
)
input_dict
[
'ann_info'
]
=
annos
return
input_dict
def
pre_pipeline
(
self
,
results
):
"""Initialization before data preparation.
Args:
results (dict): Dict before data preprocessing.
- img_fields (list): Image fields.
- pts_mask_fields (list): Mask fields of points.
- pts_seg_fields (list): Mask fields of point segments.
- mask_fields (list): Fields of masks.
- seg_fields (list): Segment fields.
"""
results
[
'img_fields'
]
=
[]
results
[
'pts_mask_fields'
]
=
[]
results
[
'pts_seg_fields'
]
=
[]
results
[
'mask_fields'
]
=
[]
results
[
'seg_fields'
]
=
[]
results
[
'bbox3d_fields'
]
=
[]
def
prepare_train_data
(
self
,
index
):
"""Training data preparation.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Training data dict of the corresponding index.
"""
input_dict
=
self
.
get_data_info
(
index
)
if
input_dict
is
None
:
return
None
self
.
pre_pipeline
(
input_dict
)
example
=
self
.
pipeline
(
input_dict
)
return
example
def
prepare_test_data
(
self
,
index
):
"""Prepare data for testing.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Testing data dict of the corresponding index.
"""
input_dict
=
self
.
get_data_info
(
index
)
self
.
pre_pipeline
(
input_dict
)
example
=
self
.
pipeline
(
input_dict
)
return
example
    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names (and palette) of the current dataset.

        This function is taken from MMSegmentation. As a side effect it
        (re)builds ``self.label_map`` (raw mask id -> training label) and
        ``self.label2cat`` (training label -> class name); when a custom
        class subset is given it also rewrites ``self.VALID_CLASS_IDS`` and
        derives a matching palette.

        Args:
            classes (Sequence[str] | str): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
                Defaults to None.
            palette (Sequence[Sequence[int]]] | np.ndarray):
                The palette of segmentation map. If None is given, random
                palette will be generated. Defaults to None.

        Returns:
            tuple: (class names, palette) actually used by the dataset.
        """
        if classes is None:
            self.custom_classes = False
            # map id in the loaded mask to label used for training;
            # ids outside VALID_CLASS_IDS fall through to ignore_index
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            # map label to category name
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(self.CLASSES)
            }
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path: one class name per line
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # update valid_class_ids: keep only the requested classes,
            # in the order they were requested
            self.VALID_CLASS_IDS = [
                self.VALID_CLASS_IDS[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(class_names)
            }

            # modify palette for visualization to match the class subset
            palette = [
                self.PALETTE[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

        return class_names, palette
def
get_scene_idxs
(
self
,
scene_idxs
):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
"""
if
self
.
test_mode
:
# when testing, we load one whole scene every time
return
np
.
arange
(
len
(
self
.
data_infos
)).
astype
(
np
.
int32
)
# we may need to re-sample different scenes according to scene_idxs
# this is necessary for indoor scene segmentation such as ScanNet
if
scene_idxs
is
None
:
scene_idxs
=
np
.
arange
(
len
(
self
.
data_infos
))
if
isinstance
(
scene_idxs
,
str
):
with
self
.
file_client
.
get_local_path
(
scene_idxs
)
as
local_path
:
scene_idxs
=
np
.
load
(
local_path
)
else
:
scene_idxs
=
np
.
array
(
scene_idxs
)
return
scene_idxs
.
astype
(
np
.
int32
)
def
format_results
(
self
,
outputs
,
pklfile_prefix
=
None
,
submission_prefix
=
None
):
"""Format the results to pkl file.
Args:
outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporal directory created for saving json
files when ``jsonfile_prefix`` is not specified.
"""
if
pklfile_prefix
is
None
:
tmp_dir
=
tempfile
.
TemporaryDirectory
()
pklfile_prefix
=
osp
.
join
(
tmp_dir
.
name
,
'results'
)
out
=
f
'
{
pklfile_prefix
}
.pkl'
mmcv
.
dump
(
outputs
,
out
)
return
outputs
,
tmp_dir
    def evaluate(self,
                 results,
                 metric=None,
                 logger=None,
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluate.

        Evaluation in semantic segmentation protocol: ground-truth masks are
        re-loaded through the data pipeline and compared with the predicted
        masks via ``seg_eval``.

        Args:
            results (list[dict]): List of results; each dict must contain a
                'semantic_mask' entry.
            metric (str | list[str]): Metrics to be evaluated (unused here;
                kept for interface compatibility).
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Defaults to None.
            show (bool, optional): Whether to visualize.
                Defaults to False.
            out_dir (str, optional): Path to save the visualization results.
                Defaults to None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: Evaluation results.
        """
        # deferred import to avoid a circular dependency at module load time
        from mmdet3d.core.evaluation import seg_eval
        assert isinstance(results, list), \
            f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(results[0], dict), \
            f'Expect elements in results to be dict, got {type(results[0])}.'

        load_pipeline = self._get_pipeline(pipeline)
        pred_sem_masks = [result['semantic_mask'] for result in results]
        # load_annos=True temporarily flips test_mode so ground truth exists
        gt_sem_masks = [
            self._extract_data(
                i, load_pipeline, 'pts_semantic_mask', load_annos=True)
            for i in range(len(self.data_infos))
        ]
        ret_dict = seg_eval(
            gt_sem_masks,
            pred_sem_masks,
            self.label2cat,
            self.ignore_index,
            logger=logger)

        if show:
            self.show(pred_sem_masks, out_dir, pipeline=pipeline)

        return ret_dict
def
_rand_another
(
self
,
idx
):
"""Randomly get another item with the same flag.
Returns:
int: Another index of item with the same flag.
"""
pool
=
np
.
where
(
self
.
flag
==
self
.
flag
[
idx
])[
0
]
return
np
.
random
.
choice
(
pool
)
def
_build_default_pipeline
(
self
):
"""Build the default pipeline for this dataset."""
raise
NotImplementedError
(
'_build_default_pipeline is not implemented '
f
'for dataset
{
self
.
__class__
.
__name__
}
'
)
def
_get_pipeline
(
self
,
pipeline
):
"""Get data loading pipeline in self.show/evaluate function.
Args:
pipeline (list[dict]): Input pipeline. If None is given,
get from self.pipeline.
"""
if
pipeline
is
None
:
if
not
hasattr
(
self
,
'pipeline'
)
or
self
.
pipeline
is
None
:
warnings
.
warn
(
'Use default pipeline for data loading, this may cause '
'errors when data is on ceph'
)
return
self
.
_build_default_pipeline
()
loading_pipeline
=
get_loading_pipeline
(
self
.
pipeline
.
transforms
)
return
Compose
(
loading_pipeline
)
return
Compose
(
pipeline
)
    def _extract_data(self, index, pipeline, key, load_annos=False):
        """Load data using input pipeline and extract data according to key.

        Args:
            index (int): Index for accessing the target data.
            pipeline (:obj:`Compose`): Composed data loading pipeline.
            key (str | list[str]): One single or a list of data key.
            load_annos (bool): Whether to load data annotations.
                If True, need to set self.test_mode as False before loading.

        Returns:
            np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor]:
                A single or a list of loaded data.
        """
        assert pipeline is not None, 'data loading pipeline is not provided'
        # when we want to load ground-truth via pipeline (e.g. bbox, seg mask)
        # we need to set self.test_mode as False so that we have 'annos';
        # the original value is restored below. NOTE(review): this mutates
        # dataset-wide state, so it is not safe under concurrent access.
        if load_annos:
            original_test_mode = self.test_mode
            self.test_mode = False
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = pipeline(input_dict)

        # extract data items according to keys
        if isinstance(key, str):
            data = extract_result_dict(example, key)
        else:
            data = [extract_result_dict(example, k) for k in key]
        if load_annos:
            self.test_mode = original_test_mode

        return data
def
__len__
(
self
):
"""Return the length of scene_idxs.
Returns:
int: Length of data infos.
"""
return
len
(
self
.
scene_idxs
)
def
__getitem__
(
self
,
idx
):
"""Get item from infos according to the given index.
In indoor scene segmentation task, each scene contains millions of
points. However, we only sample less than 10k points within a patch
each time. Therefore, we use `scene_idxs` to re-sample different rooms.
Returns:
dict: Data dictionary of the corresponding index.
"""
scene_idx
=
self
.
scene_idxs
[
idx
]
# map to scene idx
if
self
.
test_mode
:
return
self
.
prepare_test_data
(
scene_idx
)
while
True
:
data
=
self
.
prepare_train_data
(
scene_idx
)
if
data
is
None
:
idx
=
self
.
_rand_another
(
idx
)
scene_idx
=
self
.
scene_idxs
[
idx
]
# map to scene idx
continue
return
data
def
_set_group_flag
(
self
):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus are all
zeros.
"""
self
.
flag
=
np
.
zeros
(
len
(
self
),
dtype
=
np
.
uint8
)
mmdet3d/datasets/pipelines/formating.py
View file @
360c27f9
...
@@ -6,7 +6,7 @@ from mmcv import BaseTransform
...
@@ -6,7 +6,7 @@ from mmcv import BaseTransform
from
mmcv.transforms
import
to_tensor
from
mmcv.transforms
import
to_tensor
from
mmengine
import
InstanceData
from
mmengine
import
InstanceData
from
mmdet3d.core
import
Det3DDataSample
from
mmdet3d.core
import
Det3DDataSample
,
PointData
from
mmdet3d.core.bbox
import
BaseInstance3DBoxes
from
mmdet3d.core.bbox
import
BaseInstance3DBoxes
from
mmdet3d.core.points
import
BasePoints
from
mmdet3d.core.points
import
BasePoints
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet3d.registry
import
TRANSFORMS
...
@@ -143,7 +143,7 @@ class Pack3DDetInputs(BaseTransform):
...
@@ -143,7 +143,7 @@ class Pack3DDetInputs(BaseTransform):
data_sample
=
Det3DDataSample
()
data_sample
=
Det3DDataSample
()
gt_instances_3d
=
InstanceData
()
gt_instances_3d
=
InstanceData
()
gt_instances
=
InstanceData
()
gt_instances
=
InstanceData
()
seg_data
=
dict
()
gt_pts_seg
=
PointData
()
img_metas
=
{}
img_metas
=
{}
for
key
in
self
.
meta_keys
:
for
key
in
self
.
meta_keys
:
...
@@ -161,7 +161,7 @@ class Pack3DDetInputs(BaseTransform):
...
@@ -161,7 +161,7 @@ class Pack3DDetInputs(BaseTransform):
elif
key
in
self
.
INSTANCEDATA_2D_KEYS
:
elif
key
in
self
.
INSTANCEDATA_2D_KEYS
:
gt_instances
[
self
.
_remove_prefix
(
key
)]
=
results
[
key
]
gt_instances
[
self
.
_remove_prefix
(
key
)]
=
results
[
key
]
elif
key
in
self
.
SEG_KEYS
:
elif
key
in
self
.
SEG_KEYS
:
seg_data
[
self
.
_remove_prefix
(
key
)]
=
results
[
key
]
gt_pts_seg
[
self
.
_remove_prefix
(
key
)]
=
results
[
key
]
else
:
else
:
raise
NotImplementedError
(
f
'Please modified '
raise
NotImplementedError
(
f
'Please modified '
f
'`Pack3DDetInputs` '
f
'`Pack3DDetInputs` '
...
@@ -170,7 +170,7 @@ class Pack3DDetInputs(BaseTransform):
...
@@ -170,7 +170,7 @@ class Pack3DDetInputs(BaseTransform):
data_sample
.
gt_instances_3d
=
gt_instances_3d
data_sample
.
gt_instances_3d
=
gt_instances_3d
data_sample
.
gt_instances
=
gt_instances
data_sample
.
gt_instances
=
gt_instances
data_sample
.
seg_data
=
seg_data
data_sample
.
gt_pts_seg
=
gt_pts_seg
if
'eval_ann_info'
in
results
:
if
'eval_ann_info'
in
results
:
data_sample
.
eval_ann_info
=
results
[
'eval_ann_info'
]
data_sample
.
eval_ann_info
=
results
[
'eval_ann_info'
]
else
:
else
:
...
...
mmdet3d/datasets/pipelines/loading.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
from
typing
import
List
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
...
@@ -270,22 +270,6 @@ class PointSegClassMapping(BaseTransform):
...
@@ -270,22 +270,6 @@ class PointSegClassMapping(BaseTransform):
segmentation mask. Defaults to 40.
segmentation mask. Defaults to 40.
"""
"""
def
__init__
(
self
,
valid_cat_ids
:
Sequence
[
int
],
max_cat_id
:
int
=
40
)
->
None
:
assert
max_cat_id
>=
np
.
max
(
valid_cat_ids
),
\
'max_cat_id should be greater than maximum id in valid_cat_ids'
self
.
valid_cat_ids
=
valid_cat_ids
self
.
max_cat_id
=
int
(
max_cat_id
)
# build cat_id to class index mapping
neg_cls
=
len
(
valid_cat_ids
)
self
.
cat_id2class
=
np
.
ones
(
self
.
max_cat_id
+
1
,
dtype
=
np
.
int
)
*
neg_cls
for
cls_idx
,
cat_id
in
enumerate
(
valid_cat_ids
):
self
.
cat_id2class
[
cat_id
]
=
cls_idx
def
transform
(
self
,
results
:
dict
)
->
None
:
def
transform
(
self
,
results
:
dict
)
->
None
:
"""Call function to map original semantic class to valid category ids.
"""Call function to map original semantic class to valid category ids.
...
@@ -301,9 +285,19 @@ class PointSegClassMapping(BaseTransform):
...
@@ -301,9 +285,19 @@ class PointSegClassMapping(BaseTransform):
assert
'pts_semantic_mask'
in
results
assert
'pts_semantic_mask'
in
results
pts_semantic_mask
=
results
[
'pts_semantic_mask'
]
pts_semantic_mask
=
results
[
'pts_semantic_mask'
]
converted_pts_sem_mask
=
self
.
cat_id2class
[
pts_semantic_mask
]
assert
'label_mapping'
in
results
label_mapping
=
results
[
'label_mapping'
]
converted_pts_sem_mask
=
\
np
.
array
([
label_mapping
[
mask
]
for
mask
in
pts_semantic_mask
])
results
[
'pts_semantic_mask'
]
=
converted_pts_sem_mask
results
[
'pts_semantic_mask'
]
=
converted_pts_sem_mask
# 'eval_ann_info' will be passed to evaluator
if
'eval_ann_info'
in
results
:
assert
'pts_semantic_mask'
in
results
[
'eval_ann_info'
]
results
[
'eval_ann_info'
][
'pts_semantic_mask'
]
=
\
converted_pts_sem_mask
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -315,17 +309,17 @@ class PointSegClassMapping(BaseTransform):
...
@@ -315,17 +309,17 @@ class PointSegClassMapping(BaseTransform):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
NormalizePointsColor
(
object
):
class
NormalizePointsColor
(
BaseTransform
):
"""Normalize color of points.
"""Normalize color of points.
Args:
Args:
color_mean (list[float]): Mean color of the point cloud.
color_mean (list[float]): Mean color of the point cloud.
"""
"""
def
__init__
(
self
,
color_mean
)
:
def
__init__
(
self
,
color_mean
:
List
[
float
])
->
None
:
self
.
color_mean
=
color_mean
self
.
color_mean
=
color_mean
def
__call__
(
self
,
results
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to normalize color of points.
"""Call function to normalize color of points.
Args:
Args:
...
@@ -337,7 +331,7 @@ class NormalizePointsColor(object):
...
@@ -337,7 +331,7 @@ class NormalizePointsColor(object):
- points (:obj:`BasePoints`): Points after color normalization.
- points (:obj:`BasePoints`): Points after color normalization.
"""
"""
points
=
results
[
'points'
]
points
=
input_dict
[
'points'
]
assert
points
.
attribute_dims
is
not
None
and
\
assert
points
.
attribute_dims
is
not
None
and
\
'color'
in
points
.
attribute_dims
.
keys
(),
\
'color'
in
points
.
attribute_dims
.
keys
(),
\
'Expect points have color attribute'
'Expect points have color attribute'
...
@@ -345,8 +339,8 @@ class NormalizePointsColor(object):
...
@@ -345,8 +339,8 @@ class NormalizePointsColor(object):
points
.
color
=
points
.
color
-
\
points
.
color
=
points
.
color
-
\
points
.
color
.
new_tensor
(
self
.
color_mean
)
points
.
color
.
new_tensor
(
self
.
color_mean
)
points
.
color
=
points
.
color
/
255.0
points
.
color
=
points
.
color
/
255.0
results
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
return
results
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
):
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
...
...
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
random
import
warnings
import
warnings
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -18,7 +18,7 @@ from .data_augment_utils import noise_per_object_v3_
...
@@ -18,7 +18,7 @@ from .data_augment_utils import noise_per_object_v3_
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
RandomDropPointsColor
(
object
):
class
RandomDropPointsColor
(
BaseTransform
):
r
"""Randomly set the color of points to all zeros.
r
"""Randomly set the color of points to all zeros.
Once this transform is executed, all the points' color will be dropped.
Once this transform is executed, all the points' color will be dropped.
...
@@ -30,12 +30,12 @@ class RandomDropPointsColor(object):
...
@@ -30,12 +30,12 @@ class RandomDropPointsColor(object):
Defaults to 0.2.
Defaults to 0.2.
"""
"""
def
__init__
(
self
,
drop_ratio
=
0.2
)
:
def
__init__
(
self
,
drop_ratio
:
float
=
0.2
)
->
None
:
assert
isinstance
(
drop_ratio
,
(
int
,
float
))
and
0
<=
drop_ratio
<=
1
,
\
assert
isinstance
(
drop_ratio
,
(
int
,
float
))
and
0
<=
drop_ratio
<=
1
,
\
f
'invalid drop_ratio value
{
drop_ratio
}
'
f
'invalid drop_ratio value
{
drop_ratio
}
'
self
.
drop_ratio
=
drop_ratio
self
.
drop_ratio
=
drop_ratio
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to drop point colors.
"""Call function to drop point colors.
Args:
Args:
...
@@ -224,7 +224,7 @@ class RandomFlip3D(RandomFlip):
...
@@ -224,7 +224,7 @@ class RandomFlip3D(RandomFlip):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
RandomJitterPoints
(
object
):
class
RandomJitterPoints
(
BaseTransform
):
"""Randomly jitter point coordinates.
"""Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we
Different from the global translation in ``GlobalRotScaleTrans``, here we
...
@@ -246,8 +246,8 @@ class RandomJitterPoints(object):
...
@@ -246,8 +246,8 @@ class RandomJitterPoints(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
jitter_std
=
[
0.01
,
0.01
,
0.01
],
jitter_std
:
List
[
float
]
=
[
0.01
,
0.01
,
0.01
],
clip_range
=
[
-
0.05
,
0.05
]):
clip_range
:
List
[
float
]
=
[
-
0.05
,
0.05
])
->
None
:
seq_types
=
(
list
,
tuple
,
np
.
ndarray
)
seq_types
=
(
list
,
tuple
,
np
.
ndarray
)
if
not
isinstance
(
jitter_std
,
seq_types
):
if
not
isinstance
(
jitter_std
,
seq_types
):
assert
isinstance
(
jitter_std
,
(
int
,
float
)),
\
assert
isinstance
(
jitter_std
,
(
int
,
float
)),
\
...
@@ -262,7 +262,7 @@ class RandomJitterPoints(object):
...
@@ -262,7 +262,7 @@ class RandomJitterPoints(object):
clip_range
=
[
-
clip_range
,
clip_range
]
clip_range
=
[
-
clip_range
,
clip_range
]
self
.
clip_range
=
clip_range
self
.
clip_range
=
clip_range
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to jitter all the points in the scene.
"""Call function to jitter all the points in the scene.
Args:
Args:
...
@@ -780,10 +780,10 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -780,10 +780,10 @@ class GlobalRotScaleTrans(BaseTransform):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
PointShuffle
(
object
):
class
PointShuffle
(
BaseTransform
):
"""Shuffle input points."""
"""Shuffle input points."""
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to shuffle points.
"""Call function to shuffle points.
Args:
Args:
...
@@ -1113,7 +1113,7 @@ class IndoorPointSample(PointSample):
...
@@ -1113,7 +1113,7 @@ class IndoorPointSample(PointSample):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
IndoorPatchPointSample
(
object
):
class
IndoorPatchPointSample
(
BaseTransform
):
r
"""Indoor point sample within a patch. Modified from `PointNet++ <https://
r
"""Indoor point sample within a patch. Modified from `PointNet++ <https://
github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_.
github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_.
...
@@ -1152,15 +1152,15 @@ class IndoorPatchPointSample(object):
...
@@ -1152,15 +1152,15 @@ class IndoorPatchPointSample(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_points
,
num_points
:
int
,
block_size
=
1.5
,
block_size
:
float
=
1.5
,
sample_rate
=
None
,
sample_rate
:
Optional
[
float
]
=
None
,
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
use_normalized_coord
=
False
,
use_normalized_coord
:
bool
=
False
,
num_try
=
10
,
num_try
:
int
=
10
,
enlarge_size
=
0.2
,
enlarge_size
:
float
=
0.2
,
min_unique_num
=
None
,
min_unique_num
:
Optional
[
int
]
=
None
,
eps
=
1e-2
)
:
eps
:
float
=
1e-2
)
->
None
:
self
.
num_points
=
num_points
self
.
num_points
=
num_points
self
.
block_size
=
block_size
self
.
block_size
=
block_size
self
.
ignore_index
=
ignore_index
self
.
ignore_index
=
ignore_index
...
@@ -1175,8 +1175,10 @@ class IndoorPatchPointSample(object):
...
@@ -1175,8 +1175,10 @@ class IndoorPatchPointSample(object):
"'sample_rate' has been deprecated and will be removed in "
"'sample_rate' has been deprecated and will be removed in "
'the future. Please remove them from your code.'
)
'the future. Please remove them from your code.'
)
def
_input_generation
(
self
,
coords
,
patch_center
,
coord_max
,
attributes
,
def
_input_generation
(
self
,
coords
:
np
.
ndarray
,
patch_center
:
np
.
ndarray
,
attribute_dims
,
point_type
):
coord_max
:
np
.
ndarray
,
attributes
:
np
.
ndarray
,
attribute_dims
:
dict
,
point_type
:
type
)
->
BasePoints
:
"""Generating model input.
"""Generating model input.
Generate input by subtracting patch center and adding additional
Generate input by subtracting patch center and adding additional
...
@@ -1216,7 +1218,8 @@ class IndoorPatchPointSample(object):
...
@@ -1216,7 +1218,8 @@ class IndoorPatchPointSample(object):
return
points
return
points
def
_patch_points_sampling
(
self
,
points
,
sem_mask
):
def
_patch_points_sampling
(
self
,
points
:
BasePoints
,
sem_mask
:
np
.
ndarray
)
->
BasePoints
:
"""Patch points sampling.
"""Patch points sampling.
First sample a valid patch.
First sample a valid patch.
...
@@ -1316,7 +1319,7 @@ class IndoorPatchPointSample(object):
...
@@ -1316,7 +1319,7 @@ class IndoorPatchPointSample(object):
return
points
,
choices
return
points
,
choices
def
__call__
(
self
,
results
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to sample points to in indoor scenes.
"""Call function to sample points to in indoor scenes.
Args:
Args:
...
@@ -1326,22 +1329,33 @@ class IndoorPatchPointSample(object):
...
@@ -1326,22 +1329,33 @@ class IndoorPatchPointSample(object):
dict: Results after sampling, 'points', 'pts_instance_mask'
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
results
[
'points'
]
points
=
input_dict
[
'points'
]
assert
'pts_semantic_mask'
in
results
.
keys
(),
\
assert
'pts_semantic_mask'
in
input_dict
.
keys
(),
\
'semantic mask should be provided in training and evaluation'
'semantic mask should be provided in training and evaluation'
pts_semantic_mask
=
results
[
'pts_semantic_mask'
]
pts_semantic_mask
=
input_dict
[
'pts_semantic_mask'
]
points
,
choices
=
self
.
_patch_points_sampling
(
points
,
points
,
choices
=
self
.
_patch_points_sampling
(
points
,
pts_semantic_mask
)
pts_semantic_mask
)
results
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
results
[
'pts_semantic_mask'
]
=
pts_semantic_mask
[
choices
]
input_dict
[
'pts_semantic_mask'
]
=
pts_semantic_mask
[
choices
]
pts_instance_mask
=
results
.
get
(
'pts_instance_mask'
,
None
)
# 'eval_ann_info' will be passed to evaluator
if
'eval_ann_info'
in
input_dict
:
input_dict
[
'eval_ann_info'
][
'pts_semantic_mask'
]
=
\
pts_semantic_mask
[
choices
]
pts_instance_mask
=
input_dict
.
get
(
'pts_instance_mask'
,
None
)
if
pts_instance_mask
is
not
None
:
if
pts_instance_mask
is
not
None
:
results
[
'pts_instance_mask'
]
=
pts_instance_mask
[
choices
]
input_dict
[
'pts_instance_mask'
]
=
pts_instance_mask
[
choices
]
# 'eval_ann_info' will be passed to evaluator
if
'eval_ann_info'
in
input_dict
:
input_dict
[
'eval_ann_info'
][
'pts_instance_mask'
]
=
\
pts_instance_mask
[
choices
]
return
results
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
):
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
...
@@ -1358,14 +1372,14 @@ class IndoorPatchPointSample(object):
...
@@ -1358,14 +1372,14 @@ class IndoorPatchPointSample(object):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
BackgroundPointsFilter
(
object
):
class
BackgroundPointsFilter
(
BaseTransform
):
"""Filter background points near the bounding box.
"""Filter background points near the bounding box.
Args:
Args:
bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
"""
"""
def
__init__
(
self
,
bbox_enlarge_range
)
:
def
__init__
(
self
,
bbox_enlarge_range
:
Union
[
Tuple
[
float
],
float
])
->
None
:
assert
(
is_tuple_of
(
bbox_enlarge_range
,
float
)
assert
(
is_tuple_of
(
bbox_enlarge_range
,
float
)
and
len
(
bbox_enlarge_range
)
==
3
)
\
and
len
(
bbox_enlarge_range
)
==
3
)
\
or
isinstance
(
bbox_enlarge_range
,
float
),
\
or
isinstance
(
bbox_enlarge_range
,
float
),
\
...
@@ -1376,7 +1390,7 @@ class BackgroundPointsFilter(object):
...
@@ -1376,7 +1390,7 @@ class BackgroundPointsFilter(object):
self
.
bbox_enlarge_range
=
np
.
array
(
self
.
bbox_enlarge_range
=
np
.
array
(
bbox_enlarge_range
,
dtype
=
np
.
float32
)[
np
.
newaxis
,
:]
bbox_enlarge_range
,
dtype
=
np
.
float32
)[
np
.
newaxis
,
:]
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Call function to filter points by the range.
"""Call function to filter points by the range.
Args:
Args:
...
...
mmdet3d/datasets/s3dis_dataset.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
os
import
path
as
osp
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
from
mmdet3d.core
import
show_seg_result
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.registry
import
DATASETS
from
mmdet3d.registry
import
DATASETS
from
.custom_3d_seg
import
Custom3DSegDataset
from
.det3d_dataset
import
Det3DDataset
from
.det3d_dataset
import
Det3DDataset
from
.pipelines
import
Compose
from
.pipelines
import
Compose
from
.seg3d_dataset
import
Seg3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
...
@@ -153,7 +153,7 @@ class S3DISDataset(Det3DDataset):
...
@@ -153,7 +153,7 @@ class S3DISDataset(Det3DDataset):
return
Compose
(
pipeline
)
return
Compose
(
pipeline
)
class
_S3DISSegDataset
(
Custom3D
SegDataset
):
class
_S3DISSegDataset
(
Seg
3D
Dataset
):
r
"""S3DIS Dataset for Semantic Segmentation Task.
r
"""S3DIS Dataset for Semantic Segmentation Task.
This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
...
@@ -185,114 +185,44 @@ class _S3DISSegDataset(Custom3DSegDataset):
...
@@ -185,114 +185,44 @@ class _S3DISSegDataset(Custom3DSegDataset):
data. For scenes with many points, we may sample it several times.
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
"""
"""
CLASSES
=
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
METAINFO
=
{
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
)
'CLASSES'
:
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
VALID_CLASS_IDS
=
tuple
(
range
(
13
))
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
),
'PALETTE'
:
[[
0
,
255
,
0
],
[
0
,
0
,
255
],
[
0
,
255
,
255
],
[
255
,
255
,
0
],
ALL_CLASS_IDS
=
tuple
(
range
(
14
))
# possibly with 'stair' class
PALETTE
=
[[
0
,
255
,
0
],
[
0
,
0
,
255
],
[
0
,
255
,
255
],
[
255
,
255
,
0
],
[
255
,
0
,
255
],
[
100
,
100
,
255
],
[
200
,
200
,
100
],
[
255
,
0
,
255
],
[
100
,
100
,
255
],
[
200
,
200
,
100
],
[
170
,
120
,
200
],
[
255
,
0
,
0
],
[
200
,
100
,
100
],
[
10
,
200
,
100
],
[
170
,
120
,
200
],
[
255
,
0
,
0
],
[
200
,
100
,
100
],
[
200
,
200
,
200
],
[
50
,
50
,
50
]]
[
10
,
200
,
100
],
[
200
,
200
,
200
],
[
50
,
50
,
50
]],
'valid_class_ids'
:
tuple
(
range
(
13
)),
'all_class_ids'
:
tuple
(
range
(
14
))
# possibly with 'stair' class
}
def
__init__
(
self
,
def
__init__
(
self
,
data_root
,
data_root
:
Optional
[
str
]
=
None
,
ann_file
,
ann_file
:
str
=
''
,
pipeline
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
classes
=
None
,
data_prefix
:
dict
=
dict
(
palette
=
None
,
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
)
,
modality
=
None
,
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[]
,
test_mode
=
False
,
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
,
ignore_index
=
None
,
ignore_index
=
None
,
scene_idxs
=
None
,
scene_idxs
=
None
,
**
kwargs
):
test_mode
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
classes
=
classes
,
palette
=
palette
,
modality
=
modality
,
modality
=
modality
,
test_mode
=
test_mode
,
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
,
scene_idxs
=
scene_idxs
,
test_mode
=
test_mode
,
**
kwargs
)
**
kwargs
)
def
get_ann_info
(
self
,
index
):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
data_infos
[
index
]
pts_semantic_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_semantic_mask_path'
])
anns_results
=
dict
(
pts_semantic_mask_path
=
pts_semantic_mask_path
)
return
anns_results
def
_build_default_pipeline
(
self
):
"""Build the default pipeline for this dataset."""
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
use_color
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
,
valid_cat_ids
=
self
.
VALID_CLASS_IDS
,
max_cat_id
=
np
.
max
(
self
.
ALL_CLASS_IDS
)),
dict
(
type
=
'DefaultFormatBundle3D'
,
with_label
=
False
,
class_names
=
self
.
CLASSES
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
return
Compose
(
pipeline
)
def
show
(
self
,
results
,
out_dir
,
show
=
True
,
pipeline
=
None
):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert
out_dir
is
not
None
,
'Expect out_dir, got none.'
pipeline
=
self
.
_get_pipeline
(
pipeline
)
for
i
,
result
in
enumerate
(
results
):
data_info
=
self
.
data_infos
[
i
]
pts_path
=
data_info
[
'pts_path'
]
file_name
=
osp
.
split
(
pts_path
)[
-
1
].
split
(
'.'
)[
0
]
points
,
gt_sem_mask
=
self
.
_extract_data
(
i
,
pipeline
,
[
'points'
,
'pts_semantic_mask'
],
load_annos
=
True
)
points
=
points
.
numpy
()
pred_sem_mask
=
result
[
'semantic_mask'
].
numpy
()
show_seg_result
(
points
,
gt_sem_mask
,
pred_sem_mask
,
out_dir
,
file_name
,
np
.
array
(
self
.
PALETTE
),
self
.
ignore_index
,
show
)
def
get_scene_idxs
(
self
,
scene_idxs
):
def
get_scene_idxs
(
self
,
scene_idxs
):
"""Compute scene_idxs for data sampling.
"""Compute scene_idxs for data sampling.
...
@@ -341,16 +271,17 @@ class S3DISSegDataset(_S3DISSegDataset):
...
@@ -341,16 +271,17 @@ class S3DISSegDataset(_S3DISSegDataset):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
data_root
,
data_root
:
Optional
[
str
]
=
None
,
ann_files
,
ann_files
:
str
=
''
,
pipeline
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
classes
=
None
,
data_prefix
:
dict
=
dict
(
palette
=
None
,
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
)
,
modality
=
None
,
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[]
,
test_mode
=
False
,
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
,
ignore_index
=
None
,
ignore_index
=
None
,
scene_idxs
=
None
,
scene_idxs
=
None
,
**
kwargs
):
test_mode
=
False
,
**
kwargs
)
->
None
:
# make sure that ann_files and scene_idxs have same length
# make sure that ann_files and scene_idxs have same length
ann_files
=
self
.
_check_ann_files
(
ann_files
)
ann_files
=
self
.
_check_ann_files
(
ann_files
)
...
@@ -360,45 +291,45 @@ class S3DISSegDataset(_S3DISSegDataset):
...
@@ -360,45 +291,45 @@ class S3DISSegDataset(_S3DISSegDataset):
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_files
[
0
],
ann_file
=
ann_files
[
0
],
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
classes
=
classes
,
palette
=
palette
,
modality
=
modality
,
modality
=
modality
,
test_mode
=
test_mode
,
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
[
0
],
scene_idxs
=
scene_idxs
[
0
],
test_mode
=
test_mode
,
**
kwargs
)
**
kwargs
)
datasets
=
[
datasets
=
[
_S3DISSegDataset
(
_S3DISSegDataset
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_files
[
i
],
ann_file
=
ann_files
[
i
],
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
classes
=
classes
,
palette
=
palette
,
modality
=
modality
,
modality
=
modality
,
test_mode
=
test_mode
,
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
[
i
],
scene_idxs
=
scene_idxs
[
i
],
test_mode
=
test_mode
,
**
kwargs
)
for
i
in
range
(
len
(
ann_files
))
**
kwargs
)
for
i
in
range
(
len
(
ann_files
))
]
]
# data_
infos
and scene_idxs need to be concat
# data_
list
and scene_idxs need to be concat
self
.
concat_data_
infos
([
dst
.
data_
infos
for
dst
in
datasets
])
self
.
concat_data_
list
([
dst
.
data_
list
for
dst
in
datasets
])
self
.
concat_scene_idxs
([
dst
.
scene_idxs
for
dst
in
datasets
])
self
.
concat_scene_idxs
([
dst
.
scene_idxs
for
dst
in
datasets
])
# set group flag for the sampler
# set group flag for the sampler
if
not
self
.
test_mode
:
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
self
.
_set_group_flag
()
def
concat_data_
infos
(
self
,
data_
info
s
):
def
concat_data_
list
(
self
,
data_
list
s
):
"""Concat data_
infos
from several datasets to form self.data_
infos
.
"""Concat data_
list
from several datasets to form self.data_
list
.
Args:
Args:
data_
info
s (list[list[dict]])
data_
list
s (list[list[dict]])
"""
"""
self
.
data_
infos
=
[
self
.
data_
list
=
[
info
for
one_
data_
infos
in
data_
info
s
for
info
in
one_data_infos
data
for
data_
list
in
data_
list
s
for
data
in
data_list
]
]
def
concat_scene_idxs
(
self
,
scene_idxs
):
def
concat_scene_idxs
(
self
,
scene_idxs
):
...
...
mmdet3d/datasets/scannet_dataset.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
tempfile
import
warnings
import
warnings
from
os
import
path
as
osp
from
os
import
path
as
osp
from
typing
import
Callable
,
List
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
from
mmdet3d.core
import
instance_seg_eval
,
show_result
,
show_seg
_result
from
mmdet3d.core
import
show
_result
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.registry
import
DATASETS
from
mmdet3d.registry
import
DATASETS
from
.custom_3d_seg
import
Custom3DSegDataset
from
.det3d_dataset
import
Det3DDataset
from
.det3d_dataset
import
Det3DDataset
from
.pipelines
import
Compose
from
.pipelines
import
Compose
from
.seg3d_dataset
import
Seg3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
...
@@ -193,7 +192,7 @@ class ScanNetDataset(Det3DDataset):
...
@@ -193,7 +192,7 @@ class ScanNetDataset(Det3DDataset):
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
ScanNetSegDataset
(
Custom3D
SegDataset
):
class
ScanNetSegDataset
(
Seg
3D
Dataset
):
r
"""ScanNet Dataset for Semantic Segmentation Task.
r
"""ScanNet Dataset for Semantic Segmentation Task.
This class serves as the API for experiments on the ScanNet Dataset.
This class serves as the API for experiments on the ScanNet Dataset.
...
@@ -221,17 +220,13 @@ class ScanNetSegDataset(Custom3DSegDataset):
...
@@ -221,17 +220,13 @@ class ScanNetSegDataset(Custom3DSegDataset):
data. For scenes with many points, we may sample it several times.
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
"""
"""
CLASSES
=
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
METAINFO
=
{
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'CLASSES'
:
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'bathtub'
,
'otherfurniture'
)
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
VALID_CLASS_IDS
=
(
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
'otherfurniture'
),
33
,
34
,
36
,
39
)
'PALETTE'
:
[
ALL_CLASS_IDS
=
tuple
(
range
(
41
))
PALETTE
=
[
[
174
,
199
,
232
],
[
174
,
199
,
232
],
[
152
,
223
,
138
],
[
152
,
223
,
138
],
[
31
,
119
,
180
],
[
31
,
119
,
180
],
...
@@ -252,104 +247,37 @@ class ScanNetSegDataset(Custom3DSegDataset):
...
@@ -252,104 +247,37 @@ class ScanNetSegDataset(Custom3DSegDataset):
[
112
,
128
,
144
],
[
112
,
128
,
144
],
[
227
,
119
,
194
],
[
227
,
119
,
194
],
[
82
,
84
,
163
],
[
82
,
84
,
163
],
]
],
'valid_class_ids'
:
(
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
33
,
34
,
36
,
39
),
'all_class_ids'
:
tuple
(
range
(
41
)),
}
def
__init__
(
self
,
def
__init__
(
self
,
data_root
,
data_root
:
Optional
[
str
]
=
None
,
ann_file
,
ann_file
:
str
=
''
,
pipeline
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
classes
=
None
,
data_prefix
:
dict
=
dict
(
palette
=
None
,
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
)
,
modality
=
None
,
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[]
,
test_mode
=
False
,
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
,
ignore_index
=
None
,
ignore_index
=
None
,
scene_idxs
=
None
,
scene_idxs
=
None
,
**
kwargs
):
test_mode
=
False
,
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
classes
=
classes
,
palette
=
palette
,
modality
=
modality
,
modality
=
modality
,
test_mode
=
test_mode
,
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
,
scene_idxs
=
scene_idxs
,
test_mode
=
test_mode
,
**
kwargs
)
**
kwargs
)
def
get_ann_info
(
self
,
index
):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
data_infos
[
index
]
pts_semantic_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_semantic_mask_path'
])
anns_results
=
dict
(
pts_semantic_mask_path
=
pts_semantic_mask_path
)
return
anns_results
def
_build_default_pipeline
(
self
):
"""Build the default pipeline for this dataset."""
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
use_color
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
,
valid_cat_ids
=
self
.
VALID_CLASS_IDS
,
max_cat_id
=
np
.
max
(
self
.
ALL_CLASS_IDS
)),
dict
(
type
=
'DefaultFormatBundle3D'
,
with_label
=
False
,
class_names
=
self
.
CLASSES
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
return
Compose
(
pipeline
)
def
show
(
self
,
results
,
out_dir
,
show
=
True
,
pipeline
=
None
):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert
out_dir
is
not
None
,
'Expect out_dir, got none.'
pipeline
=
self
.
_get_pipeline
(
pipeline
)
for
i
,
result
in
enumerate
(
results
):
data_info
=
self
.
data_infos
[
i
]
pts_path
=
data_info
[
'pts_path'
]
file_name
=
osp
.
split
(
pts_path
)[
-
1
].
split
(
'.'
)[
0
]
points
,
gt_sem_mask
=
self
.
_extract_data
(
i
,
pipeline
,
[
'points'
,
'pts_semantic_mask'
],
load_annos
=
True
)
points
=
points
.
numpy
()
pred_sem_mask
=
result
[
'semantic_mask'
].
numpy
()
show_seg_result
(
points
,
gt_sem_mask
,
pred_sem_mask
,
out_dir
,
file_name
,
np
.
array
(
self
.
PALETTE
),
self
.
ignore_index
,
show
)
def
get_scene_idxs
(
self
,
scene_idxs
):
def
get_scene_idxs
(
self
,
scene_idxs
):
"""Compute scene_idxs for data sampling.
"""Compute scene_idxs for data sampling.
...
@@ -362,191 +290,65 @@ class ScanNetSegDataset(Custom3DSegDataset):
...
@@ -362,191 +290,65 @@ class ScanNetSegDataset(Custom3DSegDataset):
return
super
().
get_scene_idxs
(
scene_idxs
)
return
super
().
get_scene_idxs
(
scene_idxs
)
def
format_results
(
self
,
results
,
txtfile_prefix
=
None
):
r
"""Format the results to txt file. Refer to `ScanNet documentation
<http://kaldir.vc.in.tum.de/scannet_benchmark/documentation>`_.
Args:
outputs (list[dict]): Testing results of the dataset.
txtfile_prefix (str): The prefix of saved files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporal directory created for saving submission
files when ``submission_prefix`` is not specified.
"""
import
mmcv
if
txtfile_prefix
is
None
:
tmp_dir
=
tempfile
.
TemporaryDirectory
()
txtfile_prefix
=
osp
.
join
(
tmp_dir
.
name
,
'results'
)
else
:
tmp_dir
=
None
mmcv
.
mkdir_or_exist
(
txtfile_prefix
)
# need to map network output to original label idx
pred2label
=
np
.
zeros
(
len
(
self
.
VALID_CLASS_IDS
)).
astype
(
np
.
int
)
for
original_label
,
output_idx
in
self
.
label_map
.
items
():
if
output_idx
!=
self
.
ignore_index
:
pred2label
[
output_idx
]
=
original_label
outputs
=
[]
for
i
,
result
in
enumerate
(
results
):
info
=
self
.
data_infos
[
i
]
sample_idx
=
info
[
'point_cloud'
][
'lidar_idx'
]
pred_sem_mask
=
result
[
'semantic_mask'
].
numpy
().
astype
(
np
.
int
)
pred_label
=
pred2label
[
pred_sem_mask
]
curr_file
=
f
'
{
txtfile_prefix
}
/
{
sample_idx
}
.txt'
np
.
savetxt
(
curr_file
,
pred_label
,
fmt
=
'%d'
)
outputs
.
append
(
dict
(
seg_mask
=
pred_label
))
return
outputs
,
tmp_dir
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
ScanNetInstanceSegDataset
(
Custom3DSegDataset
):
class
ScanNetInstanceSegDataset
(
Seg3DDataset
):
CLASSES
=
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
)
VALID_CLASS_IDS
=
(
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
33
,
34
,
36
,
39
)
ALL_CLASS_IDS
=
tuple
(
range
(
41
))
def
get_ann_info
(
self
,
index
):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
- pts_instance_mask_path (str): Path of instance masks.
"""
# Use index to get the annos, thus the evalhook could also use this api
info
=
self
.
data_infos
[
index
]
pts_instance_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_instance_mask_path'
])
pts_semantic_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_semantic_mask_path'
])
anns_results
=
dict
(
pts_instance_mask_path
=
pts_instance_mask_path
,
pts_semantic_mask_path
=
pts_semantic_mask_path
)
return
anns_results
def
get_classes_and_palette
(
self
,
classes
=
None
,
palette
=
None
):
"""Get class names of current dataset. Palette is simply ignored for
instance segmentation.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Defaults to None.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, random
palette will be generated. Defaults to None.
"""
if
classes
is
not
None
:
return
classes
,
None
return
self
.
CLASSES
,
None
def
_build_default_pipeline
(
self
):
METAINFO
=
{
"""Build the default pipeline for this dataset."""
'CLASSES'
:
pipeline
=
[
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
dict
(
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
type
=
'LoadPointsFromFile'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
),
coord_type
=
'DEPTH'
,
'PLATTE'
:
[
shift_height
=
False
,
[
174
,
199
,
232
],
use_color
=
True
,
[
152
,
223
,
138
],
load_dim
=
6
,
[
31
,
119
,
180
],
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
[
255
,
187
,
120
],
dict
(
[
188
,
189
,
34
],
type
=
'LoadAnnotations3D'
,
[
140
,
86
,
75
],
with_bbox_3d
=
False
,
[
255
,
152
,
150
],
with_label_3d
=
False
,
[
214
,
39
,
40
],
with_mask_3d
=
True
,
[
197
,
176
,
213
],
with_seg_3d
=
True
),
[
148
,
103
,
189
],
dict
(
[
196
,
156
,
148
],
type
=
'PointSegClassMapping'
,
[
23
,
190
,
207
],
valid_cat_ids
=
self
.
VALID_CLASS_IDS
,
[
247
,
182
,
210
],
max_cat_id
=
40
),
[
219
,
219
,
141
],
dict
(
[
255
,
127
,
14
],
type
=
'DefaultFormatBundle3D'
,
[
158
,
218
,
229
],
with_label
=
False
,
[
44
,
160
,
44
],
class_names
=
self
.
CLASSES
),
[
112
,
128
,
144
],
dict
(
[
227
,
119
,
194
],
type
=
'Collect3D'
,
[
82
,
84
,
163
],
keys
=
[
'points'
,
'pts_semantic_mask'
,
'pts_instance_mask'
])
],
]
'valid_class_ids'
:
return
Compose
(
pipeline
)
(
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
14
,
16
,
24
,
28
,
33
,
34
,
36
,
39
),
'all_class_ids'
:
def
evaluate
(
self
,
tuple
(
range
(
41
))
results
,
}
metric
=
None
,
options
=
None
,
logger
=
None
,
show
=
False
,
out_dir
=
None
,
pipeline
=
None
):
"""Evaluation in instance segmentation protocol.
Args:
results (list[dict]): List of results.
metric (str | list[str]): Metrics to be evaluated.
options (dict, optional): options for instance_seg_eval.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Defaults to None.
show (bool, optional): Whether to visualize.
Defaults to False.
out_dir (str, optional): Path to save the visualization results.
Defaults to None.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
Returns:
def
__init__
(
self
,
dict: Evaluation results.
data_root
:
Optional
[
str
]
=
None
,
"""
ann_file
:
str
=
''
,
assert
isinstance
(
metainfo
:
Optional
[
dict
]
=
None
,
results
,
list
),
f
'Expect results to be list, got
{
type
(
results
)
}
.'
data_prefix
:
dict
=
dict
(
assert
len
(
results
)
>
0
,
'Expect length of results > 0.'
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
assert
len
(
results
)
==
len
(
self
.
data_infos
)
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
assert
isinstance
(
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
results
[
0
],
dict
test_mode
=
False
,
),
f
'Expect elements in results to be dict, got
{
type
(
results
[
0
])
}
.'
ignore_index
=
None
,
scene_idxs
=
None
,
load_pipeline
=
self
.
_get_pipeline
(
pipeline
)
file_client_args
=
dict
(
backend
=
'disk'
),
pred_instance_masks
=
[
result
[
'instance_mask'
]
for
result
in
results
]
**
kwargs
)
->
None
:
pred_instance_labels
=
[
result
[
'instance_label'
]
for
result
in
results
]
super
().
__init__
(
pred_instance_scores
=
[
result
[
'instance_score'
]
for
result
in
results
]
data_root
=
data_root
,
gt_semantic_masks
,
gt_instance_masks
=
zip
(
*
[
ann_file
=
ann_file
,
self
.
_extract_data
(
metainfo
=
metainfo
,
index
=
i
,
pipeline
=
pipeline
,
pipeline
=
load_pipeline
,
data_prefix
=
data_prefix
,
key
=
[
'pts_semantic_mask'
,
'pts_instance_mask'
],
modality
=
modality
,
load_annos
=
True
)
for
i
in
range
(
len
(
self
.
data_infos
))
test_mode
=
test_mode
,
])
ignore_index
=
ignore_index
,
ret_dict
=
instance_seg_eval
(
scene_idxs
=
scene_idxs
,
gt_semantic_masks
,
file_client_args
=
file_client_args
,
gt_instance_masks
,
**
kwargs
)
pred_instance_masks
,
pred_instance_labels
,
pred_instance_scores
,
valid_class_ids
=
self
.
VALID_CLASS_IDS
,
class_labels
=
self
.
CLASSES
,
options
=
options
,
logger
=
logger
)
if
show
:
raise
NotImplementedError
(
'show is not implemented for now'
)
return
ret_dict
mmdet3d/datasets/seg3d_dataset.py
0 → 100644
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Union
import
mmcv
import
numpy
as
np
from
mmengine.dataset
import
BaseDataset
from
mmdet3d.registry
import
DATASETS
@
DATASETS
.
register_module
()
class
Seg3DDataset
(
BaseDataset
):
"""Base Class for 3D semantic segmentation dataset.
This is the base dataset of ScanNet, S3DIS and SemanticKITTI dataset.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
dict(pts='velodyne', img='', instance_mask='', semantic_mask='').
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input, it usually has following keys.
- use_camera: bool
- use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)`
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
load_eval_anns (bool): Whether to load annotations
in test_mode, the annotation will be save in
`eval_ann_infos`, which can be use in Evaluator.
file_client_args (dict): Configuration of file client.
Defaults to `dict(backend='disk')`.
"""
METAINFO
=
{
'CLASSES'
:
None
,
# names of all classes data used for the task
'PALETTE'
:
None
,
# official color for visualization
'valid_class_ids'
:
None
,
# class_ids used for training
'all_class_ids'
:
None
,
# all possible class_ids in loaded seg mask
}
def
__init__
(
self
,
data_root
:
Optional
[
str
]
=
None
,
ann_file
:
str
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
pts_instance_mask
=
''
,
pts_emantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
:
Optional
[
str
]
=
None
,
test_mode
:
bool
=
False
,
load_eval_anns
:
bool
=
True
,
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
**
kwargs
)
->
None
:
# init file client
self
.
file_client
=
mmcv
.
FileClient
(
**
file_client_args
)
self
.
modality
=
modality
self
.
load_eval_anns
=
load_eval_anns
# TODO: We maintain the ignore_index attributes,
# but we may consider to remove it in the future.
self
.
ignore_index
=
len
(
self
.
METAINFO
[
'CLASSES'
])
if
\
ignore_index
is
None
else
ignore_index
# Get label mapping for custom classes
new_classes
=
metainfo
.
get
(
'CLASSES'
,
None
)
self
.
label_mapping
,
self
.
label2cat
,
valid_class_ids
=
\
self
.
get_label_mapping
(
new_classes
)
metainfo
[
'label_mapping'
]
=
self
.
label_mapping
metainfo
[
'label2cat'
]
=
self
.
label2cat
metainfo
[
'valid_class_ids'
]
=
valid_class_ids
# generate palette if it is not defined based on
# label mapping, otherwise directly use palette
# defined in dataset config.
palette
=
metainfo
.
get
(
'PALETTE'
,
None
)
updated_palette
=
self
.
_update_palette
(
new_classes
,
palette
)
metainfo
[
'PALETTE'
]
=
updated_palette
super
().
__init__
(
ann_file
=
ann_file
,
metainfo
=
metainfo
,
data_root
=
data_root
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
test_mode
=
test_mode
,
**
kwargs
)
self
.
scene_idxs
=
self
.
get_scene_idxs
(
scene_idxs
)
# set group flag for the sampler
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
def
get_label_mapping
(
self
,
new_classes
:
Optional
[
Sequence
]
=
None
)
->
Union
[
Dict
,
None
]:
"""Get label mapping.
The ``label_mapping`` is a dictionary, its keys are the old label ids
and its values are the new label ids, and is used for changing pixel
labels in load_annotations. If and only if old classes in cls.METAINFO
is not equal to new classes in self._metainfo and nether of them is not
None, `label_mapping` is not None.
Args:
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
Returns:
tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo
"""
old_classes
=
self
.
METAINFO
.
get
(
'CLASSSES'
,
None
)
if
(
new_classes
is
not
None
and
old_classes
is
not
None
and
list
(
new_classes
)
!=
list
(
old_classes
)):
label_mapping
=
{}
if
not
set
(
new_classes
).
issubset
(
old_classes
):
raise
ValueError
(
f
'new classes
{
new_classes
}
is not a '
f
'subset of CLASSES
{
old_classes
}
in METAINFO.'
)
# obtain true id from valid_class_ids
valid_class_ids
=
[
self
.
METAINFO
[
'valid_class_ids'
][
old_classes
.
index
(
cls_name
)]
for
cls_name
in
new_classes
]
label_mapping
=
{
cls_id
:
self
.
ignore_index
for
cls_id
in
self
.
METAINFO
[
'all_class_ids'
]
}
label_mapping
.
update
(
{
cls_id
:
i
for
i
,
cls_id
in
enumerate
(
valid_class_ids
)})
label2cat
=
{
i
:
cat_name
for
i
,
cat_name
in
enumerate
(
new_classes
)}
else
:
label_mapping
=
{
cls_id
:
self
.
ignore_index
for
cls_id
in
self
.
METAINFO
[
'all_class_ids'
]
}
label_mapping
.
update
({
cls_id
:
i
for
i
,
cls_id
in
enumerate
(
self
.
METAINFO
[
'valid_class_ids'
])
})
# map label to category name
label2cat
=
{
i
:
cat_name
for
i
,
cat_name
in
enumerate
(
self
.
METAINFO
[
'CLASSES'
])
}
valid_class_ids
=
self
.
METAINFO
[
'valid_class_ids'
]
return
label_mapping
,
label2cat
,
valid_class_ids
def
_update_palette
(
self
,
new_classes
,
palette
)
->
list
:
"""Update palette according to metainfo.
If length of palette is equal to classes, just return the palette.
If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
Returns:
Sequence: Palette for current dataset.
"""
if
palette
is
None
:
# If palette is not defined, it generate a palette according
# to the original PALETTE and classes.
old_classes
=
self
.
METAINFO
.
get
(
'CLASSSES'
,
None
)
palette
=
[
self
.
METAINFO
[
'PALETTE'
][
old_classes
.
index
(
cls_name
)]
for
cls_name
in
new_classes
]
return
palette
# palette does match classes
if
len
(
palette
)
==
len
(
new_classes
):
return
palette
else
:
raise
ValueError
(
'Once PLATTE in set in metainfo, it should'
'match CLASSES in metainfo'
)
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
"""Process the raw data info.
Convert all relative path of needed modality data file to
the absolute path. And process
the `instances` field to `ann_info` in training stage.
Args:
info (dict): Raw info dict.
Returns:
dict: Has `ann_info` in training stage. And
all path has been converted to absolute path.
"""
if
self
.
modality
[
'use_lidar'
]:
info
[
'lidar_points'
][
'lidar_path'
]
=
\
osp
.
join
(
self
.
data_prefix
.
get
(
'pts'
,
''
),
info
[
'lidar_points'
][
'lidar_path'
])
if
self
.
modality
[
'use_camera'
]:
for
cam_id
,
img_info
in
info
[
'images'
].
items
():
if
'img_path'
in
img_info
:
img_info
[
'img_path'
]
=
osp
.
join
(
self
.
data_prefix
.
get
(
'img'
,
''
),
img_info
[
'img_path'
])
if
'pts_instance_mask_path'
in
info
:
info
[
'pts_instance_mask_path'
]
=
\
osp
.
join
(
self
.
data_prefix
.
get
(
'pts_instance_mask'
,
''
),
info
[
'pts_instance_mask_path'
])
if
'pts_semantic_mask_path'
in
info
:
info
[
'pts_semantic_mask_path'
]
=
\
osp
.
join
(
self
.
data_prefix
.
get
(
'pts_semantic_mask'
,
''
),
info
[
'pts_semantic_mask_path'
])
# Add label_mapping to input dict for directly
# use it in PointSegClassMapping pipeline
info
[
'label_mapping'
]
=
self
.
label_mapping
# 'eval_ann_info' will be updated in loading pipelines
if
self
.
test_mode
and
self
.
load_eval_anns
:
info
[
'eval_ann_info'
]
=
dict
()
return
info
def
get_scene_idxs
(
self
,
scene_idxs
):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
"""
if
self
.
test_mode
:
# when testing, we load one whole scene every time
return
np
.
arange
(
len
(
self
.
data_list
)).
astype
(
np
.
int32
)
# we may need to re-sample different scenes according to scene_idxs
# this is necessary for indoor scene segmentation such as ScanNet
if
scene_idxs
is
None
:
scene_idxs
=
np
.
arange
(
len
(
self
.
data_list
))
if
isinstance
(
scene_idxs
,
str
):
with
self
.
file_client
.
get_local_path
(
scene_idxs
)
as
local_path
:
scene_idxs
=
np
.
load
(
local_path
)
else
:
scene_idxs
=
np
.
array
(
scene_idxs
)
return
scene_idxs
.
astype
(
np
.
int32
)
def
_set_group_flag
(
self
):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus are all
zeros.
"""
self
.
flag
=
np
.
zeros
(
len
(
self
),
dtype
=
np
.
uint8
)
mmdet3d/datasets/semantickitti_dataset.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
typing
import
Callable
,
List
,
Optional
,
Union
from
mmdet3d.registry
import
DATASETS
from
mmdet3d.registry
import
DATASETS
from
.
det
3d_dataset
import
Det
3DDataset
from
.
seg
3d_dataset
import
Seg
3DDataset
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
class
SemanticKITTIDataset
(
Det
3DDataset
):
class
SemanticKITTIDataset
(
Seg
3DDataset
):
r
"""SemanticKITTI Dataset.
r
"""SemanticKITTI Dataset.
This class serves as the API for experiments on the SemanticKITTI Dataset
This class serves as the API for experiments on the SemanticKITTI Dataset
...
@@ -36,75 +36,38 @@ class SemanticKITTIDataset(Det3DDataset):
...
@@ -36,75 +36,38 @@ class SemanticKITTIDataset(Det3DDataset):
test_mode (bool, optional): Whether the dataset is in test mode.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
CLASSES
=
(
'unlabeled'
,
'car'
,
'bicycle'
,
'motorcycle'
,
'truck'
,
'bus'
,
METAINFO
=
{
'person'
,
'bicyclist'
,
'motorcyclist'
,
'road'
,
'parking'
,
'CLASSES'
:
(
'unlabeled'
,
'car'
,
'bicycle'
,
'motorcycle'
,
'truck'
,
'sidewalk'
,
'other-ground'
,
'building'
,
'fence'
,
'vegetation'
,
'bus'
,
'person'
,
'bicyclist'
,
'motorcyclist'
,
'road'
,
'trunck'
,
'terrian'
,
'pole'
,
'traffic-sign'
)
'parking'
,
'sidewalk'
,
'other-ground'
,
'building'
,
'fence'
,
'vegetation'
,
'trunck'
,
'terrian'
,
'pole'
,
'traffic-sign'
),
'valid_class_ids'
:
tuple
(
range
(
20
)),
'all_class_ids'
:
tuple
(
range
(
20
))
}
def __init__(self,
             data_root: Optional[str] = None,
             ann_file: str = '',
             metainfo: Optional[dict] = None,
             data_prefix: dict = dict(
                 pts='points',
                 img='',
                 instance_mask='',
                 semantic_mask=''),
             pipeline: List[Union[dict, Callable]] = [],
             modality: dict = dict(use_lidar=True, use_camera=False),
             ignore_index=None,
             scene_idxs=None,
             test_mode=False,
             **kwargs) -> None:
    """Initialize the SemanticKITTI dataset.

    All arguments are forwarded unchanged to ``Seg3DDataset.__init__``.

    Args:
        data_root (str, optional): Root directory of the dataset.
        ann_file (str): Annotation file path relative to ``data_root``.
        metainfo (dict, optional): Meta information (e.g. class names and
            palette) overriding the class-level ``METAINFO``.
        data_prefix (dict): Sub-directory prefixes for points, images and
            mask files.
        pipeline (list[dict | Callable]): Data processing pipeline.
        modality (dict): Which sensor modalities to use.
        ignore_index (int, optional): Label index ignored during training.
        scene_idxs (np.ndarray | str, optional): Scene indices (or a path
            to them) used for re-sampling; see ``get_scene_idxs``.
        test_mode (bool): Whether the dataset is in test mode.
    """
    # NOTE(review): the mutable defaults (``data_prefix``/``pipeline``/
    # ``modality``) follow the upstream API; they are only read here and
    # passed through, never mutated — TODO confirm the base class does
    # not mutate them either.
    super().__init__(
        data_root=data_root,
        ann_file=ann_file,
        metainfo=metainfo,
        data_prefix=data_prefix,
        pipeline=pipeline,
        modality=modality,
        ignore_index=ignore_index,
        scene_idxs=scene_idxs,
        test_mode=test_mode,
        **kwargs)
def get_data_info(self, index):
    """Get data info according to the given index.

    Args:
        index (int): Index of the sample data to get.

    Returns:
        dict or None: Data information that will be passed to the data
        preprocessing pipelines, or ``None`` when ``filter_empty_gt`` is
        enabled and the sample has no valid ground truth. It includes the
        following keys:

            - sample_idx (str): Sample index.
            - pts_filename (str): Filename of point clouds.
            - file_name (str): Filename of point clouds.
            - ann_info (dict): Annotation info (train mode only).
    """
    info = self.data_infos[index]
    sample_idx = info['point_cloud']['lidar_idx']
    pts_filename = osp.join(self.data_root, info['pts_path'])

    input_dict = dict(
        pts_filename=pts_filename,
        sample_idx=sample_idx,
        file_name=pts_filename)

    if not self.test_mode:
        annos = self.get_ann_info(index)
        input_dict['ann_info'] = annos
        # Fix: use boolean ``not`` instead of the original bitwise ``~``.
        # ``~`` only behaves like logical negation on NumPy/torch bool
        # scalars; on a plain Python bool it is integer inversion
        # (``~True == -2``, which is truthy) and would silently break the
        # empty-GT filter.
        if self.filter_empty_gt and not (annos['gt_labels_3d'] != -1).any():
            return None
    return input_dict
def get_ann_info(self, index):
    """Get annotation info according to the given index.

    Args:
        index (int): Index of the annotation data to get.

    Returns:
        dict: Annotation information with the following key:

            - pts_semantic_mask_path (str): Path of semantic masks.
    """
    # Look up by index so the eval hook can reuse this API directly.
    info = self.data_infos[index]
    mask_path = osp.join(self.data_root, info['pts_semantic_mask_path'])
    return dict(pts_semantic_mask_path=mask_path)
tests/data/s3dis/s3dis_infos.pkl
View file @
360c27f9
No preview for this file type
tests/data/semantickitti/semantickitti_infos.pkl
View file @
360c27f9
No preview for this file type
tests/test_core/test_data_structure/test_det_data_sample.py
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
from
unittest
import
TestCase
from
unittest
import
TestCase
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
mmengine.data
import
InstanceData
,
PixelData
from
mmengine.data
import
InstanceData
from
mmdet3d.core.data_structures
import
Det3DDataSample
from
mmdet3d.core.data_structures
import
Det3DDataSample
,
PointData
def
_equal
(
a
,
b
):
def
_equal
(
a
,
b
):
...
@@ -86,47 +87,34 @@ class TestDet3DataSample(TestCase):
...
@@ -86,47 +87,34 @@ class TestDet3DataSample(TestCase):
assert
_equal
(
det3d_data_sample
.
img_pred_instances_3d
.
scores_3d
,
assert
_equal
(
det3d_data_sample
.
img_pred_instances_3d
.
scores_3d
,
img_pred_instances_3d_data
[
'scores_3d'
])
img_pred_instances_3d_data
[
'scores_3d'
])
# test gt_panoptic_seg
# test gt_seg
gt_pts_panoptic_seg_data
=
dict
(
panoptic_seg
=
torch
.
rand
(
5
,
4
))
gt_pts_seg_data
=
dict
(
gt_pts_panoptic_seg
=
PixelData
(
**
gt_pts_panoptic_seg_data
)
pts_instance_mask
=
torch
.
rand
(
20
),
pts_semantic_mask
=
torch
.
rand
(
20
))
det3d_data_sample
.
gt_pts_panoptic_seg
=
gt_pts_panoptic_seg
gt_pts_seg
=
PointData
(
**
gt_pts_seg_data
)
assert
'gt_pts_panoptic_seg'
in
det3d_data_sample
det3d_data_sample
.
gt_pts_seg
=
gt_pts_seg
assert
_equal
(
det3d_data_sample
.
gt_pts_panoptic_seg
.
panoptic_seg
,
assert
'gt_pts_seg'
in
det3d_data_sample
gt_pts_panoptic_seg_data
[
'panoptic_seg'
])
assert
_equal
(
det3d_data_sample
.
gt_pts_seg
.
pts_instance_mask
,
gt_pts_seg_data
[
'pts_instance_mask'
])
# test pred_panoptic_seg
assert
_equal
(
det3d_data_sample
.
gt_pts_seg
.
pts_semantic_mask
,
pred_pts_panoptic_seg_data
=
dict
(
panoptic_seg
=
torch
.
rand
(
5
,
4
))
gt_pts_seg_data
[
'pts_semantic_mask'
])
pred_pts_panoptic_seg
=
PixelData
(
**
pred_pts_panoptic_seg_data
)
det3d_data_sample
.
pred_pts_panoptic_seg
=
pred_pts_panoptic_seg
# test pred_seg
assert
'pred_pts_panoptic_seg'
in
det3d_data_sample
pred_pts_seg_data
=
dict
(
assert
_equal
(
det3d_data_sample
.
pred_pts_panoptic_seg
.
panoptic_seg
,
pts_instance_mask
=
torch
.
rand
(
20
),
pts_semantic_mask
=
torch
.
rand
(
20
))
pred_pts_panoptic_seg_data
[
'panoptic_seg'
])
pred_pts_seg
=
PointData
(
**
pred_pts_seg_data
)
det3d_data_sample
.
pred_pts_seg
=
pred_pts_seg
# test gt_sem_seg
assert
'pred_pts_seg'
in
det3d_data_sample
gt_pts_sem_seg_data
=
dict
(
segm_seg
=
torch
.
rand
(
5
,
4
,
2
))
assert
_equal
(
det3d_data_sample
.
pred_pts_seg
.
pts_instance_mask
,
gt_pts_sem_seg
=
PixelData
(
**
gt_pts_sem_seg_data
)
pred_pts_seg_data
[
'pts_instance_mask'
])
det3d_data_sample
.
gt_pts_sem_seg
=
gt_pts_sem_seg
assert
_equal
(
det3d_data_sample
.
pred_pts_seg
.
pts_semantic_mask
,
assert
'gt_pts_sem_seg'
in
det3d_data_sample
pred_pts_seg_data
[
'pts_semantic_mask'
])
assert
_equal
(
det3d_data_sample
.
gt_pts_sem_seg
.
segm_seg
,
gt_pts_sem_seg_data
[
'segm_seg'
])
# test pred_segm_seg
pred_pts_sem_seg_data
=
dict
(
segm_seg
=
torch
.
rand
(
5
,
4
,
2
))
pred_pts_sem_seg
=
PixelData
(
**
pred_pts_sem_seg_data
)
det3d_data_sample
.
pred_pts_sem_seg
=
pred_pts_sem_seg
assert
'pred_pts_sem_seg'
in
det3d_data_sample
assert
_equal
(
det3d_data_sample
.
pred_pts_sem_seg
.
segm_seg
,
pred_pts_sem_seg_data
[
'segm_seg'
])
# test type error
# test type error
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
det3d_data_sample
.
pred_instances_3d
=
torch
.
rand
(
2
,
4
)
det3d_data_sample
.
pred_instances_3d
=
torch
.
rand
(
2
,
4
)
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
det3d_data_sample
.
pred_pts_panoptic_seg
=
torch
.
rand
(
2
,
4
)
det3d_data_sample
.
pred_pts_seg
=
torch
.
rand
(
20
)
with
pytest
.
raises
(
AssertionError
):
det3d_data_sample
.
pred_pts_sem_seg
=
torch
.
rand
(
2
,
4
)
def
test_deleter
(
self
):
def
test_deleter
(
self
):
tmp_instances_3d_data
=
dict
(
tmp_instances_3d_data
=
dict
(
...
@@ -157,17 +145,10 @@ class TestDet3DataSample(TestCase):
...
@@ -157,17 +145,10 @@ class TestDet3DataSample(TestCase):
del
det3d_data_sample
.
img_pred_instances_3d
del
det3d_data_sample
.
img_pred_instances_3d
assert
'img_pred_instances_3d'
not
in
det3d_data_sample
assert
'img_pred_instances_3d'
not
in
det3d_data_sample
pred_pts_panoptic_seg_data
=
torch
.
rand
(
5
,
4
)
pred_pts_seg_data
=
dict
(
pred_pts_panoptic_seg_data
=
PixelData
(
data
=
pred_pts_panoptic_seg_data
)
pts_instance_mask
=
torch
.
rand
(
20
),
pts_semantic_mask
=
torch
.
rand
(
20
))
det3d_data_sample
.
pred_pts_panoptic_seg_data
=
\
pred_pts_seg
=
PointData
(
**
pred_pts_seg_data
)
pred_pts_panoptic_seg_data
det3d_data_sample
.
pred_pts_seg
=
pred_pts_seg
assert
'pred_pts_panoptic_seg_data'
in
det3d_data_sample
assert
'pred_pts_seg'
in
det3d_data_sample
del
det3d_data_sample
.
pred_pts_panoptic_seg_data
del
det3d_data_sample
.
pred_pts_seg
assert
'pred_pts_panoptic_seg_data'
not
in
det3d_data_sample
assert
'pred_pts_seg'
not
in
det3d_data_sample
pred_pts_sem_seg_data
=
dict
(
segm_seg
=
torch
.
rand
(
5
,
4
,
2
))
pred_pts_sem_seg
=
PixelData
(
**
pred_pts_sem_seg_data
)
det3d_data_sample
.
pred_pts_sem_seg
=
pred_pts_sem_seg
assert
'pred_pts_sem_seg'
in
det3d_data_sample
del
det3d_data_sample
.
pred_pts_sem_seg
assert
'pred_pts_sem_seg'
not
in
det3d_data_sample
tests/test_data/test_datasets/test_s3dis_dataset.py
0 → 100644
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
import
unittest
import
numpy
as
np
import
torch
from
mmdet3d.datasets
import
S3DISSegDataset
from
mmdet3d.utils
import
register_all_modules
def
_generate_s3dis_seg_dataset_config
():
data_root
=
'./tests/data/s3dis/'
ann_file
=
's3dis_infos.pkl'
classes
=
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
)
palette
=
[[
0
,
255
,
0
],
[
0
,
0
,
255
],
[
0
,
255
,
255
],
[
255
,
255
,
0
],
[
255
,
0
,
255
],
[
100
,
100
,
255
],
[
200
,
200
,
100
],
[
170
,
120
,
200
],
[
255
,
0
,
0
],
[
200
,
100
,
100
],
[
10
,
200
,
100
],
[
200
,
200
,
200
],
[
50
,
50
,
50
]]
scene_idxs
=
[
0
for
_
in
range
(
20
)]
modality
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
use_color
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
),
dict
(
type
=
'IndoorPatchPointSample'
,
num_points
=
5
,
block_size
=
1.0
,
ignore_index
=
len
(
classes
),
use_normalized_coord
=
True
,
enlarge_size
=
0.2
,
min_unique_num
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
data_prefix
=
dict
(
pts
=
'points'
,
pts_instance_mask
=
'instance_mask'
,
pts_semantic_mask
=
'semantic_mask'
)
return
(
data_root
,
ann_file
,
classes
,
palette
,
scene_idxs
,
data_prefix
,
pipeline
,
modality
)
class TestS3DISDataset(unittest.TestCase):
    """Unit tests for ``S3DISSegDataset`` on the bundled tiny test split."""

    def test_s3dis_seg(self):
        # Patch sampling in IndoorPatchPointSample is random; the expected
        # tensors below are only valid for this exact seed.
        np.random.seed(0)
        data_root, ann_file, classes, palette, scene_idxs, data_prefix, \
            pipeline, modality, = _generate_s3dis_seg_dataset_config()
        register_all_modules()
        s3dis_seg_dataset = S3DISSegDataset(
            data_root,
            ann_file,
            metainfo=dict(CLASSES=classes, PALETTE=palette),
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            scene_idxs=scene_idxs)

        input_dict = s3dis_seg_dataset.prepare_data(0)

        points = input_dict['inputs']['points']
        data_sample = input_dict['data_sample']
        pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask

        # 5 sampled points; each row is x, y, z, r, g, b plus
        # 3 normalized-coordinate channels from use_normalized_coord.
        expected_points = torch.tensor(
            [[0.0000, 0.0000, 3.1720, 0.4706, 0.4431, 0.3725, 0.4624,
              0.7502, 0.9543],
             [0.2880, -0.5900, 0.0650, 0.3451, 0.3373, 0.3490, 0.5119,
              0.5518, 0.0196],
             [0.1570, 0.6000, 3.1700, 0.4941, 0.4667, 0.3569, 0.4893,
              0.9519, 0.9537],
             [-0.1320, 0.3950, 0.2720, 0.3216, 0.2863, 0.2275, 0.4397,
              0.8830, 0.0818],
             [-0.4860, -0.0640, 3.1710, 0.3843, 0.3725, 0.3059, 0.3789,
              0.7286, 0.9540]])
        expected_pts_semantic_mask = np.array([0, 1, 0, 8, 0])

        # NOTE(review): the positional third argument of torch.allclose
        # is rtol — presumably a loose tolerance is intended here.
        assert torch.allclose(points, expected_points, 1e-2)
        self.assertTrue(
            (pts_semantic_mask.numpy() == expected_pts_semantic_mask).all())
tests/test_data/test_datasets/test_scannet_dataset.py
View file @
360c27f9
...
@@ -6,7 +6,74 @@ import torch
...
@@ -6,7 +6,74 @@ import torch
from
mmengine.testing
import
assert_allclose
from
mmengine.testing
import
assert_allclose
from
mmdet3d.core
import
DepthInstance3DBoxes
from
mmdet3d.core
import
DepthInstance3DBoxes
from
mmdet3d.datasets
import
ScanNetDataset
from
mmdet3d.datasets
import
ScanNetDataset
,
ScanNetSegDataset
from
mmdet3d.utils
import
register_all_modules
def
_generate_scannet_seg_dataset_config
():
data_root
=
'./tests/data/scannet/'
ann_file
=
'scannet_infos.pkl'
classes
=
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'otherfurniture'
)
palette
=
[
[
174
,
199
,
232
],
[
152
,
223
,
138
],
[
31
,
119
,
180
],
[
255
,
187
,
120
],
[
188
,
189
,
34
],
[
140
,
86
,
75
],
[
255
,
152
,
150
],
[
214
,
39
,
40
],
[
197
,
176
,
213
],
[
148
,
103
,
189
],
[
196
,
156
,
148
],
[
23
,
190
,
207
],
[
247
,
182
,
210
],
[
219
,
219
,
141
],
[
255
,
127
,
14
],
[
158
,
218
,
229
],
[
44
,
160
,
44
],
[
112
,
128
,
144
],
[
227
,
119
,
194
],
[
82
,
84
,
163
],
]
scene_idxs
=
[
0
for
_
in
range
(
20
)]
modality
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
use_color
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
),
dict
(
type
=
'IndoorPatchPointSample'
,
num_points
=
5
,
block_size
=
1.5
,
ignore_index
=
len
(
classes
),
use_normalized_coord
=
True
,
enlarge_size
=
0.2
,
min_unique_num
=
None
),
dict
(
type
=
'NormalizePointsColor'
,
color_mean
=
None
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
data_prefix
=
dict
(
pts
=
'points'
,
pts_instance_mask
=
'instance_mask'
,
pts_semantic_mask
=
'semantic_mask'
)
return
(
data_root
,
ann_file
,
classes
,
palette
,
scene_idxs
,
data_prefix
,
pipeline
,
modality
)
def
_generate_scannet_dataset_config
():
def
_generate_scannet_dataset_config
():
...
@@ -92,3 +159,54 @@ class TestScanNetDataset(unittest.TestCase):
...
@@ -92,3 +159,54 @@ class TestScanNetDataset(unittest.TestCase):
# all instance have been filtered by classes
# all instance have been filtered by classes
self
.
assertEqual
(
len
(
ann_info
[
'gt_labels_3d'
]),
27
)
self
.
assertEqual
(
len
(
ann_info
[
'gt_labels_3d'
]),
27
)
self
.
assertEqual
(
len
(
no_class_scannet_dataset
.
metainfo
[
'CLASSES'
]),
1
)
self
.
assertEqual
(
len
(
no_class_scannet_dataset
.
metainfo
[
'CLASSES'
]),
1
)
def test_scannet_seg(self):
    """Smoke-test ``ScanNetSegDataset`` on the bundled tiny scene."""
    # Patch sampling in IndoorPatchPointSample is random; the expected
    # tensors below are only valid for this exact seed.
    np.random.seed(0)
    data_root, ann_file, classes, palette, scene_idxs, data_prefix, \
        pipeline, modality, = _generate_scannet_seg_dataset_config()
    register_all_modules()
    scannet_seg_dataset = ScanNetSegDataset(
        data_root,
        ann_file,
        metainfo=dict(CLASSES=classes, PALETTE=palette),
        data_prefix=data_prefix,
        pipeline=pipeline,
        modality=modality,
        scene_idxs=scene_idxs)

    input_dict = scannet_seg_dataset.prepare_data(0)

    points = input_dict['inputs']['points']
    data_sample = input_dict['data_sample']
    pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask

    # 5 sampled points; each row is x, y, z, r, g, b plus
    # 3 normalized-coordinate channels from use_normalized_coord.
    expected_points = torch.tensor(
        [[0.0000, 0.0000, 1.2427, 0.6118, 0.5529, 0.4471, -0.6462,
          -1.0046, 0.4280],
         [0.1553, -0.0074, 1.6077, 0.5882, 0.6157, 0.5569, -0.6001,
          -1.0068, 0.5537],
         [0.1518, 0.6016, 0.6548, 0.1490, 0.1059, 0.0431, -0.6012,
          -0.8309, 0.2255],
         [-0.7494, 0.1033, 0.6756, 0.5216, 0.4353, 0.3333, -0.8687,
          -0.9748, 0.2327],
         [-0.6836, -0.0203, 0.5884, 0.5765, 0.5020, 0.4510, -0.8491,
          -1.0105, 0.2027]])
    expected_pts_semantic_mask = np.array([13, 13, 12, 2, 0])

    # NOTE(review): the positional third argument of torch.allclose is
    # rtol — presumably a loose tolerance is intended here.
    assert torch.allclose(points, expected_points, 1e-2)
    self.assertTrue(
        (pts_semantic_mask.numpy() == expected_pts_semantic_mask).all())
tests/test_data/test_datasets/test_semantickitti_dataset.py
0 → 100644
View file @
360c27f9
# Copyright (c) OpenMMLab. All rights reserved.
import
unittest
import
numpy
as
np
from
mmdet3d.datasets
import
SemanticKITTIDataset
from
mmdet3d.utils
import
register_all_modules
def
_generate_semantickitti_dataset_config
():
data_root
=
'./tests/data/semantickitti/'
ann_file
=
'semantickitti_infos.pkl'
classes
=
(
'unlabeled'
,
'car'
,
'bicycle'
,
'motorcycle'
,
'truck'
,
'bus'
,
'person'
,
'bicyclist'
,
'motorcyclist'
,
'road'
,
'parking'
,
'sidewalk'
,
'other-ground'
,
'building'
,
'fence'
,
'vegetation'
,
'trunck'
,
'terrian'
,
'pole'
,
'traffic-sign'
)
palette
=
[
[
174
,
199
,
232
],
[
152
,
223
,
138
],
[
31
,
119
,
180
],
[
255
,
187
,
120
],
[
188
,
189
,
34
],
[
140
,
86
,
75
],
[
255
,
152
,
150
],
[
214
,
39
,
40
],
[
197
,
176
,
213
],
[
148
,
103
,
189
],
[
196
,
156
,
148
],
[
23
,
190
,
207
],
[
247
,
182
,
210
],
[
219
,
219
,
141
],
[
255
,
127
,
14
],
[
158
,
218
,
229
],
[
44
,
160
,
44
],
[
112
,
128
,
144
],
[
227
,
119
,
194
],
[
82
,
84
,
163
],
]
modality
=
dict
(
use_lidar
=
True
,
use_camera
=
False
)
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'LIDAR'
,
shift_height
=
True
,
load_dim
=
4
,
use_dim
=
[
0
,
1
,
2
]),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
,
seg_3d_dtype
=
np
.
int32
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
data_prefix
=
dict
(
pts
=
'sequences/00/velodyne'
,
pts_semantic_mask
=
'sequences/00/labels'
)
return
(
data_root
,
ann_file
,
classes
,
palette
,
data_prefix
,
pipeline
,
modality
)
class TestSemanticKITTIDataset(unittest.TestCase):
    """Unit tests for ``SemanticKITTIDataset`` on the bundled tiny split."""

    def test_semantickitti(self):
        # Seed for reproducibility of any randomness in the pipeline.
        np.random.seed(0)
        data_root, ann_file, classes, palette, data_prefix, \
            pipeline, modality, = _generate_semantickitti_dataset_config()
        register_all_modules()
        semantickitti_dataset = SemanticKITTIDataset(
            data_root,
            ann_file,
            metainfo=dict(CLASSES=classes, PALETTE=palette),
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality)

        input_dict = semantickitti_dataset.prepare_data(0)

        points = input_dict['inputs']['points']
        data_sample = input_dict['data_sample']
        pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask

        # One semantic label per loaded point.
        self.assertEqual(points.shape[0], pts_semantic_mask.shape[0])
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment