Commit 5db915b5 authored by wuyuefeng

modified sparse block ops

parent 7695d4a4
@@ -2,9 +2,9 @@ import torch
import torch.nn as nn
import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock
from mmdet.ops import build_norm_layer
from ..registry import MIDDLE_ENCODERS
from .sparse_block_utils import SparseBasicBlock
@MIDDLE_ENCODERS.register_module
@@ -122,7 +122,10 @@ class SparseUnetV2(nn.Module):
# decoder
# [400, 352, 11] <- [200, 176, 5]
self.conv_up_t4 = SparseBasicBlock(
64, 64, indice_key='subm4', norm_cfg=norm_cfg)
64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm4'),
norm_cfg=norm_cfg)
self.conv_up_m4 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
self.inv_conv4 = block(
@@ -135,7 +138,10 @@ class SparseUnetV2(nn.Module):
# [800, 704, 21] <- [400, 352, 11]
self.conv_up_t3 = SparseBasicBlock(
64, 64, indice_key='subm3', norm_cfg=norm_cfg)
64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm3'),
norm_cfg=norm_cfg)
self.conv_up_m3 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
self.inv_conv3 = block(
@@ -148,7 +154,10 @@ class SparseUnetV2(nn.Module):
# [1600, 1408, 41] <- [800, 704, 21]
self.conv_up_t2 = SparseBasicBlock(
32, 32, indice_key='subm2', norm_cfg=norm_cfg)
32,
32,
conv_cfg=dict(type='SubMConv3d', indice_key='subm2'),
norm_cfg=norm_cfg)
self.conv_up_m2 = block(
64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2')
self.inv_conv2 = block(
@@ -161,7 +170,10 @@ class SparseUnetV2(nn.Module):
# [1600, 1408, 41] <- [1600, 1408, 41]
self.conv_up_t1 = SparseBasicBlock(
16, 16, indice_key='subm1', norm_cfg=norm_cfg)
16,
16,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=norm_cfg)
self.conv_up_m1 = block(
32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1')
@@ -2,12 +2,29 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
get_compiling_cuda_version, nms, roi_align,
sigmoid_focal_loss)
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .sparse_block import (SparseBasicBlock, SparseBasicBlockV0,
SparseBottleneck, SparseBottleneckV0)
from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization
__all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'get_compiler_version',
'get_compiling_cuda_version', 'build_conv_layer', 'NaiveSyncBatchNorm1d',
'NaiveSyncBatchNorm2d', 'batched_nms', 'Voxelization', 'voxelization',
'dynamic_scatter', 'DynamicScatter', 'sigmoid_focal_loss',
'SigmoidFocalLoss'
'nms',
'soft_nms',
'RoIAlign',
'roi_align',
'get_compiler_version',
'get_compiling_cuda_version',
'build_conv_layer',
'NaiveSyncBatchNorm1d',
'NaiveSyncBatchNorm2d',
'batched_nms',
'Voxelization',
'voxelization',
'dynamic_scatter',
'DynamicScatter',
'sigmoid_focal_loss',
'SigmoidFocalLoss',
'SparseBasicBlockV0',
'SparseBottleneckV0',
'SparseBasicBlock',
'SparseBottleneck',
]
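With these exports in place, both the legacy V0 blocks and the refactored blocks can be imported directly from mmdet3d.ops, for example:

# Import the refactored sparse blocks alongside the legacy V0 versions.
from mmdet3d.ops import (SparseBasicBlock, SparseBasicBlockV0,
                         SparseBottleneck, SparseBottleneckV0)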
from torch import nn
import mmdet3d.ops.spconv as spconv
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.ops import build_norm_layer
from mmdet.ops.conv import conv_cfg
conv_cfg.update({
'SubMConv3d': spconv.SubMConv3d,
})
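Registering spconv.SubMConv3d under the 'SubMConv3d' key is what lets the new blocks build sparse convolutions from a plain config dict. A minimal sketch of that path is below; it assumes build_conv_layer is exposed from mmdet.ops alongside build_norm_layer, pops 'type' from the cfg, and forwards the remaining keys (such as indice_key) together with the positional arguments to the selected conv class. Channel sizes and the indice_key are illustrative.

from mmdet.ops import build_conv_layer

# 'SubMConv3d' is resolved through the conv_cfg registry updated above.
conv = build_conv_layer(
    dict(type='SubMConv3d', indice_key='subm1'),
    16,  # in_channels
    16,  # out_channels
    3,   # kernel_size
    padding=1,
    bias=False)
# conv is now a spconv.SubMConv3d module.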
def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
@@ -16,6 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
"""
# TODO: duplicate this function
return spconv.SubMConv3d(
in_planes,
out_planes,
@@ -38,6 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
"""
# TODO: duplicate this function
return spconv.SubMConv3d(
in_planes,
out_planes,
@@ -48,7 +56,7 @@
indice_key=indice_key)
class SparseBasicBlock(spconv.SparseModule):
class SparseBasicBlockV0(spconv.SparseModule):
expansion = 1
def __init__(self,
@@ -62,7 +70,8 @@ class SparseBasicBlock(spconv.SparseModule):
Sparse basic block implemented with submanifold sparse convolution.
"""
super(SparseBasicBlock, self).__init__()
# TODO: duplicate this class
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key)
norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
self.bn1 = norm_layer1
@@ -94,7 +103,7 @@ class SparseBasicBlock(spconv.SparseModule):
return out
class SparseBottleneck(spconv.SparseModule):
class SparseBottleneckV0(spconv.SparseModule):
expansion = 4
def __init__(self,
@@ -108,7 +117,8 @@ class SparseBottleneck(spconv.SparseModule):
Bottleneck block implemented with submanifold sparse convolution.
"""
super(SparseBottleneck, self).__init__()
# TODO: duplicate this class
super().__init__()
self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key)
self.bn1 = norm_fn(planes)
self.conv2 = conv3x3(planes, planes, stride, indice_key=indice_key)
@@ -141,3 +151,95 @@ class SparseBottleneck(spconv.SparseModule):
out.features = self.relu(out.features)
return out
class SparseBottleneck(Bottleneck, spconv.SparseModule):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
conv_cfg=None,
norm_cfg=None):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
spconv.SparseModule.__init__(self)
Bottleneck.__init__(
self,
inplanes,
planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
def forward(self, x):
identity = x.features
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
out.features = self.relu(out.features)
out = self.conv3(out)
out.features = self.bn3(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
class SparseBasicBlock(BasicBlock, spconv.SparseModule):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
conv_cfg=None,
norm_cfg=None):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
spconv.SparseModule.__init__(self)
BasicBlock.__init__(
self,
inplanes,
planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x)
out.features = self.norm1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.norm2(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
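The refactored blocks reuse the dense BasicBlock/Bottleneck constructors (driven by conv_cfg and norm_cfg) and only override forward to operate on the .features tensor of a spconv.SparseConvTensor. A minimal usage sketch follows; it assumes the vendored spconv exposes SparseSequential as in upstream spconv, and the channel sizes and indice_key are illustrative. The coordinates reuse the layout of the test data below.

import torch
import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock

norm_cfg = dict(type='BN1d', eps=1e-3, momentum=0.01)
# Two submanifold residual blocks sharing one indice_key, so the sparse
# indices computed by the first block are reused by the second.
stage = spconv.SparseSequential(
    SparseBasicBlock(
        4, 4,
        conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
        norm_cfg=norm_cfg),
    SparseBasicBlock(
        4, 4,
        conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
        norm_cfg=norm_cfg))

features = torch.rand(4, 4)  # (num_voxels, num_features)
coors = torch.tensor(
    [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
     [1, 35, 930, 469]],
    dtype=torch.int32)  # (batch index + 3 spatial indices)
x = spconv.SparseConvTensor(features, coors, [41, 1600, 1408], 2)
out = stage(x)  # out.features has shape (4, 4)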
@@ -27,5 +27,37 @@ def test_SparseUnetV2():
assert spatial_features.shape == torch.Size([2, 256, 200, 176])
if __name__ == '__main__':
test_SparseUnetV2()
def test_SparseBasicBlock():
from mmdet3d.ops import SparseBasicBlockV0, SparseBasicBlock
import mmdet3d.ops.spconv as spconv
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16],
[23.482342, 6.5036807, 0.5806964, 0.35]],
dtype=torch.float32) # n, point_features
coordinates = torch.tensor(
[[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
[1, 35, 930, 469]],
dtype=torch.int32) # n, 4(batch, ind_x, ind_y, ind_z)
# test v0
self = SparseBasicBlockV0(
4,
4,
indice_key='subm0',
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coordinates,
[41, 1600, 1408], 2)
out_features = self(input_sp_tensor)
assert out_features.features.shape == torch.Size([4, 4])
# test the new SparseBasicBlock
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coordinates,
[41, 1600, 1408], 2)
self = SparseBasicBlock(
4,
4,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
out_features = self(input_sp_tensor)
assert out_features.features.shape == torch.Size([4, 4])