"...text-generation-inference.git" did not exist on "c5eae25be787192c0e9601352b42bf647f21888e"
Commit c66197c7 authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

[Refactor] 3D Segmentor and EncoderDecoder3D

parent 522cc20d
 # model settings
 model = dict(
     type='EncoderDecoder3D',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='DGCNNBackbone',
         in_channels=9,  # [xyz, rgb, normal_xyz], modified with dataset
@@ -19,7 +20,7 @@ model = dict(
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
         loss_decode=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # modified with dataset
             loss_weight=1.0)),
...
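The `mmdet.` prefix is MMEngine's cross-registry scope: the loss type is looked up in MMDetection's MODELS registry rather than MMDetection3D's own. A minimal sketch of what the refactored head does with this config entry (assuming an environment where mmdet and mmdet3d are installed and their registries are initialized):

```python
from mmdet3d.registry import MODELS

# 'mmdet.' scopes the lookup to MMDetection's MODELS registry, so the 3D
# codebase can reuse the 2D CrossEntropyLoss without re-registering it here.
loss_cfg = dict(
    type='mmdet.CrossEntropyLoss',
    use_sigmoid=False,
    class_weight=None,
    loss_weight=1.0)
loss_decode = MODELS.build(loss_cfg)  # an nn.Module computing the CE loss
```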
 # model settings
 model = dict(
     type='EncoderDecoder3D',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='PointNet2SASSG',
         in_channels=9,  # [xyz, rgb, normalized_xyz]
@@ -37,7 +38,7 @@ model = dict(
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='ReLU'),
         loss_decode=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # should be modified with dataset
             loss_weight=1.0)),
...
 # model settings
 model = dict(
     type='EncoderDecoder3D',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='PointNet2SASSG',
         in_channels=6,  # [xyz, rgb], should be modified with dataset
@@ -26,7 +27,7 @@ model = dict(
         norm_cfg=dict(type='BN1d'),
         act_cfg=dict(type='ReLU'),
         loss_decode=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             use_sigmoid=False,
             class_weight=None,  # should be modified with dataset
             loss_weight=1.0)),
...
# optimizer
 # optimizer
 # This schedule is mainly used on S3DIS dataset in segmentation task
-optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
-optimizer_config = dict(grad_clip=None)
-lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.001),
+    clip_grad=None)
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=100,
+        eta_min=1e-5,
+        by_epoch=True,
+        begin=0,
+        end=100)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=100)
+train_cfg = dict(by_epoch=True, max_epochs=100)
+val_cfg = dict(interval=1)
+test_cfg = dict()
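The mapping from the legacy fields to the MMEngine ones is mechanical: `lr_config`'s `min_lr` becomes `eta_min`, `runner.max_epochs` becomes `T_max`/`end` plus `train_cfg.max_epochs`, and gradient clipping moves into the optimizer wrapper. A quick sanity check of the learning rates a `CosineAnnealingLR` scheduler should produce under the schedule above (standard cosine-annealing formula, stepped per epoch since `by_epoch=True`):

```python
import math

# base_lr=0.1, eta_min=1e-5, T_max=100, as in the schedule above
base_lr, eta_min, T_max = 0.1, 1e-5, 100
for epoch in (0, 50, 100):
    lr = eta_min + 0.5 * (base_lr - eta_min) * (
        1 + math.cos(math.pi * epoch / T_max))
    print(epoch, round(lr, 6))  # 0.1 at start, ~0.05 midway, 1e-5 at the end
```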
 # optimizer
 # This schedule is mainly used on S3DIS dataset in segmentation task
-optimizer = dict(type='SGD', lr=0.2, weight_decay=0.0001, momentum=0.9)
-optimizer_config = dict(grad_clip=None)
-lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=0.002)
-momentum_config = None
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=0.0001),
+    clip_grad=None)
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=150,
+        eta_min=0.002,
+        by_epoch=True,
+        begin=0,
+        end=150)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=150)
+train_cfg = dict(by_epoch=True, max_epochs=150)
+val_cfg = dict(interval=1)
+test_cfg = dict()
 # optimizer
-# This schedule is mainly used on ScanNet dataset in segmentation task
-optimizer = dict(type='Adam', lr=0.001, weight_decay=0.01)
-optimizer_config = dict(grad_clip=None)
-lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
-momentum_config = None
+# This schedule is mainly used on S3DIS dataset in segmentation task
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='Adam', lr=0.001, weight_decay=0.01),
+    clip_grad=None)
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=200,
+        eta_min=1e-5,
+        by_epoch=True,
+        begin=0,
+        end=200)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=200)
+train_cfg = dict(by_epoch=True, max_epochs=200)
+val_cfg = dict(interval=1)
+test_cfg = dict()
 # optimizer
 # This schedule is mainly used on S3DIS dataset in segmentation task
-optimizer = dict(type='Adam', lr=0.001, weight_decay=0.001)
-optimizer_config = dict(grad_clip=None)
-lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
-momentum_config = None
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='Adam', lr=0.001, weight_decay=0.001),
+    clip_grad=None)
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=50,
+        eta_min=1e-5,
+        by_epoch=True,
+        begin=0,
+        end=50)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=50)
+train_cfg = dict(by_epoch=True, max_epochs=50)
+val_cfg = dict(interval=1)
+test_cfg = dict()
@@ -3,10 +3,6 @@ _base_ = [
     '../_base_/schedules/seg_cosine_100e.py', '../_base_/default_runtime.py'
 ]

-# data settings
-data = dict(samples_per_gpu=32)
-evaluation = dict(interval=2)
-
 # model settings
 model = dict(
     backbone=dict(in_channels=9),  # [xyz, rgb, normalized_xyz]
@@ -20,5 +16,6 @@ model = dict(
         use_normalized_coord=True,
         batch_size=24))

-# runtime settings
-checkpoint_config = dict(interval=2)
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2), )
+train_dataloader = dict(batch_size=32)
+val_cfg = dict(interval=2)
@@ -4,9 +4,20 @@ _base_ = [
     '../_base_/default_runtime.py'
 ]

+# file_client_args = dict(backend='disk')
+# Uncomment the following if use ceph or other file clients.
+# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
+# for more details.
+file_client_args = dict(
+    backend='petrel',
+    path_mapping=dict({
+        './data/s3dis/':
+        's3://openmmlab/datasets/detection3d/s3dis_processed/',
+        'data/s3dis/':
+        's3://openmmlab/datasets/detection3d/s3dis_processed/'
+    }))
+
 # data settings
-class_names = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
-               'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
 num_points = 4096
 train_pipeline = [
     dict(
@@ -15,17 +26,16 @@ train_pipeline = [
         shift_height=False,
         use_color=True,
         load_dim=6,
-        use_dim=[0, 1, 2, 3, 4, 5]),
+        use_dim=[0, 1, 2, 3, 4, 5],
+        file_client_args=file_client_args),
     dict(
         type='LoadAnnotations3D',
         with_bbox_3d=False,
         with_label_3d=False,
         with_mask_3d=False,
-        with_seg_3d=True),
-    dict(
-        type='PointSegClassMapping',
-        valid_cat_ids=tuple(range(len(class_names))),
-        max_cat_id=13),
+        with_seg_3d=True,
+        file_client_args=file_client_args),
+    dict(type='PointSegClassMapping'),
     dict(
         type='IndoorPatchPointSample',
         num_points=num_points,
@@ -46,13 +56,9 @@ train_pipeline = [
         jitter_std=[0.01, 0.01, 0.01],
         clip_range=[-0.05, 0.05]),
     dict(type='RandomDropPointsColor', drop_ratio=0.2),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
+    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
 ]
-data = dict(samples_per_gpu=8, train=dict(pipeline=train_pipeline))
-evaluation = dict(interval=1)

 # model settings
 model = dict(
     decode_head=dict(
@@ -64,3 +70,6 @@ model = dict(
         sample_rate=0.5,
         use_normalized_coord=True,
         batch_size=12))
+
+train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline))
+val_cfg = dict(interval=1)
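For context, `file_client_args` is forwarded to the loading transforms, which read raw files through mmcv's `FileClient`; with the petrel backend, `path_mapping` rewrites the local prefixes to the s3 bucket before reading. A minimal sketch using the dependency-free disk backend (the file path below is made up for illustration):

```python
from mmcv.fileio import FileClient

# The loading transforms do roughly this under the hood; the 'petrel'
# backend would additionally rewrite 'data/s3dis/...' via path_mapping.
file_client = FileClient(backend='disk')
pts_bytes = file_client.get('data/s3dis/points/Area_1_office_1.bin')  # hypothetical path
```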
@@ -4,10 +4,6 @@ _base_ = [
     '../_base_/schedules/seg_cosine_200e.py', '../_base_/default_runtime.py'
 ]

-# data settings
-data = dict(samples_per_gpu=16)
-evaluation = dict(interval=5)
-
 # model settings
 model = dict(
     decode_head=dict(
@@ -30,5 +26,9 @@ model = dict(
         use_normalized_coord=False,
         batch_size=24))

+# data settings
+train_dataloader = dict(batch_size=16)
+
 # runtime settings
-checkpoint_config = dict(interval=5)
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5), )
+val_cfg = dict(interval=5)
@@ -4,10 +4,6 @@ _base_ = [
     '../_base_/schedules/seg_cosine_50e.py', '../_base_/default_runtime.py'
 ]

-# data settings
-data = dict(samples_per_gpu=16)
-evaluation = dict(interval=2)
-
 # model settings
 model = dict(
     backbone=dict(in_channels=9),  # [xyz, rgb, normalized_xyz]
@@ -21,5 +17,9 @@ model = dict(
         use_normalized_coord=True,
         batch_size=24))

+# data settings
+train_dataloader = dict(batch_size=6)
+
 # runtime settings
-checkpoint_config = dict(interval=2)
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2), )
+val_cfg = dict(interval=2)
@@ -2,6 +2,7 @@
 from .array_converter import ArrayConverter, array_converter
 from .gaussian import (draw_heatmap_gaussian, ellip_gaussian2D, gaussian_2d,
                        gaussian_radius, get_ellip_gaussian_2D)
+from .misc import add_prefix
 from .typing import (ConfigType, ForwardResults, InstanceList, MultiConfig,
                      OptConfigType, OptInstanceList, OptMultiConfig,
                      OptSampleList, OptSamplingResultList, SampleList,
@@ -13,5 +14,5 @@ __all__ = [
     'get_ellip_gaussian_2D', 'ConfigType', 'OptConfigType', 'MultiConfig',
     'OptMultiConfig', 'InstanceList', 'OptInstanceList', 'SampleList',
     'OptSampleList', 'SamplingResultList', 'ForwardResults',
-    'OptSamplingResultList'
+    'OptSamplingResultList', 'add_prefix'
 ]
# Copyright (c) OpenMMLab. All rights reserved.


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    outputs = dict()
    for name, value in inputs.items():
        outputs[f'{prefix}.{name}'] = value

    return outputs
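The segmentor uses this helper to namespace the losses returned by each head (e.g. `decode.loss_sem_seg`, `aux_0.loss_sem_seg`). For example:

```python
losses = dict(loss_sem_seg=0.8)
print(add_prefix(losses, 'decode'))  # {'decode.loss_sem_seg': 0.8}
```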
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
+from typing import List
+
+import torch
 from mmcv.cnn import normal_init
-from mmcv.runner import BaseModule, auto_fp16, force_fp32
+from mmcv.runner import BaseModule, auto_fp16
+from torch import Tensor
 from torch import nn as nn

-from ..builder import build_loss
+from mmdet3d.core.utils.typing import ConfigType, SampleList
+from mmdet3d.registry import MODELS


 class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
     """Base class for BaseDecodeHead.

+    1. The ``init_weights`` method is used to initialize decode_head's
+    model parameters. After segmentor initialization, ``init_weights``
+    is triggered when ``segmentor.init_weights()`` is called externally.
+
+    2. The ``loss`` method is used to calculate the loss of decode_head,
+    which includes two steps: (1) the decode_head model performs forward
+    propagation to obtain the feature maps; (2) the ``loss_by_feat`` method
+    is called based on the feature maps to calculate the loss.
+
+    .. code:: text
+
+        loss(): forward() -> loss_by_feat()
+
+    3. The ``predict`` method is used to predict segmentation results,
+    which includes two steps: (1) the decode_head model performs forward
+    propagation to obtain the feature maps; (2) the ``predict_by_feat``
+    method is called based on the feature maps to predict segmentation
+    results including post-processing.
+
+    .. code:: text
+
+        predict(): forward() -> predict_by_feat()
+
     Args:
         channels (int): Channels after modules, before conv_seg.
         num_classes (int): Number of classes.
@@ -26,6 +53,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
         ignore_index (int, optional): The label index to be ignored.
             When using masked BCE loss, ignore_index should be set to None.
             Default: 255.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """

     def __init__(self,
@@ -36,12 +64,12 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
                  norm_cfg=dict(type='BN1d'),
                  act_cfg=dict(type='ReLU'),
                  loss_decode=dict(
-                     type='CrossEntropyLoss',
+                     type='mmdet.CrossEntropyLoss',
                      use_sigmoid=False,
                      class_weight=None,
                      loss_weight=1.0),
                  ignore_index=255,
-                 init_cfg=None):
+                 init_cfg=None) -> None:
         super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg)
         self.channels = channels
         self.num_classes = num_classes
@@ -49,7 +77,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.act_cfg = act_cfg
-        self.loss_decode = build_loss(loss_decode)
+        self.loss_decode = MODELS.build(loss_decode)
         self.ignore_index = ignore_index

         self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
@@ -57,6 +85,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
             self.dropout = nn.Dropout(dropout_ratio)
         else:
             self.dropout = None
+
         self.fp16_enabled = False

     def init_weights(self):
@@ -66,11 +95,19 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
     @auto_fp16()
     @abstractmethod
-    def forward(self, inputs):
+    def forward(self, feats_dict: dict):
         """Placeholder of forward function."""
         pass

-    def forward_train(self, inputs, img_metas, pts_semantic_mask, train_cfg):
+    def cls_seg(self, feat: Tensor) -> Tensor:
+        """Classify each point."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.conv_seg(feat)
+        return output
+
+    def loss(self, inputs: List[Tensor], batch_data_samples: SampleList,
+             train_cfg: ConfigType) -> dict:
         """Forward function for training.

         Args:
@@ -84,39 +121,44 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
             dict[str, Tensor]: a dictionary of loss components
         """
         seg_logits = self.forward(inputs)
-        losses = self.losses(seg_logits, pts_semantic_mask)
+        losses = self.loss_by_feat(seg_logits, batch_data_samples)
         return losses

-    def forward_test(self, inputs, img_metas, test_cfg):
+    def predict(self, inputs: List[Tensor], batch_input_metas: List[dict],
+                test_cfg: ConfigType) -> List[Tensor]:
         """Forward function for testing.

         Args:
             inputs (list[Tensor]): List of multi-level point features.
-            img_metas (list[dict]): Meta information of each sample.
+            batch_input_metas (list[dict]): Meta information of each sample.
             test_cfg (dict): The testing config.

         Returns:
             Tensor: Output segmentation map.
         """
-        return self.forward(inputs)
+        seg_logits = self.forward(inputs)
+        return seg_logits

-    def cls_seg(self, feat):
-        """Classify each points."""
-        if self.dropout is not None:
-            feat = self.dropout(feat)
-        output = self.conv_seg(feat)
-        return output
+    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
+        gt_semantic_segs = [
+            data_sample.gt_pts_seg.pts_semantic_mask
+            for data_sample in batch_data_samples
+        ]
+        return torch.stack(gt_semantic_segs, dim=0)

-    @force_fp32(apply_to=('seg_logit', ))
-    def losses(self, seg_logit, seg_label):
+    def loss_by_feat(self, seg_logit: Tensor,
+                     batch_data_samples: SampleList) -> dict:
         """Compute semantic segmentation loss.

         Args:
             seg_logit (torch.Tensor): Predicted per-point segmentation logits
                 of shape [B, num_classes, N].
-            seg_label (torch.Tensor): Ground-truth segmentation label of
-                shape [B, N].
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_pts_seg`.
         """
+        seg_label = self._stack_batch_gt(batch_data_samples)
         loss = dict()
         loss['loss_sem_seg'] = self.loss_decode(
             seg_logit, seg_label, ignore_index=self.ignore_index)
...
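`loss_by_feat` no longer receives a stacked label tensor; it pulls the per-sample masks out of the data samples via `_stack_batch_gt`. In isolation the stacking step is just (sketch with random masks):

```python
import torch

# per-sample [N] semantic masks -> one [B, N] label tensor for the loss
masks = [torch.randint(0, 13, (4096, )) for _ in range(2)]
seg_label = torch.stack(masks, dim=0)
print(seg_label.shape)  # torch.Size([2, 4096])
```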
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
 from mmcv.cnn.bricks import ConvModule
+from torch import Tensor

 from mmdet3d.ops import DGCNNFPModule
 from mmdet3d.registry import MODELS
@@ -19,7 +22,7 @@ class DGCNNHead(Base3DDecodeHead):
             propagation (FP) modules. Defaults to (1216, 512).
     """

-    def __init__(self, fp_channels=(1216, 512), **kwargs):
+    def __init__(self, fp_channels: Tuple = (1216, 512), **kwargs) -> None:
         super(DGCNNHead, self).__init__(**kwargs)

         self.FP_module = DGCNNFPModule(
@@ -35,7 +38,7 @@ class DGCNNHead(Base3DDecodeHead):
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg)

-    def _extract_input(self, feat_dict):
+    def _extract_input(self, feat_dict: dict) -> Tensor:
         """Extract inputs from features dictionary.

         Args:
@@ -48,7 +51,7 @@ class DGCNNHead(Base3DDecodeHead):
         return fa_points

-    def forward(self, feat_dict):
+    def forward(self, feat_dict: dict) -> Tensor:
         """Forward pass.

         Args:
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
 from mmcv.cnn.bricks import ConvModule
+from torch import Tensor

+from mmdet3d.core.utils import ConfigType
 from mmdet3d.registry import MODELS
 from .pointnet2_head import PointNet2Head
@@ -19,11 +23,14 @@ class PAConvHead(PointNet2Head):
     """

     def __init__(self,
-                 fp_channels=((768, 256, 256), (384, 256, 256),
-                              (320, 256, 128), (128 + 6, 128, 128, 128)),
-                 fp_norm_cfg=dict(type='BN2d'),
-                 **kwargs):
-        super(PAConvHead, self).__init__(fp_channels, fp_norm_cfg, **kwargs)
+                 fp_channels: Tuple[Tuple[int]] = ((768, 256, 256),
+                                                   (384, 256, 256),
+                                                   (320, 256, 128),
+                                                   (128 + 6, 128, 128, 128)),
+                 fp_norm_cfg: ConfigType = dict(type='BN2d'),
+                 **kwargs) -> None:
+        super(PAConvHead, self).__init__(
+            fp_channels=fp_channels, fp_norm_cfg=fp_norm_cfg, **kwargs)

     # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53
     # PointNet++'s decoder conv has bias while PAConv's doesn't have
@@ -37,7 +44,7 @@ class PAConvHead(PointNet2Head):
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg)

-    def forward(self, feat_dict):
+    def forward(self, feat_dict: dict) -> Tensor:
         """Forward pass.

         Args:
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
 from mmcv.cnn.bricks import ConvModule
+from torch import Tensor
 from torch import nn as nn

+from mmdet3d.core.utils.typing import ConfigType
 from mmdet3d.ops import PointFPModule
 from mmdet3d.registry import MODELS
 from .decode_head import Base3DDecodeHead
@@ -21,10 +25,12 @@ class PointNet2Head(Base3DDecodeHead):
     """

     def __init__(self,
-                 fp_channels=((768, 256, 256), (384, 256, 256),
-                              (320, 256, 128), (128, 128, 128, 128)),
-                 fp_norm_cfg=dict(type='BN2d'),
-                 **kwargs):
+                 fp_channels: Tuple[Tuple[int]] = ((768, 256, 256),
+                                                   (384, 256, 256),
+                                                   (320, 256, 128),
+                                                   (128, 128, 128, 128)),
+                 fp_norm_cfg: ConfigType = dict(type='BN2d'),
+                 **kwargs) -> None:
         super(PointNet2Head, self).__init__(**kwargs)

         self.num_fp = len(fp_channels)
@@ -43,7 +49,7 @@ class PointNet2Head(Base3DDecodeHead):
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg)

-    def _extract_input(self, feat_dict):
+    def _extract_input(self, feat_dict: dict) -> Tensor:
         """Extract inputs from features dictionary.

         Args:
@@ -59,7 +65,7 @@ class PointNet2Head(Base3DDecodeHead):
         return sa_xyz, sa_features

-    def forward(self, feat_dict):
+    def forward(self, feat_dict: dict) -> Tensor:
         """Forward pass.

         Args:
...
 # Copyright (c) OpenMMLab. All rights reserved.
-from os import path as osp
-
-import mmcv
-import numpy as np
-import torch
-from mmcv.parallel import DataContainer as DC
-from mmcv.runner import auto_fp16
-
-from mmdet3d.core import show_seg_result
-from mmseg.models.segmentors import BaseSegmentor
+from abc import ABCMeta, abstractmethod
+from typing import List, Tuple
+
+from mmengine.data import PixelData
+from mmengine.model import BaseModel
+from torch import Tensor
+
+from mmdet3d.core import Det3DDataSample
+from mmdet3d.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
+                                OptSampleList, SampleList)
-class Base3DSegmentor(BaseSegmentor):
+class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
     """Base class for 3D segmentors.

-    The main difference with `BaseSegmentor` is that we modify the keys in
-    data_dict and use a 3D seg specific visualization function.
+    Args:
+        data_preprocessor (dict, optional): Model preprocessing config
+            for processing the input data. It usually includes
+            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
+            ``mean`` and ``std``. Defaults to None.
+        init_cfg (dict, optional): The config to control the
+            initialization. Defaults to None.
     """

+    def __init__(self,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super(Base3DSegmentor, self).__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+
+    @property
+    def with_neck(self) -> bool:
+        """bool: whether the segmentor has neck"""
+        return hasattr(self, 'neck') and self.neck is not None
+
+    @property
+    def with_auxiliary_head(self) -> bool:
+        """bool: whether the segmentor has auxiliary head"""
+        return hasattr(self,
+                       'auxiliary_head') and self.auxiliary_head is not None
+
+    @property
+    def with_decode_head(self) -> bool:
+        """bool: whether the segmentor has decode head"""
+        return hasattr(self, 'decode_head') and self.decode_head is not None
+
     @property
-    def with_regularization_loss(self):
+    def with_regularization_loss(self) -> bool:
         """bool: whether the segmentor has regularization loss for weight"""
         return hasattr(self, 'loss_regularization') and \
             self.loss_regularization is not None
-    def forward_test(self, points, img_metas, **kwargs):
-        """Calls either simple_test or aug_test depending on the length of
-        outer list of points. If len(points) == 1, call simple_test. Otherwise
-        call aug_test to aggregate the test results by e.g. voting.
-
-        Args:
-            points (list[list[torch.Tensor]]): the outer list indicates
-                test-time augmentations and inner torch.Tensor should have a
-                shape BxNxC, which contains all points in the batch.
-            img_metas (list[list[dict]]): the outer list indicates test-time
-                augs (multiscale, flip, etc.) and the inner list indicates
-                images in a batch.
-        """
-        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
-            if not isinstance(var, list):
-                raise TypeError(f'{name} must be a list, but got {type(var)}')
-
-        num_augs = len(points)
-        if num_augs != len(img_metas):
-            raise ValueError(f'num of augmentations ({len(points)}) != '
-                             f'num of image meta ({len(img_metas)})')
-
-        if num_augs == 1:
-            return self.simple_test(points[0], img_metas[0], **kwargs)
-        else:
-            return self.aug_test(points, img_metas, **kwargs)
-
-    @auto_fp16(apply_to=('points'))
-    def forward(self, return_loss=True, **kwargs):
-        """Calls either forward_train or forward_test depending on whether
-        return_loss=True.
-
-        Note this setting will change the expected inputs. When
-        `return_loss=True`, point and img_metas are single-nested (i.e.
-        torch.Tensor and list[dict]), and when `return_loss=False`, point and
-        img_metas should be double nested (i.e. list[torch.Tensor],
-        list[list[dict]]), with the outer list indicating test time
-        augmentations.
-        """
-        if return_loss:
-            return self.forward_train(**kwargs)
-        else:
-            return self.forward_test(**kwargs)
+    @abstractmethod
+    def extract_feat(self, batch_inputs: Tensor) -> bool:
+        """Placeholder for extracting features from inputs."""
+        pass
+
+    @abstractmethod
+    def encode_decode(self, batch_inputs: Tensor,
+                      batch_data_samples: SampleList):
+        """Placeholder for encoding inputs with backbone and decoding into a
+        semantic segmentation map of the same size as input."""
+        pass
+
+    def forward(self,
+                batch_inputs_dict: Tensor,
+                batch_data_samples: OptSampleList = None,
+                mode: str = 'tensor') -> ForwardResults:
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes: "tensor", "predict" and "loss":
+
+        - "tensor": Forward the whole network and return tensor or tuple of
+          tensor without any post-processing, same as a common nn.Module.
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`SegDataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in the :meth:`train_step`.
+
+        Args:
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
+            batch_data_samples (list[:obj:`Det3DDataSample`], optional):
+                The annotation data of every sample. Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+        if mode == 'loss':
+            return self.loss(batch_inputs_dict, batch_data_samples)
+        elif mode == 'predict':
+            return self.predict(batch_inputs_dict, batch_data_samples)
+        elif mode == 'tensor':
+            return self._forward(batch_inputs_dict, batch_data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}". '
+                               'Only supports loss, predict and tensor mode')
-    def show_results(self,
-                     data,
-                     result,
-                     palette=None,
-                     out_dir=None,
-                     ignore_index=None,
-                     show=False,
-                     score_thr=None):
-        """Results visualization.
-
-        Args:
-            data (list[dict]): Input points and the information of the sample.
-            result (list[dict]): Prediction results.
-            palette (list[list[int]] | np.ndarray): The palette of
-                segmentation map. If None is given, a random palette will be
-                generated. Default: None.
-            out_dir (str): Output directory of visualization result.
-            ignore_index (int, optional): The label index to be ignored, e.g.
-                unannotated points. If None is given, set to len(self.CLASSES).
-                Defaults to None.
-            show (bool, optional): Determines whether you are
-                going to show result by open3d.
-                Defaults to False.
-
-            TODO: implement score_thr of Base3DSegmentor.
-            score_thr (float, optional): Score threshold of bounding boxes.
-                Defaults to None.
-                Not implemented yet, but it is here for unification.
-        """
-        assert out_dir is not None, 'Expect out_dir, got none.'
-        if palette is None:
-            if self.PALETTE is None:
-                palette = np.random.randint(
-                    0, 255, size=(len(self.CLASSES), 3))
-            else:
-                palette = self.PALETTE
-        palette = np.array(palette)
-        for batch_id in range(len(result)):
-            if isinstance(data['points'][0], DC):
-                points = data['points'][0]._data[0][batch_id].numpy()
-            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
-                points = data['points'][0][batch_id]
-            else:
-                raise ValueError(
-                    f"Unsupported data type {type(data['points'][0])} "
-                    f'for visualization!')
-            if isinstance(data['img_metas'][0], DC):
-                pts_filename = data['img_metas'][0]._data[0][batch_id][
-                    'pts_filename']
-            elif mmcv.is_list_of(data['img_metas'][0], dict):
-                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
-            else:
-                raise ValueError(
-                    f"Unsupported data type {type(data['img_metas'][0])} "
-                    f'for visualization!')
-            file_name = osp.split(pts_filename)[-1].split('.')[0]
-
-            pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy()
-
-            show_seg_result(
-                points,
-                None,
-                pred_sem_mask,
-                out_dir,
-                file_name,
-                palette,
-                ignore_index,
-                show=show)
+    @abstractmethod
+    def loss(self, batch_inputs: Tensor,
+             batch_data_samples: SampleList) -> dict:
+        """Calculate losses from a batch of inputs and data samples."""
+        pass
+
+    @abstractmethod
+    def predict(self, batch_inputs: Tensor,
+                batch_data_samples: SampleList) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing."""
+        pass
+
+    @abstractmethod
+    def _forward(
+            self,
+            batch_inputs: Tensor,
+            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
+        """Network forward process.
+
+        Usually includes backbone, neck and head forward without any post-
+        processing.
+        """
+        pass
+
+    @abstractmethod
+    def aug_test(self, batch_inputs, batch_img_metas):
+        """Placeholder for augmentation test."""
+        pass
+
+    def postprocess_result(self, seg_logits_list: List[Tensor],
+                           batch_img_metas: List[dict]) -> list:
+        """Convert results list to `Det3DDataSample`.
+
+        Args:
+            seg_logits_list (List[Tensor]): List of segmentation results,
+                seg_logits from model of each input point cloud sample.
+
+        Returns:
+            list[:obj:`Det3DDataSample`]: Segmentation results of the
+            input point clouds. Each Det3DDataSample usually contains:
+
+            - ``pred_pts_sem_seg`` (PixelData): Prediction of 3D semantic
+              segmentation.
+            - ``seg_logits`` (PixelData): Predicted logits of semantic
+              segmentation before normalization.
+        """
+        predictions = []
+
+        for i in range(len(seg_logits_list)):
+            img_meta = batch_img_metas[i]
+            seg_logits = seg_logits_list[i][None]
+            seg_pred = seg_logits.argmax(dim=0, keepdim=True)
+            prediction = Det3DDataSample(**{'metainfo': img_meta})
+            prediction.set_data(
+                {'pred_pts_sem_seg': PixelData(**{'data': seg_pred})})
+            predictions.append(prediction)
+        return predictions
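The `mode` dispatch replaces the old `return_loss` switch: MMEngine's train loop calls the model with `mode='loss'`, the val/test loops with `mode='predict'`, and `mode='tensor'` returns raw network output. A toy stand-in showing just the dispatch contract (not a real segmentor):

```python
class ToySegmentor:
    """Illustrates the three-mode forward contract only."""

    def loss(self, inputs, samples):
        return dict(loss_sem_seg=0.7)

    def predict(self, inputs, samples):
        return [dict(pred_pts_sem_seg=None)]

    def _forward(self, inputs, samples=None):
        return 'raw seg logits'

    def forward(self, inputs, samples=None, mode='tensor'):
        if mode == 'loss':
            return self.loss(inputs, samples)
        elif mode == 'predict':
            return self.predict(inputs, samples)
        elif mode == 'tensor':
            return self._forward(inputs, samples)
        raise RuntimeError(f'Invalid mode "{mode}".')


print(ToySegmentor().forward(None, mode='loss'))  # {'loss_sem_seg': 0.7}
```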
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
 import numpy as np
 import torch
+from torch import Tensor
 from torch import nn as nn
 from torch.nn import functional as F

+from mmdet3d.core import add_prefix
+from mmdet3d.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
+                                OptSampleList, SampleList)
 from mmdet3d.registry import MODELS
-from mmseg.core import add_prefix
 from .base import Base3DSegmentor
@@ -15,20 +20,69 @@ class EncoderDecoder3D(Base3DSegmentor):
     EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
     Note that auxiliary_head is only used for deep supervision during training,
-    which could be thrown during inference.
-    """
+    which can be dropped during inference.
+
+    1. The ``loss`` method is used to calculate the loss of the model,
+    which includes two steps: (1) extract features to obtain the feature maps;
+    (2) call the decode head loss function to forward the decode head model
+    and calculate losses.
+
+    .. code:: text
+
+        loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train() (optional)
+        _decode_head_forward_train(): decode_head.loss()
+        _auxiliary_head_forward_train(): auxiliary_head.loss() (optional)
+
+    2. The ``predict`` method is used to predict segmentation results,
+    which includes two steps: (1) run the inference function to obtain the
+    list of seg_logits; (2) call the post-processing function to obtain a
+    list of ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
+
+    .. code:: text
+
+        predict(): inference() -> postprocess_result()
+        inference(): whole_inference()/slide_inference()
+        whole_inference()/slide_inference(): encode_decode()
+        encode_decode(): extract_feat() -> decode_head.predict()
+
+    3. The ``_forward`` method is used to output the tensor by running the
+    model, which includes two steps: (1) extract features to obtain the
+    feature maps; (2) call the decode head forward function to forward the
+    decode head model.
+
+    .. code:: text
+
+        _forward(): extract_feat() -> _decode_head.forward()
+
+    Args:
+        backbone (ConfigType): The config for the backbone of the segmentor.
+        decode_head (ConfigType): The config for the decode head of the
+            segmentor.
+        neck (OptConfigType): The config for the neck of the segmentor.
+            Defaults to None.
+        auxiliary_head (OptConfigType): The config for the auxiliary head of
+            the segmentor. Defaults to None.
+        loss_regularization (OptConfigType): The config for the regularization
+            loss. Defaults to None.
+        train_cfg (OptConfigType): The config for training. Defaults to None.
+        test_cfg (OptConfigType): The config for testing. Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`.
+        init_cfg (dict, optional): The weight initialization config for
+            :class:`BaseModule`.
+    """  # noqa: E501
     def __init__(self,
-                 backbone,
-                 decode_head,
-                 neck=None,
-                 auxiliary_head=None,
-                 loss_regularization=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
-        super(EncoderDecoder3D, self).__init__(init_cfg=init_cfg)
+                 backbone: ConfigType,
+                 decode_head: ConfigType,
+                 neck: OptConfigType = None,
+                 auxiliary_head: OptConfigType = None,
+                 loss_regularization: OptConfigType = None,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super(EncoderDecoder3D, self).__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
         self.backbone = MODELS.build(backbone)
         if neck is not None:
             self.neck = MODELS.build(neck)
@@ -38,15 +92,16 @@ class EncoderDecoder3D(Base3DSegmentor):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

         assert self.with_decode_head, \
             '3D EncoderDecoder Segmentor should have a decode_head'

-    def _init_decode_head(self, decode_head):
+    def _init_decode_head(self, decode_head: ConfigType) -> None:
         """Initialize ``decode_head``"""
         self.decode_head = MODELS.build(decode_head)
         self.num_classes = self.decode_head.num_classes

-    def _init_auxiliary_head(self, auxiliary_head):
+    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
         """Initialize ``auxiliary_head``"""
         if auxiliary_head is not None:
             if isinstance(auxiliary_head, list):
@@ -56,7 +111,8 @@ class EncoderDecoder3D(Base3DSegmentor):
             else:
                 self.auxiliary_head = MODELS.build(auxiliary_head)

-    def _init_loss_regularization(self, loss_regularization):
+    def _init_loss_regularization(self,
+                                  loss_regularization: ConfigType) -> None:
         """Initialize ``loss_regularization``"""
         if loss_regularization is not None:
             if isinstance(loss_regularization, list):
@@ -66,58 +122,64 @@ class EncoderDecoder3D(Base3DSegmentor):
             else:
                 self.loss_regularization = MODELS.build(loss_regularization)
-    def extract_feat(self, points):
+    def extract_feat(self, batch_inputs_dict: dict) -> List[Tensor]:
         """Extract features from points."""
-        x = self.backbone(points)
+        points = batch_inputs_dict['points']
+        stack_points = torch.stack(points)
+        x = self.backbone(stack_points)
         if self.with_neck:
             x = self.neck(x)
         return x

-    def encode_decode(self, points, img_metas):
+    def encode_decode(self, batch_inputs_dict: dict,
+                      batch_input_metas: List[dict]) -> List[Tensor]:
         """Encode points with backbone and decode into a semantic segmentation
         map of the same size as input.

         Args:
-            points (torch.Tensor): Input points of shape [B, N, 3+C].
-            img_metas (list[dict]): Meta information of each sample.
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
+            batch_input_metas (list[dict]): Meta information of each sample.

         Returns:
             torch.Tensor: Segmentation logits of shape [B, num_classes, N].
         """
-        x = self.extract_feat(points)
-        out = self._decode_head_forward_test(x, img_metas)
-        return out
+        x = self.extract_feat(batch_inputs_dict)
+        seg_logits = self.decode_head.predict(x, batch_input_metas,
+                                              self.test_cfg)
+        return seg_logits
-    def _decode_head_forward_train(self, x, img_metas, pts_semantic_mask):
+    def _decode_head_forward_train(self, batch_inputs_dict: dict,
+                                   batch_data_samples: SampleList) -> dict:
         """Run forward function and calculate loss for decode head in
         training."""
         losses = dict()
-        loss_decode = self.decode_head.forward_train(x, img_metas,
-                                                     pts_semantic_mask,
-                                                     self.train_cfg)
+        loss_decode = self.decode_head.loss(batch_inputs_dict,
+                                            batch_data_samples,
+                                            self.train_cfg)

         losses.update(add_prefix(loss_decode, 'decode'))
         return losses

-    def _decode_head_forward_test(self, x, img_metas):
-        """Run forward function and calculate loss for decode head in
-        inference."""
-        seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
-        return seg_logits
-
-    def _auxiliary_head_forward_train(self, x, img_metas, pts_semantic_mask):
+    def _auxiliary_head_forward_train(
+        self,
+        batch_inputs_dict: dict,
+        batch_data_samples: SampleList,
+    ) -> dict:
         """Run forward function and calculate loss for auxiliary head in
         training."""
         losses = dict()
         if isinstance(self.auxiliary_head, nn.ModuleList):
             for idx, aux_head in enumerate(self.auxiliary_head):
-                loss_aux = aux_head.forward_train(x, img_metas,
-                                                  pts_semantic_mask,
-                                                  self.train_cfg)
+                loss_aux = aux_head.loss(batch_inputs_dict,
+                                         batch_data_samples, self.train_cfg)
                 losses.update(add_prefix(loss_aux, f'aux_{idx}'))
         else:
-            loss_aux = self.auxiliary_head.forward_train(
-                x, img_metas, pts_semantic_mask, self.train_cfg)
+            loss_aux = self.auxiliary_head.loss(batch_inputs_dict,
+                                                batch_data_samples,
                                                self.train_cfg)
             losses.update(add_prefix(loss_aux, 'aux'))

         return losses
@@ -137,39 +199,36 @@ class EncoderDecoder3D(Base3DSegmentor):
         return losses

-    def forward_dummy(self, points):
-        """Dummy forward function."""
-        seg_logit = self.encode_decode(points, None)
-        return seg_logit
-
-    def forward_train(self, points, img_metas, pts_semantic_mask):
-        """Forward function for training.
+    def loss(self, batch_inputs_dict: dict,
+             batch_data_samples: SampleList) -> dict:
+        """Calculate losses from a batch of inputs and data samples.

         Args:
-            points (list[torch.Tensor]): List of points of shape [N, C].
-            img_metas (list): Image metas.
-            pts_semantic_mask (list[torch.Tensor]): List of point-wise
-                semantic labels of shape [N].
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor, optional): Image tensor has shape
+                  (B, C, H, W).
+            batch_data_samples (list[:obj:`Det3DDataSample`]): The det3d
+                data samples. It usually includes information such
+                as `metainfo` and `gt_pts_sem_seg`.

         Returns:
-            dict[str, Tensor]: Losses.
+            dict[str, Tensor]: A dictionary of loss components.
         """
-        points_cat = torch.stack(points)
-        pts_semantic_mask_cat = torch.stack(pts_semantic_mask)
-
         # extract features using backbone
-        x = self.extract_feat(points_cat)
+        x = self.extract_feat(batch_inputs_dict)

         losses = dict()

-        loss_decode = self._decode_head_forward_train(x, img_metas,
-                                                      pts_semantic_mask_cat)
+        loss_decode = self._decode_head_forward_train(x, batch_data_samples)
         losses.update(loss_decode)

         if self.with_auxiliary_head:
-            loss_aux = self._auxiliary_head_forward_train(
-                x, img_metas, pts_semantic_mask_cat)
+            loss_aux = self._auxiliary_head_forward_train(
+                x, batch_data_samples)
             losses.update(loss_aux)

         if self.with_regularization_loss:
@@ -180,10 +239,10 @@ class EncoderDecoder3D(Base3DSegmentor):
     @staticmethod
     def _input_generation(coords,
-                          patch_center,
-                          coord_max,
-                          feats,
-                          use_normalized_coord=False):
+                          patch_center: Tensor,
+                          coord_max: Tensor,
+                          feats: Tensor,
+                          use_normalized_coord: bool = False):
         """Generating model input.

         Generate input by subtracting patch center and adding additional
@@ -215,12 +274,12 @@ class EncoderDecoder3D(Base3DSegmentor):
         return points

     def _sliding_patch_generation(self,
-                                  points,
-                                  num_points,
-                                  block_size,
-                                  sample_rate=0.5,
-                                  use_normalized_coord=False,
-                                  eps=1e-3):
+                                  points: Tensor,
+                                  num_points: int,
+                                  block_size: float,
+                                  sample_rate: float = 0.5,
+                                  use_normalized_coord: bool = False,
+                                  eps: float = 1e-3):
         """Sampling points in a sliding window fashion.

         First sample patches to cover all the input points.
@@ -318,7 +377,8 @@ class EncoderDecoder3D(Base3DSegmentor):
         return patch_points, patch_idxs

-    def slide_inference(self, point, img_meta, rescale):
+    def slide_inference(self, point: Tensor, img_meta: List[dict],
+                        rescale: bool):
         """Inference by sliding-window with overlap.

         Args:
@@ -362,18 +422,20 @@ class EncoderDecoder3D(Base3DSegmentor):
         return preds.transpose(0, 1)  # to [num_classes, K*N]

-    def whole_inference(self, points, img_metas, rescale):
+    def whole_inference(self, points: Tensor, input_metas: List[dict],
+                        rescale: bool):
         """Inference with full scene (one forward pass without sliding)."""
-        seg_logit = self.encode_decode(points, img_metas)
+        seg_logit = self.encode_decode(points, input_metas)
         # TODO: if rescale and voxelization segmentor
         return seg_logit

-    def inference(self, points, img_metas, rescale):
+    def inference(self, points: Tensor, input_metas: List[dict],
+                  rescale: bool):
         """Inference with slide/whole style.

         Args:
             points (torch.Tensor): Input points of shape [B, N, 3+C].
-            img_metas (list[dict]): Meta information of each sample.
+            input_metas (list[dict]): Meta information of each sample.
             rescale (bool): Whether transform to original number of points.
                 Will be used for voxelization based segmentors.

@@ -384,19 +446,29 @@ class EncoderDecoder3D(Base3DSegmentor):
         if self.test_cfg.mode == 'slide':
             seg_logit = torch.stack([
                 self.slide_inference(point, img_meta, rescale)
-                for point, img_meta in zip(points, img_metas)
+                for point, img_meta in zip(points, input_metas)
             ], 0)
         else:
-            seg_logit = self.whole_inference(points, img_metas, rescale)
+            seg_logit = self.whole_inference(points, input_metas, rescale)
         output = F.softmax(seg_logit, dim=1)
         return output
-    def simple_test(self, points, img_metas, rescale=True):
+    def predict(self,
+                batch_inputs_dict: dict,
+                batch_data_samples: SampleList,
+                rescale: bool = True) -> SampleList:
         """Simple test with single scene.

         Args:
-            points (list[torch.Tensor]): List of points of shape [N, 3+C].
-            img_metas (list[dict]): Meta information of each sample.
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor, optional): Image tensor has shape
+                  (B, C, H, W).
+            batch_data_samples (list[:obj:`Det3DDataSample`]): The det3d
+                data samples. It usually includes information such
+                as `metainfo` and `gt_pts_sem_seg`.
             rescale (bool): Whether transform to original number of points.
                 Will be used for voxelization based segmentors.
                 Defaults to True.
@@ -410,9 +482,14 @@ class EncoderDecoder3D(Base3DSegmentor):
         # to use down-sampling to get a batch of scenes with same num_points
         # therefore, we only support testing one scene every time
         seg_pred = []
-        for point, img_meta in zip(points, img_metas):
-            seg_prob = self.inference(point.unsqueeze(0), [img_meta],
-                                      rescale)[0]
+        batch_input_metas = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+
+        points = batch_inputs_dict['points']
+        for point, input_meta in zip(points, batch_input_metas):
+            seg_prob = self.inference(
+                point.unsqueeze(0), [input_meta], rescale)[0]
             seg_map = seg_prob.argmax(0)  # [N]
             # to cpu tensor for consistency with det3d
             seg_map = seg_map.cpu()
@@ -421,33 +498,24 @@ class EncoderDecoder3D(Base3DSegmentor):
         seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
         return seg_pred
-    def aug_test(self, points, img_metas, rescale=True):
-        """Test with augmentations.
-
-        Args:
-            points (list[torch.Tensor]): List of points of shape [B, N, 3+C].
-            img_metas (list[list[dict]]): Meta information of each sample.
-                Outer list are different samples while inner is different augs.
-            rescale (bool): Whether transform to original number of points.
-                Will be used for voxelization based segmentors.
-                Defaults to True.
-
-        Returns:
-            list[dict]: The output prediction result with following keys:
-
-                - semantic_mask (Tensor): Segmentation mask of shape [N].
-        """
-        # in aug_test, one scene going through different augmentations could
-        # have the same number of points and are stacked as a batch
-        # to save memory, we get augmented seg logit inplace
-        seg_pred = []
-        for point, img_meta in zip(points, img_metas):
-            seg_prob = self.inference(point, img_meta, rescale)
-            seg_prob = seg_prob.mean(0)  # [num_classes, N]
-            seg_map = seg_prob.argmax(0)  # [N]
-            # to cpu tensor for consistency with det3d
-            seg_map = seg_map.cpu()
-            seg_pred.append(seg_map)
-        # wrap in dict
-        seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
-        return seg_pred
+    def _forward(self,
+                 batch_inputs_dict: dict,
+                 batch_data_samples: OptSampleList = None) -> Tensor:
+        """Network forward process.
+
+        Args:
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor, optional): Image tensor has shape
+                  (B, C, H, W).
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_pts_sem_seg`.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(batch_inputs_dict)
+        return self.decode_head.forward(x)
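Note the data-format change hidden in `extract_feat`: it now takes the whole input dict and stacks the per-sample point clouds itself, so all samples in a batch must carry the same number of points (which the segmentation pipelines guarantee via patch sampling). In isolation:

```python
import torch

# two samples of 4096 points with 9 channels each, as in the S3DIS configs
batch_inputs_dict = dict(points=[torch.rand(4096, 9) for _ in range(2)])
stack_points = torch.stack(batch_inputs_dict['points'])
print(stack_points.shape)  # torch.Size([2, 4096, 9]) -> fed to the backbone
```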
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet3d.core import Det3DDataSample, PointData
from mmdet3d.models.decode_heads import DGCNNHead


class TestDGCNNHead(TestCase):

    def test_dgcnn_head_loss(self):
        """Test DGCNN head loss."""
        dgcnn_head = DGCNNHead(
            fp_channels=(1024, 512),
            channels=256,
            num_classes=13,
            dropout_ratio=0.5,
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d'),
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            loss_decode=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=None,
                loss_weight=1.0),
            ignore_index=13)

        # DGCNN head expects dict format features
        fa_points = torch.rand(1, 4096, 1024).float()
        feat_dict = dict(fa_points=fa_points)

        # test forward
        seg_logits = dgcnn_head.forward(feat_dict)
        self.assertEqual(seg_logits.shape, torch.Size([1, 13, 4096]))

        # when truth is non-empty, losses should be nonzero for random inputs
        pts_semantic_mask = torch.randint(0, 13, (4096, )).long()
        gt_pts_seg = PointData(pts_semantic_mask=pts_semantic_mask)
        datasample = Det3DDataSample()
        datasample.gt_pts_seg = gt_pts_seg
        gt_losses = dgcnn_head.loss_by_feat(seg_logits, [datasample])
        gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()
        self.assertGreater(gt_sem_seg_loss, 0,
                           'semantic seg loss should be positive')