Commit 5db915b5 authored by wuyuefeng's avatar wuyuefeng
Browse files

modified sparse block ops

parent 7695d4a4
...@@ -2,9 +2,9 @@ import torch ...@@ -2,9 +2,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import mmdet3d.ops.spconv as spconv import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock
from mmdet.ops import build_norm_layer from mmdet.ops import build_norm_layer
from ..registry import MIDDLE_ENCODERS from ..registry import MIDDLE_ENCODERS
from .sparse_block_utils import SparseBasicBlock
@MIDDLE_ENCODERS.register_module @MIDDLE_ENCODERS.register_module
...@@ -122,7 +122,10 @@ class SparseUnetV2(nn.Module): ...@@ -122,7 +122,10 @@ class SparseUnetV2(nn.Module):
# decoder # decoder
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
self.conv_up_t4 = SparseBasicBlock( self.conv_up_t4 = SparseBasicBlock(
64, 64, indice_key='subm4', norm_cfg=norm_cfg) 64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm4'),
norm_cfg=norm_cfg)
self.conv_up_m4 = block( self.conv_up_m4 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4') 128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
self.inv_conv4 = block( self.inv_conv4 = block(
...@@ -135,7 +138,10 @@ class SparseUnetV2(nn.Module): ...@@ -135,7 +138,10 @@ class SparseUnetV2(nn.Module):
# [800, 704, 21] <- [400, 352, 11] # [800, 704, 21] <- [400, 352, 11]
self.conv_up_t3 = SparseBasicBlock( self.conv_up_t3 = SparseBasicBlock(
64, 64, indice_key='subm3', norm_cfg=norm_cfg) 64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm3'),
norm_cfg=norm_cfg)
self.conv_up_m3 = block( self.conv_up_m3 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3') 128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
self.inv_conv3 = block( self.inv_conv3 = block(
...@@ -148,7 +154,10 @@ class SparseUnetV2(nn.Module): ...@@ -148,7 +154,10 @@ class SparseUnetV2(nn.Module):
# [1600, 1408, 41] <- [800, 704, 21] # [1600, 1408, 41] <- [800, 704, 21]
self.conv_up_t2 = SparseBasicBlock( self.conv_up_t2 = SparseBasicBlock(
32, 32, indice_key='subm2', norm_cfg=norm_cfg) 32,
32,
conv_cfg=dict(type='SubMConv3d', indice_key='subm2'),
norm_cfg=norm_cfg)
self.conv_up_m2 = block( self.conv_up_m2 = block(
64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2') 64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2')
self.inv_conv2 = block( self.inv_conv2 = block(
...@@ -161,7 +170,10 @@ class SparseUnetV2(nn.Module): ...@@ -161,7 +170,10 @@ class SparseUnetV2(nn.Module):
# [1600, 1408, 41] <- [1600, 1408, 41] # [1600, 1408, 41] <- [1600, 1408, 41]
self.conv_up_t1 = SparseBasicBlock( self.conv_up_t1 = SparseBasicBlock(
16, 16, indice_key='subm1', norm_cfg=norm_cfg) 16,
16,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=norm_cfg)
self.conv_up_m1 = block( self.conv_up_m1 = block(
32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1') 32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1')
......
...@@ -2,12 +2,29 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version, ...@@ -2,12 +2,29 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
get_compiling_cuda_version, nms, roi_align, get_compiling_cuda_version, nms, roi_align,
sigmoid_focal_loss) sigmoid_focal_loss)
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .sparse_block import (SparseBasicBlock, SparseBasicBlockV0,
SparseBottleneck, SparseBottleneckV0)
from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization
__all__ = [ __all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'get_compiler_version', 'nms',
'get_compiling_cuda_version', 'build_conv_layer', 'NaiveSyncBatchNorm1d', 'soft_nms',
'NaiveSyncBatchNorm2d', 'batched_nms', 'Voxelization', 'voxelization', 'RoIAlign',
'dynamic_scatter', 'DynamicScatter', 'sigmoid_focal_loss', 'roi_align',
'SigmoidFocalLoss' 'get_compiler_version',
'get_compiling_cuda_version',
'build_conv_layer',
'NaiveSyncBatchNorm1d',
'NaiveSyncBatchNorm2d',
'batched_nms',
'Voxelization',
'voxelization',
'dynamic_scatter',
'DynamicScatter',
'sigmoid_focal_loss',
'SigmoidFocalLoss',
'SparseBasicBlockV0',
'SparseBottleneckV0',
'SparseBasicBlock',
'SparseBottleneck',
] ]
from torch import nn from torch import nn
import mmdet3d.ops.spconv as spconv import mmdet3d.ops.spconv as spconv
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.ops import build_norm_layer from mmdet.ops import build_norm_layer
from mmdet.ops.conv import conv_cfg
conv_cfg.update({
'SubMConv3d': spconv.SubMConv3d,
})
def conv3x3(in_planes, out_planes, stride=1, indice_key=None): def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
...@@ -16,6 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None): ...@@ -16,6 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
Returns: Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
""" """
# TODO: duplicate this class
return spconv.SubMConv3d( return spconv.SubMConv3d(
in_planes, in_planes,
out_planes, out_planes,
...@@ -38,6 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None): ...@@ -38,6 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
Returns: Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
""" """
# TODO: duplicate this class
return spconv.SubMConv3d( return spconv.SubMConv3d(
in_planes, in_planes,
out_planes, out_planes,
...@@ -48,7 +56,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None): ...@@ -48,7 +56,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
indice_key=indice_key) indice_key=indice_key)
class SparseBasicBlock(spconv.SparseModule): class SparseBasicBlockV0(spconv.SparseModule):
expansion = 1 expansion = 1
def __init__(self, def __init__(self,
...@@ -62,7 +70,8 @@ class SparseBasicBlock(spconv.SparseModule): ...@@ -62,7 +70,8 @@ class SparseBasicBlock(spconv.SparseModule):
Sparse basic block implemented with submanifold sparse convolution. Sparse basic block implemented with submanifold sparse convolution.
""" """
super(SparseBasicBlock, self).__init__() # TODO: duplicate this class
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key) self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key)
norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes) norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
self.bn1 = norm_layer1 self.bn1 = norm_layer1
...@@ -94,7 +103,7 @@ class SparseBasicBlock(spconv.SparseModule): ...@@ -94,7 +103,7 @@ class SparseBasicBlock(spconv.SparseModule):
return out return out
class SparseBottleneck(spconv.SparseModule): class SparseBottleneckV0(spconv.SparseModule):
expansion = 4 expansion = 4
def __init__(self, def __init__(self,
...@@ -108,7 +117,8 @@ class SparseBottleneck(spconv.SparseModule): ...@@ -108,7 +117,8 @@ class SparseBottleneck(spconv.SparseModule):
Bottleneck block implemented with submanifold sparse convolution. Bottleneck block implemented with submanifold sparse convolution.
""" """
super(SparseBottleneck, self).__init__() # TODO: duplicate this class
super().__init__()
self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key) self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key)
self.bn1 = norm_fn(planes) self.bn1 = norm_fn(planes)
self.conv2 = conv3x3(planes, planes, stride, indice_key=indice_key) self.conv2 = conv3x3(planes, planes, stride, indice_key=indice_key)
...@@ -141,3 +151,95 @@ class SparseBottleneck(spconv.SparseModule): ...@@ -141,3 +151,95 @@ class SparseBottleneck(spconv.SparseModule):
out.features = self.relu(out.features) out.features = self.relu(out.features)
return out return out
class SparseBottleneck(Bottleneck, spconv.SparseModule):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
conv_cfg=None,
norm_cfg=None):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
spconv.SparseModule.__init__(self)
Bottleneck.__init__(
self,
inplanes,
planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
def forward(self, x):
identity = x.features
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
out.features = self.relu(out.features)
out = self.conv3(out)
out.features = self.bn3(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
class SparseBasicBlock(BasicBlock, spconv.SparseModule):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
conv_cfg=None,
norm_cfg=None):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
spconv.SparseModule.__init__(self)
BasicBlock.__init__(
self,
inplanes,
planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x)
out.features = self.norm1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.norm2(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
...@@ -27,5 +27,37 @@ def test_SparseUnetV2(): ...@@ -27,5 +27,37 @@ def test_SparseUnetV2():
assert spatial_features.shape == torch.Size([2, 256, 200, 176]) assert spatial_features.shape == torch.Size([2, 256, 200, 176])
if __name__ == '__main__': def test_SparseBasicBlock():
test_SparseUnetV2() from mmdet3d.ops import SparseBasicBlockV0, SparseBasicBlock
import mmdet3d.ops.spconv as spconv
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16],
[23.482342, 6.5036807, 0.5806964, 0.35]],
dtype=torch.float32) # n, point_features
coordinates = torch.tensor(
[[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
[1, 35, 930, 469]],
dtype=torch.int32) # n, 4(batch, ind_x, ind_y, ind_z)
# test v0
self = SparseBasicBlockV0(
4,
4,
indice_key='subm0',
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coordinates,
[41, 1600, 1408], 2)
out_features = self(input_sp_tensor)
assert out_features.features.shape == torch.Size([4, 4])
# test
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coordinates,
[41, 1600, 1408], 2)
self = SparseBasicBlock(
4,
4,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
out_features = self(input_sp_tensor)
assert out_features.features.shape == torch.Size([4, 4])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment