Unverified commit 53e06229, authored by Ziyi Wu and committed by GitHub
Browse files

[Feature] Support PointNet++ decode head (#479)

* support PN2 decode head

* add mmseg dependency in github workflow

* complete PN2 decode head

* modify backbone pn2 to support seg task & its unit test

* add unit test for PN2 decode_head
parent 36400705
......@@ -92,6 +92,7 @@ jobs:
run: |
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
pip install mmdet==2.11.0
pip install -q git+https://github.com/open-mmlab/mmsegmentation.git
pip install -r requirements.txt
- name: Build and install
run: |
......
......@@ -4,6 +4,7 @@ from .builder import (FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS,
build_head, build_loss, build_middle_encoder, build_neck,
build_roi_extractor, build_shared_head,
build_voxel_encoder)
from .decode_heads import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .fusion_layers import * # noqa: F401,F403
......
......@@ -102,15 +102,21 @@ class PointNet2SAMSG(BasePointNet):
cfg=sa_cfg,
bias=True))
skip_channel_list.append(sa_out_channel)
self.aggregation_mlps.append(
ConvModule(
sa_out_channel,
aggregation_channels[sa_index],
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
kernel_size=1,
bias=True))
sa_in_channel = aggregation_channels[sa_index]
cur_aggregation_channel = aggregation_channels[sa_index]
if cur_aggregation_channel is None:
self.aggregation_mlps.append(None)
sa_in_channel = sa_out_channel
else:
self.aggregation_mlps.append(
ConvModule(
sa_out_channel,
cur_aggregation_channel,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
kernel_size=1,
bias=True))
sa_in_channel = cur_aggregation_channel
@auto_fp16(apply_to=('points', ))
def forward(self, points):
......@@ -139,14 +145,15 @@ class PointNet2SAMSG(BasePointNet):
sa_features = [features]
sa_indices = [indices]
out_sa_xyz = []
out_sa_features = []
out_sa_indices = []
out_sa_xyz = [xyz]
out_sa_features = [features]
out_sa_indices = [indices]
for i in range(self.num_sa):
cur_xyz, cur_features, cur_indices = self.SA_modules[i](
sa_xyz[i], sa_features[i])
cur_features = self.aggregation_mlps[i](cur_features)
if self.aggregation_mlps[i] is not None:
cur_features = self.aggregation_mlps[i](cur_features)
sa_xyz.append(cur_xyz)
sa_features.append(cur_features)
sa_indices.append(
......
......@@ -132,5 +132,10 @@ class PointNet2SASSG(BasePointNet):
fp_indices.append(sa_indices[self.num_sa - i - 1])
ret = dict(
fp_xyz=fp_xyz, fp_features=fp_features, fp_indices=fp_indices)
fp_xyz=fp_xyz,
fp_features=fp_features,
fp_indices=fp_indices,
sa_xyz=sa_xyz,
sa_features=sa_features,
sa_indices=sa_indices)
return ret
from .pointnet2_head import PointNet2Head
__all__ = ['PointNet2Head']
from abc import ABCMeta, abstractmethod
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from torch import nn as nn
from mmseg.models.builder import build_loss
class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
    """Abstract base class for 3D (per-point) segmentation decode heads.

    Sub-classes implement :meth:`forward`. This base class owns the final
    1x1 classification convolution, the optional dropout in front of it and
    the segmentation loss.

    Args:
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.5.
        conv_cfg (dict|None): Config of conv layers.
            Default: dict(type='Conv1d').
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN1d').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
    """

    def __init__(self,
                 channels,
                 num_classes,
                 dropout_ratio=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     class_weight=None,
                     loss_weight=1.0),
                 ignore_index=255):
        super().__init__()

        # Plain hyper-parameters and layer configs, kept for sub-classes.
        self.channels, self.num_classes = channels, num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg, self.norm_cfg, self.act_cfg = \
            conv_cfg, norm_cfg, act_cfg
        self.ignore_index = ignore_index
        self.fp16_enabled = False

        # Sub-modules are created in the order loss -> classifier -> dropout
        # so that the module registration order stays stable.
        self.loss_decode = build_loss(loss_decode)
        self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
        # Dropout is only instantiated for a strictly positive ratio.
        self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 \
            else None

    def init_weights(self):
        """Initialize the classification layer with a normal distribution."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function, to be provided by sub-classes."""

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        ``img_metas`` and ``train_cfg`` are accepted but not used by this
        base implementation.

        Args:
            inputs (list[Tensor]): List of multi-level point features.
            img_metas (list[dict]): Meta information of each sample.
            gt_semantic_seg (torch.Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        return self.losses(self.forward(inputs), gt_semantic_seg)

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        ``img_metas`` and ``test_cfg`` are accepted but not used by this
        base implementation.

        Args:
            inputs (list[Tensor]): List of multi-level point features.
            img_metas (list[dict]): Meta information of each sample.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each point: optional dropout followed by ``conv_seg``."""
        if self.dropout is None:
            return self.conv_seg(feat)
        return self.conv_seg(self.dropout(feat))

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute semantic segmentation loss.

        Args:
            seg_logit (torch.Tensor): Predicted per-point segmentation logits
                of shape [B, num_classes, N].
            seg_label (torch.Tensor): Ground-truth segmentation label of
                shape [B, N].

        Returns:
            dict[str, Tensor]: Dictionary with the single key
                ``loss_sem_seg``.
        """
        return dict(
            loss_sem_seg=self.loss_decode(
                seg_logit, seg_label, ignore_index=self.ignore_index))
from mmcv.cnn.bricks import ConvModule
from torch import nn as nn
from mmdet3d.ops import PointFPModule
from mmdet.models import HEADS
from .decode_head import Base3DDecodeHead
@HEADS.register_module()
class PointNet2Head(Base3DDecodeHead):
    r"""PointNet2 decoder head.

    Decoder head used in `PointNet++ <https://arxiv.org/abs/1706.02413>`_.
    Refer to the `official code <https://github.com/charlesq34/pointnet2>`_.

    Args:
        fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``Base3DDecodeHead.__init__``.
    """

    def __init__(self,
                 fp_channels=((768, 256, 256), (384, 256, 256),
                              (320, 256, 128), (128, 128, 128, 128)),
                 **kwargs):
        super(PointNet2Head, self).__init__(**kwargs)

        self.num_fp = len(fp_channels)
        # One feature-propagation module per entry of ``fp_channels``.
        self.FP_modules = nn.ModuleList()
        for cur_fp_mlps in fp_channels:
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))

        # Extra conv between the last FP module and the classifier, see
        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
        self.pre_seg_conv = ConvModule(
            fp_channels[-1][-1],
            self.channels,
            kernel_size=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _extract_input(self, feat_dict):
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone. Must provide the
                keys ``sa_xyz`` and ``sa_features`` with equal lengths.

        Returns:
            list[torch.Tensor]: Coordinates of multiple levels of points.
            list[torch.Tensor]: Features of multiple levels of points.
        """
        sa_xyz = feat_dict['sa_xyz']
        sa_features = feat_dict['sa_features']
        assert len(sa_xyz) == len(sa_features)

        return sa_xyz, sa_features

    def forward(self, feat_dict):
        """Forward pass.

        The backbone output in ``feat_dict`` is left untouched: the list of
        per-level features is copied before its first entry is dropped.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Segmentation map of shape [B, num_classes, N].
        """
        sa_xyz, sa_features = self._extract_input(feat_dict)

        # Work on a shallow copy so that dropping the first-level features
        # below does not modify the list stored in the caller's feat_dict.
        sa_features = list(sa_features)

        # The first-level features are not used as skip features, see
        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L24
        sa_features[0] = None

        fp_feature = sa_features[-1]

        for i in range(self.num_fp):
            # consume the points in a bottom-up manner: each step pairs the
            # level at index -(i + 2) with the level at index -(i + 1)
            fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
                                            sa_features[-(i + 2)], fp_feature)
        output = self.pre_seg_conv(fp_feature)
        output = self.cls_seg(output)

        return output
......@@ -7,7 +7,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet,mmdet3d
known_first_party = mmdet,mmseg,mmdet3d
known_third_party = cv2,indoor3d_util,load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,recommonmark,scannet_utils,scipy,seaborn,shapely,skimage,tensorflow,terminaltables,torch,trimesh,waymo_open_dataset
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
......@@ -34,12 +34,29 @@ def test_pointnet2_sa_ssg():
fp_xyz = ret_dict['fp_xyz']
fp_features = ret_dict['fp_features']
fp_indices = ret_dict['fp_indices']
sa_xyz = ret_dict['sa_xyz']
sa_features = ret_dict['sa_features']
sa_indices = ret_dict['sa_indices']
assert len(fp_xyz) == len(fp_features) == len(fp_indices) == 3
assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 3
assert fp_xyz[0].shape == torch.Size([1, 16, 3])
assert fp_xyz[1].shape == torch.Size([1, 32, 3])
assert fp_xyz[2].shape == torch.Size([1, 100, 3])
assert fp_features[0].shape == torch.Size([1, 16, 16])
assert fp_features[1].shape == torch.Size([1, 16, 32])
assert fp_features[2].shape == torch.Size([1, 16, 100])
assert fp_indices[0].shape == torch.Size([1, 16])
assert fp_indices[1].shape == torch.Size([1, 32])
assert fp_indices[2].shape == torch.Size([1, 100])
assert sa_xyz[0].shape == torch.Size([1, 100, 3])
assert sa_xyz[1].shape == torch.Size([1, 32, 3])
assert sa_xyz[2].shape == torch.Size([1, 16, 3])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 16, 32])
assert sa_features[2].shape == torch.Size([1, 16, 16])
assert sa_indices[0].shape == torch.Size([1, 100])
assert sa_indices[1].shape == torch.Size([1, 32])
assert sa_indices[2].shape == torch.Size([1, 16])
def test_multi_backbone():
......@@ -156,6 +173,8 @@ def test_multi_backbone():
def test_pointnet2_sa_msg():
if not torch.cuda.is_available():
pytest.skip()
# PN2MSG used in 3DSSD
cfg = dict(
type='PointNet2SAMSG',
in_channels=4,
......@@ -216,3 +235,50 @@ def test_pointnet2_sa_msg():
pool_mod='max',
use_xyz=True,
normalize_xyz=False)))
# PN2MSG used in segmentation
cfg = dict(
type='PointNet2SAMSG',
in_channels=6, # [xyz, rgb]
num_points=(1024, 256, 64, 16),
radii=((0.05, 0.1), (0.1, 0.2), (0.2, 0.4), (0.4, 0.8)),
num_samples=((16, 32), (16, 32), (16, 32), (16, 32)),
sa_channels=(((16, 16, 32), (32, 32, 64)), ((64, 64, 128), (64, 96,
128)),
((128, 196, 256), (128, 196, 256)), ((256, 256, 512),
(256, 384, 512))),
aggregation_channels=(None, None, None, None),
fps_mods=(('D-FPS'), ('D-FPS'), ('D-FPS'), ('D-FPS')),
fps_sample_range_lists=((-1), (-1), (-1), (-1)),
dilated_group=(False, False, False, False),
out_indices=(0, 1, 2, 3),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False))
self = build_backbone(cfg)
self.cuda()
ret_dict = self(xyz)
sa_xyz = ret_dict['sa_xyz']
sa_features = ret_dict['sa_features']
sa_indices = ret_dict['sa_indices']
assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 5
assert sa_xyz[0].shape == torch.Size([1, 100, 3])
assert sa_xyz[1].shape == torch.Size([1, 1024, 3])
assert sa_xyz[2].shape == torch.Size([1, 256, 3])
assert sa_xyz[3].shape == torch.Size([1, 64, 3])
assert sa_xyz[4].shape == torch.Size([1, 16, 3])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 96, 1024])
assert sa_features[2].shape == torch.Size([1, 256, 256])
assert sa_features[3].shape == torch.Size([1, 512, 64])
assert sa_features[4].shape == torch.Size([1, 1024, 16])
assert sa_indices[0].shape == torch.Size([1, 100])
assert sa_indices[1].shape == torch.Size([1, 1024])
assert sa_indices[2].shape == torch.Size([1, 256])
assert sa_indices[3].shape == torch.Size([1, 64])
assert sa_indices[4].shape == torch.Size([1, 16])
import numpy as np
import pytest
import torch
from mmcv.cnn.bricks import ConvModule
from mmdet3d.models.builder import build_head
def test_pn2_decode_head_loss():
    """Check construction, forward shape and losses of ``PointNet2Head``.

    The test is skipped when CUDA is not available.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    plain_ce_loss = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=None,
        loss_weight=1.0)
    pn2_decode_head_cfg = dict(
        type='PointNet2Head',
        fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
                     (128, 128, 128, 128)),
        channels=128,
        num_classes=20,
        dropout_ratio=0.5,
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        act_cfg=dict(type='ReLU'),
        loss_decode=plain_ce_loss,
        ignore_index=20)

    head = build_head(pn2_decode_head_cfg)
    head.cuda()

    # classification layer
    classifier = head.conv_seg
    assert isinstance(classifier, torch.nn.Conv1d)
    assert classifier.in_channels == 128
    assert classifier.out_channels == 20
    assert classifier.kernel_size == (1, )

    # conv module placed right before the classification layer
    pre_conv = head.pre_seg_conv
    assert isinstance(pre_conv, ConvModule)
    assert isinstance(pre_conv.conv, torch.nn.Conv1d)
    assert pre_conv.conv.in_channels == 128
    assert pre_conv.conv.out_channels == 128
    assert pre_conv.conv.kernel_size == (1, )
    assert isinstance(pre_conv.bn, torch.nn.BatchNorm1d)
    assert pre_conv.bn.num_features == 128
    assert isinstance(pre_conv.activate, torch.nn.ReLU)

    # test forward
    level_num_points = (4096, 1024, 256, 64, 16)
    level_channels = (6, 64, 128, 256, 512)
    sa_xyz = [
        torch.rand(2, num_points, 3).float().cuda()
        for num_points in level_num_points
    ]
    sa_features = [
        torch.rand(2, num_channels, num_points).float().cuda()
        for num_channels, num_points in zip(level_channels,
                                            level_num_points)
    ]
    input_dict = dict(sa_xyz=sa_xyz, sa_features=sa_features)
    seg_logits = head(input_dict)
    assert seg_logits.shape == torch.Size([2, 20, 4096])

    # test loss
    gt_semantic_seg = torch.randint(0, 20, (2, 4096)).long().cuda()
    losses = head.losses(seg_logits, gt_semantic_seg)
    assert losses['loss_sem_seg'].item() > 0

    # test loss with ignore_index: every label equals the ignored index
    ignore_index_mask = torch.full_like(gt_semantic_seg, 20)
    losses = head.losses(seg_logits, ignore_index_mask)
    assert losses['loss_sem_seg'].item() == 0

    # test loss with class_weight
    pn2_decode_head_cfg.update(
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=np.random.rand(20),
            loss_weight=1.0))
    head = build_head(pn2_decode_head_cfg)
    head.cuda()
    losses = head.losses(seg_logits, gt_semantic_seg)
    assert losses['loss_sem_seg'].item() > 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment