Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
28e511cd
Commit
28e511cd
authored
May 17, 2020
by
zhangwenwei
Browse files
Merge branch 'feat_clean_sparse_block' into 'master'
Feat clean sparse block See merge request open-mmlab/mmdet.3d!33
parents
97e4ed42
df7e4e30
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
324 additions
and
550 deletions
+324
-550
configs/kitti/dv_mvx-v2_second_secfpn_fpn-fusion_adamw_2x8_80e_kitti-3d-3class.py
...second_secfpn_fpn-fusion_adamw_2x8_80e_kitti-3d-3class.py
+2
-3
configs/kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
.../kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
+2
-3
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
+2
-3
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
.../kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
+2
-2
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
...igs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
+2
-2
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
+2
-3
mmdet3d/models/middle_encoders/sparse_encoder.py
mmdet3d/models/middle_encoders/sparse_encoder.py
+120
-176
mmdet3d/models/middle_encoders/sparse_unet.py
mmdet3d/models/middle_encoders/sparse_unet.py
+65
-194
mmdet3d/ops/__init__.py
mmdet3d/ops/__init__.py
+5
-5
mmdet3d/ops/sparse_block.py
mmdet3d/ops/sparse_block.py
+66
-145
tests/test_sparse_unet.py
tests/test_sparse_unet.py
+56
-14
No files found.
configs/kitti/dv_mvx-v2_second_secfpn_fpn-fusion_adamw_2x8_80e_kitti-3d-3class.py
View file @
28e511cd
...
@@ -49,9 +49,8 @@ model = dict(
...
@@ -49,9 +49,8 @@ model = dict(
pts_middle_encoder
=
dict
(
pts_middle_encoder
=
dict
(
type
=
'SparseEncoder'
,
type
=
'SparseEncoder'
,
in_channels
=
128
,
in_channels
=
128
,
output_shape
=
[
41
,
1600
,
1408
],
# checked from PointCloud3D
sparse_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
,
order
=
(
'conv'
,
'norm'
,
'act'
)),
),
pts_backbone
=
dict
(
pts_backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
configs/kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
View file @
28e511cd
...
@@ -18,9 +18,8 @@ model = dict(
...
@@ -18,9 +18,8 @@ model = dict(
middle_encoder
=
dict
(
middle_encoder
=
dict
(
type
=
'SparseEncoder'
,
type
=
'SparseEncoder'
,
in_channels
=
4
,
in_channels
=
4
,
output_shape
=
[
41
,
1600
,
1408
],
sparse_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
,
order
=
(
'conv'
,
'norm'
,
'act'
)),
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
View file @
28e511cd
...
@@ -18,9 +18,8 @@ model = dict(
...
@@ -18,9 +18,8 @@ model = dict(
middle_encoder
=
dict
(
middle_encoder
=
dict
(
type
=
'SparseEncoder'
,
type
=
'SparseEncoder'
,
in_channels
=
4
,
in_channels
=
4
,
output_shape
=
[
41
,
1600
,
1408
],
# checked from PointCloud3D
sparse_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
,
order
=
(
'conv'
,
'norm'
,
'act'
)),
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
View file @
28e511cd
...
@@ -18,8 +18,8 @@ model = dict(
...
@@ -18,8 +18,8 @@ model = dict(
middle_encoder
=
dict
(
middle_encoder
=
dict
(
type
=
'SparseUNet'
,
type
=
'SparseUNet'
,
in_channels
=
4
,
in_channels
=
4
,
output
_shape
=
[
41
,
1600
,
1408
],
sparse
_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
),
order
=
(
'conv'
,
'norm'
,
'act'
)
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
View file @
28e511cd
...
@@ -18,8 +18,8 @@ model = dict(
...
@@ -18,8 +18,8 @@ model = dict(
middle_encoder
=
dict
(
middle_encoder
=
dict
(
type
=
'SparseUNet'
,
type
=
'SparseUNet'
,
in_channels
=
4
,
in_channels
=
4
,
output
_shape
=
[
41
,
1600
,
1408
],
sparse
_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
),
order
=
(
'conv'
,
'norm'
,
'act'
)
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
View file @
28e511cd
...
@@ -18,9 +18,8 @@ model = dict(
...
@@ -18,9 +18,8 @@ model = dict(
middle_encoder
=
dict
(
middle_encoder
=
dict
(
type
=
'SparseEncoder'
,
type
=
'SparseEncoder'
,
in_channels
=
4
,
in_channels
=
4
,
output_shape
=
[
41
,
1600
,
1408
],
# checked from PointCloud3D
sparse_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
,
order
=
(
'conv'
,
'norm'
,
'act'
)),
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'SECOND'
,
type
=
'SECOND'
,
in_channels
=
256
,
in_channels
=
256
,
...
...
mmdet3d/models/middle_encoders/sparse_encoder.py
View file @
28e511cd
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
build_norm_layer
import
mmdet3d.ops.spconv
as
spconv
import
mmdet3d.ops.spconv
as
spconv
from
mmdet3d.ops
import
make_sparse_convmodule
from
..registry
import
MIDDLE_ENCODERS
from
..registry
import
MIDDLE_ENCODERS
@
MIDDLE_ENCODERS
.
register_module
()
@
MIDDLE_ENCODERS
.
register_module
()
class
SparseEncoder
(
nn
.
Module
):
class
SparseEncoder
(
nn
.
Module
):
"""Sparse encoder for Second
See https://arxiv.org/abs/1907.03670 for more detials.
Args:
in_channels (int): the number of input channels
sparse_shape (list[int]): the sparse shape of input tensor
norm_cfg (dict): config of normalization layer
base_channels (int): out channels for conv_input layer
output_channels (int): out channels for conv_out layer
encoder_channels (tuple[tuple[int]]):
conv channels of each encode block
encoder_paddings (tuple[tuple[int]]): paddings of each encode block
"""
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
output_shape
,
sparse_shape
,
pre_act
,
order
=
(
'conv'
,
'norm'
,
'act'
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
)):
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
base_channels
=
16
,
output_channels
=
128
,
encoder_channels
=
((
16
,
),
(
32
,
32
,
32
),
(
64
,
64
,
64
),
(
64
,
64
,
64
)),
encoder_paddings
=
((
1
,
),
(
1
,
1
,
1
),
(
1
,
1
,
1
),
((
0
,
1
,
1
),
1
,
1
))):
super
().
__init__
()
super
().
__init__
()
self
.
sparse_shape
=
output_shape
self
.
sparse_shape
=
sparse_shape
self
.
output_shape
=
output_shape
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
pre_act
=
pre_act
self
.
order
=
order
self
.
base_channels
=
base_channels
self
.
output_channels
=
output_channels
self
.
encoder_channels
=
encoder_channels
self
.
encoder_paddings
=
encoder_paddings
self
.
stage_num
=
len
(
self
.
encoder_channels
)
# Spconv init all weight on its own
# Spconv init all weight on its own
# TODO: make the network could be modified
if
pre_act
:
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
==
3
self
.
conv_input
=
spconv
.
SparseSequential
(
assert
set
(
order
)
==
{
'conv'
,
'norm'
,
'act'
}
spconv
.
SubMConv3d
(
in_channels
,
if
self
.
order
[
0
]
!=
'conv'
:
# pre activate
16
,
self
.
conv_input
=
make_sparse_convmodule
(
3
,
padding
=
1
,
bias
=
False
,
indice_key
=
'subm1'
),
)
block
=
self
.
pre_act_block
else
:
norm_name
,
norm_layer
=
build_norm_layer
(
norm_cfg
,
16
)
self
.
conv_input
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
in_channels
,
in_channels
,
16
,
self
.
base_channels
,
3
,
padding
=
1
,
bias
=
False
,
indice_key
=
'subm1'
),
norm_layer
,
nn
.
ReLU
(),
)
block
=
self
.
post_act_block
self
.
conv1
=
spconv
.
SparseSequential
(
block
(
16
,
16
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm1'
),
)
self
.
conv2
=
spconv
.
SparseSequential
(
# [1600, 1408, 41] -> [800, 704, 21]
block
(
16
,
32
,
3
,
3
,
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
stride
=
2
,
padding
=
1
,
padding
=
1
,
indice_key
=
'spconv2'
,
indice_key
=
'subm1'
,
conv_type
=
'spconv'
),
conv_type
=
'SubMConv3d'
,
block
(
32
,
32
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm2'
),
order
=
(
'conv'
,
))
block
(
32
,
32
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm2'
),
else
:
# post activate
)
self
.
conv_input
=
make_sparse_convmodule
(
in_channels
,
self
.
conv3
=
spconv
.
SparseSequential
(
self
.
base_channels
,
# [800, 704, 21] -> [400, 352, 11]
block
(
32
,
64
,
3
,
3
,
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
stride
=
2
,
padding
=
1
,
padding
=
1
,
indice_key
=
'spconv3'
,
indice_key
=
'subm1'
,
conv_type
=
'spconv'
),
conv_type
=
'SubMConv3d'
)
block
(
64
,
64
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm3'
),
block
(
64
,
64
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm3'
),
encoder_out_channels
=
self
.
make_encoder_layers
(
)
make_sparse_convmodule
,
norm_cfg
,
self
.
base_channels
)
self
.
conv4
=
spconv
.
SparseSequential
(
self
.
conv_out
=
make_sparse_convmodule
(
# [400, 352, 11] -> [200, 176, 5]
encoder_out_channels
,
block
(
self
.
output_channels
,
64
,
kernel_size
=
(
3
,
1
,
1
),
64
,
3
,
norm_cfg
=
norm_cfg
,
stride
=
2
,
padding
=
(
0
,
1
,
1
),
indice_key
=
'spconv4'
,
conv_type
=
'spconv'
),
block
(
64
,
64
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm4'
),
block
(
64
,
64
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
indice_key
=
'subm4'
),
)
norm_name
,
norm_layer
=
build_norm_layer
(
norm_cfg
,
128
)
self
.
conv_out
=
spconv
.
SparseSequential
(
# [200, 176, 5] -> [200, 176, 2]
spconv
.
SparseConv3d
(
128
,
128
,
(
3
,
1
,
1
),
stride
=
(
2
,
1
,
1
),
stride
=
(
2
,
1
,
1
),
norm_cfg
=
norm_cfg
,
padding
=
0
,
padding
=
0
,
bias
=
False
,
indice_key
=
'spconv_down2'
,
indice_key
=
'spconv_down2'
),
conv_type
=
'SparseConv3d'
)
norm_layer
,
nn
.
ReLU
(),
)
def
forward
(
self
,
voxel_features
,
coors
,
batch_size
):
def
forward
(
self
,
voxel_features
,
coors
,
batch_size
):
"""
"""Forward of SparseEncoder
:param voxel_features: (N, C)
:param coors: (N, 4) [batch_idx, z_idx, y_idx, x_idx]
Args:
:param batch_size:
voxel_features (torch.float32): shape [N, C]
:return:
coors (torch.int32): shape [N, 4](batch_idx, z_idx, y_idx, x_idx)
batch_size (int): batch size
Returns:
dict: backbone features
"""
"""
coors
=
coors
.
int
()
coors
=
coors
.
int
()
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coors
,
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coors
,
...
@@ -122,14 +97,14 @@ class SparseEncoder(nn.Module):
...
@@ -122,14 +97,14 @@ class SparseEncoder(nn.Module):
batch_size
)
batch_size
)
x
=
self
.
conv_input
(
input_sp_tensor
)
x
=
self
.
conv_input
(
input_sp_tensor
)
x_conv1
=
self
.
conv1
(
x
)
encode_features
=
[]
x_conv2
=
self
.
conv2
(
x_conv1
)
for
encoder_layer
in
self
.
encoder_layers
:
x_conv3
=
self
.
conv3
(
x_conv2
)
x
=
encoder_layer
(
x
)
x_conv4
=
self
.
conv4
(
x_conv3
)
encode_features
.
append
(
x
)
# for detection head
# for detection head
# [200, 176, 5] -> [200, 176, 2]
# [200, 176, 5] -> [200, 176, 2]
out
=
self
.
conv_out
(
x_conv4
)
out
=
self
.
conv_out
(
encode_features
[
-
1
]
)
spatial_features
=
out
.
dense
()
spatial_features
=
out
.
dense
()
N
,
C
,
D
,
H
,
W
=
spatial_features
.
shape
N
,
C
,
D
,
H
,
W
=
spatial_features
.
shape
...
@@ -137,79 +112,48 @@ class SparseEncoder(nn.Module):
...
@@ -137,79 +112,48 @@ class SparseEncoder(nn.Module):
return
spatial_features
return
spatial_features
def
pre_act_block
(
self
,
def
make_encoder_layers
(
self
,
make_block
,
norm_cfg
,
in_channels
):
in_channels
,
"""make encoder layers using sparse convs
out_channels
,
kernel_size
,
Args:
indice_key
=
None
,
make_block (method): a bounded function to build blocks
stride
=
1
,
norm_cfg (dict[str]): config of normalization layer
padding
=
0
,
in_channels (int): the number of encoder input channels
conv_type
=
'subm'
,
norm_cfg
=
None
):
Returns:
norm_name
,
norm_layer
=
build_norm_layer
(
norm_cfg
,
in_channels
)
int: the number of encoder output channels
if
conv_type
==
'subm'
:
"""
m
=
spconv
.
SparseSequential
(
self
.
encoder_layers
=
spconv
.
SparseSequential
()
norm_layer
,
nn
.
ReLU
(
inplace
=
True
),
for
i
,
blocks
in
enumerate
(
self
.
encoder_channels
):
spconv
.
SubMConv3d
(
blocks_list
=
[]
in_channels
,
for
j
,
out_channels
in
enumerate
(
tuple
(
blocks
)):
out_channels
,
padding
=
tuple
(
self
.
encoder_paddings
[
i
])[
j
]
kernel_size
,
# each stage started with a spconv layer
padding
=
padding
,
# except the first stage
bias
=
False
,
if
i
!=
0
and
j
==
0
:
indice_key
=
indice_key
),
blocks_list
.
append
(
)
make_block
(
elif
conv_type
==
'spconv'
:
m
=
spconv
.
SparseSequential
(
norm_layer
,
nn
.
ReLU
(
inplace
=
True
),
spconv
.
SparseConv3d
(
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
3
,
stride
=
stride
,
norm_cfg
=
norm_cfg
,
stride
=
2
,
padding
=
padding
,
padding
=
padding
,
bias
=
False
,
indice_key
=
f
'spconv
{
i
+
1
}
'
,
indice_key
=
indice_key
),
conv_type
=
'SparseConv3d'
))
)
else
:
else
:
raise
NotImplementedError
blocks_list
.
append
(
return
m
make_block
(
def
post_act_block
(
self
,
in_channels
,
out_channels
,
kernel_size
,
indice_key
,
stride
=
1
,
padding
=
0
,
conv_type
=
'subm'
,
norm_cfg
=
None
):
norm_name
,
norm_layer
=
build_norm_layer
(
norm_cfg
,
out_channels
)
if
conv_type
==
'subm'
:
m
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
3
,
bias
=
False
,
norm_cfg
=
norm_cfg
,
indice_key
=
indice_key
),
norm_layer
,
nn
.
ReLU
(
inplace
=
True
),
)
elif
conv_type
==
'spconv'
:
m
=
spconv
.
SparseSequential
(
spconv
.
SparseConv3d
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
padding
=
padding
,
bias
=
False
,
indice_key
=
f
'subm
{
i
+
1
}
'
,
indice_key
=
indice_key
),
conv_type
=
'SubMConv3d'
))
norm_layer
,
in_channels
=
out_channels
nn
.
ReLU
(
inplace
=
True
),
stage_name
=
f
'encoder_layer
{
i
+
1
}
'
)
stage_layers
=
spconv
.
SparseSequential
(
*
blocks_list
)
else
:
self
.
encoder_layers
.
add_module
(
stage_name
,
stage_layers
)
raise
NotImplementedError
return
out_channels
return
m
mmdet3d/models/middle_encoders/sparse_unet.py
View file @
28e511cd
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
build_norm_layer
import
mmdet3d.ops.spconv
as
spconv
import
mmdet3d.ops.spconv
as
spconv
from
mmdet3d.ops
import
SparseBasicBlock
from
mmdet3d.ops
import
SparseBasicBlock
,
make_sparse_convmodule
from
..registry
import
MIDDLE_ENCODERS
from
..registry
import
MIDDLE_ENCODERS
@
MIDDLE_ENCODERS
.
register_module
()
@
MIDDLE_ENCODERS
.
register_module
()
class
SparseUNet
(
nn
.
Module
):
class
SparseUNet
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
output_shape
,
pre_act
=
False
,
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
base_channels
=
16
,
output_channels
=
128
,
encoder_channels
=
((
16
,
),
(
32
,
32
,
32
),
(
64
,
64
,
64
),
(
64
,
64
,
64
)),
encoder_paddings
=
((
1
,
),
(
1
,
1
,
1
),
(
1
,
1
,
1
),
((
0
,
1
,
1
),
1
,
1
)),
decoder_channels
=
((
64
,
64
,
64
),
(
64
,
64
,
32
),
(
32
,
32
,
16
),
(
16
,
16
,
16
)),
decoder_paddings
=
((
1
,
0
),
(
1
,
0
),
(
0
,
0
),
(
0
,
1
))):
"""SparseUNet for PartA^2
"""SparseUNet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials.
See https://arxiv.org/abs/1907.03670 for more detials.
Args:
Args:
in_channels (int): the number of input channels
in_channels (int): the number of input channels
output_shape (list[int]): the shape of output tensor
sparse_shape (list[int]): the sparse shape of input tensor
pre_act (bool): use pre_act_block or post_act_block
norm_cfg (dict): config of normalization layer
norm_cfg (dict): config of normalization layer
base_channels (int): out channels for conv_input layer
base_channels (int): out channels for conv_input layer
output_channels (int): out channels for conv_out layer
output_channels (int): out channels for conv_out layer
...
@@ -42,11 +25,25 @@ class SparseUNet(nn.Module):
...
@@ -42,11 +25,25 @@ class SparseUNet(nn.Module):
conv channels of each decode block
conv channels of each decode block
decoder_paddings (tuple[tuple[int]]): paddings of each decode block
decoder_paddings (tuple[tuple[int]]): paddings of each decode block
"""
"""
def
__init__
(
self
,
in_channels
,
sparse_shape
,
order
=
(
'conv'
,
'norm'
,
'act'
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
base_channels
=
16
,
output_channels
=
128
,
encoder_channels
=
((
16
,
),
(
32
,
32
,
32
),
(
64
,
64
,
64
),
(
64
,
64
,
64
)),
encoder_paddings
=
((
1
,
),
(
1
,
1
,
1
),
(
1
,
1
,
1
),
((
0
,
1
,
1
),
1
,
1
)),
decoder_channels
=
((
64
,
64
,
64
),
(
64
,
64
,
32
),
(
32
,
32
,
16
),
(
16
,
16
,
16
)),
decoder_paddings
=
((
1
,
0
),
(
1
,
0
),
(
0
,
0
),
(
0
,
1
))):
super
().
__init__
()
super
().
__init__
()
self
.
sparse_shape
=
output_shape
self
.
sparse_shape
=
sparse_shape
self
.
output_shape
=
output_shape
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
pre_act
=
pre_act
self
.
order
=
order
self
.
base_channels
=
base_channels
self
.
base_channels
=
base_channels
self
.
output_channels
=
output_channels
self
.
output_channels
=
output_channels
self
.
encoder_channels
=
encoder_channels
self
.
encoder_channels
=
encoder_channels
...
@@ -56,44 +53,43 @@ class SparseUNet(nn.Module):
...
@@ -56,44 +53,43 @@ class SparseUNet(nn.Module):
self
.
stage_num
=
len
(
self
.
encoder_channels
)
self
.
stage_num
=
len
(
self
.
encoder_channels
)
# Spconv init all weight on its own
# Spconv init all weight on its own
if
pre_act
:
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
==
3
# TODO: use ConvModule to encapsulate
assert
set
(
order
)
==
{
'conv'
,
'norm'
,
'act'
}
self
.
conv_input
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
if
self
.
order
[
0
]
!=
'conv'
:
# pre activate
self
.
conv_input
=
make_sparse_convmodule
(
in_channels
,
in_channels
,
self
.
base_channels
,
self
.
base_channels
,
3
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
padding
=
1
,
bias
=
False
,
indice_key
=
'subm1'
,
indice_key
=
'subm1'
))
conv_type
=
'SubMConv3d'
,
make_block
=
self
.
pre_act_block
order
=
(
'conv'
,
))
else
:
else
:
# post activate
self
.
conv_input
=
spconv
.
SparseSequential
(
self
.
conv_input
=
make_sparse_convmodule
(
spconv
.
SubMConv3d
(
in_channels
,
in_channels
,
self
.
base_channels
,
self
.
base_channels
,
3
,
3
,
norm_cfg
=
norm_cfg
,
padding
=
1
,
padding
=
1
,
bias
=
False
,
indice_key
=
'subm1'
,
indice_key
=
'subm1'
),
conv_type
=
'SubMConv3d'
)
build_norm_layer
(
norm_cfg
,
self
.
base_channels
)[
1
],
nn
.
ReLU
())
make_block
=
self
.
post_act_block
encoder_out_channels
=
self
.
make_encoder_layers
(
encoder_out_channels
=
self
.
make_encoder_layers
(
make_block
,
norm_cfg
,
self
.
base_channels
)
make_sparse_convmodule
,
norm_cfg
,
self
.
base_channels
)
self
.
make_decoder_layers
(
make_block
,
norm_cfg
,
encoder_out_channels
)
self
.
make_decoder_layers
(
make_sparse_convmodule
,
norm_cfg
,
encoder_out_channels
)
self
.
conv_out
=
spconv
.
SparseSequential
(
self
.
conv_out
=
make_sparse_convmodule
(
# [200, 176, 5] -> [200, 176, 2]
spconv
.
SparseConv3d
(
encoder_out_channels
,
encoder_out_channels
,
self
.
output_channels
,
(
3
,
1
,
1
),
self
.
output_channels
,
kernel_size
=
(
3
,
1
,
1
),
stride
=
(
2
,
1
,
1
),
stride
=
(
2
,
1
,
1
),
norm_cfg
=
norm_cfg
,
padding
=
0
,
padding
=
0
,
bias
=
False
,
indice_key
=
'spconv_down2'
,
indice_key
=
'spconv_down2'
),
conv_type
=
'SparseConv3d'
)
build_norm_layer
(
norm_cfg
,
self
.
output_channels
)[
1
],
nn
.
ReLU
())
def
forward
(
self
,
voxel_features
,
coors
,
batch_size
):
def
forward
(
self
,
voxel_features
,
coors
,
batch_size
):
"""Forward of SparseUNet
"""Forward of SparseUNet
...
@@ -187,133 +183,6 @@ class SparseUNet(nn.Module):
...
@@ -187,133 +183,6 @@ class SparseUNet(nn.Module):
x
.
features
=
features
.
view
(
n
,
out_channels
,
-
1
).
sum
(
dim
=
2
)
x
.
features
=
features
.
view
(
n
,
out_channels
,
-
1
).
sum
(
dim
=
2
)
return
x
return
x
def
pre_act_block
(
self
,
in_channels
,
out_channels
,
kernel_size
,
indice_key
=
None
,
stride
=
1
,
padding
=
0
,
conv_type
=
'subm'
,
norm_cfg
=
None
):
"""Make pre activate sparse convolution block.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict): config of normalization layer
Returns:
spconv.SparseSequential: pre activate sparse convolution block.
"""
# TODO: use ConvModule to encapsulate
assert
conv_type
in
[
'subm'
,
'spconv'
,
'inverseconv'
]
if
conv_type
==
'subm'
:
m
=
spconv
.
SparseSequential
(
build_norm_layer
(
norm_cfg
,
in_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
),
spconv
.
SubMConv3d
(
in_channels
,
out_channels
,
kernel_size
,
padding
=
padding
,
bias
=
False
,
indice_key
=
indice_key
))
elif
conv_type
==
'spconv'
:
m
=
spconv
.
SparseSequential
(
build_norm_layer
(
norm_cfg
,
in_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
),
spconv
.
SparseConv3d
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
bias
=
False
,
indice_key
=
indice_key
))
elif
conv_type
==
'inverseconv'
:
m
=
spconv
.
SparseSequential
(
build_norm_layer
(
norm_cfg
,
in_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
),
spconv
.
SparseInverseConv3d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
False
,
indice_key
=
indice_key
))
else
:
raise
NotImplementedError
return
m
def
post_act_block
(
self
,
in_channels
,
out_channels
,
kernel_size
,
indice_key
,
stride
=
1
,
padding
=
0
,
conv_type
=
'subm'
,
norm_cfg
=
None
):
"""Make post activate sparse convolution block.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
norm_cfg (dict[str]): config of normalization layer
Returns:
spconv.SparseSequential: post activate sparse convolution block.
"""
# TODO: use ConvModule to encapsulate
assert
conv_type
in
[
'subm'
,
'spconv'
,
'inverseconv'
]
if
conv_type
==
'subm'
:
m
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
False
,
indice_key
=
indice_key
),
build_norm_layer
(
norm_cfg
,
out_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
))
elif
conv_type
==
'spconv'
:
m
=
spconv
.
SparseSequential
(
spconv
.
SparseConv3d
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
bias
=
False
,
indice_key
=
indice_key
),
build_norm_layer
(
norm_cfg
,
out_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
))
elif
conv_type
==
'inverseconv'
:
m
=
spconv
.
SparseSequential
(
spconv
.
SparseInverseConv3d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
False
,
indice_key
=
indice_key
),
build_norm_layer
(
norm_cfg
,
out_channels
)[
1
],
nn
.
ReLU
(
inplace
=
True
))
else
:
raise
NotImplementedError
return
m
def
make_encoder_layers
(
self
,
make_block
,
norm_cfg
,
in_channels
):
def
make_encoder_layers
(
self
,
make_block
,
norm_cfg
,
in_channels
):
"""make encoder layers using sparse convs
"""make encoder layers using sparse convs
...
@@ -326,6 +195,7 @@ class SparseUNet(nn.Module):
...
@@ -326,6 +195,7 @@ class SparseUNet(nn.Module):
int: the number of encoder output channels
int: the number of encoder output channels
"""
"""
self
.
encoder_layers
=
spconv
.
SparseSequential
()
self
.
encoder_layers
=
spconv
.
SparseSequential
()
for
i
,
blocks
in
enumerate
(
self
.
encoder_channels
):
for
i
,
blocks
in
enumerate
(
self
.
encoder_channels
):
blocks_list
=
[]
blocks_list
=
[]
for
j
,
out_channels
in
enumerate
(
tuple
(
blocks
)):
for
j
,
out_channels
in
enumerate
(
tuple
(
blocks
)):
...
@@ -342,7 +212,7 @@ class SparseUNet(nn.Module):
...
@@ -342,7 +212,7 @@ class SparseUNet(nn.Module):
stride
=
2
,
stride
=
2
,
padding
=
padding
,
padding
=
padding
,
indice_key
=
f
'spconv
{
i
+
1
}
'
,
indice_key
=
f
'spconv
{
i
+
1
}
'
,
conv_type
=
'
spc
onv'
))
conv_type
=
'
SparseC
onv
3d
'
))
else
:
else
:
blocks_list
.
append
(
blocks_list
.
append
(
make_block
(
make_block
(
...
@@ -351,7 +221,8 @@ class SparseUNet(nn.Module):
...
@@ -351,7 +221,8 @@ class SparseUNet(nn.Module):
3
,
3
,
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
padding
=
padding
,
padding
=
padding
,
indice_key
=
f
'subm
{
i
+
1
}
'
))
indice_key
=
f
'subm
{
i
+
1
}
'
,
conv_type
=
'SubMConv3d'
))
in_channels
=
out_channels
in_channels
=
out_channels
stage_name
=
f
'encoder_layer
{
i
+
1
}
'
stage_name
=
f
'encoder_layer
{
i
+
1
}
'
stage_layers
=
spconv
.
SparseSequential
(
*
blocks_list
)
stage_layers
=
spconv
.
SparseSequential
(
*
blocks_list
)
...
@@ -388,7 +259,8 @@ class SparseUNet(nn.Module):
...
@@ -388,7 +259,8 @@ class SparseUNet(nn.Module):
3
,
3
,
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
padding
=
paddings
[
0
],
padding
=
paddings
[
0
],
indice_key
=
f
'subm
{
block_num
-
i
}
'
))
indice_key
=
f
'subm
{
block_num
-
i
}
'
,
conv_type
=
'SubMConv3d'
))
if
block_num
-
i
!=
1
:
if
block_num
-
i
!=
1
:
setattr
(
setattr
(
self
,
f
'upsample_layer
{
block_num
-
i
}
'
,
self
,
f
'upsample_layer
{
block_num
-
i
}
'
,
...
@@ -397,9 +269,8 @@ class SparseUNet(nn.Module):
...
@@ -397,9 +269,8 @@ class SparseUNet(nn.Module):
block_channels
[
2
],
block_channels
[
2
],
3
,
3
,
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
padding
=
paddings
[
1
],
indice_key
=
f
'spconv
{
block_num
-
i
}
'
,
indice_key
=
f
'spconv
{
block_num
-
i
}
'
,
conv_type
=
'
i
nverse
c
onv'
))
conv_type
=
'
SparseI
nverse
C
onv
3d
'
))
else
:
else
:
# use submanifold conv instead of inverse conv
# use submanifold conv instead of inverse conv
# in the last block
# in the last block
...
@@ -412,5 +283,5 @@ class SparseUNet(nn.Module):
...
@@ -412,5 +283,5 @@ class SparseUNet(nn.Module):
norm_cfg
=
norm_cfg
,
norm_cfg
=
norm_cfg
,
padding
=
paddings
[
1
],
padding
=
paddings
[
1
],
indice_key
=
'subm1'
,
indice_key
=
'subm1'
,
conv_type
=
'
s
ub
m
'
))
conv_type
=
'
S
ub
MConv3d
'
))
in_channels
=
block_channels
[
2
]
in_channels
=
block_channels
[
2
]
mmdet3d/ops/__init__.py
View file @
28e511cd
...
@@ -4,8 +4,8 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
...
@@ -4,8 +4,8 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
points_in_boxes_gpu
)
from
.sparse_block
import
(
SparseBasicBlock
,
SparseB
asicBlockV0
,
from
.sparse_block
import
(
SparseBasicBlock
,
SparseB
ottleneck
,
SparseBottleneck
,
SparseBottleneckV0
)
make_sparse_convmodule
)
from
.voxel
import
DynamicScatter
,
Voxelization
,
dynamic_scatter
,
voxelization
from
.voxel
import
DynamicScatter
,
Voxelization
,
dynamic_scatter
,
voxelization
__all__
=
[
__all__
=
[
...
@@ -13,7 +13,7 @@ __all__ = [
...
@@ -13,7 +13,7 @@ __all__ = [
'get_compiling_cuda_version'
,
'NaiveSyncBatchNorm1d'
,
'get_compiling_cuda_version'
,
'NaiveSyncBatchNorm1d'
,
'NaiveSyncBatchNorm2d'
,
'batched_nms'
,
'Voxelization'
,
'voxelization'
,
'NaiveSyncBatchNorm2d'
,
'batched_nms'
,
'Voxelization'
,
'voxelization'
,
'dynamic_scatter'
,
'DynamicScatter'
,
'sigmoid_focal_loss'
,
'dynamic_scatter'
,
'DynamicScatter'
,
'sigmoid_focal_loss'
,
'SigmoidFocalLoss'
,
'SparseBasicBlock
V0
'
,
'SparseBottleneck
V0
'
,
'SigmoidFocalLoss'
,
'SparseBasicBlock'
,
'SparseBottleneck'
,
'
SparseBasicBlock'
,
'SparseBottleneck'
,
'RoIAwarePool3d
'
,
'
RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu
'
,
'
points_in_boxes_gpu'
,
'points_in_boxes_cpu
'
'
make_sparse_convmodule
'
]
]
mmdet3d/ops/sparse_block.py
View file @
28e511cd
from
mmcv.cnn
import
build_norm_layer
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
torch
import
nn
from
torch
import
nn
from
mmdet3d.ops
import
spconv
from
mmdet.models.backbones.resnet
import
BasicBlock
,
Bottleneck
from
mmdet.models.backbones.resnet
import
BasicBlock
,
Bottleneck
from
.
import
spconv
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
indice_key
=
None
):
"""3x3 submanifold sparse convolution with padding.
Args:
in_planes (int): the number of input channels
out_planes (int): the number of output channels
stride (int): the stride of convolution
indice_key (str): the indice key used for sparse tensor
Returns:
spconv.conv.SubMConv3d: 3x3 submanifold sparse convolution ops
"""
# TODO: deprecate this class
return
spconv
.
SubMConv3d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
,
indice_key
=
indice_key
)
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
,
indice_key
=
None
):
"""1x1 submanifold sparse convolution with padding.
Args:
in_planes (int): the number of input channels
out_planes (int): the number of output channels
stride (int): the stride of convolution
indice_key (str): the indice key used for sparse tensor
Returns:
spconv.conv.SubMConv3d: 1x1 submanifold sparse convolution ops
"""
# TODO: deprecate this class
return
spconv
.
SubMConv3d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
padding
=
1
,
bias
=
False
,
indice_key
=
indice_key
)
class
SparseBasicBlockV0
(
spconv
.
SparseModule
):
expansion
=
1
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
,
indice_key
=
None
,
norm_cfg
=
None
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
# TODO: deprecate this class
super
().
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
indice_key
=
indice_key
)
norm_name1
,
norm_layer1
=
build_norm_layer
(
norm_cfg
,
planes
)
self
.
bn1
=
norm_layer1
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
conv3x3
(
planes
,
planes
,
indice_key
=
indice_key
)
norm_name2
,
norm_layer2
=
build_norm_layer
(
norm_cfg
,
planes
)
self
.
bn2
=
norm_layer2
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
identity
=
x
.
features
assert
x
.
features
.
dim
()
==
2
,
f
'x.features.dim()=
{
x
.
features
.
dim
()
}
'
out
=
self
.
conv1
(
x
)
out
.
features
=
self
.
bn1
(
out
.
features
)
out
.
features
=
self
.
relu
(
out
.
features
)
out
=
self
.
conv2
(
out
)
out
.
features
=
self
.
bn2
(
out
.
features
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
.
features
+=
identity
out
.
features
=
self
.
relu
(
out
.
features
)
return
out
class
SparseBottleneckV0
(
spconv
.
SparseModule
):
expansion
=
4
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
,
indice_key
=
None
,
norm_fn
=
None
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
# TODO: deprecate this class
super
().
__init__
()
self
.
conv1
=
conv1x1
(
inplanes
,
planes
,
indice_key
=
indice_key
)
self
.
bn1
=
norm_fn
(
planes
)
self
.
conv2
=
conv3x3
(
planes
,
planes
,
stride
,
indice_key
=
indice_key
)
self
.
bn2
=
norm_fn
(
planes
)
self
.
conv3
=
conv1x1
(
planes
,
planes
*
self
.
expansion
,
indice_key
=
indice_key
)
self
.
bn3
=
norm_fn
(
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
()
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
identity
=
x
.
features
out
=
self
.
conv1
(
x
)
out
.
features
=
self
.
bn1
(
out
.
features
)
out
.
features
=
self
.
relu
(
out
.
features
)
out
=
self
.
conv2
(
out
)
out
.
features
=
self
.
bn2
(
out
.
features
)
out
.
features
=
self
.
relu
(
out
.
features
)
out
=
self
.
conv3
(
out
)
out
.
features
=
self
.
bn3
(
out
.
features
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
.
features
+=
identity
out
.
features
=
self
.
relu
(
out
.
features
)
return
out
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
...
@@ -238,3 +95,67 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
...
@@ -238,3 +95,67 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
out
.
features
=
self
.
relu
(
out
.
features
)
out
.
features
=
self
.
relu
(
out
.
features
)
return
out
return
out
def
make_sparse_convmodule
(
in_channels
,
out_channels
,
kernel_size
,
indice_key
,
stride
=
1
,
padding
=
0
,
conv_type
=
'SubMConv3d'
,
norm_cfg
=
None
,
order
=
(
'conv'
,
'norm'
,
'act'
)):
"""Make sparse convolution module.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of out channels
kernel_size (int|tuple(int)): kernel size of convolution
indice_key (str): the indice key used for sparse tensor
stride (int|tuple(int)): the stride of convolution
padding (int or list[int]): the padding number of input
conv_type (str): sparse conv type in spconv
norm_cfg (dict[str]): config of normalization layer
order (tuple[str]): The order of conv/norm/activation layers. It is a
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Returns:
spconv.SparseSequential: sparse convolution module.
"""
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
<=
3
conv_cfg
=
dict
(
type
=
conv_type
,
indice_key
=
indice_key
)
layers
=
list
()
for
layer
in
order
:
if
layer
==
'conv'
:
if
conv_type
not
in
[
'SparseInverseConv3d'
,
'SparseInverseConv2d'
,
'SparseInverseConv1d'
]:
layers
.
append
(
build_conv_layer
(
conv_cfg
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
bias
=
False
))
else
:
layers
.
append
(
build_conv_layer
(
conv_cfg
,
in_channels
,
out_channels
,
kernel_size
,
bias
=
False
))
elif
layer
==
'norm'
:
layers
.
append
(
build_norm_layer
(
norm_cfg
,
out_channels
)[
1
])
elif
layer
==
'act'
:
layers
.
append
(
nn
.
ReLU
(
inplace
=
True
))
layers
=
spconv
.
SparseSequential
(
*
layers
)
return
layers
tests/test_sparse_unet.py
View file @
28e511cd
import
torch
import
torch
import
mmdet3d.ops.spconv
as
spconv
import
mmdet3d.ops.spconv
as
spconv
from
mmdet3d.ops
import
SparseBasicBlock
,
SparseBasicBlockV0
from
mmdet3d.ops
import
SparseBasicBlock
def
test_SparseUNet
():
def
test_SparseUNet
():
from
mmdet3d.models.middle_encoders.sparse_unet
import
SparseUNet
from
mmdet3d.models.middle_encoders.sparse_unet
import
SparseUNet
self
=
SparseUNet
(
self
=
SparseUNet
(
in_channels
=
4
,
sparse_shape
=
[
41
,
1600
,
1408
])
in_channels
=
4
,
output_shape
=
[
41
,
1600
,
1408
],
pre_act
=
False
)
# test encoder layers
# test encoder layers
assert
len
(
self
.
encoder_layers
)
==
4
assert
len
(
self
.
encoder_layers
)
==
4
...
@@ -61,17 +60,6 @@ def test_SparseBasicBlock():
...
@@ -61,17 +60,6 @@ def test_SparseBasicBlock():
[
1
,
35
,
930
,
469
]],
[
1
,
35
,
930
,
469
]],
dtype
=
torch
.
int32
)
# n, 4(batch, ind_x, ind_y, ind_z)
dtype
=
torch
.
int32
)
# n, 4(batch, ind_x, ind_y, ind_z)
# test v0
self
=
SparseBasicBlockV0
(
4
,
4
,
indice_key
=
'subm0'
,
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
))
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coordinates
,
[
41
,
1600
,
1408
],
2
)
out_features
=
self
(
input_sp_tensor
)
assert
out_features
.
features
.
shape
==
torch
.
Size
([
4
,
4
])
# test
# test
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coordinates
,
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coordinates
,
[
41
,
1600
,
1408
],
2
)
[
41
,
1600
,
1408
],
2
)
...
@@ -92,3 +80,57 @@ def test_SparseBasicBlock():
...
@@ -92,3 +80,57 @@ def test_SparseBasicBlock():
out_features
=
self
(
input_sp_tensor
)
out_features
=
self
(
input_sp_tensor
)
assert
out_features
.
features
.
shape
==
torch
.
Size
([
4
,
4
])
assert
out_features
.
features
.
shape
==
torch
.
Size
([
4
,
4
])
def
test_make_sparse_convmodule
():
from
mmdet3d.ops
import
make_sparse_convmodule
voxel_features
=
torch
.
tensor
([[
6.56126
,
0.9648336
,
-
1.7339306
,
0.315
],
[
6.8162713
,
-
2.480431
,
-
1.3616394
,
0.36
],
[
11.643568
,
-
4.744306
,
-
1.3580885
,
0.16
],
[
23.482342
,
6.5036807
,
0.5806964
,
0.35
]],
dtype
=
torch
.
float32
)
# n, point_features
coordinates
=
torch
.
tensor
(
[[
0
,
12
,
819
,
131
],
[
0
,
16
,
750
,
136
],
[
1
,
16
,
705
,
232
],
[
1
,
35
,
930
,
469
]],
dtype
=
torch
.
int32
)
# n, 4(batch, ind_x, ind_y, ind_z)
# test
input_sp_tensor
=
spconv
.
SparseConvTensor
(
voxel_features
,
coordinates
,
[
41
,
1600
,
1408
],
2
)
sparse_block0
=
make_sparse_convmodule
(
4
,
16
,
3
,
'test0'
,
stride
=
1
,
padding
=
0
,
conv_type
=
'SubMConv3d'
,
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
order
=
(
'conv'
,
'norm'
,
'act'
))
assert
isinstance
(
sparse_block0
[
0
],
spconv
.
SubMConv3d
)
assert
sparse_block0
[
0
].
in_channels
==
4
assert
sparse_block0
[
0
].
out_channels
==
16
assert
isinstance
(
sparse_block0
[
1
],
torch
.
nn
.
BatchNorm1d
)
assert
sparse_block0
[
1
].
eps
==
0.001
assert
sparse_block0
[
1
].
momentum
==
0.01
assert
isinstance
(
sparse_block0
[
2
],
torch
.
nn
.
ReLU
)
# test forward
out_features
=
sparse_block0
(
input_sp_tensor
)
assert
out_features
.
features
.
shape
==
torch
.
Size
([
4
,
16
])
sparse_block1
=
make_sparse_convmodule
(
4
,
16
,
3
,
'test1'
,
stride
=
1
,
padding
=
0
,
conv_type
=
'SparseInverseConv3d'
,
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
order
=
(
'norm'
,
'act'
,
'conv'
))
assert
isinstance
(
sparse_block1
[
0
],
torch
.
nn
.
BatchNorm1d
)
assert
isinstance
(
sparse_block1
[
1
],
torch
.
nn
.
ReLU
)
assert
isinstance
(
sparse_block1
[
2
],
spconv
.
SparseInverseConv3d
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment