"git@developer.sourcefind.cn:change/sglang.git" did not exist on "96d0e37fa7621c37a130ec12f867c8f99c9ef878"
Commit 7ca9e90f authored by wuyuefeng's avatar wuyuefeng
Browse files

refactor sparse_unet

parent 1a74819d
......@@ -14,7 +14,16 @@ class SparseUnet(nn.Module):
in_channels,
output_shape,
pre_act=False,
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01)):
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
out_conv_channels=128,
encode_conv_channels=((16, ), (32, 32, 32), (64, 64, 64),
(64, 64, 64)),
encode_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
decode_conv_channels=((64, 64, 64), (64, 64, 32),
(32, 32, 16), (16, 16, 16)),
decode_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
"""SparseUnet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials.
......@@ -24,12 +33,27 @@ class SparseUnet(nn.Module):
output_shape (list[int]): the shape of output tensor
pre_act (bool): use pre_act_block or post_act_block
norm_cfg (dict): normalize layer config
base_channels (int): out channels for conv_input layer
out_conv_channels (int): out channels for conv_out layer
encode_conv_channels (tuple[tuple[int]]):
conv channels of each encond block
encode_paddings (tuple[tuple[int]]): paddings of each encode block
decode_conv_channels (tuple[tuple[int]]):
conv channels of each decode block
decode_paddings (tuple[tuple[int]]): paddings of each decode block
"""
super().__init__()
self.sparse_shape = output_shape
self.output_shape = output_shape
self.in_channels = in_channels
self.pre_act = pre_act
self.base_channels = base_channels
self.out_conv_channels = out_conv_channels
self.encode_conv_channels = encode_conv_channels
self.encode_paddings = encode_paddings
self.decode_conv_channels = decode_conv_channels
self.decode_paddings = decode_paddings
self.stage_num = len(self.encode_conv_channels)
# Spconv init all weight on its own
# TODO: make the network could be modified
......@@ -38,18 +62,19 @@ class SparseUnet(nn.Module):
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
16,
self.base_channels,
3,
padding=1,
bias=False,
indice_key='subm1'), )
block = self.pre_act_block
make_block = self.pre_act_block
else:
norm_name, norm_layer = build_norm_layer(norm_cfg, 16)
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.base_channels)
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(
in_channels,
16,
self.base_channels,
3,
padding=1,
bias=False,
......@@ -57,63 +82,19 @@ class SparseUnet(nn.Module):
norm_layer,
nn.ReLU(),
)
block = self.post_act_block
self.conv1 = spconv.SparseSequential(
block(16, 16, 3, norm_cfg=norm_cfg, padding=1,
indice_key='subm1'), )
self.conv2 = spconv.SparseSequential(
# [1600, 1408, 41] -> [800, 704, 21]
block(
16,
32,
3,
norm_cfg=norm_cfg,
stride=2,
padding=1,
indice_key='spconv2',
conv_type='spconv'),
block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
)
self.conv3 = spconv.SparseSequential(
# [800, 704, 21] -> [400, 352, 11]
block(
32,
64,
3,
norm_cfg=norm_cfg,
stride=2,
padding=1,
indice_key='spconv3',
conv_type='spconv'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
)
make_block = self.post_act_block
self.conv4 = spconv.SparseSequential(
# [400, 352, 11] -> [200, 176, 5]
block(
64,
64,
3,
norm_cfg=norm_cfg,
stride=2,
padding=(0, 1, 1),
indice_key='spconv4',
conv_type='spconv'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
)
encoder_out_channels = self.make_encode_layers(make_block, norm_cfg,
self.base_channels)
self.make_decode_layers(make_block, norm_cfg, encoder_out_channels)
norm_name, norm_layer = build_norm_layer(norm_cfg, 128)
norm_name, norm_layer = build_norm_layer(norm_cfg,
self.out_conv_channels)
self.conv_out = spconv.SparseSequential(
# [200, 176, 5] -> [200, 176, 2]
spconv.SparseConv3d(
64,
128, (3, 1, 1),
encoder_out_channels,
self.out_conv_channels, (3, 1, 1),
stride=(2, 1, 1),
padding=0,
bias=False,
......@@ -122,67 +103,6 @@ class SparseUnet(nn.Module):
nn.ReLU(),
)
# decoder
# [400, 352, 11] <- [200, 176, 5]
self.conv_up_t4 = SparseBasicBlock(
64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm4'),
norm_cfg=norm_cfg)
self.conv_up_m4 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
self.inv_conv4 = block(
64,
64,
3,
norm_cfg=norm_cfg,
indice_key='spconv4',
conv_type='inverseconv')
# [800, 704, 21] <- [400, 352, 11]
self.conv_up_t3 = SparseBasicBlock(
64,
64,
conv_cfg=dict(type='SubMConv3d', indice_key='subm3'),
norm_cfg=norm_cfg)
self.conv_up_m3 = block(
128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
self.inv_conv3 = block(
64,
32,
3,
norm_cfg=norm_cfg,
indice_key='spconv3',
conv_type='inverseconv')
# [1600, 1408, 41] <- [800, 704, 21]
self.conv_up_t2 = SparseBasicBlock(
32,
32,
conv_cfg=dict(type='SubMConv3d', indice_key='subm2'),
norm_cfg=norm_cfg)
self.conv_up_m2 = block(
64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2')
self.inv_conv2 = block(
32,
16,
3,
norm_cfg=norm_cfg,
indice_key='spconv2',
conv_type='inverseconv')
# [1600, 1408, 41] <- [1600, 1408, 41]
self.conv_up_t1 = SparseBasicBlock(
16,
16,
conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
norm_cfg=norm_cfg)
self.conv_up_m1 = block(
32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1')
self.conv5 = spconv.SparseSequential(
block(16, 16, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm1'))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUnet
......@@ -200,14 +120,15 @@ class SparseUnet(nn.Module):
batch_size)
x = self.conv_input(input_sp_tensor)
x_conv1 = self.conv1(x)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
encode_features = []
for i, stage_name in enumerate(self.encoder):
stage = getattr(self, stage_name)
x = stage(x)
encode_features.append(x)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
out = self.conv_out(encode_features[-1])
spatial_features = out.dense()
N, C, D, H, W = spatial_features.shape
......@@ -215,21 +136,24 @@ class SparseUnet(nn.Module):
ret = {'spatial_features': spatial_features}
# for segmentation head
# for segmentation head, with output shape:
# [400, 352, 11] <- [200, 176, 5]
x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4,
self.conv_up_m4, self.inv_conv4)
# [800, 704, 21] <- [400, 352, 11]
x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3,
self.conv_up_m3, self.inv_conv3)
# [1600, 1408, 41] <- [800, 704, 21]
x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2,
self.conv_up_m2, self.inv_conv2)
# [1600, 1408, 41] <- [1600, 1408, 41]
x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1,
self.conv_up_m1, self.conv5)
decode_features = []
x = encode_features[-1]
for i in range(self.stage_num, 0, -1):
x = self.UR_block_forward(
encode_features[i - 1],
x,
getattr(self, 'conv_up_t{}'.format(i)),
getattr(self, 'conv_up_m{}'.format(i)),
getattr(self, 'inv_conv{}'.format(i)),
)
decode_features.append(x)
seg_features = x_up1.features
seg_features = decode_features[-1].features
ret.update({'seg_features': seg_features})
......@@ -410,3 +334,97 @@ class SparseUnet(nn.Module):
else:
raise NotImplementedError
return m
def make_encode_layers(self, make_block, norm_cfg, in_channels):
"""make encode layers using sparse convs
Args:
make_block (method): a bounded function to build blocks
norm_cfg (dict[str]): normal layer configs
in_channels (int): the number of encoder input channels
Returns:
int: the number of encoder output channels
"""
self.encoder = []
for i, blocks in enumerate(self.encode_conv_channels):
blocks_list = []
for j, out_channels in enumerate(tuple(blocks)):
padding = tuple(self.encode_paddings[i])[j]
# each stage started with a spconv layer
# except the first stage
if i != 0 and j == 0:
blocks_list.append(
make_block(
in_channels,
out_channels,
3,
norm_cfg=norm_cfg,
stride=2,
padding=padding,
indice_key='spconv{}'.format(i + 1),
conv_type='spconv'))
else:
blocks_list.append(
make_block(
in_channels,
out_channels,
3,
norm_cfg=norm_cfg,
padding=padding,
indice_key='subm{}'.format(i + 1)))
in_channels = out_channels
stage_name = 'conv{}'.format(i + 1)
stage_layers = spconv.SparseSequential(*blocks_list)
self.add_module(stage_name, stage_layers)
self.encoder.append(stage_name)
return out_channels
def make_decode_layers(self, make_block, norm_cfg, in_channels):
"""make decode layers using sparse convs
Args:
make_block (method): a bounded function to build blocks
norm_cfg (dict[str]): normal layer configs
in_channels (int): the number of encoder input channels
Returns:
int: the number of encoder output channels
"""
block_num = len(self.decode_conv_channels)
for i, block_channels in enumerate(self.decode_conv_channels):
paddings = self.decode_paddings[i]
setattr(
self, 'conv_up_t{}'.format(block_num - i),
SparseBasicBlock(
in_channels,
block_channels[0],
conv_cfg=dict(
type='SubMConv3d',
indice_key='subm{}'.format(block_num - i)),
norm_cfg=norm_cfg))
setattr(
self, 'conv_up_m{}'.format(block_num - i),
make_block(
in_channels * 2,
block_channels[1],
3,
norm_cfg=norm_cfg,
padding=paddings[0],
indice_key='subm{}'.format(block_num - i)))
setattr(
self,
'inv_conv{}'.format(block_num - i),
make_block(
in_channels,
block_channels[2],
3,
norm_cfg=norm_cfg,
padding=paddings[1],
indice_key='spconv{}'.format(block_num - i)
if block_num - i != 1 else 'subm1',
conv_type='inverseconv' if block_num - i != 1 else
'subm') # use submanifold conv instead of inverse conv
# in the last block
)
in_channels = block_channels[2]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment