".github/vscode:/vscode.git/clone" did not exist on "070850cf8a5b4591e6d699363f88bd970973d63d"
Commit 66315452 authored by wuyuefeng's avatar wuyuefeng
Browse files

unittest

parent d0719036
...@@ -116,9 +116,8 @@ class SparseUnet(nn.Module): ...@@ -116,9 +116,8 @@ class SparseUnet(nn.Module):
x = self.conv_input(input_sp_tensor) x = self.conv_input(input_sp_tensor)
encode_features = [] encode_features = []
for i, stage_name in enumerate(self.encoder): for encoder_layer in self.encoder_layers:
stage = getattr(self, stage_name) x = encoder_layer(x)
x = stage(x)
encode_features.append(x) encode_features.append(x)
# for detection head # for detection head
...@@ -154,32 +153,32 @@ class SparseUnet(nn.Module): ...@@ -154,32 +153,32 @@ class SparseUnet(nn.Module):
return ret return ret
def decoder_layer_forward(self, x_lateral, x_bottom, conv_t, conv_m, def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
conv_inv): merge_layer, upsample_layer):
"""Forward of upsample and residual block. """Forward of upsample and residual block.
Args: Args:
x_lateral (SparseConvTensor): lateral tensor x_lateral (SparseConvTensor): lateral tensor
x_bottom (SparseConvTensor): tensor from bottom layer x_bottom (SparseConvTensor): tensor from bottom layer
conv_t (SparseBasicBlock): convolution for lateral tensor lateral_layer (SparseBasicBlock): convolution for lateral tensor
conv_m (SparseSequential): convolution for merging features merge_layer (SparseSequential): convolution for merging features
conv_inv (SparseSequential): convolution for upsampling upsample_layer (SparseSequential): convolution for upsampling
Returns: Returns:
SparseConvTensor: upsampled feature SparseConvTensor: upsampled feature
""" """
x_trans = conv_t(x_lateral) x_trans = lateral_layer(x_lateral)
x = x_trans x = x_trans
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1) x.features = torch.cat((x_bottom.features, x_trans.features), dim=1)
x_m = conv_m(x) x_m = merge_layer(x)
x = self.channel_reduction(x, x_m.features.shape[1]) x = self.reduce_channel(x, x_m.features.shape[1])
x.features = x_m.features + x.features x.features = x_m.features + x.features
x = conv_inv(x) x = upsample_layer(x)
return x return x
@staticmethod @staticmethod
def channel_reduction(x, out_channels): def reduce_channel(x, out_channels):
"""Channel reduction for element-wise addition. """reduce channel for element-wise addition.
Args: Args:
x (SparseConvTensor): x.features (N, C1) x (SparseConvTensor): x.features (N, C1)
...@@ -340,7 +339,7 @@ class SparseUnet(nn.Module): ...@@ -340,7 +339,7 @@ class SparseUnet(nn.Module):
Returns: Returns:
int: the number of encoder output channels int: the number of encoder output channels
""" """
self.encoder = [] self.encoder_layers = spconv.SparseSequential()
for i, blocks in enumerate(self.encoder_channels): for i, blocks in enumerate(self.encoder_channels):
blocks_list = [] blocks_list = []
for j, out_channels in enumerate(tuple(blocks)): for j, out_channels in enumerate(tuple(blocks)):
...@@ -370,8 +369,7 @@ class SparseUnet(nn.Module): ...@@ -370,8 +369,7 @@ class SparseUnet(nn.Module):
in_channels = out_channels in_channels = out_channels
stage_name = f'encoder_layer{i + 1}' stage_name = f'encoder_layer{i + 1}'
stage_layers = spconv.SparseSequential(*blocks_list) stage_layers = spconv.SparseSequential(*blocks_list)
self.add_module(stage_name, stage_layers) self.encoder_layers.add_module(stage_name, stage_layers)
self.encoder.append(stage_name)
return out_channels return out_channels
def make_decoder_layers(self, make_block, norm_cfg, in_channels): def make_decoder_layers(self, make_block, norm_cfg, in_channels):
...@@ -405,19 +403,28 @@ class SparseUnet(nn.Module): ...@@ -405,19 +403,28 @@ class SparseUnet(nn.Module):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
padding=paddings[0], padding=paddings[0],
indice_key=f'subm{block_num - i}')) indice_key=f'subm{block_num - i}'))
if block_num - i != 1:
setattr( setattr(
self, self, f'upsample_layer{block_num - i}',
f'upsample_layer{block_num - i}',
make_block( make_block(
in_channels, in_channels,
block_channels[2], block_channels[2],
3, 3,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
padding=paddings[1], padding=paddings[1],
indice_key=f'spconv{block_num - i}' indice_key=f'spconv{block_num - i}',
if block_num - i != 1 else 'subm1', conv_type='inverseconv'))
conv_type='inverseconv' if block_num - i != 1 else else:
'subm') # use submanifold conv instead of inverse conv # use submanifold conv instead of inverse conv
# in the last block # in the last block
) setattr(
self, f'upsample_layer{block_num - i}',
make_block(
in_channels,
block_channels[2],
3,
norm_cfg=norm_cfg,
padding=paddings[1],
indice_key='subm1',
conv_type='subm'))
in_channels = block_channels[2] in_channels = block_channels[2]
import torch import torch
import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock, SparseBasicBlockV0
def test_SparseUnet(): def test_SparseUnet():
from mmdet3d.models.middle_encoders.sparse_unet import SparseUnet from mmdet3d.models.middle_encoders.sparse_unet import SparseUnet
self = SparseUnet( self = SparseUnet(
in_channels=4, output_shape=[41, 1600, 1408], pre_act=False) in_channels=4, output_shape=[41, 1600, 1408], pre_act=False)
# test encoder layers
assert len(self.encoder_layers) == 4
assert self.encoder_layers.encoder_layer1[0][0].in_channels == 16
assert self.encoder_layers.encoder_layer1[0][0].out_channels == 16
assert isinstance(self.encoder_layers.encoder_layer1[0][0],
spconv.conv.SubMConv3d)
assert isinstance(self.encoder_layers.encoder_layer1[0][1],
torch.nn.modules.batchnorm.BatchNorm1d)
assert isinstance(self.encoder_layers.encoder_layer1[0][2],
torch.nn.modules.activation.ReLU)
assert self.encoder_layers.encoder_layer4[0][0].in_channels == 64
assert self.encoder_layers.encoder_layer4[0][0].out_channels == 64
assert isinstance(self.encoder_layers.encoder_layer4[0][0],
spconv.conv.SparseConv3d)
assert isinstance(self.encoder_layers.encoder_layer4[2][0],
spconv.conv.SubMConv3d)
# test decoder layers
assert isinstance(self.lateral_layer1, SparseBasicBlock)
assert isinstance(self.merge_layer1[0], spconv.conv.SubMConv3d)
assert isinstance(self.upsample_layer1[0], spconv.conv.SubMConv3d)
assert isinstance(self.upsample_layer2[0], spconv.conv.SparseInverseConv3d)
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315], voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36], [6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16], [11.643568, -4.744306, -1.3580885, 0.16],
...@@ -24,8 +51,6 @@ def test_SparseUnet(): ...@@ -24,8 +51,6 @@ def test_SparseUnet():
def test_SparseBasicBlock(): def test_SparseBasicBlock():
from mmdet3d.ops import SparseBasicBlockV0, SparseBasicBlock
import mmdet3d.ops.spconv as spconv
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315], voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36], [6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16], [11.643568, -4.744306, -1.3580885, 0.16],
...@@ -55,5 +80,15 @@ def test_SparseBasicBlock(): ...@@ -55,5 +80,15 @@ def test_SparseBasicBlock():
4, 4,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'), conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01)) norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
# test conv and bn layer
assert isinstance(self.conv1, spconv.conv.SubMConv3d)
assert self.conv1.in_channels == 4
assert self.conv1.out_channels == 4
assert isinstance(self.conv2, spconv.conv.SubMConv3d)
assert self.conv2.out_channels == 4
assert self.conv2.out_channels == 4
assert self.bn1.eps == 1e-3
assert self.bn1.momentum == 0.01
out_features = self(input_sp_tensor) out_features = self(input_sp_tensor)
assert out_features.features.shape == torch.Size([4, 4]) assert out_features.features.shape == torch.Size([4, 4])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment