Commit d0719036 authored by wuyuefeng's avatar wuyuefeng
Browse files

using f-string and modify some attributes

parent 7ca9e90f
...@@ -16,14 +16,14 @@ class SparseUnet(nn.Module): ...@@ -16,14 +16,14 @@ class SparseUnet(nn.Module):
pre_act=False, pre_act=False,
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16, base_channels=16,
out_conv_channels=128, output_channels=128,
encode_conv_channels=((16, ), (32, 32, 32), (64, 64, 64), encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
(64, 64, 64)), 64)),
encode_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)), 1)),
decode_conv_channels=((64, 64, 64), (64, 64, 32), decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(32, 32, 16), (16, 16, 16)), (16, 16, 16)),
decode_paddings=((1, 0), (1, 0), (0, 0), (0, 1))): decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
"""SparseUnet for PartA^2 """SparseUnet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials. See https://arxiv.org/abs/1907.03670 for more detials.
...@@ -32,15 +32,15 @@ class SparseUnet(nn.Module): ...@@ -32,15 +32,15 @@ class SparseUnet(nn.Module):
in_channels (int): the number of input channels in_channels (int): the number of input channels
output_shape (list[int]): the shape of output tensor output_shape (list[int]): the shape of output tensor
pre_act (bool): use pre_act_block or post_act_block pre_act (bool): use pre_act_block or post_act_block
norm_cfg (dict): normalize layer config norm_cfg (dict): config of normalization layer
base_channels (int): out channels for conv_input layer base_channels (int): out channels for conv_input layer
out_conv_channels (int): out channels for conv_out layer output_channels (int): out channels for conv_out layer
encode_conv_channels (tuple[tuple[int]]): encoder_channels (tuple[tuple[int]]):
conv channels of each encond block conv channels of each encode block
encode_paddings (tuple[tuple[int]]): paddings of each encode block encoder_paddings (tuple[tuple[int]]): paddings of each encode block
decode_conv_channels (tuple[tuple[int]]): decoder_channels (tuple[tuple[int]]):
conv channels of each decode block conv channels of each decode block
decode_paddings (tuple[tuple[int]]): paddings of each decode block decoder_paddings (tuple[tuple[int]]): paddings of each decode block
""" """
super().__init__() super().__init__()
self.sparse_shape = output_shape self.sparse_shape = output_shape
...@@ -48,14 +48,13 @@ class SparseUnet(nn.Module): ...@@ -48,14 +48,13 @@ class SparseUnet(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.pre_act = pre_act self.pre_act = pre_act
self.base_channels = base_channels self.base_channels = base_channels
self.out_conv_channels = out_conv_channels self.output_channels = output_channels
self.encode_conv_channels = encode_conv_channels self.encoder_channels = encoder_channels
self.encode_paddings = encode_paddings self.encoder_paddings = encoder_paddings
self.decode_conv_channels = decode_conv_channels self.decoder_channels = decoder_channels
self.decode_paddings = decode_paddings self.decoder_paddings = decoder_paddings
self.stage_num = len(self.encode_conv_channels) self.stage_num = len(self.encoder_channels)
# Spconv init all weight on its own # Spconv init all weight on its own
# TODO: make the network could be modified
if pre_act: if pre_act:
# TODO: use ConvModule to encapsulate # TODO: use ConvModule to encapsulate
...@@ -69,8 +68,6 @@ class SparseUnet(nn.Module): ...@@ -69,8 +68,6 @@ class SparseUnet(nn.Module):
indice_key='subm1'), ) indice_key='subm1'), )
make_block = self.pre_act_block make_block = self.pre_act_block
else: else:
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.base_channels)
self.conv_input = spconv.SparseSequential( self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d( spconv.SubMConv3d(
in_channels, in_channels,
...@@ -79,27 +76,25 @@ class SparseUnet(nn.Module): ...@@ -79,27 +76,25 @@ class SparseUnet(nn.Module):
padding=1, padding=1,
bias=False, bias=False,
indice_key='subm1'), indice_key='subm1'),
norm_layer, build_norm_layer(norm_cfg, self.base_channels)[1],
nn.ReLU(), nn.ReLU(),
) )
make_block = self.post_act_block make_block = self.post_act_block
encoder_out_channels = self.make_encode_layers(make_block, norm_cfg, encoder_out_channels = self.make_encoder_layers(
self.base_channels) make_block, norm_cfg, self.base_channels)
self.make_decode_layers(make_block, norm_cfg, encoder_out_channels) self.make_decoder_layers(make_block, norm_cfg, encoder_out_channels)
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.out_conv_channels)
self.conv_out = spconv.SparseSequential( self.conv_out = spconv.SparseSequential(
# [200, 176, 5] -> [200, 176, 2] # [200, 176, 5] -> [200, 176, 2]
spconv.SparseConv3d( spconv.SparseConv3d(
encoder_out_channels, encoder_out_channels,
self.out_conv_channels, (3, 1, 1), self.output_channels, (3, 1, 1),
stride=(2, 1, 1), stride=(2, 1, 1),
padding=0, padding=0,
bias=False, bias=False,
indice_key='spconv_down2'), indice_key='spconv_down2'),
norm_layer, build_norm_layer(norm_cfg, self.output_channels)[1],
nn.ReLU(), nn.ReLU(),
) )
...@@ -144,12 +139,12 @@ class SparseUnet(nn.Module): ...@@ -144,12 +139,12 @@ class SparseUnet(nn.Module):
decode_features = [] decode_features = []
x = encode_features[-1] x = encode_features[-1]
for i in range(self.stage_num, 0, -1): for i in range(self.stage_num, 0, -1):
x = self.UR_block_forward( x = self.decoder_layer_forward(
encode_features[i - 1], encode_features[i - 1],
x, x,
getattr(self, 'conv_up_t{}'.format(i)), getattr(self, f'lateral_layer{i}'),
getattr(self, 'conv_up_m{}'.format(i)), getattr(self, f'merge_layer{i}'),
getattr(self, 'inv_conv{}'.format(i)), getattr(self, f'upsample_layer{i}'),
) )
decode_features.append(x) decode_features.append(x)
...@@ -159,7 +154,8 @@ class SparseUnet(nn.Module): ...@@ -159,7 +154,8 @@ class SparseUnet(nn.Module):
return ret return ret
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv): def decoder_layer_forward(self, x_lateral, x_bottom, conv_t, conv_m,
conv_inv):
"""Forward of upsample and residual block. """Forward of upsample and residual block.
Args: Args:
...@@ -183,7 +179,7 @@ class SparseUnet(nn.Module): ...@@ -183,7 +179,7 @@ class SparseUnet(nn.Module):
@staticmethod @staticmethod
def channel_reduction(x, out_channels): def channel_reduction(x, out_channels):
"""Channel reduction for element-wise add. """Channel reduction for element-wise addition.
Args: Args:
x (SparseConvTensor): x.features (N, C1) x (SparseConvTensor): x.features (N, C1)
...@@ -219,7 +215,7 @@ class SparseUnet(nn.Module): ...@@ -219,7 +215,7 @@ class SparseUnet(nn.Module):
stride (int): the stride of convolution stride (int): the stride of convolution
padding (int or list[int]): the padding number of input padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv' conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict): normal layer configs norm_cfg (dict): config of normalization layer
Returns: Returns:
spconv.SparseSequential: pre activate sparse convolution block. spconv.SparseSequential: pre activate sparse convolution block.
...@@ -227,10 +223,9 @@ class SparseUnet(nn.Module): ...@@ -227,10 +223,9 @@ class SparseUnet(nn.Module):
# TODO: use ConvModule to encapsulate # TODO: use ConvModule to encapsulate
assert conv_type in ['subm', 'spconv', 'inverseconv'] assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, in_channels)
if conv_type == 'subm': if conv_type == 'subm':
m = spconv.SparseSequential( m = spconv.SparseSequential(
norm_layer, build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
spconv.SubMConv3d( spconv.SubMConv3d(
in_channels, in_channels,
...@@ -242,7 +237,7 @@ class SparseUnet(nn.Module): ...@@ -242,7 +237,7 @@ class SparseUnet(nn.Module):
) )
elif conv_type == 'spconv': elif conv_type == 'spconv':
m = spconv.SparseSequential( m = spconv.SparseSequential(
norm_layer, build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
spconv.SparseConv3d( spconv.SparseConv3d(
in_channels, in_channels,
...@@ -255,7 +250,7 @@ class SparseUnet(nn.Module): ...@@ -255,7 +250,7 @@ class SparseUnet(nn.Module):
) )
elif conv_type == 'inverseconv': elif conv_type == 'inverseconv':
m = spconv.SparseSequential( m = spconv.SparseSequential(
norm_layer, build_norm_layer(norm_cfg, in_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
spconv.SparseInverseConv3d( spconv.SparseInverseConv3d(
in_channels, in_channels,
...@@ -287,7 +282,7 @@ class SparseUnet(nn.Module): ...@@ -287,7 +282,7 @@ class SparseUnet(nn.Module):
stride (int): the stride of convolution stride (int): the stride of convolution
padding (int or list[int]): the padding number of input padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv' conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict[str]): normal layer configs norm_cfg (dict[str]): config of normalization layer
Returns: Returns:
spconv.SparseSequential: post activate sparse convolution block. spconv.SparseSequential: post activate sparse convolution block.
...@@ -295,7 +290,6 @@ class SparseUnet(nn.Module): ...@@ -295,7 +290,6 @@ class SparseUnet(nn.Module):
# TODO: use ConvModule to encapsulate # TODO: use ConvModule to encapsulate
assert conv_type in ['subm', 'spconv', 'inverseconv'] assert conv_type in ['subm', 'spconv', 'inverseconv']
norm_name, norm_layer = build_norm_layer(norm_cfg, out_channels)
if conv_type == 'subm': if conv_type == 'subm':
m = spconv.SparseSequential( m = spconv.SparseSequential(
spconv.SubMConv3d( spconv.SubMConv3d(
...@@ -304,7 +298,7 @@ class SparseUnet(nn.Module): ...@@ -304,7 +298,7 @@ class SparseUnet(nn.Module):
kernel_size, kernel_size,
bias=False, bias=False,
indice_key=indice_key), indice_key=indice_key),
norm_layer, build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
elif conv_type == 'spconv': elif conv_type == 'spconv':
...@@ -317,7 +311,7 @@ class SparseUnet(nn.Module): ...@@ -317,7 +311,7 @@ class SparseUnet(nn.Module):
padding=padding, padding=padding,
bias=False, bias=False,
indice_key=indice_key), indice_key=indice_key),
norm_layer, build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
elif conv_type == 'inverseconv': elif conv_type == 'inverseconv':
...@@ -328,29 +322,29 @@ class SparseUnet(nn.Module): ...@@ -328,29 +322,29 @@ class SparseUnet(nn.Module):
kernel_size, kernel_size,
bias=False, bias=False,
indice_key=indice_key), indice_key=indice_key),
norm_layer, build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
else: else:
raise NotImplementedError raise NotImplementedError
return m return m
def make_encode_layers(self, make_block, norm_cfg, in_channels): def make_encoder_layers(self, make_block, norm_cfg, in_channels):
"""make encode layers using sparse convs """make encoder layers using sparse convs
Args: Args:
make_block (method): a bounded function to build blocks make_block (method): a bounded function to build blocks
norm_cfg (dict[str]): normal layer configs norm_cfg (dict[str]): config of normalization layer
in_channels (int): the number of encoder input channels in_channels (int): the number of encoder input channels
Returns: Returns:
int: the number of encoder output channels int: the number of encoder output channels
""" """
self.encoder = [] self.encoder = []
for i, blocks in enumerate(self.encode_conv_channels): for i, blocks in enumerate(self.encoder_channels):
blocks_list = [] blocks_list = []
for j, out_channels in enumerate(tuple(blocks)): for j, out_channels in enumerate(tuple(blocks)):
padding = tuple(self.encode_paddings[i])[j] padding = tuple(self.encoder_paddings[i])[j]
# each stage started with a spconv layer # each stage started with a spconv layer
# except the first stage # except the first stage
if i != 0 and j == 0: if i != 0 and j == 0:
...@@ -362,7 +356,7 @@ class SparseUnet(nn.Module): ...@@ -362,7 +356,7 @@ class SparseUnet(nn.Module):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
stride=2, stride=2,
padding=padding, padding=padding,
indice_key='spconv{}'.format(i + 1), indice_key=f'spconv{i + 1}',
conv_type='spconv')) conv_type='spconv'))
else: else:
blocks_list.append( blocks_list.append(
...@@ -372,56 +366,55 @@ class SparseUnet(nn.Module): ...@@ -372,56 +366,55 @@ class SparseUnet(nn.Module):
3, 3,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
padding=padding, padding=padding,
indice_key='subm{}'.format(i + 1))) indice_key=f'subm{i + 1}'))
in_channels = out_channels in_channels = out_channels
stage_name = 'conv{}'.format(i + 1) stage_name = f'encoder_layer{i + 1}'
stage_layers = spconv.SparseSequential(*blocks_list) stage_layers = spconv.SparseSequential(*blocks_list)
self.add_module(stage_name, stage_layers) self.add_module(stage_name, stage_layers)
self.encoder.append(stage_name) self.encoder.append(stage_name)
return out_channels return out_channels
def make_decode_layers(self, make_block, norm_cfg, in_channels): def make_decoder_layers(self, make_block, norm_cfg, in_channels):
"""make decode layers using sparse convs """make decoder layers using sparse convs
Args: Args:
make_block (method): a bounded function to build blocks make_block (method): a bounded function to build blocks
norm_cfg (dict[str]): normal layer configs norm_cfg (dict[str]): config of normalization layer
in_channels (int): the number of encoder input channels in_channels (int): the number of encoder input channels
Returns: Returns:
int: the number of encoder output channels int: the number of encoder output channels
""" """
block_num = len(self.decode_conv_channels) block_num = len(self.decoder_channels)
for i, block_channels in enumerate(self.decode_conv_channels): for i, block_channels in enumerate(self.decoder_channels):
paddings = self.decode_paddings[i] paddings = self.decoder_paddings[i]
setattr( setattr(
self, 'conv_up_t{}'.format(block_num - i), self, f'lateral_layer{block_num - i}',
SparseBasicBlock( SparseBasicBlock(
in_channels, in_channels,
block_channels[0], block_channels[0],
conv_cfg=dict( conv_cfg=dict(
type='SubMConv3d', type='SubMConv3d', indice_key=f'subm{block_num - i}'),
indice_key='subm{}'.format(block_num - i)),
norm_cfg=norm_cfg)) norm_cfg=norm_cfg))
setattr( setattr(
self, 'conv_up_m{}'.format(block_num - i), self, f'merge_layer{block_num - i}',
make_block( make_block(
in_channels * 2, in_channels * 2,
block_channels[1], block_channels[1],
3, 3,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
padding=paddings[0], padding=paddings[0],
indice_key='subm{}'.format(block_num - i))) indice_key=f'subm{block_num - i}'))
setattr( setattr(
self, self,
'inv_conv{}'.format(block_num - i), f'upsample_layer{block_num - i}',
make_block( make_block(
in_channels, in_channels,
block_channels[2], block_channels[2],
3, 3,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
padding=paddings[1], padding=paddings[1],
indice_key='spconv{}'.format(block_num - i) indice_key=f'spconv{block_num - i}'
if block_num - i != 1 else 'subm1', if block_num - i != 1 else 'subm1',
conv_type='inverseconv' if block_num - i != 1 else conv_type='inverseconv' if block_num - i != 1 else
'subm') # use submanifold conv instead of inverse conv 'subm') # use submanifold conv instead of inverse conv
......
...@@ -22,7 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None): ...@@ -22,7 +22,7 @@ def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
Returns: Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
""" """
# TODO: duplicate this class # TODO: deprecate this class
return spconv.SubMConv3d( return spconv.SubMConv3d(
in_planes, in_planes,
out_planes, out_planes,
...@@ -45,7 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None): ...@@ -45,7 +45,7 @@ def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
Returns: Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
""" """
# TODO: duplicate this class # TODO: deprecate this class
return spconv.SubMConv3d( return spconv.SubMConv3d(
in_planes, in_planes,
out_planes, out_planes,
...@@ -70,7 +70,7 @@ class SparseBasicBlockV0(spconv.SparseModule): ...@@ -70,7 +70,7 @@ class SparseBasicBlockV0(spconv.SparseModule):
Sparse basic block implemented with submanifold sparse convolution. Sparse basic block implemented with submanifold sparse convolution.
""" """
# TODO: duplicate this class # TODO: deprecate this class
super().__init__() super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key) self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key)
norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes) norm_name1, norm_layer1 = build_norm_layer(norm_cfg, planes)
...@@ -85,7 +85,7 @@ class SparseBasicBlockV0(spconv.SparseModule): ...@@ -85,7 +85,7 @@ class SparseBasicBlockV0(spconv.SparseModule):
def forward(self, x): def forward(self, x):
identity = x.features identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim() assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x) out = self.conv1(x)
out.features = self.bn1(out.features) out.features = self.bn1(out.features)
...@@ -117,7 +117,7 @@ class SparseBottleneckV0(spconv.SparseModule): ...@@ -117,7 +117,7 @@ class SparseBottleneckV0(spconv.SparseModule):
Bottleneck block implemented with submanifold sparse convolution. Bottleneck block implemented with submanifold sparse convolution.
""" """
# TODO: duplicate this class # TODO: deprecate this class
super().__init__() super().__init__()
self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key) self.conv1 = conv1x1(inplanes, planes, indice_key=indice_key)
self.bn1 = norm_fn(planes) self.bn1 = norm_fn(planes)
...@@ -227,7 +227,7 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule): ...@@ -227,7 +227,7 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
def forward(self, x): def forward(self, x):
identity = x.features identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim() assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x) out = self.conv1(x)
out.features = self.norm1(out.features) out.features = self.norm1(out.features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment