Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
c66197c7
Commit
c66197c7
authored
Jul 17, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor] 3D Segmentor and EncoderDecoder3D
parent
522cc20d
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
597 additions
and
309 deletions
+597
-309
configs/_base_/models/dgcnn.py
configs/_base_/models/dgcnn.py
+2
-1
configs/_base_/models/paconv_ssg.py
configs/_base_/models/paconv_ssg.py
+2
-1
configs/_base_/models/pointnet2_ssg.py
configs/_base_/models/pointnet2_ssg.py
+2
-1
configs/_base_/schedules/seg_cosine_100e.py
configs/_base_/schedules/seg_cosine_100e.py
+17
-4
configs/_base_/schedules/seg_cosine_150e.py
configs/_base_/schedules/seg_cosine_150e.py
+17
-5
configs/_base_/schedules/seg_cosine_200e.py
configs/_base_/schedules/seg_cosine_200e.py
+18
-6
configs/_base_/schedules/seg_cosine_50e.py
configs/_base_/schedules/seg_cosine_50e.py
+17
-5
configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
+3
-6
configs/paconv/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class.py
...paconv/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class.py
+22
-13
configs/pointnet2/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class.py
.../pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class.py
+5
-5
configs/pointnet2/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class.py
...et2/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class.py
+5
-5
mmdet3d/core/utils/__init__.py
mmdet3d/core/utils/__init__.py
+2
-1
mmdet3d/core/utils/misc.py
mmdet3d/core/utils/misc.py
+20
-0
mmdet3d/models/decode_heads/decode_head.py
mmdet3d/models/decode_heads/decode_head.py
+63
-21
mmdet3d/models/decode_heads/dgcnn_head.py
mmdet3d/models/decode_heads/dgcnn_head.py
+6
-3
mmdet3d/models/decode_heads/paconv_head.py
mmdet3d/models/decode_heads/paconv_head.py
+13
-6
mmdet3d/models/decode_heads/pointnet2_head.py
mmdet3d/models/decode_heads/pointnet2_head.py
+12
-6
mmdet3d/models/segmentors/base.py
mmdet3d/models/segmentors/base.py
+144
-113
mmdet3d/models/segmentors/encoder_decoder.py
mmdet3d/models/segmentors/encoder_decoder.py
+175
-107
tests/test_models/test_decode_heads/test_dgcnn_head.py
tests/test_models/test_decode_heads/test_dgcnn_head.py
+52
-0
No files found.
configs/_base_/models/dgcnn.py
View file @
c66197c7
# model settings
model
=
dict
(
type
=
'EncoderDecoder3D'
,
data_preprocessor
=
dict
(
type
=
'Det3DDataPreprocessor'
),
backbone
=
dict
(
type
=
'DGCNNBackbone'
,
in_channels
=
9
,
# [xyz, rgb, normal_xyz], modified with dataset
...
...
@@ -19,7 +20,7 @@ model = dict(
norm_cfg
=
dict
(
type
=
'BN1d'
),
act_cfg
=
dict
(
type
=
'LeakyReLU'
,
negative_slope
=
0.2
),
loss_decode
=
dict
(
type
=
'CrossEntropyLoss'
,
type
=
'
mmdet.
CrossEntropyLoss'
,
use_sigmoid
=
False
,
class_weight
=
None
,
# modified with dataset
loss_weight
=
1.0
)),
...
...
configs/_base_/models/paconv_ssg.py
View file @
c66197c7
# model settings
model
=
dict
(
type
=
'EncoderDecoder3D'
,
data_preprocessor
=
dict
(
type
=
'Det3DDataPreprocessor'
),
backbone
=
dict
(
type
=
'PointNet2SASSG'
,
in_channels
=
9
,
# [xyz, rgb, normalized_xyz]
...
...
@@ -37,7 +38,7 @@ model = dict(
norm_cfg
=
dict
(
type
=
'BN1d'
),
act_cfg
=
dict
(
type
=
'ReLU'
),
loss_decode
=
dict
(
type
=
'CrossEntropyLoss'
,
type
=
'
mmdet.
CrossEntropyLoss'
,
use_sigmoid
=
False
,
class_weight
=
None
,
# should be modified with dataset
loss_weight
=
1.0
)),
...
...
configs/_base_/models/pointnet2_ssg.py
View file @
c66197c7
# model settings
model
=
dict
(
type
=
'EncoderDecoder3D'
,
data_preprocessor
=
dict
(
type
=
'Det3DDataPreprocessor'
),
backbone
=
dict
(
type
=
'PointNet2SASSG'
,
in_channels
=
6
,
# [xyz, rgb], should be modified with dataset
...
...
@@ -26,7 +27,7 @@ model = dict(
norm_cfg
=
dict
(
type
=
'BN1d'
),
act_cfg
=
dict
(
type
=
'ReLU'
),
loss_decode
=
dict
(
type
=
'CrossEntropyLoss'
,
type
=
'
mmdet.
CrossEntropyLoss'
,
use_sigmoid
=
False
,
class_weight
=
None
,
# should be modified with dataset
loss_weight
=
1.0
)),
...
...
configs/_base_/schedules/seg_cosine_100e.py
View file @
c66197c7
# optimizer
# This schedule is mainly used on S3DIS dataset in segmentation task
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
None
)
lr_config
=
dict
(
policy
=
'CosineAnnealing'
,
warmup
=
None
,
min_lr
=
1e-5
)
optim_wrapper
=
dict
(
type
=
'OptimWrapper'
,
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
0.001
),
clip_grad
=
None
)
param_scheduler
=
[
dict
(
type
=
'CosineAnnealingLR'
,
T_max
=
100
,
eta_min
=
1e-5
,
by_epoch
=
True
,
begin
=
0
,
end
=
100
)
]
# runtime settings
runner
=
dict
(
type
=
'EpochBasedRunner'
,
max_epochs
=
100
)
train_cfg
=
dict
(
by_epoch
=
True
,
max_epochs
=
100
)
val_cfg
=
dict
(
interval
=
1
)
test_cfg
=
dict
()
configs/_base_/schedules/seg_cosine_150e.py
View file @
c66197c7
# optimizer
# This schedule is mainly used on S3DIS dataset in segmentation task
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.2
,
weight_decay
=
0.0001
,
momentum
=
0.9
)
optimizer_config
=
dict
(
grad_clip
=
None
)
lr_config
=
dict
(
policy
=
'CosineAnnealing'
,
warmup
=
None
,
min_lr
=
0.002
)
momentum_config
=
None
optim_wrapper
=
dict
(
type
=
'OptimWrapper'
,
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.2
,
momentum
=
0.9
,
weight_decay
=
0.0001
),
clip_grad
=
None
)
param_scheduler
=
[
dict
(
type
=
'CosineAnnealingLR'
,
T_max
=
150
,
eta_min
=
0.002
,
by_epoch
=
True
,
begin
=
0
,
end
=
150
)
]
# runtime settings
runner
=
dict
(
type
=
'EpochBasedRunner'
,
max_epochs
=
150
)
train_cfg
=
dict
(
by_epoch
=
True
,
max_epochs
=
150
)
val_cfg
=
dict
(
interval
=
1
)
test_cfg
=
dict
()
configs/_base_/schedules/seg_cosine_200e.py
View file @
c66197c7
# optimizer
# This schedule is mainly used on ScanNet dataset in segmentation task
optimizer
=
dict
(
type
=
'Adam'
,
lr
=
0.001
,
weight_decay
=
0.01
)
optimizer_config
=
dict
(
grad_clip
=
None
)
lr_config
=
dict
(
policy
=
'CosineAnnealing'
,
warmup
=
None
,
min_lr
=
1e-5
)
momentum_config
=
None
# This schedule is mainly used on S3DIS dataset in segmentation task
optim_wrapper
=
dict
(
type
=
'OptimWrapper'
,
optimizer
=
dict
(
type
=
'Adam'
,
lr
=
0.001
,
weight_decay
=
0.01
),
clip_grad
=
None
)
param_scheduler
=
[
dict
(
type
=
'CosineAnnealingLR'
,
T_max
=
200
,
eta_min
=
1e-5
,
by_epoch
=
True
,
begin
=
0
,
end
=
200
)
]
# runtime settings
runner
=
dict
(
type
=
'EpochBasedRunner'
,
max_epochs
=
200
)
train_cfg
=
dict
(
by_epoch
=
True
,
max_epochs
=
200
)
val_cfg
=
dict
(
interval
=
1
)
test_cfg
=
dict
()
configs/_base_/schedules/seg_cosine_50e.py
View file @
c66197c7
# optimizer
# This schedule is mainly used on S3DIS dataset in segmentation task
optimizer
=
dict
(
type
=
'Adam'
,
lr
=
0.001
,
weight_decay
=
0.001
)
optimizer_config
=
dict
(
grad_clip
=
None
)
lr_config
=
dict
(
policy
=
'CosineAnnealing'
,
warmup
=
None
,
min_lr
=
1e-5
)
momentum_config
=
None
optim_wrapper
=
dict
(
type
=
'OptimWrapper'
,
optimizer
=
dict
(
type
=
'Adam'
,
lr
=
0.001
,
weight_decay
=
0.001
),
clip_grad
=
None
)
param_scheduler
=
[
dict
(
type
=
'CosineAnnealingLR'
,
T_max
=
50
,
eta_min
=
1e-5
,
by_epoch
=
True
,
begin
=
0
,
end
=
50
)
]
# runtime settings
runner
=
dict
(
type
=
'EpochBasedRunner'
,
max_epochs
=
50
)
train_cfg
=
dict
(
by_epoch
=
True
,
max_epochs
=
50
)
val_cfg
=
dict
(
interval
=
1
)
test_cfg
=
dict
()
configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
View file @
c66197c7
...
...
@@ -3,10 +3,6 @@ _base_ = [
'../_base_/schedules/seg_cosine_100e.py'
,
'../_base_/default_runtime.py'
]
# data settings
data
=
dict
(
samples_per_gpu
=
32
)
evaluation
=
dict
(
interval
=
2
)
# model settings
model
=
dict
(
backbone
=
dict
(
in_channels
=
9
),
# [xyz, rgb, normalized_xyz]
...
...
@@ -20,5 +16,6 @@ model = dict(
use_normalized_coord
=
True
,
batch_size
=
24
))
# runtime settings
checkpoint_config
=
dict
(
interval
=
2
)
default_hooks
=
dict
(
checkpoint
=
dict
(
type
=
'CheckpointHook'
,
interval
=
2
),
)
train_dataloader
=
dict
(
batch_size
=
32
)
val_cfg
=
dict
(
interval
=
2
)
configs/paconv/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class.py
View file @
c66197c7
...
...
@@ -4,9 +4,20 @@ _base_ = [
'../_base_/default_runtime.py'
]
# file_client_args = dict(backend='disk')
# Uncomment the following if use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
file_client_args
=
dict
(
backend
=
'petrel'
,
path_mapping
=
dict
({
'./data/s3dis/'
:
's3://openmmlab/datasets/detection3d/s3dis_processed/'
,
'data/s3dis/'
:
's3://openmmlab/datasets/detection3d/s3dis_processed/'
}))
# data settings
class_names
=
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
)
num_points
=
4096
train_pipeline
=
[
dict
(
...
...
@@ -15,17 +26,16 @@ train_pipeline = [
shift_height
=
False
,
use_color
=
True
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
],
file_client_args
=
file_client_args
),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
False
,
with_label_3d
=
False
,
with_mask_3d
=
False
,
with_seg_3d
=
True
),
dict
(
type
=
'PointSegClassMapping'
,
valid_cat_ids
=
tuple
(
range
(
len
(
class_names
))),
max_cat_id
=
13
),
with_seg_3d
=
True
,
file_client_args
=
file_client_args
),
dict
(
type
=
'PointSegClassMapping'
),
dict
(
type
=
'IndoorPatchPointSample'
,
num_points
=
num_points
,
...
...
@@ -46,13 +56,9 @@ train_pipeline = [
jitter_std
=
[
0.01
,
0.01
,
0.01
],
clip_range
=
[
-
0.05
,
0.05
]),
dict
(
type
=
'RandomDropPointsColor'
,
drop_ratio
=
0.2
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'pts_semantic_mask'
])
]
data
=
dict
(
samples_per_gpu
=
8
,
train
=
dict
(
pipeline
=
train_pipeline
))
evaluation
=
dict
(
interval
=
1
)
# model settings
model
=
dict
(
decode_head
=
dict
(
...
...
@@ -64,3 +70,6 @@ model = dict(
sample_rate
=
0.5
,
use_normalized_coord
=
True
,
batch_size
=
12
))
train_dataloader
=
dict
(
batch_size
=
8
,
dataset
=
dict
(
pipeline
=
train_pipeline
))
val_cfg
=
dict
(
interval
=
1
)
configs/pointnet2/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class.py
View file @
c66197c7
...
...
@@ -4,10 +4,6 @@ _base_ = [
'../_base_/schedules/seg_cosine_200e.py'
,
'../_base_/default_runtime.py'
]
# data settings
data
=
dict
(
samples_per_gpu
=
16
)
evaluation
=
dict
(
interval
=
5
)
# model settings
model
=
dict
(
decode_head
=
dict
(
...
...
@@ -30,5 +26,9 @@ model = dict(
use_normalized_coord
=
False
,
batch_size
=
24
))
# data settings
train_dataloader
=
dict
(
batch_size
=
16
)
# runtime settings
checkpoint_config
=
dict
(
interval
=
5
)
default_hooks
=
dict
(
checkpoint
=
dict
(
type
=
'CheckpointHook'
,
interval
=
5
),
)
val_cfg
=
dict
(
interval
=
5
)
configs/pointnet2/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class.py
View file @
c66197c7
...
...
@@ -4,10 +4,6 @@ _base_ = [
'../_base_/schedules/seg_cosine_50e.py'
,
'../_base_/default_runtime.py'
]
# data settings
data
=
dict
(
samples_per_gpu
=
16
)
evaluation
=
dict
(
interval
=
2
)
# model settings
model
=
dict
(
backbone
=
dict
(
in_channels
=
9
),
# [xyz, rgb, normalized_xyz]
...
...
@@ -21,5 +17,9 @@ model = dict(
use_normalized_coord
=
True
,
batch_size
=
24
))
# data settings
train_dataloader
=
dict
(
batch_size
=
6
)
# runtime settings
checkpoint_config
=
dict
(
interval
=
2
)
default_hooks
=
dict
(
checkpoint
=
dict
(
type
=
'CheckpointHook'
,
interval
=
2
),
)
val_cfg
=
dict
(
interval
=
2
)
mmdet3d/core/utils/__init__.py
View file @
c66197c7
...
...
@@ -2,6 +2,7 @@
from
.array_converter
import
ArrayConverter
,
array_converter
from
.gaussian
import
(
draw_heatmap_gaussian
,
ellip_gaussian2D
,
gaussian_2d
,
gaussian_radius
,
get_ellip_gaussian_2D
)
from
.misc
import
add_prefix
from
.typing
import
(
ConfigType
,
ForwardResults
,
InstanceList
,
MultiConfig
,
OptConfigType
,
OptInstanceList
,
OptMultiConfig
,
OptSampleList
,
OptSamplingResultList
,
SampleList
,
...
...
@@ -13,5 +14,5 @@ __all__ = [
'get_ellip_gaussian_2D'
,
'ConfigType'
,
'OptConfigType'
,
'MultiConfig'
,
'OptMultiConfig'
,
'InstanceList'
,
'OptInstanceList'
,
'SampleList'
,
'OptSampleList'
,
'SamplingResultList'
,
'ForwardResults'
,
'OptSamplingResultList'
'OptSamplingResultList'
,
'add_prefix'
]
mmdet3d/core/utils/misc.py
0 → 100644
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
def
add_prefix
(
inputs
,
prefix
):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs
=
dict
()
for
name
,
value
in
inputs
.
items
():
outputs
[
f
'
{
prefix
}
.
{
name
}
'
]
=
value
return
outputs
mmdet3d/models/decode_heads/decode_head.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
List
import
torch
from
mmcv.cnn
import
normal_init
from
mmcv.runner
import
BaseModule
,
auto_fp16
,
force_fp32
from
mmcv.runner
import
BaseModule
,
auto_fp16
from
torch
import
Tensor
from
torch
import
nn
as
nn
from
..builder
import
build_loss
from
mmdet3d.core.utils.typing
import
ConfigType
,
SampleList
from
mmdet3d.registry
import
MODELS
class
Base3DDecodeHead
(
BaseModule
,
metaclass
=
ABCMeta
):
"""Base class for BaseDecodeHead.
1. The ``init_weights`` method is used to initialize decode_head's
model parameters. After segmentor initialization, ``init_weights``
is triggered when ``segmentor.init_weights()`` is called externally.
2. The ``loss`` method is used to calculate the loss of decode_head,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
is called based on the feature maps to calculate the loss.
.. code:: text
loss(): forward() -> loss_by_feat()
3. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
is called based on the feature maps to predict segmentation results
including post-processing.
.. code:: text
predict(): forward() -> predict_by_feat()
Args:
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
...
...
@@ -26,6 +53,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
ignore_index (int, optional): The label index to be ignored.
When using masked BCE loss, ignore_index should be set to None.
Default: 255.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def
__init__
(
self
,
...
...
@@ -36,12 +64,12 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
norm_cfg
=
dict
(
type
=
'BN1d'
),
act_cfg
=
dict
(
type
=
'ReLU'
),
loss_decode
=
dict
(
type
=
'CrossEntropyLoss'
,
type
=
'
mmdet.
CrossEntropyLoss'
,
use_sigmoid
=
False
,
class_weight
=
None
,
loss_weight
=
1.0
),
ignore_index
=
255
,
init_cfg
=
None
):
init_cfg
=
None
)
->
None
:
super
(
Base3DDecodeHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
self
.
channels
=
channels
self
.
num_classes
=
num_classes
...
...
@@ -49,7 +77,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
self
.
conv_cfg
=
conv_cfg
self
.
norm_cfg
=
norm_cfg
self
.
act_cfg
=
act_cfg
self
.
loss_decode
=
build
_loss
(
loss_decode
)
self
.
loss_decode
=
MODELS
.
build
(
loss_decode
)
self
.
ignore_index
=
ignore_index
self
.
conv_seg
=
nn
.
Conv1d
(
channels
,
num_classes
,
kernel_size
=
1
)
...
...
@@ -57,6 +85,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
self
.
dropout
=
nn
.
Dropout
(
dropout_ratio
)
else
:
self
.
dropout
=
None
self
.
fp16_enabled
=
False
def
init_weights
(
self
):
...
...
@@ -66,11 +95,19 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
@
auto_fp16
()
@
abstractmethod
def
forward
(
self
,
inputs
):
def
forward
(
self
,
feats_dict
:
dict
):
"""Placeholder of forward function."""
pass
def
forward_train
(
self
,
inputs
,
img_metas
,
pts_semantic_mask
,
train_cfg
):
def
cls_seg
(
self
,
feat
:
Tensor
)
->
Tensor
:
"""Classify each points."""
if
self
.
dropout
is
not
None
:
feat
=
self
.
dropout
(
feat
)
output
=
self
.
conv_seg
(
feat
)
return
output
def
loss
(
self
,
inputs
:
List
[
Tensor
],
batch_data_samples
:
SampleList
,
train_cfg
:
ConfigType
)
->
dict
:
"""Forward function for training.
Args:
...
...
@@ -84,39 +121,44 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits
=
self
.
forward
(
inputs
)
losses
=
self
.
loss
es
(
seg_logits
,
pts_semantic_mask
)
losses
=
self
.
loss
_by_feat
(
seg_logits
,
batch_data_samples
)
return
losses
def
forward_test
(
self
,
inputs
,
img_metas
,
test_cfg
):
def
predict
(
self
,
inputs
:
List
[
Tensor
],
batch_input_metas
:
List
[
dict
],
test_cfg
:
ConfigType
)
->
List
[
Tensor
]:
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level point features.
img_metas (list[dict]): Meta information of each sample.
batch_
img_metas (list[dict]): Meta information of each sample.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return
self
.
forward
(
inputs
)
seg_logits
=
self
.
forward
(
inputs
)
def
cls_seg
(
self
,
feat
):
"""Classify each points."""
if
self
.
dropout
is
not
None
:
feat
=
self
.
dropout
(
feat
)
output
=
self
.
conv_seg
(
feat
)
return
output
return
seg_logits
def
_stack_batch_gt
(
self
,
batch_data_samples
:
SampleList
)
->
Tensor
:
gt_semantic_segs
=
[
data_sample
.
gt_pts_seg
.
pts_semantic_mask
for
data_sample
in
batch_data_samples
]
return
torch
.
stack
(
gt_semantic_segs
,
dim
=
0
)
@
force_fp32
(
apply_to
=
(
'seg_logit'
,
))
def
losses
(
self
,
seg_logit
,
seg_label
)
:
def
loss_by_feat
(
self
,
seg_logit
:
Tensor
,
batch_data_samples
:
SampleList
)
->
dict
:
"""Compute semantic segmentation loss.
Args:
seg_logit (torch.Tensor): Predicted per-point segmentation logits
of shape [B, num_classes, N].
seg_label (torch.Tensor): Ground-truth segmentation label of
shape [B, N].
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_pts_seg`.
"""
seg_label
=
self
.
_stack_batch_gt
(
batch_data_samples
)
loss
=
dict
()
loss
[
'loss_sem_seg'
]
=
self
.
loss_decode
(
seg_logit
,
seg_label
,
ignore_index
=
self
.
ignore_index
)
...
...
mmdet3d/models/decode_heads/dgcnn_head.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Tuple
from
mmcv.cnn.bricks
import
ConvModule
from
torch
import
Tensor
from
mmdet3d.ops
import
DGCNNFPModule
from
mmdet3d.registry
import
MODELS
...
...
@@ -19,7 +22,7 @@ class DGCNNHead(Base3DDecodeHead):
propagation (FP) modules. Defaults to (1216, 512).
"""
def
__init__
(
self
,
fp_channels
=
(
1216
,
512
),
**
kwargs
):
def
__init__
(
self
,
fp_channels
:
Tuple
=
(
1216
,
512
),
**
kwargs
)
->
None
:
super
(
DGCNNHead
,
self
).
__init__
(
**
kwargs
)
self
.
FP_module
=
DGCNNFPModule
(
...
...
@@ -35,7 +38,7 @@ class DGCNNHead(Base3DDecodeHead):
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
)
def
_extract_input
(
self
,
feat_dict
)
:
def
_extract_input
(
self
,
feat_dict
:
dict
)
->
Tensor
:
"""Extract inputs from features dictionary.
Args:
...
...
@@ -48,7 +51,7 @@ class DGCNNHead(Base3DDecodeHead):
return
fa_points
def
forward
(
self
,
feat_dict
)
:
def
forward
(
self
,
feat_dict
:
dict
)
->
Tensor
:
"""Forward pass.
Args:
...
...
mmdet3d/models/decode_heads/paconv_head.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Tuple
from
mmcv.cnn.bricks
import
ConvModule
from
torch
import
Tensor
from
mmdet3d.core.utils
import
ConfigType
from
mmdet3d.registry
import
MODELS
from
.pointnet2_head
import
PointNet2Head
...
...
@@ -19,11 +23,14 @@ class PAConvHead(PointNet2Head):
"""
def
__init__
(
self
,
fp_channels
=
((
768
,
256
,
256
),
(
384
,
256
,
256
),
(
320
,
256
,
128
),
(
128
+
6
,
128
,
128
,
128
)),
fp_norm_cfg
=
dict
(
type
=
'BN2d'
),
**
kwargs
):
super
(
PAConvHead
,
self
).
__init__
(
fp_channels
,
fp_norm_cfg
,
**
kwargs
)
fp_channels
:
Tuple
[
Tuple
[
int
]]
=
((
768
,
256
,
256
),
(
384
,
256
,
256
),
(
320
,
256
,
128
),
(
128
+
6
,
128
,
128
,
128
)),
fp_norm_cfg
:
ConfigType
=
dict
(
type
=
'BN2d'
),
**
kwargs
)
->
None
:
super
(
PAConvHead
,
self
).
__init__
(
fp_channels
=
fp_channels
,
fp_norm_cfg
=
fp_norm_cfg
,
**
kwargs
)
# https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53
# PointNet++'s decoder conv has bias while PAConv's doesn't have
...
...
@@ -37,7 +44,7 @@ class PAConvHead(PointNet2Head):
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
)
def
forward
(
self
,
feat_dict
)
:
def
forward
(
self
,
feat_dict
:
dict
)
->
Tensor
:
"""Forward pass.
Args:
...
...
mmdet3d/models/decode_heads/pointnet2_head.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Tuple
from
mmcv.cnn.bricks
import
ConvModule
from
torch
import
Tensor
from
torch
import
nn
as
nn
from
mmdet3d.core.utils.typing
import
ConfigType
from
mmdet3d.ops
import
PointFPModule
from
mmdet3d.registry
import
MODELS
from
.decode_head
import
Base3DDecodeHead
...
...
@@ -21,10 +25,12 @@ class PointNet2Head(Base3DDecodeHead):
"""
def
__init__
(
self
,
fp_channels
=
((
768
,
256
,
256
),
(
384
,
256
,
256
),
(
320
,
256
,
128
),
(
128
,
128
,
128
,
128
)),
fp_norm_cfg
=
dict
(
type
=
'BN2d'
),
**
kwargs
):
fp_channels
:
Tuple
[
Tuple
[
int
]]
=
((
768
,
256
,
256
),
(
384
,
256
,
256
),
(
320
,
256
,
128
),
(
128
,
128
,
128
,
128
)),
fp_norm_cfg
:
ConfigType
=
dict
(
type
=
'BN2d'
),
**
kwargs
)
->
None
:
super
(
PointNet2Head
,
self
).
__init__
(
**
kwargs
)
self
.
num_fp
=
len
(
fp_channels
)
...
...
@@ -43,7 +49,7 @@ class PointNet2Head(Base3DDecodeHead):
norm_cfg
=
self
.
norm_cfg
,
act_cfg
=
self
.
act_cfg
)
def
_extract_input
(
self
,
feat_dict
)
:
def
_extract_input
(
self
,
feat_dict
:
dict
)
->
Tensor
:
"""Extract inputs from features dictionary.
Args:
...
...
@@ -59,7 +65,7 @@ class PointNet2Head(Base3DDecodeHead):
return
sa_xyz
,
sa_features
def
forward
(
self
,
feat_dict
)
:
def
forward
(
self
,
feat_dict
:
dict
)
->
Tensor
:
"""Forward pass.
Args:
...
...
mmdet3d/models/segmentors/base.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
List
,
Tuple
import
mmcv
import
numpy
as
np
import
torch
from
mmcv.parallel
import
DataContainer
as
DC
from
mmcv.runner
import
auto_fp16
from
mmengine.data
import
PixelData
from
mmengine.model
import
BaseModel
from
torch
import
Tensor
from
mmdet3d.core
import
show_seg_result
from
mmseg.models.segmentors
import
BaseSegmentor
from
mmdet3d.core
import
Det3DDataSample
from
mmdet3d.core.utils
import
(
ForwardResults
,
OptConfigType
,
OptMultiConfig
,
OptSampleList
,
SampleList
)
class
Base3DSegmentor
(
Base
Segmentor
):
class
Base3DSegmentor
(
Base
Model
,
metaclass
=
ABCMeta
):
"""Base class for 3D segmentors.
The main difference with `BaseSegmentor` is that we modify the keys in
data_dict and use a 3D seg specific visualization function.
Args:
data_preprocessor (dict, optional): Model preprocessing config
for processing the input data. it usually includes
``to_rgb``, ``pad_size_divisor``, ``pad_val``,
``mean`` and ``std``. Default to None.
init_cfg (dict, optional): the config to control the
initialization. Default to None.
"""
def
__init__
(
self
,
data_preprocessor
:
OptConfigType
=
None
,
init_cfg
:
OptMultiConfig
=
None
):
super
(
Base3DSegmentor
,
self
).
__init__
(
data_preprocessor
=
data_preprocessor
,
init_cfg
=
init_cfg
)
@
property
def
with_neck
(
self
)
->
bool
:
"""bool: whether the segmentor has neck"""
return
hasattr
(
self
,
'neck'
)
and
self
.
neck
is
not
None
@
property
def
with_auxiliary_head
(
self
)
->
bool
:
"""bool: whether the segmentor has auxiliary head"""
return
hasattr
(
self
,
'auxiliary_head'
)
and
self
.
auxiliary_head
is
not
None
@
property
def
with_regularization_loss
(
self
):
def
with_decode_head
(
self
)
->
bool
:
"""bool: whether the segmentor has decode head"""
return
hasattr
(
self
,
'decode_head'
)
and
self
.
decode_head
is
not
None
@
property
def
with_regularization_loss
(
self
)
->
bool
:
"""bool: whether the segmentor has regularization loss for weight"""
return
hasattr
(
self
,
'loss_regularization'
)
and
\
self
.
loss_regularization
is
not
None
def
forward_test
(
self
,
points
,
img_metas
,
**
kwargs
):
"""Calls either simple_test or aug_test depending on the length of
outer list of points. If len(points) == 1, call simple_test. Otherwise
call aug_test to aggregate the test results by e.g. voting.
@
abstractmethod
def
extract_feat
(
self
,
batch_inputs
:
Tensor
)
->
bool
:
"""Placeholder for extract features from images."""
pass
@
abstractmethod
def
encode_decode
(
self
,
batch_inputs
:
Tensor
,
batch_data_samples
:
SampleList
):
"""Placeholder for encode images with backbone and decode into a
semantic segmentation map of the same size as input."""
pass
def
forward
(
self
,
batch_inputs_dict
:
Tensor
,
batch_data_samples
:
OptSampleList
=
None
,
mode
:
str
=
'tensor'
)
->
ForwardResults
:
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`SegDataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method doesn't handle neither back propagation nor
optimizer updating, which are done in the :meth:`train_step`.
Args:
points (list[list[torch.Tensor]]): the outer list indicates
test-time augmentations and inner torch.Tensor should have a
shape BXNxC, which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for
var
,
name
in
[(
points
,
'points'
),
(
img_metas
,
'img_metas'
)]:
if
not
isinstance
(
var
,
list
):
raise
TypeError
(
f
'
{
name
}
must be a list, but got
{
type
(
var
)
}
'
)
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'imgs' keys.
num_augs
=
len
(
points
)
if
num_augs
!=
len
(
img_metas
):
raise
ValueError
(
f
'num of augmentations (
{
len
(
points
)
}
) != '
f
'num of image meta (
{
len
(
img_metas
)
}
)'
)
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
batch_data_samples (list[:obj:`Det3DDataSample`], optional):
The annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
if
num_augs
==
1
:
return
self
.
simple_test
(
points
[
0
],
img_metas
[
0
],
**
kwargs
)
else
:
return
self
.
aug_test
(
points
,
img_metas
,
**
kwargs
)
@
auto_fp16
(
apply_to
=
(
'points'
))
def
forward
(
self
,
return_loss
=
True
,
**
kwargs
):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, point and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, point and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if
return_loss
:
return
self
.
forward_train
(
**
kwargs
)
if
mode
==
'loss'
:
return
self
.
loss
(
batch_inputs_dict
,
batch_data_samples
)
elif
mode
==
'predict'
:
return
self
.
predict
(
batch_inputs_dict
,
batch_data_samples
)
elif
mode
==
'tensor'
:
return
self
.
_forward
(
batch_inputs_dict
,
batch_data_samples
)
else
:
return
self
.
forward_test
(
**
kwargs
)
def
show_results
(
self
,
data
,
result
,
palette
=
None
,
out_dir
=
None
,
ignore_index
=
None
,
show
=
False
,
score_thr
=
None
):
"""Results visualization.
raise
RuntimeError
(
f
'Invalid mode "
{
mode
}
". '
'Only supports loss, predict and tensor mode'
)
@
abstractmethod
def
loss
(
self
,
batch_inputs
:
Tensor
,
batch_data_samples
:
SampleList
)
->
dict
:
"""Calculate losses from a batch of inputs and data samples."""
pass
@
abstractmethod
def
predict
(
self
,
batch_inputs
:
Tensor
,
batch_data_samples
:
SampleList
)
->
SampleList
:
"""Predict results from a batch of inputs and data samples with post-
processing."""
pass
@
abstractmethod
def
_forward
(
self
,
batch_inputs
:
Tensor
,
batch_data_samples
:
OptSampleList
=
None
)
->
Tuple
[
List
[
Tensor
]]:
"""Network forward process.
Usually includes backbone, neck and head forward without any post-
processing.
"""
pass
@
abstractmethod
def
aug_test
(
self
,
batch_inputs
,
batch_img_metas
):
"""Placeholder for augmentation test."""
pass
def
postprocess_result
(
self
,
seg_logits_list
:
List
[
dict
],
batch_img_metas
:
List
[
dict
])
->
list
:
""" Convert results list to `Det3DDataSample`.
Args:
data (list[dict]): Input points and the information of the sample.
result (list[dict]): Prediction results.
palette (list[list[int]]] | np.ndarray): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
out_dir (str): Output directory of visualization result.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None.
show (bool, optional): Determines whether you are
going to show result by open3d.
Defaults to False.
TODO: implement score_thr of Base3DSegmentor.
score_thr (float, optional): Score threshold of bounding boxes.
Default to None.
Not implemented yet, but it is here for unification.
seg_logits_list (List[dict]): List of segmentation results,
seg_logits from model of each input point clouds sample.
Returns:
list[:obj:`Det3DDataSample`]: Segmentation results of the
input images. Each Det3DDataSample usually contain:
- ``pred_pts_sem_seg``(PixelData): Prediction of 3D
semantic segmentation.
- ``seg_logits``(PixelData): Predicted logits of semantic
segmentation before normalization.
"""
assert
out_dir
is
not
None
,
'Expect out_dir, got none.'
if
palette
is
None
:
if
self
.
PALETTE
is
None
:
palette
=
np
.
random
.
randint
(
0
,
255
,
size
=
(
len
(
self
.
CLASSES
),
3
))
else
:
palette
=
self
.
PALETTE
palette
=
np
.
array
(
palette
)
for
batch_id
in
range
(
len
(
result
)):
if
isinstance
(
data
[
'points'
][
0
],
DC
):
points
=
data
[
'points'
][
0
].
_data
[
0
][
batch_id
].
numpy
()
elif
mmcv
.
is_list_of
(
data
[
'points'
][
0
],
torch
.
Tensor
):
points
=
data
[
'points'
][
0
][
batch_id
]
else
:
ValueError
(
f
"Unsupported data type
{
type
(
data
[
'points'
][
0
])
}
"
f
'for visualization!'
)
if
isinstance
(
data
[
'img_metas'
][
0
],
DC
):
pts_filename
=
data
[
'img_metas'
][
0
].
_data
[
0
][
batch_id
][
'pts_filename'
]
elif
mmcv
.
is_list_of
(
data
[
'img_metas'
][
0
],
dict
):
pts_filename
=
data
[
'img_metas'
][
0
][
batch_id
][
'pts_filename'
]
else
:
ValueError
(
f
"Unsupported data type
{
type
(
data
[
'img_metas'
][
0
])
}
"
f
'for visualization!'
)
file_name
=
osp
.
split
(
pts_filename
)[
-
1
].
split
(
'.'
)[
0
]
pred_sem_mask
=
result
[
batch_id
][
'semantic_mask'
].
cpu
().
numpy
()
show_seg_result
(
points
,
None
,
pred_sem_mask
,
out_dir
,
file_name
,
palette
,
ignore_index
,
show
=
show
)
predictions
=
[]
for
i
in
range
(
len
(
seg_logits_list
)):
img_meta
=
batch_img_metas
[
i
]
seg_logits
=
seg_logits_list
[
i
][
None
],
seg_pred
=
seg_logits
.
argmax
(
dim
=
0
,
keepdim
=
True
)
prediction
=
Det3DDataSample
(
**
{
'metainfo'
:
img_meta
})
prediction
.
set_data
(
{
'pred_pts_sem_seg'
:
PixelData
(
**
{
'data'
:
seg_pred
})})
predictions
.
append
(
prediction
)
return
predictions
mmdet3d/models/segmentors/encoder_decoder.py
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
import
numpy
as
np
import
torch
from
torch
import
Tensor
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
from
mmdet3d.core
import
add_prefix
from
mmdet3d.core.utils
import
(
ConfigType
,
OptConfigType
,
OptMultiConfig
,
OptSampleList
,
SampleList
)
from
mmdet3d.registry
import
MODELS
from
mmseg.core
import
add_prefix
from
.base
import
Base3DSegmentor
...
...
@@ -15,20 +20,69 @@ class EncoderDecoder3D(Base3DSegmentor):
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
Note that auxiliary_head is only used for deep supervision during training,
which could be thrown during inference.
"""
which could be dumped during inference.
1. The ``loss`` method is used to calculate the loss of model,
which includes two steps: (1) Extracts features to obtain the feature maps
(2) Call the decode head loss function to forward decode head model and
calculate losses.
.. code:: text
loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
_decode_head_forward_train(): decode_head.loss()
_auxiliary_head_forward_train(): auxiliary_head.loss (optional)
2. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) Run inference function to obtain the list of
seg_logits (2) Call post-processing function to obtain list of
``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
.. code:: text
predict(): inference() -> postprocess_result()
inference(): whole_inference()/slide_inference()
whole_inference()/slide_inference(): encoder_decoder()
encoder_decoder(): extract_feat() -> decode_head.predict()
3. The ``_forward`` method is used to output the tensor by running the model,
which includes two steps: (1) Extracts features to obtain the feature maps
(2)Call the decode head forward function to forward decode head model.
.. code:: text
_forward(): extract_feat() -> _decode_head.forward()
Args:
backbone (ConfigType): The config for the backbone of segmentor.
decode_head (ConfigType): The config for the decode head of segmentor.
neck (OptConfigType): The config for the neck of segmentor.
Defaults to None.
auxiliary_head (OptConfigType): The config for the auxiliary head of
segmentor. Defaults to None.
loss_regularization (OptConfigType): The config for the regularization
loss. Defaults to None.
train_cfg (OptConfigType): The config for training. Defaults to None.
test_cfg (OptConfigType): The config for testing. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialized config for
:class:`BaseModule`.
"""
# noqa: E501
def
__init__
(
self
,
backbone
,
decode_head
,
neck
=
None
,
auxiliary_head
=
None
,
loss_regularization
=
None
,
train_cfg
=
None
,
test_cfg
=
None
,
pretrained
=
None
,
init_cfg
=
None
):
super
(
EncoderDecoder3D
,
self
).
__init__
(
init_cfg
=
init_cfg
)
backbone
:
ConfigType
,
decode_head
:
ConfigType
,
neck
:
OptConfigType
=
None
,
auxiliary_head
:
OptConfigType
=
None
,
loss_regularization
:
OptConfigType
=
None
,
train_cfg
:
OptConfigType
=
None
,
test_cfg
:
OptConfigType
=
None
,
data_preprocessor
:
OptConfigType
=
None
,
init_cfg
:
OptMultiConfig
=
None
):
super
(
EncoderDecoder3D
,
self
).
__init__
(
data_preprocessor
=
data_preprocessor
,
init_cfg
=
init_cfg
)
self
.
backbone
=
MODELS
.
build
(
backbone
)
if
neck
is
not
None
:
self
.
neck
=
MODELS
.
build
(
neck
)
...
...
@@ -38,15 +92,16 @@ class EncoderDecoder3D(Base3DSegmentor):
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
assert
self
.
with_decode_head
,
\
'3D EncoderDecoder Segmentor should have a decode_head'
def _init_decode_head(self, decode_head: ConfigType) -> None:
    """Build the decode head from config and record its class count.

    Args:
        decode_head (ConfigType): Config dict for the decode head.
    """
    head = MODELS.build(decode_head)
    self.decode_head = head
    # The segmentor mirrors the head's number of classes so callers can
    # query it without reaching into the head.
    self.num_classes = head.num_classes
def
_init_auxiliary_head
(
self
,
auxiliary_head
)
:
def
_init_auxiliary_head
(
self
,
auxiliary_head
:
ConfigType
)
->
None
:
"""Initialize ``auxiliary_head``"""
if
auxiliary_head
is
not
None
:
if
isinstance
(
auxiliary_head
,
list
):
...
...
@@ -56,7 +111,8 @@ class EncoderDecoder3D(Base3DSegmentor):
else
:
self
.
auxiliary_head
=
MODELS
.
build
(
auxiliary_head
)
def
_init_loss_regularization
(
self
,
loss_regularization
):
def
_init_loss_regularization
(
self
,
loss_regularization
:
ConfigType
)
->
None
:
"""Initialize ``loss_regularization``"""
if
loss_regularization
is
not
None
:
if
isinstance
(
loss_regularization
,
list
):
...
...
@@ -66,58 +122,64 @@ class EncoderDecoder3D(Base3DSegmentor):
else
:
self
.
loss_regularization
=
MODELS
.
build
(
loss_regularization
)
def extract_feat(self, batch_inputs_dict: dict) -> List[Tensor]:
    """Extract features from the batched point clouds.

    Args:
        batch_inputs_dict (dict): Input dict holding a 'points' key,
            a list of per-sample point tensors that are stacked into
            one batch tensor before the backbone runs.

    Returns:
        List[Tensor]: Backbone (and optional neck) feature maps.
    """
    stacked_points = torch.stack(batch_inputs_dict['points'])
    feats = self.backbone(stacked_points)
    # Only route through the neck when one was configured.
    return self.neck(feats) if self.with_neck else feats
def encode_decode(self, batch_inputs_dict: dict,
                  batch_input_metas: List[dict]) -> List[Tensor]:
    """Encode points with the backbone and decode into segmentation
    logits of the same point count as the input.

    Args:
        batch_inputs_dict (dict): Input sample dict which includes
            'points' and 'imgs' keys.

            - points (list[torch.Tensor]): Point cloud of each sample.
            - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
        batch_input_metas (list[dict]): Meta information of each sample.

    Returns:
        torch.Tensor: Segmentation logits of shape [B, num_classes, N].
    """
    feats = self.extract_feat(batch_inputs_dict)
    return self.decode_head.predict(feats, batch_input_metas,
                                    self.test_cfg)
def _decode_head_forward_train(self, batch_inputs_dict: dict,
                               batch_data_samples: SampleList) -> dict:
    """Compute decode-head losses during training.

    NOTE(review): despite its name, ``loss()`` passes the extracted
    feature maps (``x``) as the first argument here, not the raw inputs
    dict — confirm and consider renaming the parameter.
    """
    decode_losses = self.decode_head.loss(batch_inputs_dict,
                                          batch_data_samples,
                                          self.train_cfg)
    losses = dict()
    # Prefix keys with 'decode' so they stay distinct from auxiliary
    # head losses when merged by the caller.
    losses.update(add_prefix(decode_losses, 'decode'))
    return losses
def
_decode_head_forward_test
(
self
,
x
,
img_metas
):
"""Run forward function and calculate loss for decode head in
inference."""
seg_logits
=
self
.
decode_head
.
forward_test
(
x
,
img_metas
,
self
.
test_cfg
)
return
seg_logits
def _auxiliary_head_forward_train(self, batch_inputs_dict: dict,
                                  batch_data_samples: SampleList) -> dict:
    """Compute auxiliary-head losses during training.

    Supports either a single auxiliary head or an ``nn.ModuleList`` of
    heads; each head's losses are key-prefixed so they remain
    distinguishable after merging.

    NOTE(review): as with ``_decode_head_forward_train``, the first
    argument appears to carry extracted features, not the inputs dict —
    confirm naming.
    """
    losses = dict()
    aux = self.auxiliary_head
    if isinstance(aux, nn.ModuleList):
        for idx, head in enumerate(aux):
            head_losses = head.loss(batch_inputs_dict,
                                    batch_data_samples, self.train_cfg)
            losses.update(add_prefix(head_losses, f'aux_{idx}'))
    else:
        head_losses = aux.loss(batch_inputs_dict, batch_data_samples,
                               self.train_cfg)
        losses.update(add_prefix(head_losses, 'aux'))
    return losses
...
...
@@ -137,39 +199,36 @@ class EncoderDecoder3D(Base3DSegmentor):
return
losses
def
forward_dummy
(
self
,
points
):
"""Dummy forward function."""
seg_logit
=
self
.
encode_decode
(
points
,
None
)
return
seg_logit
def
forward_train
(
self
,
points
,
img_metas
,
pts_semantic_mask
):
"""Forward function for training.
def
loss
(
self
,
batch_inputs_dict
:
dict
,
batch_data_samples
:
SampleList
)
->
dict
:
"""Calculate losses from a batch of inputs and data samples.
Args:
points (list[torch.Tensor]): List of points of shape [N, C].
img_metas (list): Image metas.
pts_semantic_mask (list[torch.Tensor]): List of point-wise semantic
labels of shape [N].
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'imgs' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image tensor has shape
(B, C, H, W).
batch_data_samples (list[:obj:`Det3DDataSample`]): The det3d
data samples. It usually includes information such
as `metainfo` and `gt_pts_sem_seg`.
Returns:
dict[str, Tensor]:
Losse
s.
dict[str, Tensor]:
a dictionary of loss component
s.
"""
points_cat
=
torch
.
stack
(
points
)
pts_semantic_mask_cat
=
torch
.
stack
(
pts_semantic_mask
)
# extract features using backbone
x
=
self
.
extract_feat
(
points_ca
t
)
x
=
self
.
extract_feat
(
batch_inputs_dic
t
)
losses
=
dict
()
loss_decode
=
self
.
_decode_head_forward_train
(
x
,
img_metas
,
pts_semantic_mask_cat
)
loss_decode
=
self
.
_decode_head_forward_train
(
x
,
batch_data_samples
)
losses
.
update
(
loss_decode
)
if
self
.
with_auxiliary_head
:
loss_aux
=
self
.
_auxiliary_head_forward_train
(
x
,
img_metas
,
pts_semantic_mask_cat
)
x
,
batch_data_samples
)
losses
.
update
(
loss_aux
)
if
self
.
with_regularization_loss
:
...
...
@@ -180,10 +239,10 @@ class EncoderDecoder3D(Base3DSegmentor):
@
staticmethod
def
_input_generation
(
coords
,
patch_center
,
coord_max
,
feats
,
use_normalized_coord
=
False
):
patch_center
:
Tensor
,
coord_max
:
Tensor
,
feats
:
Tensor
,
use_normalized_coord
:
bool
=
False
):
"""Generating model input.
Generate input by subtracting patch center and adding additional
...
...
@@ -215,12 +274,12 @@ class EncoderDecoder3D(Base3DSegmentor):
return
points
def
_sliding_patch_generation
(
self
,
points
,
num_points
,
block_size
,
sample_rate
=
0.5
,
use_normalized_coord
=
False
,
eps
=
1e-3
):
points
:
Tensor
,
num_points
:
int
,
block_size
:
float
,
sample_rate
:
float
=
0.5
,
use_normalized_coord
:
bool
=
False
,
eps
:
float
=
1e-3
):
"""Sampling points in a sliding window fashion.
First sample patches to cover all the input points.
...
...
@@ -318,7 +377,8 @@ class EncoderDecoder3D(Base3DSegmentor):
return
patch_points
,
patch_idxs
def
slide_inference
(
self
,
point
,
img_meta
,
rescale
):
def
slide_inference
(
self
,
point
:
Tensor
,
img_meta
:
List
[
dict
],
rescale
:
bool
):
"""Inference by sliding-window with overlap.
Args:
...
...
@@ -362,18 +422,20 @@ class EncoderDecoder3D(Base3DSegmentor):
return
preds
.
transpose
(
0
,
1
)
# to [num_classes, K*N]
def whole_inference(self, points: Tensor, input_metas: List[dict],
                    rescale: bool):
    """Inference on the full scene in one forward pass (no sliding).

    NOTE(review): the refactored ``encode_decode`` expects a
    ``batch_inputs_dict``; here a raw points Tensor is forwarded —
    confirm the call contract.
    """
    # TODO: if rescale and voxelization segmentor
    return self.encode_decode(points, input_metas)
def inference(self, points: Tensor, input_metas: List[dict],
              rescale: bool):
    """Inference with slide/whole style.

    Args:
        points (torch.Tensor): Input points of shape [B, N, 3+C].
        input_metas (list[dict]): Meta information of each sample.
        rescale (bool): Whether transform to original number of points.
            Will be used for voxelization based segmentors.

    Returns:
        torch.Tensor: Per-class probabilities (softmax over dim 1).
    """
    if self.test_cfg.mode == 'slide':
        # Run each scene through sliding-window inference, then batch.
        per_scene_logits = [
            self.slide_inference(scene_points, scene_meta, rescale)
            for scene_points, scene_meta in zip(points, input_metas)
        ]
        seg_logit = torch.stack(per_scene_logits, 0)
    else:
        seg_logit = self.whole_inference(points, input_metas, rescale)
    return F.softmax(seg_logit, dim=1)
def
simple_test
(
self
,
points
,
img_metas
,
rescale
=
True
):
def
predict
(
self
,
batch_inputs_dict
:
dict
,
batch_data_samples
:
SampleList
,
rescale
:
bool
=
True
)
->
SampleList
:
"""Simple test with single scene.
Args:
points (list[torch.Tensor]): List of points of shape [N, 3+C].
img_metas (list[dict]): Meta information of each sample.
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'imgs' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image tensor has shape
(B, C, H, W).
batch_data_samples (list[:obj:`Det3DDataSample`]): The det3d
data samples. It usually includes information such
as `metainfo` and `gt_pts_sem_seg`.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Defaults to True.
...
...
@@ -410,9 +482,14 @@ class EncoderDecoder3D(Base3DSegmentor):
# to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time
seg_pred
=
[]
for
point
,
img_meta
in
zip
(
points
,
img_metas
):
seg_prob
=
self
.
inference
(
point
.
unsqueeze
(
0
),
[
img_meta
],
rescale
)[
0
]
batch_input_metas
=
[]
for
data_sample
in
batch_data_samples
:
batch_input_metas
.
append
(
data_sample
.
metainfo
)
points
=
batch_inputs_dict
[
'points'
]
for
point
,
input_meta
in
zip
(
points
,
batch_input_metas
):
seg_prob
=
self
.
inference
(
point
.
unsqueeze
(
0
),
[
input_meta
],
rescale
)[
0
]
seg_map
=
seg_prob
.
argmax
(
0
)
# [N]
# to cpu tensor for consistency with det3d
seg_map
=
seg_map
.
cpu
()
...
...
@@ -421,33 +498,24 @@ class EncoderDecoder3D(Base3DSegmentor):
seg_pred
=
[
dict
(
semantic_mask
=
seg_map
)
for
seg_map
in
seg_pred
]
return
seg_pred
def _forward(self, batch_inputs_dict: dict,
             batch_data_samples: OptSampleList = None) -> Tensor:
    """Network forward process without any post-processing.

    Args:
        batch_inputs_dict (dict): Input sample dict which includes
            'points' and 'imgs' keys.

            - points (list[torch.Tensor]): Point cloud of each sample.
            - imgs (torch.Tensor, optional): Image tensor has shape
              (B, C, H, W).
        batch_data_samples (List[:obj:`Det3DDataSample`], optional):
            The seg data samples; not consumed here, kept for interface
            compatibility. Defaults to None.

    Returns:
        Tensor: Forward output of model without any post-processes.
    """
    feats = self.extract_feat(batch_inputs_dict)
    return self.decode_head.forward(feats)
tests/test_models/test_decode_heads/test_dgcnn_head.py
0 → 100644
View file @
c66197c7
# Copyright (c) OpenMMLab. All rights reserved.
from
unittest
import
TestCase
import
torch
from
mmdet3d.core
import
Det3DDataSample
,
PointData
from
mmdet3d.models.decode_heads
import
DGCNNHead
class TestDGCNNHead(TestCase):
    """Unit tests for the DGCNN decode head."""

    def test_dgcnn_head_loss(self):
        """Test DGCNN head forward shape and loss computation."""
        dgcnn_head = DGCNNHead(
            fp_channels=(1024, 512),
            channels=256,
            num_classes=13,
            dropout_ratio=0.5,
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d'),
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            loss_decode=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=None,
                loss_weight=1.0),
            ignore_index=13)

        # DGCNN head expects dict format features.
        fa_points = torch.rand(1, 4096, 1024).float()
        feat_dict = dict(fa_points=fa_points)

        # Test forward.
        seg_logits = dgcnn_head.forward(feat_dict)
        # FIX: compare the tensor's *shape* against torch.Size — the
        # original compared the tensor itself to a Size, which can never
        # hold.
        self.assertEqual(seg_logits.shape, torch.Size([1, 13, 4096]))

        # When truth is non-empty then losses should be nonzero for
        # random inputs.
        # FIX: the batch holds one sample of 4096 points, so the
        # ground-truth mask must be (1, 4096); the original (2, 4096)
        # contradicted both the input batch size and the single data
        # sample passed to loss().
        pts_semantic_mask = torch.randint(0, 13, (1, 4096)).long()
        gt_pts_seg = PointData(pts_semantic_mask=pts_semantic_mask)
        datasample = Det3DDataSample()
        datasample.gt_pts_seg = gt_pts_seg
        gt_losses = dgcnn_head.loss(seg_logits, [datasample])
        gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()
        self.assertGreater(gt_sem_seg_loss, 0,
                           'semantic seg loss should be positive')
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment