Commit 7695d4a4 authored by wuyuefeng

sparse_unetv2 backbone

parent 2f5b24ce
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder
from .sparse_unetv2 import SparseUnetV2
__all__ = ['PointPillarsScatter', 'SparseEncoder', 'SparseUnetV2']
from torch import nn
import mmdet3d.ops.spconv as spconv
from mmdet.ops import build_norm_layer
def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
"""3x3 submanifold sparse convolution with padding.
Args:
in_planes (int): the number of input channels
out_planes (int): the number of output channels
stride (int): the stride of convolution
indice_key (str): the indice key used for sparse tensor
Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
"""
return spconv.SubMConv3d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
indice_key=indice_key)
def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
"""1x1 submanifold sparse convolution with padding.
Args:
in_planes (int): the number of input channels
out_planes (int): the number of output channels
stride (int): the stride of convolution
indice_key (str): the indice key used for sparse tensor
Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
"""
return spconv.SubMConv3d(
in_planes,
out_planes,
kernel_size=1,
stride=stride,
        padding=0,  # a 1x1 kernel needs no padding
bias=False,
indice_key=indice_key)
class SparseBasicBlock(spconv.SparseModule):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
indice_key=None,
norm_cfg=None):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
super(SparseBasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key)
norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
self.bn1 = norm_layer1
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes, indice_key=indice_key)
norm_name2, norm_layer2 = build_norm_layer(norm_cfg, planes)
self.bn2 = norm_layer2
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
class SparseBottleneck(spconv.SparseModule):
expansion = 4
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 indice_key=None,
                 norm_cfg=None):
        """Sparse bottleneck block for PartA^2.

        Bottleneck block implemented with submanifold sparse convolution,
        using the same norm_cfg interface as SparseBasicBlock.
        """
        super(SparseBottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key)
        norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
        self.bn1 = norm_layer1
        self.conv2 = conv3x3(planes, planes, stride, indice_key=indice_key)
        norm_name2, norm_layer2 = build_norm_layer(norm_cfg, planes)
        self.bn2 = norm_layer2
        self.conv3 = conv1x1(
            planes, planes * self.expansion, indice_key=indice_key)
        norm_name3, norm_layer3 = build_norm_layer(
            norm_cfg, planes * self.expansion)
        self.bn3 = norm_layer3
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
def forward(self, x):
identity = x.features
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
out.features = self.relu(out.features)
out = self.conv3(out)
out.features = self.bn3(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
import torch
import torch.nn as nn
import mmdet3d.ops.spconv as spconv
from mmdet.ops import build_norm_layer
from ..registry import MIDDLE_ENCODERS
from .sparse_block_utils import SparseBasicBlock
@MIDDLE_ENCODERS.register_module
class SparseUnetV2(nn.Module):
def __init__(self,
in_channels,
output_shape,
pre_act,
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01)):
"""SparseUnet for PartA^2
Args:
in_channels (int): the number of input channels
output_shape (list[int]): the shape of output tensor
pre_act (bool): use pre_act_block or post_act_block
norm_cfg (dict): normalize layer config
"""
super().__init__()
self.sparse_shape = output_shape
self.output_shape = output_shape
self.in_channels = in_channels
self.pre_act = pre_act
        # spconv initializes all weights on its own
        # TODO: make the network structure configurable
if pre_act:
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
16,
3,
padding=1,
bias=False,
indice_key='subm1'), )
block = self.pre_act_block
else:
norm_name, norm_layer = build_norm_layer(norm_cfg, 16)
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
16,
3,
padding=1,
bias=False,
indice_key='subm1'),
norm_layer,
nn.ReLU(),
)
block = self.post_act_block
self.conv1 = spconv.SparseSequential(
block(16, 16, 3, norm_cfg=norm_cfg, padding=1,
indice_key='subm1'), )
self.conv2 = spconv.SparseSequential(
# [1600, 1408, 41] -> [800, 704, 21]
block(
16,
32,
3,
norm_cfg=norm_cfg,
stride=2,
padding=1,
indice_key='spconv2',
conv_type='spconv'),
block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
)
self.conv3 = spconv.SparseSequential(
# [800, 704, 21] -> [400, 352, 11]
block(
32,
64,
3,
norm_cfg=norm_cfg,
stride=2,
padding=1,
indice_key='spconv3',
conv_type='spconv'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
)
self.conv4 = spconv.SparseSequential(
# [400, 352, 11] -> [200, 176, 5]
block(
64,
64,
3,
norm_cfg=norm_cfg,
stride=2,
padding=(0, 1, 1),
indice_key='spconv4',
conv_type='spconv'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
)
norm_name, norm_layer = build_norm_layer(norm_cfg, 128)
self.conv_out = spconv.SparseSequential(
# [200, 176, 5] -> [200, 176, 2]
spconv.SparseConv3d(
64,
128, (3, 1, 1),
stride=(2, 1, 1),
padding=0,
bias=False,
indice_key='spconv_down2'),
norm_layer,
nn.ReLU(),
)
# decoder
# [400, 352, 11] <- [200, 176, 5]
self.conv_up_t4 = SparseBasicBlock(
64, 64, indice_key='subm4', norm_cfg=norm_cfg)
self.conv_up_m4 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
self.inv_conv4 = block(
64,
64,
3,
norm_cfg=norm_cfg,
indice_key='spconv4',
conv_type='inverseconv')
# [800, 704, 21] <- [400, 352, 11]
self.conv_up_t3 = SparseBasicBlock(
64, 64, indice_key='subm3', norm_cfg=norm_cfg)
self.conv_up_m3 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
self.inv_conv3 = block(
64,
32,
3,
norm_cfg=norm_cfg,
indice_key='spconv3',
conv_type='inverseconv')
# [1600, 1408, 41] <- [800, 704, 21]
self.conv_up_t2 = SparseBasicBlock(
32, 32, indice_key='subm2', norm_cfg=norm_cfg)
        self.conv_up_m2 = block(
            64, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2')
self.inv_conv2 = block(
32,
16,
3,
norm_cfg=norm_cfg,
indice_key='spconv2',
conv_type='inverseconv')
# [1600, 1408, 41] <- [1600, 1408, 41]
self.conv_up_t1 = SparseBasicBlock(
16, 16, indice_key='subm1', norm_cfg=norm_cfg)
        self.conv_up_m1 = block(
            32, 16, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm1')
self.conv5 = spconv.SparseSequential(
block(16, 16, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm1'))
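        # point-wise heads on the decoder output: foreground score (1) and
        # intra-object part location regression (3), as in PartA^2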
self.seg_cls_layer = nn.Linear(16, 1, bias=True)
self.seg_reg_layer = nn.Linear(16, 3, bias=True)
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUnetV2
Args:
voxel_features (torch.float32): shape [N, C]
coors (torch.int32): shape [N, 4](batch_idx, z_idx, y_idx, x_idx)
batch_size (int): batch size
Returns:
dict: backbone features
"""
coors = coors.int()
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
self.sparse_shape,
batch_size)
x = self.conv_input(input_sp_tensor)
x_conv1 = self.conv1(x)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
spatial_features = out.dense()
N, C, D, H, W = spatial_features.shape
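        # fold the remaining depth dimension into channels:
        # (N, 128, 2, 200, 176) -> (N, 256, 200, 176)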
spatial_features = spatial_features.view(N, C * D, H, W)
ret = {'spatial_features': spatial_features}
# for segmentation head
# [400, 352, 11] <- [200, 176, 5]
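        # at the deepest level, lateral and bottom inputs are the same tensor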
x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4,
self.conv_up_m4, self.inv_conv4)
# [800, 704, 21] <- [400, 352, 11]
x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3,
self.conv_up_m3, self.inv_conv3)
# [1600, 1408, 41] <- [800, 704, 21]
x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2,
self.conv_up_m2, self.inv_conv2)
# [1600, 1408, 41] <- [1600, 1408, 41]
x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1,
self.conv_up_m1, self.conv5)
seg_features = x_up1.features
seg_cls_preds = self.seg_cls_layer(seg_features) # (N, 1)
seg_reg_preds = self.seg_reg_layer(seg_features) # (N, 3)
ret.update({
'u_seg_preds': seg_cls_preds,
'u_reg_preds': seg_reg_preds,
'seg_features': seg_features
})
return ret
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
"""Forward of upsample and residual block.
Args:
x_lateral (SparseConvTensor): lateral tensor
x_bottom (SparseConvTensor): tensor from bottom layer
conv_t (SparseBasicBlock): convolution for lateral tensor
conv_m (SparseSequential): convolution for merging features
conv_inv (SparseSequential): convolution for upsampling
Returns:
SparseConvTensor: upsampled feature
"""
x_trans = conv_t(x_lateral)
x = x_trans
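        # concatenate bottom features with the transformed lateral features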
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1)
x_m = conv_m(x)
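        # reduce the concatenated channels for the element-wise residual add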
x = self.channel_reduction(x, x_m.features.shape[1])
x.features = x_m.features + x.features
x = conv_inv(x)
return x
@staticmethod
def channel_reduction(x, out_channels):
"""Channel reduction for element-wise add.
Args:
x (SparseConvTensor): x.features (N, C1)
out_channels (int): the number of channel after reduction
Returns:
SparseConvTensor: channel reduced feature
"""
features = x.features
n, in_channels = features.shape
assert (in_channels %
out_channels == 0) and (in_channels >= out_channels)
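        # e.g. (hypothetical numbers) reducing 4 -> 2 channels sums adjacent
        # pairs: [[1., 2., 3., 4.]] -> view(1, 2, 2).sum(dim=2) -> [[3., 7.]]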
x.features = features.view(n, out_channels, -1).sum(dim=2)
return x
def pre_act_block(self,
in_channels,
out_channels,
kernel_size,
indice_key=None,
stride=1,
padding=0,
conv_type='subm',
norm_cfg=None):
"""Make pre activate sparse convolution block.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict): normalization layer config
Returns:
            spconv.SparseSequential: pre-activation sparse convolution block.
"""
assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, in_channels)
if conv_type == 'subm':
m = spconv.SparseSequential(
norm_layer,
nn.ReLU(inplace=True),
spconv.SubMConv3d(
in_channels,
out_channels,
kernel_size,
padding=padding,
bias=False,
indice_key=indice_key),
)
elif conv_type == 'spconv':
m = spconv.SparseSequential(
norm_layer,
nn.ReLU(inplace=True),
spconv.SparseConv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
indice_key=indice_key),
)
elif conv_type == 'inverseconv':
m = spconv.SparseSequential(
norm_layer,
nn.ReLU(inplace=True),
spconv.SparseInverseConv3d(
in_channels,
out_channels,
kernel_size,
bias=False,
indice_key=indice_key),
)
else:
raise NotImplementedError
return m
def post_act_block(self,
in_channels,
out_channels,
kernel_size,
indice_key,
stride=1,
padding=0,
conv_type='subm',
norm_cfg=None):
"""Make post activate sparse convolution block.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict): normalization layer config
Returns:
            spconv.SparseSequential: post-activation sparse convolution block.
"""
assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, out_channels)
if conv_type == 'subm':
            m = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=padding,
                    bias=False,
                    indice_key=indice_key),
norm_layer,
nn.ReLU(inplace=True),
)
elif conv_type == 'spconv':
m = spconv.SparseSequential(
spconv.SparseConv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
indice_key=indice_key),
norm_layer,
nn.ReLU(inplace=True),
)
elif conv_type == 'inverseconv':
m = spconv.SparseSequential(
spconv.SparseInverseConv3d(
in_channels,
out_channels,
kernel_size,
bias=False,
indice_key=indice_key),
norm_layer,
nn.ReLU(inplace=True),
)
else:
raise NotImplementedError
return m
import torch
def test_SparseUnetV2():
from mmdet3d.models.middle_encoders.sparse_unetv2 import SparseUnetV2
self = SparseUnetV2(
in_channels=4, output_shape=[41, 1600, 1408], pre_act=False)
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16],
[23.482342, 6.5036807, 0.5806964, 0.35]],
dtype=torch.float32) # n, point_features
coordinates = torch.tensor(
[[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
[1, 35, 930, 469]],
        dtype=torch.int32)  # [N, 4] in (batch_idx, z_idx, y_idx, x_idx) order
unet_ret_dict = self.forward(voxel_features, coordinates, 2)
seg_cls_preds = unet_ret_dict['u_seg_preds']
seg_reg_preds = unet_ret_dict['u_reg_preds']
seg_features = unet_ret_dict['seg_features']
spatial_features = unet_ret_dict['spatial_features']
assert seg_cls_preds.shape == torch.Size([4, 1])
assert seg_reg_preds.shape == torch.Size([4, 3])
assert seg_features.shape == torch.Size([4, 16])
assert spatial_features.shape == torch.Size([2, 256, 200, 176])
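

def test_SparseBasicBlock():
    # A minimal sketch (not part of the original commit): it exercises
    # SparseBasicBlock from sparse_block_utils on the same toy voxels as
    # above; the channel count (16) and indice_key are arbitrary choices.
    import mmdet3d.ops.spconv as spconv
    from mmdet3d.models.middle_encoders.sparse_block_utils import \
        SparseBasicBlock
    features = torch.rand(4, 16)
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32)
    x = spconv.SparseConvTensor(features, coordinates, [41, 1600, 1408], 2)
    block = SparseBasicBlock(
        16,
        16,
        indice_key='subm1',
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
    out = block(x)
    # a submanifold basic block keeps the active sites and channel count
    assert out.features.shape == torch.Size([4, 16])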
if __name__ == '__main__':
test_SparseUnetV2()
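

def test_SparseUnetV2_pre_act():
    # A minimal sketch (an assumption, not from the original commit): it only
    # checks that the pre-activation variant (pre_act=True) builds and runs
    # on the same toy input, producing the same output shapes.
    from mmdet3d.models.middle_encoders.sparse_unetv2 import SparseUnetV2
    self = SparseUnetV2(
        in_channels=4, output_shape=[41, 1600, 1408], pre_act=True)
    voxel_features = torch.rand(4, 4)
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32)
    unet_ret_dict = self.forward(voxel_features, coordinates, 2)
    assert unet_ret_dict['spatial_features'].shape == torch.Size(
        [2, 256, 200, 176])
    assert unet_ret_dict['seg_features'].shape == torch.Size([4, 16])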