Commit d0719036 authored by wuyuefeng

Use f-strings and rename some attributes

parent 7ca9e90f
......@@ -16,14 +16,14 @@ class SparseUnet(nn.Module):
pre_act=False,
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
out_conv_channels=128,
encode_conv_channels=((16, ), (32, 32, 32), (64, 64, 64),
(64, 64, 64)),
encode_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
decode_conv_channels=((64, 64, 64), (64, 64, 32),
(32, 32, 16), (16, 16, 16)),
decode_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
output_channels=128,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(16, 16, 16)),
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
"""SparseUnet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials.
......@@ -32,15 +32,15 @@ class SparseUnet(nn.Module):
in_channels (int): the number of input channels
output_shape (list[int]): the shape of output tensor
pre_act (bool): use pre_act_block or post_act_block
norm_cfg (dict): normalize layer config
norm_cfg (dict): config of normalization layer
base_channels (int): out channels for conv_input layer
out_conv_channels (int): out channels for conv_out layer
encode_conv_channels (tuple[tuple[int]]):
conv channels of each encond block
encode_paddings (tuple[tuple[int]]): paddings of each encode block
decode_conv_channels (tuple[tuple[int]]):
output_channels (int): out channels for conv_out layer
encoder_channels (tuple[tuple[int]]):
conv channels of each encode block
encoder_paddings (tuple[tuple[int]]): paddings of each encode block
decoder_channels (tuple[tuple[int]]):
conv channels of each decode block
decode_paddings (tuple[tuple[int]]): paddings of each decode block
decoder_paddings (tuple[tuple[int]]): paddings of each decode block
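Example:
A minimal, illustrative sketch of construction with the renamed
keyword arguments (the values below are assumptions for
demonstration, not a tested configuration):
>>> unet = SparseUnet(
...     in_channels=4,
...     output_shape=[41, 1600, 1408],
...     output_channels=128)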
"""
super().__init__()
self.sparse_shape = output_shape
......@@ -48,14 +48,13 @@ class SparseUnet(nn.Module):
self.in_channels = in_channels
self.pre_act = pre_act
self.base_channels = base_channels
self.out_conv_channels = out_conv_channels
self.encode_conv_channels = encode_conv_channels
self.encode_paddings = encode_paddings
self.decode_conv_channels = decode_conv_channels
self.decode_paddings = decode_paddings
self.stage_num = len(self.encode_conv_channels)
self.output_channels = output_channels
self.encoder_channels = encoder_channels
self.encoder_paddings = encoder_paddings
self.decoder_channels = decoder_channels
self.decoder_paddings = decoder_paddings
self.stage_num = len(self.encoder_channels)
# Spconv initializes all weights on its own
# TODO: make the network modifiable
if pre_act:
# TODO: use ConvModule to encapsulate
......@@ -69,8 +68,6 @@ class SparseUnet(nn.Module):
indice_key='subm1'), )
make_block = self.pre_act_block
else:
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.base_channels)
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
......@@ -79,27 +76,25 @@ class SparseUnet(nn.Module):
padding=1,
bias=False,
indice_key='subm1'),
norm_layer,
build_norm_layer(norm_cfg, self.base_channels)[1],
nn.ReLU(),
)
make_block = self.post_act_block
encoder_out_channels = self.make_encode_layers(make_block, norm_cfg,
self.base_channels)
self.make_decode_layers(make_block, norm_cfg, encoder_out_channels)
encoder_out_channels = self.make_encoder_layers(
make_block, norm_cfg, self.base_channels)
self.make_decoder_layers(make_block, norm_cfg, encoder_out_channels)
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.out_conv_channels)
self.conv_out = spconv.SparseSequential(
# [200, 176, 5] -> [200, 176, 2]
spconv.SparseConv3d(
encoder_out_channels,
self.out_conv_channels, (3, 1, 1),
self.output_channels, (3, 1, 1),
stride=(2, 1, 1),
padding=0,
bias=False,
indice_key='spconv_down2'),
norm_layer,
build_norm_layer(norm_cfg, self.output_channels)[1],
nn.ReLU(),
)
......@@ -144,12 +139,12 @@ class SparseUnet(nn.Module):
decode_features = []
x = encode_features[-1]
for i in range(self.stage_num, 0, -1):
x = self.UR_block_forward(
x = self.decoder_layer_forward(
encode_features[i - 1],
x,
getattr(self, 'conv_up_t{}'.format(i)),
getattr(self, 'conv_up_m{}'.format(i)),
getattr(self, 'inv_conv{}'.format(i)),
getattr(self, f'lateral_layer{i}'),
getattr(self, f'merge_layer{i}'),
getattr(self, f'upsample_layer{i}'),
)
decode_features.append(x)
......@@ -159,7 +154,8 @@ class SparseUnet(nn.Module):
return ret
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
def decoder_layer_forward(self, x_lateral, x_bottom, conv_t, conv_m,
conv_inv):
"""Forward of upsample and residual block.
Args:
......@@ -183,7 +179,7 @@ class SparseUnet(nn.Module):
@staticmethod
def channel_reduction(x, out_channels):
"""Channel reduction for element-wise add.
"""Channel reduction for element-wise addition.
Args:
x (SparseConvTensor): x.features (N, C1)
......@@ -219,7 +215,7 @@ class SparseUnet(nn.Module):
stride (int): the stride of convolution
padding (int or list[int]): the padding of the input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict): normal layer configs
norm_cfg (dict): config of normalization layer
Returns:
spconv.SparseSequential: pre-activation sparse convolution block.
......@@ -227,10 +223,9 @@ class SparseUnet(nn.Module):
# TODO: use ConvModule to encapsulate
assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, in_channels)
if conv_type == 'subm':
m = spconv.SparseSequential(
norm_layer,
build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True),
spconv.SubMConv3d(
in_channels,
......@@ -242,7 +237,7 @@ class SparseUnet(nn.Module):
)
elif conv_type == 'spconv':
m = spconv.SparseSequential(
norm_layer,
build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True),
spconv.SparseConv3d(
in_channels,
......@@ -255,7 +250,7 @@ class SparseUnet(nn.Module):
)
elif conv_type == 'inverseconv':
m = spconv.SparseSequential(
norm_layer,
build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True),
spconv.SparseInverseConv3d(
in_channels,
......@@ -287,7 +282,7 @@ class SparseUnet(nn.Module):
stride (int): the stride of convolution
padding (int or list[int]): the padding of the input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict[str]): normal layer configs
norm_cfg (dict[str]): config of normalization layer
Returns:
spconv.SparseSequential: post-activation sparse convolution block.
......@@ -295,7 +290,6 @@ class SparseUnet(nn.Module):
# TODO: use ConvModule to encapsulate
assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, out_channels)
if conv_type == 'subm':
m = spconv.SparseSequential(
spconv.SubMConv3d(
......@@ -304,7 +298,7 @@ class SparseUnet(nn.Module):
kernel_size,
bias=False,
indice_key=indice_key),
norm_layer,
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True),
)
elif conv_type == 'spconv':
......@@ -317,7 +311,7 @@ class SparseUnet(nn.Module):
padding=padding,
bias=False,
indice_key=indice_key),
norm_layer,
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True),
)
elif conv_type == 'inverseconv':
......@@ -328,29 +322,29 @@ class SparseUnet(nn.Module):
kernel_size,
bias=False,
indice_key=indice_key),
norm_layer,
build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True),
)
else:
raise NotImplementedError
return m
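# Note: for a given conv type the two builders differ only in operation
# order: pre_act_block yields SparseSequential(norm, ReLU, conv), while
# post_act_block yields SparseSequential(conv, norm, ReLU).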
def make_encode_layers(self, make_block, norm_cfg, in_channels):
"""make encode layers using sparse convs
def make_encoder_layers(self, make_block, norm_cfg, in_channels):
"""make encoder layers using sparse convs
Args:
make_block (method): a bound method used to build blocks
norm_cfg (dict[str]): normal layer configs
norm_cfg (dict[str]): config of normalization layer
in_channels (int): the number of encoder input channels
Returns:
int: the number of encoder output channels
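Example:
An illustrative sketch (pure Python, with a hypothetical two-stage
config) of how the nested tuples map to registered stage names:
>>> channels = ((16, ), (32, 32))
>>> for i, blocks in enumerate(channels):
...     print(f'encoder_layer{i + 1}: {len(blocks)} conv(s)')
encoder_layer1: 1 conv(s)
encoder_layer2: 2 conv(s)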
"""
self.encoder = []
for i, blocks in enumerate(self.encode_conv_channels):
for i, blocks in enumerate(self.encoder_channels):
blocks_list = []
for j, out_channels in enumerate(tuple(blocks)):
padding = tuple(self.encode_paddings[i])[j]
padding = tuple(self.encoder_paddings[i])[j]
# each stage starts with an spconv layer,
# except the first stage
if i != 0 and j == 0:
......@@ -362,7 +356,7 @@ class SparseUnet(nn.Module):
norm_cfg=norm_cfg,
stride=2,
padding=padding,
indice_key='spconv{}'.format(i + 1),
indice_key=f'spconv{i + 1}',
conv_type='spconv'))
else:
blocks_list.append(
......@@ -372,56 +366,55 @@ class SparseUnet(nn.Module):
3,
norm_cfg=norm_cfg,
padding=padding,
indice_key='subm{}'.format(i + 1)))
indice_key=f'subm{i + 1}'))
in_channels = out_channels
stage_name = 'conv{}'.format(i + 1)
stage_name = f'encoder_layer{i + 1}'
stage_layers = spconv.SparseSequential(*blocks_list)
self.add_module(stage_name, stage_layers)
self.encoder.append(stage_name)
return out_channels
def make_decode_layers(self, make_block, norm_cfg, in_channels):
"""make decode layers using sparse convs
def make_decoder_layers(self, make_block, norm_cfg, in_channels):
"""make decoder layers using sparse convs
Args:
make_block (method): a bound method used to build blocks
norm_cfg (dict[str]): normal layer configs
norm_cfg (dict[str]): config of normalization layer
in_channels (int): the number of decoder input channels
Returns:
int: the number of decoder output channels
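Example:
An illustrative sketch (pure Python) of the per-stage attribute
names, which count down from the deepest stage:
>>> block_num = 2
>>> for i in range(block_num):
...     print(f'upsample_layer{block_num - i}')
upsample_layer2
upsample_layer1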
"""
block_num = len(self.decode_conv_channels)
for i, block_channels in enumerate(self.decode_conv_channels):
paddings = self.decode_paddings[i]
block_num = len(self.decoder_channels)
for i, block_channels in enumerate(self.decoder_channels):
paddings = self.decoder_paddings[i]
setattr(
self, 'conv_up_t{}'.format(block_num - i),
self, f'lateral_layer{block_num - i}',
SparseBasicBlock(
in_channels,
block_channels[0],
conv_cfg=dict(
type='SubMConv3d',
indice_key='subm{}'.format(block_num - i)),
type='SubMConv3d', indice_key=f'subm{block_num - i}'),
norm_cfg=norm_cfg))
setattr(
self, 'conv_up_m{}'.format(block_num - i),
self, f'merge_layer{block_num - i}',
make_block(
in_channels * 2,
block_channels[1],
3,
norm_cfg=norm_cfg,
padding=paddings[0],
indice_key='subm{}'.format(block_num - i)))
indice_key=f'subm{block_num - i}'))
setattr(
self,
'inv_conv{}'.format(block_num - i),
f'upsample_layer{block_num - i}',
make_block(
in_channels,
block_channels[2],
3,
norm_cfg=norm_cfg,
padding=paddings[1],
indice_key='spconv{}'.format(block_num - i)
indice_key=f'spconv{block_num - i}'
if block_num - i != 1 else 'subm1',
conv_type='inverseconv' if block_num - i != 1 else
'subm') # use submanifold conv instead of inverse conv
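# (stage 1 of the encoder applies no downsampling, so the matching
# decoder stage needs no upsampling and a submanifold conv suffices)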
......
......@@ -22,7 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
"""
# TODO: duplicate this class
# TODO: deprecate this class
return spconv.SubMConv3d(
in_planes,
out_planes,
......@@ -45,7 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
"""
# TODO: duplicate this class
# TODO: deprecate this class
return spconv.SubMConv3d(
in_planes,
out_planes,
......@@ -70,7 +70,7 @@ class SparseBasicBlockV0(spconv.SparseModule):
Sparse basic block implemented with submanifold sparse convolution.
"""
# TODO: duplicate this class
# TODO: deprecate this class
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key)
norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
......@@ -85,7 +85,7 @@ class SparseBasicBlockV0(spconv.SparseModule):
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x)
out.features = self.bn1(out.features)
......@@ -117,7 +117,7 @@ class SparseBottleneckV0(spconv.SparseModule):
Bottleneck block implemented with submanifold sparse convolution.
"""
# TODO: duplicate this class
# TODO: deprecate this class
super().__init__()
self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key)
self.bn1 = norm_fn(planes)
......@@ -227,7 +227,7 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x)
out.features = self.norm1(out.features)
......