OpenDAS / mmdetection3d
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "cf721fdece1f76b38765353626118736575ebeab"
Commit 7ca9e90f, authored Apr 30, 2020 by wuyuefeng

refactor sparse_unet

Parent: 1a74819d
Showing 1 changed file with 151 additions and 133 deletions.

mmdet3d/models/middle_encoders/sparse_unet.py (+151, -133)
@@ -14,7 +14,16 @@ class SparseUnet(nn.Module):
                  in_channels,
                  output_shape,
                  pre_act=False,
-                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01)):
+                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
+                 base_channels=16,
+                 out_conv_channels=128,
+                 encode_conv_channels=((16, ), (32, 32, 32), (64, 64, 64),
+                                       (64, 64, 64)),
+                 encode_paddings=((1, ), (1, 1, 1), (1, 1, 1),
+                                  ((0, 1, 1), 1, 1)),
+                 decode_conv_channels=((64, 64, 64), (64, 64, 32),
+                                       (32, 32, 16), (16, 16, 16)),
+                 decode_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
         """SparseUnet for PartA^2

         See https://arxiv.org/abs/1907.03670 for more detials.
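With the widened signature, the encoder/decoder layout is driven by keyword arguments instead of hard-coded literals. A minimal construction sketch, assuming mmdet3d and spconv are importable; the import path follows the file location in this commit, while the in_channels and output_shape values are illustrative assumptions, not taken from the diff:

    from mmdet3d.models.middle_encoders.sparse_unet import SparseUnet

    # Values below are assumed examples for illustration only.
    middle_encoder = SparseUnet(
        in_channels=4,                  # per-voxel feature size (assumed)
        output_shape=[41, 1600, 1408],  # sparse grid shape (assumed)
        base_channels=16,               # new: channels of conv_input
        out_conv_channels=128,          # new: channels of conv_out
    )
    # Leaving encode_conv_channels / encode_paddings / decode_conv_channels /
    # decode_paddings at their defaults reproduces the previously hard-coded
    # conv1..conv4 and conv_up_*/inv_conv* layers.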
@@ -24,12 +33,27 @@ class SparseUnet(nn.Module):
             output_shape (list[int]): the shape of output tensor
             pre_act (bool): use pre_act_block or post_act_block
             norm_cfg (dict): normalize layer config
+            base_channels (int): out channels for conv_input layer
+            out_conv_channels (int): out channels for conv_out layer
+            encode_conv_channels (tuple[tuple[int]]):
+                conv channels of each encond block
+            encode_paddings (tuple[tuple[int]]): paddings of each encode block
+            decode_conv_channels (tuple[tuple[int]]):
+                conv channels of each decode block
+            decode_paddings (tuple[tuple[int]]): paddings of each decode block
         """
         super().__init__()
         self.sparse_shape = output_shape
         self.output_shape = output_shape
         self.in_channels = in_channels
         self.pre_act = pre_act
+        self.base_channels = base_channels
+        self.out_conv_channels = out_conv_channels
+        self.encode_conv_channels = encode_conv_channels
+        self.encode_paddings = encode_paddings
+        self.decode_conv_channels = decode_conv_channels
+        self.decode_paddings = decode_paddings
+        self.stage_num = len(self.encode_conv_channels)

         # Spconv init all weight on its own
         # TODO: make the network could be modified
@@ -38,18 +62,19 @@ class SparseUnet(nn.Module):
             self.conv_input = spconv.SparseSequential(
                 spconv.SubMConv3d(
                     in_channels,
-                    16,
+                    self.base_channels,
                     3,
                     padding=1,
                     bias=False,
                     indice_key='subm1'),
             )
-            block = self.pre_act_block
+            make_block = self.pre_act_block
         else:
-            norm_name, norm_layer = build_norm_layer(norm_cfg, 16)
+            norm_name, norm_layer = build_norm_layer(norm_cfg,
+                                                     self.base_channels)
             self.conv_input = spconv.SparseSequential(
                 spconv.SubMConv3d(
                     in_channels,
-                    16,
+                    self.base_channels,
                     3,
                     padding=1,
                     bias=False,
@@ -57,63 +82,19 @@ class SparseUnet(nn.Module):
                 norm_layer,
                 nn.ReLU(),
             )
-            block = self.post_act_block
-
-        self.conv1 = spconv.SparseSequential(
-            block(16, 16, 3, norm_cfg=norm_cfg, padding=1,
-                  indice_key='subm1'),
-        )
-
-        self.conv2 = spconv.SparseSequential(
-            # [1600, 1408, 41] -> [800, 704, 21]
-            block(16, 32, 3, norm_cfg=norm_cfg, stride=2, padding=1,
-                  indice_key='spconv2', conv_type='spconv'),
-            block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
-            block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
-        )
-
-        self.conv3 = spconv.SparseSequential(
-            # [800, 704, 21] -> [400, 352, 11]
-            block(32, 64, 3, norm_cfg=norm_cfg, stride=2, padding=1,
-                  indice_key='spconv3', conv_type='spconv'),
-            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
-            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
-        )
-
-        self.conv4 = spconv.SparseSequential(
-            # [400, 352, 11] -> [200, 176, 5]
-            block(64, 64, 3, norm_cfg=norm_cfg, stride=2, padding=(0, 1, 1),
-                  indice_key='spconv4', conv_type='spconv'),
-            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
-            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
-        )
-
-        norm_name, norm_layer = build_norm_layer(norm_cfg, 128)
+            make_block = self.post_act_block
+
+        encoder_out_channels = self.make_encode_layers(make_block, norm_cfg,
+                                                       self.base_channels)
+        self.make_decode_layers(make_block, norm_cfg, encoder_out_channels)
+
+        norm_name, norm_layer = build_norm_layer(norm_cfg,
+                                                 self.out_conv_channels)
         self.conv_out = spconv.SparseSequential(
             # [200, 176, 5] -> [200, 176, 2]
             spconv.SparseConv3d(
-                64,
-                128,
-                (3, 1, 1),
+                encoder_out_channels,
+                self.out_conv_channels, (3, 1, 1),
                 stride=(2, 1, 1),
                 padding=0,
                 bias=False,
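The [200, 176, 5] -> [200, 176, 2] comment on conv_out follows from the standard convolution output-size formula applied to the depth axis (kernel (3, 1, 1), stride (2, 1, 1), no padding); a quick check:

    # floor((D + 2*pad - kernel) / stride) + 1 along the depth axis
    d_in, kernel, stride, pad = 5, 3, 2, 0
    d_out = (d_in + 2 * pad - kernel) // stride + 1
    print(d_out)  # 2 -> [200, 176, 5] becomes [200, 176, 2]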
@@ -122,67 +103,6 @@ class SparseUnet(nn.Module):
             nn.ReLU(),
         )
-
-        # decoder
-        # [400, 352, 11] <- [200, 176, 5]
-        self.conv_up_t4 = SparseBasicBlock(
-            64,
-            64,
-            conv_cfg=dict(type='SubMConv3d', indice_key='subm4'),
-            norm_cfg=norm_cfg)
-        self.conv_up_m4 = block(
-            128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
-        self.inv_conv4 = block(
-            64, 64, 3, norm_cfg=norm_cfg, indice_key='spconv4',
-            conv_type='inverseconv')
-
-        # [800, 704, 21] <- [400, 352, 11]
-        self.conv_up_t3 = SparseBasicBlock(
-            64,
-            64,
-            conv_cfg=dict(type='SubMConv3d', indice_key='subm3'),
-            norm_cfg=norm_cfg)
-        self.conv_up_m3 = block(
-            128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
-        self.inv_conv3 = block(
-            64, 32, 3, norm_cfg=norm_cfg, indice_key='spconv3',
-            conv_type='inverseconv')
-
-        # [1600, 1408, 41] <- [800, 704, 21]
-        self.conv_up_t2 = SparseBasicBlock(
-            32,
-            32,
-            conv_cfg=dict(type='SubMConv3d', indice_key='subm2'),
-            norm_cfg=norm_cfg)
-        self.conv_up_m2 = block(
-            64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2')
-        self.inv_conv2 = block(
-            32, 16, 3, norm_cfg=norm_cfg, indice_key='spconv2',
-            conv_type='inverseconv')
-
-        # [1600, 1408, 41] <- [1600, 1408, 41]
-        self.conv_up_t1 = SparseBasicBlock(
-            16,
-            16,
-            conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
-            norm_cfg=norm_cfg)
-        self.conv_up_m1 = block(
-            32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1')
-
-        self.conv5 = spconv.SparseSequential(
-            block(16, 16, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm1'))

     def forward(self, voxel_features, coors, batch_size):
         """Forward of SparseUnet
@@ -200,14 +120,15 @@ class SparseUnet(nn.Module):
             batch_size)
         x = self.conv_input(input_sp_tensor)

-        x_conv1 = self.conv1(x)
-        x_conv2 = self.conv2(x_conv1)
-        x_conv3 = self.conv3(x_conv2)
-        x_conv4 = self.conv4(x_conv3)
+        encode_features = []
+        for i, stage_name in enumerate(self.encoder):
+            stage = getattr(self, stage_name)
+            x = stage(x)
+            encode_features.append(x)

         # for detection head
         # [200, 176, 5] -> [200, 176, 2]
-        out = self.conv_out(x_conv4)
+        out = self.conv_out(encode_features[-1])
         spatial_features = out.dense()

         N, C, D, H, W = spatial_features.shape
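The new forward loop depends on the encoder stages having been registered as named submodules (see make_encode_layers later in the diff), so getattr(self, stage_name) returns modules whose parameters PyTorch tracks. A minimal sketch of the same pattern with hypothetical names, using dense nn.Sequential stages instead of spconv, to show the add_module / getattr round trip:

    import torch
    import torch.nn as nn


    class TinyStagedEncoder(nn.Module):
        # Illustrative only: register stages by name, then iterate them in
        # forward, mirroring SparseUnet.encoder / make_encode_layers.
        def __init__(self, channels=(16, 32, 64)):
            super().__init__()
            self.encoder = []  # list of stage names, not modules
            in_ch = channels[0]
            for i, out_ch in enumerate(channels):
                stage = nn.Sequential(
                    nn.Conv1d(in_ch, out_ch, 3, padding=1), nn.ReLU())
                stage_name = 'conv{}'.format(i + 1)
                self.add_module(stage_name, stage)  # registers parameters
                self.encoder.append(stage_name)
                in_ch = out_ch

        def forward(self, x):
            encode_features = []
            for stage_name in self.encoder:
                x = getattr(self, stage_name)(x)
                encode_features.append(x)
            return encode_features


    feats = TinyStagedEncoder()(torch.randn(2, 16, 100))
    print([f.shape[1] for f in feats])  # [16, 32, 64]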
@@ -215,21 +136,24 @@ class SparseUnet(nn.Module):
         ret = {'spatial_features': spatial_features}

-        # for segmentation head
+        # for segmentation head, with output shape:
         # [400, 352, 11] <- [200, 176, 5]
-        x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4,
-                                      self.conv_up_m4, self.inv_conv4)
         # [800, 704, 21] <- [400, 352, 11]
-        x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3,
-                                      self.conv_up_m3, self.inv_conv3)
         # [1600, 1408, 41] <- [800, 704, 21]
-        x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2,
-                                      self.conv_up_m2, self.inv_conv2)
         # [1600, 1408, 41] <- [1600, 1408, 41]
-        x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1,
-                                      self.conv_up_m1, self.conv5)
-
-        seg_features = x_up1.features
+        decode_features = []
+        x = encode_features[-1]
+        for i in range(self.stage_num, 0, -1):
+            x = self.UR_block_forward(
+                encode_features[i - 1],
+                x,
+                getattr(self, 'conv_up_t{}'.format(i)),
+                getattr(self, 'conv_up_m{}'.format(i)),
+                getattr(self, 'inv_conv{}'.format(i)),
+            )
+            decode_features.append(x)
+
+        seg_features = decode_features[-1].features

         ret.update({'seg_features': seg_features})
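The reversed loop pairs each upsampling step with the matching encoder output: with the default four stages, i runs 4, 3, 2, 1, the lateral input is encode_features[i - 1], and decode_features[-1] ends up at the finest resolution. A quick index check (pure Python, illustrative only):

    stage_num = 4  # len(encode_conv_channels) with the defaults
    pairs = [(i, i - 1) for i in range(stage_num, 0, -1)]
    print(pairs)  # [(4, 3), (3, 2), (2, 1), (1, 0)]
    # i picks conv_up_t{i} / conv_up_m{i} / inv_conv{i}; i - 1 indexes
    # encode_features, so the loop walks from the coarsest stage back to
    # the finest, and the final x replaces the old x_up1.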
@@ -410,3 +334,97 @@ class SparseUnet(nn.Module):
         else:
             raise NotImplementedError
         return m
+
+    def make_encode_layers(self, make_block, norm_cfg, in_channels):
+        """make encode layers using sparse convs
+
+        Args:
+            make_block (method): a bounded function to build blocks
+            norm_cfg (dict[str]): normal layer configs
+            in_channels (int): the number of encoder input channels
+
+        Returns:
+            int: the number of encoder output channels
+        """
+        self.encoder = []
+        for i, blocks in enumerate(self.encode_conv_channels):
+            blocks_list = []
+            for j, out_channels in enumerate(tuple(blocks)):
+                padding = tuple(self.encode_paddings[i])[j]
+                # each stage started with a spconv layer
+                # except the first stage
+                if i != 0 and j == 0:
+                    blocks_list.append(
+                        make_block(
+                            in_channels,
+                            out_channels,
+                            3,
+                            norm_cfg=norm_cfg,
+                            stride=2,
+                            padding=padding,
+                            indice_key='spconv{}'.format(i + 1),
+                            conv_type='spconv'))
+                else:
+                    blocks_list.append(
+                        make_block(
+                            in_channels,
+                            out_channels,
+                            3,
+                            norm_cfg=norm_cfg,
+                            padding=padding,
+                            indice_key='subm{}'.format(i + 1)))
+                in_channels = out_channels
+            stage_name = 'conv{}'.format(i + 1)
+            stage_layers = spconv.SparseSequential(*blocks_list)
+            self.add_module(stage_name, stage_layers)
+            self.encoder.append(stage_name)
+        return out_channels
+
+    def make_decode_layers(self, make_block, norm_cfg, in_channels):
+        """make decode layers using sparse convs
+
+        Args:
+            make_block (method): a bounded function to build blocks
+            norm_cfg (dict[str]): normal layer configs
+            in_channels (int): the number of encoder input channels
+
+        Returns:
+            int: the number of encoder output channels
+        """
+        block_num = len(self.decode_conv_channels)
+        for i, block_channels in enumerate(self.decode_conv_channels):
+            paddings = self.decode_paddings[i]
+            setattr(
+                self, 'conv_up_t{}'.format(block_num - i),
+                SparseBasicBlock(
+                    in_channels,
+                    block_channels[0],
+                    conv_cfg=dict(
+                        type='SubMConv3d',
+                        indice_key='subm{}'.format(block_num - i)),
+                    norm_cfg=norm_cfg))
+            setattr(
+                self, 'conv_up_m{}'.format(block_num - i),
+                make_block(
+                    in_channels * 2,
+                    block_channels[1],
+                    3,
+                    norm_cfg=norm_cfg,
+                    padding=paddings[0],
+                    indice_key='subm{}'.format(block_num - i)))
+            setattr(
+                self, 'inv_conv{}'.format(block_num - i),
+                make_block(
+                    in_channels,
+                    block_channels[2],
+                    3,
+                    norm_cfg=norm_cfg,
+                    padding=paddings[1],
+                    indice_key='spconv{}'.format(block_num - i)
+                    if block_num - i != 1 else 'subm1',
+                    conv_type='inverseconv'
+                    if block_num - i != 1 else 'subm')
+                # use submanifold conv instead of inverse conv
+                # in the last block
+            )
+            in_channels = block_channels[2]
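The encoder construction is driven entirely by the nested channel and padding tuples: each inner tuple is one stage, in_channels is threaded from block to block, and every stage after the first opens with a stride-2 sparse conv. A pure-Python sketch of that bookkeeping for the default config (it builds no layers and needs no spconv; it only prints the per-block specs make_encode_layers would create):

    encode_conv_channels = ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64))
    encode_paddings = ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1))

    in_channels = 16  # base_channels feeds the first stage
    for i, blocks in enumerate(encode_conv_channels):
        for j, out_channels in enumerate(blocks):
            padding = encode_paddings[i][j]
            # stages 2..4 start with a downsampling 'spconv' block
            kind = 'spconv' if (i != 0 and j == 0) else 'subm'
            print('conv{} {}: {} -> {}, padding={}'.format(
                i + 1, kind, in_channels, out_channels, padding))
            in_channels = out_channels
    print('encoder_out_channels =', in_channels)  # 64, consumed by conv_out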