Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
db986fa4
Commit
db986fa4
authored
May 19, 2020
by
zhangwenwei
Browse files
Merge branch 'feat_pointnet2_modules' into 'master'
Feat pointnet2 modules See merge request open-mmlab/mmdet.3d!36
parents
fb2120b9
b64375c8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
346 additions
and
2 deletions
+346
-2
mmdet3d/ops/__init__.py
mmdet3d/ops/__init__.py
+5
-2
mmdet3d/ops/pointnet_modules/__init__.py
mmdet3d/ops/pointnet_modules/__init__.py
+4
-0
mmdet3d/ops/pointnet_modules/point_fp_module.py
mmdet3d/ops/pointnet_modules/point_fp_module.py
+76
-0
mmdet3d/ops/pointnet_modules/point_sa_module.py
mmdet3d/ops/pointnet_modules/point_sa_module.py
+176
-0
tests/test_pointnet_modules.py
tests/test_pointnet_modules.py
+85
-0
No files found.
mmdet3d/ops/__init__.py
View file @
db986fa4
...
...
@@ -4,9 +4,11 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
from
.ball_query
import
ball_query
from
.furthest_point_sample
import
furthest_point_sample
from
.gather_points
import
gather_points
from
.group_points
import
group_points
,
grouping_operation
from
.group_points
import
(
GroupAll
,
QueryAndGroup
,
group_points
,
grouping_operation
)
from
.interpolate
import
three_interpolate
,
three_nn
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.pointnet_modules
import
PointFPModule
,
PointSAModule
,
PointSAModuleMSG
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.sparse_block
import
(
SparseBasicBlock
,
SparseBottleneck
,
...
...
@@ -22,5 +24,6 @@ __all__ = [
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'make_sparse_convmodule'
,
'ball_query'
,
'furthest_point_sample'
,
'three_interpolate'
,
'three_nn'
,
'gather_points'
,
'grouping_operation'
,
'group_points'
'group_points'
,
'GroupAll'
,
'QueryAndGroup'
,
'PointSAModule'
,
'PointSAModuleMSG'
,
'PointFPModule'
]
mmdet3d/ops/pointnet_modules/__init__.py
0 → 100644
View file @
db986fa4
from
.point_fp_module
import
PointFPModule
from
.point_sa_module
import
PointSAModule
,
PointSAModuleMSG
__all__
=
[
'PointSAModuleMSG'
,
'PointSAModule'
,
'PointFPModule'
]
mmdet3d/ops/pointnet_modules/point_fp_module.py
0 → 100644
View file @
db986fa4
from
typing
import
List
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmdet3d.ops
import
three_interpolate
,
three_nn
class
PointFPModule
(
nn
.
Module
):
"""Point feature propagation module used in PointNets.
Propagate the features from one set to another.
Args:
mlp_channels (list[int]): List of mlp channels.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
"""
def
__init__
(
self
,
mlp_channels
:
List
[
int
],
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
)):
super
().
__init__
()
self
.
mlps
=
nn
.
Sequential
()
for
i
in
range
(
len
(
mlp_channels
)
-
1
):
self
.
mlps
.
add_module
(
f
'layer
{
i
}
'
,
ConvModule
(
mlp_channels
[
i
],
mlp_channels
[
i
+
1
],
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
conv_cfg
=
dict
(
type
=
'Conv2d'
),
norm_cfg
=
norm_cfg
))
def
forward
(
self
,
target
:
torch
.
Tensor
,
source
:
torch
.
Tensor
,
target_feats
:
torch
.
Tensor
,
source_feats
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""forward.
Args:
target (Tensor): (B, n, 3) tensor of the xyz positions of
the target features.
source (Tensor): (B, m, 3) tensor of the xyz positions of
the source features.
target_feats (Tensor): (B, C1, n) tensor of the features to be
propagated to.
source_feats (Tensor): (B, C2, m) tensor of features
to be propagated.
Return:
Tensor: (B, M, N) M = mlp[-1], tensor of the target features.
"""
if
source
is
not
None
:
dist
,
idx
=
three_nn
(
target
,
source
)
dist_reciprocal
=
1.0
/
(
dist
+
1e-8
)
norm
=
torch
.
sum
(
dist_reciprocal
,
dim
=
2
,
keepdim
=
True
)
weight
=
dist_reciprocal
/
norm
interpolated_feats
=
three_interpolate
(
source_feats
,
idx
,
weight
)
else
:
interpolated_feats
=
source_feats
.
expand
(
*
source_feats
.
size
()[
0
:
2
],
target
.
size
(
1
))
if
target_feats
is
not
None
:
new_features
=
torch
.
cat
([
interpolated_feats
,
target_feats
],
dim
=
1
)
# (B, C2 + C1, n)
else
:
new_features
=
interpolated_feats
new_features
=
new_features
.
unsqueeze
(
-
1
)
new_features
=
self
.
mlps
(
new_features
)
return
new_features
.
squeeze
(
-
1
)
mmdet3d/ops/pointnet_modules/point_sa_module.py
0 → 100644
View file @
db986fa4
from
typing
import
List
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
ConvModule
from
mmdet3d.ops
import
(
GroupAll
,
QueryAndGroup
,
furthest_point_sample
,
gather_points
)
class
PointSAModuleMSG
(
nn
.
Module
):
"""Point set abstraction module with multi-scale grouping used in Pointnets.
Args:
num_point (int): Number of points.
radii (list[float]): List of radius in each ball query.
sample_nums (list[int]): Number of samples in each ball query.
mlp_channels (list[int]): Specify of the pointnet before
the global pooling for each scale.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz.
Default: True.
pool_mod (str): Type of pooling method.
Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False.
"""
def
__init__
(
self
,
num_point
:
int
,
radii
:
List
[
float
],
sample_nums
:
List
[
int
],
mlp_channels
:
List
[
List
[
int
]],
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
use_xyz
:
bool
=
True
,
pool_mod
=
'max'
,
normalize_xyz
:
bool
=
False
):
super
().
__init__
()
assert
len
(
radii
)
==
len
(
sample_nums
)
==
len
(
mlp_channels
)
assert
pool_mod
in
[
'max'
,
'avg'
]
self
.
num_point
=
num_point
self
.
pool_mod
=
pool_mod
self
.
groupers
=
nn
.
ModuleList
()
self
.
mlps
=
nn
.
ModuleList
()
for
i
in
range
(
len
(
radii
)):
radius
=
radii
[
i
]
sample_num
=
sample_nums
[
i
]
if
num_point
is
not
None
:
grouper
=
QueryAndGroup
(
radius
,
sample_num
,
use_xyz
=
use_xyz
,
normalize_xyz
=
normalize_xyz
)
else
:
grouper
=
GroupAll
(
use_xyz
)
self
.
groupers
.
append
(
grouper
)
mlp_spec
=
mlp_channels
[
i
]
if
use_xyz
:
mlp_spec
[
0
]
+=
3
mlp
=
nn
.
Sequential
()
for
i
in
range
(
len
(
mlp_spec
)
-
1
):
mlp
.
add_module
(
f
'layer
{
i
}
'
,
ConvModule
(
mlp_spec
[
i
],
mlp_spec
[
i
+
1
],
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
conv_cfg
=
dict
(
type
=
'Conv2d'
),
norm_cfg
=
norm_cfg
))
self
.
mlps
.
append
(
mlp
)
def
forward
(
self
,
points_xyz
:
torch
.
Tensor
,
features
:
torch
.
Tensor
=
None
,
indices
:
torch
.
Tensor
=
None
)
->
(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
):
"""forward.
Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
features (Tensor): (B, C, N) features of each point.
Default: None.
indices (Tensor): (B, num_point) Index of the features.
Default: None.
Returns:
Tensor: (B, M, 3) where M is the number of points.
New features xyz.
Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number
of points. New feature descriptors.
Tensor: (B, M) where M is the number of points.
Index of the features.
"""
new_features_list
=
[]
xyz_flipped
=
points_xyz
.
transpose
(
1
,
2
).
contiguous
()
if
indices
is
None
:
indices
=
furthest_point_sample
(
points_xyz
,
self
.
num_point
)
else
:
assert
(
indices
.
shape
[
1
]
==
self
.
num_point
)
new_xyz
=
gather_points
(
xyz_flipped
,
indices
).
transpose
(
1
,
2
).
contiguous
()
if
self
.
num_point
is
not
None
else
None
for
i
in
range
(
len
(
self
.
groupers
)):
# (B, C, num_point, nsample)
new_features
=
self
.
groupers
[
i
](
points_xyz
,
new_xyz
,
features
)
# (B, mlp[-1], num_point, nsample)
new_features
=
self
.
mlps
[
i
](
new_features
)
if
self
.
pool_mod
==
'max'
:
# (B, mlp[-1], num_point, 1)
new_features
=
F
.
max_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)])
elif
self
.
pool_mod
==
'avg'
:
# (B, mlp[-1], num_point, 1)
new_features
=
F
.
avg_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)])
else
:
raise
NotImplementedError
new_features
=
new_features
.
squeeze
(
-
1
)
# (B, mlp[-1], num_point)
new_features_list
.
append
(
new_features
)
return
new_xyz
,
torch
.
cat
(
new_features_list
,
dim
=
1
),
indices
class
PointSAModule
(
PointSAModuleMSG
):
"""Point set abstraction module used in Pointnets.
Args:
mlp_channels (list[int]): Specify of the pointnet before
the global pooling for each scale.
num_point (int): Number of points.
Default: None.
radius (float): Radius to group with.
Default: None.
num_sample (int): Number of samples in each ball query.
Default: None.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz.
Default: True.
pool_mod (str): Type of pooling method.
Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False.
"""
def
__init__
(
self
,
mlp_channels
:
List
[
int
],
num_point
:
int
=
None
,
radius
:
float
=
None
,
num_sample
:
int
=
None
,
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
use_xyz
:
bool
=
True
,
pool_mod
:
str
=
'max'
,
normalize_xyz
:
bool
=
False
):
super
().
__init__
(
mlp_channels
=
[
mlp_channels
],
num_point
=
num_point
,
radii
=
[
radius
],
sample_nums
=
[
num_sample
],
norm_cfg
=
norm_cfg
,
use_xyz
=
use_xyz
,
pool_mod
=
pool_mod
,
normalize_xyz
=
normalize_xyz
)
tests/test_pointnet_modules.py
0 → 100644
View file @
db986fa4
import
numpy
as
np
import
torch
def
test_pointnet_sa_module_msg
():
from
mmdet3d.ops
import
PointSAModuleMSG
self
=
PointSAModuleMSG
(
num_point
=
16
,
radii
=
[
0.2
,
0.4
],
sample_nums
=
[
4
,
8
],
mlp_channels
=
[[
12
,
16
],
[
12
,
32
]],
norm_cfg
=
dict
(
type
=
'BN2d'
),
use_xyz
=
False
,
pool_mod
=
'max'
).
cuda
()
assert
self
.
mlps
[
0
].
layer0
.
conv
.
in_channels
==
12
assert
self
.
mlps
[
0
].
layer0
.
conv
.
out_channels
==
16
assert
self
.
mlps
[
1
].
layer0
.
conv
.
in_channels
==
12
assert
self
.
mlps
[
1
].
layer0
.
conv
.
out_channels
==
32
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz
=
torch
.
from_numpy
(
xyz
[...,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C, N)
features
=
xyz
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# test forward
new_xyz
,
new_features
,
inds
=
self
(
xyz
,
features
)
assert
new_xyz
.
shape
==
torch
.
Size
([
1
,
16
,
3
])
assert
new_features
.
shape
==
torch
.
Size
([
1
,
48
,
16
])
assert
inds
.
shape
==
torch
.
Size
([
1
,
16
])
def
test_pointnet_sa_module
():
from
mmdet3d.ops
import
PointSAModule
self
=
PointSAModule
(
num_point
=
16
,
radius
=
0.2
,
num_sample
=
8
,
mlp_channels
=
[
12
,
32
],
norm_cfg
=
dict
(
type
=
'BN2d'
),
use_xyz
=
True
,
pool_mod
=
'max'
).
cuda
()
assert
self
.
mlps
[
0
].
layer0
.
conv
.
in_channels
==
15
assert
self
.
mlps
[
0
].
layer0
.
conv
.
out_channels
==
32
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz
=
torch
.
from_numpy
(
xyz
[...,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C, N)
features
=
xyz
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# test forward
new_xyz
,
new_features
,
inds
=
self
(
xyz
,
features
)
assert
new_xyz
.
shape
==
torch
.
Size
([
1
,
16
,
3
])
assert
new_features
.
shape
==
torch
.
Size
([
1
,
32
,
16
])
assert
inds
.
shape
==
torch
.
Size
([
1
,
16
])
def
test_pointnet_fp_module
():
from
mmdet3d.ops
import
PointFPModule
self
=
PointFPModule
(
mlp_channels
=
[
24
,
16
]).
cuda
()
assert
self
.
mlps
.
layer0
.
conv
.
in_channels
==
24
assert
self
.
mlps
.
layer0
.
conv
.
out_channels
==
16
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz1
=
torch
.
from_numpy
(
xyz
[
0
::
2
,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C1, N)
features1
=
xyz1
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# (B, M, 3)
xyz2
=
torch
.
from_numpy
(
xyz
[
1
::
3
,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C2, N)
features2
=
xyz2
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
fp_features
=
self
(
xyz1
,
xyz2
,
features1
,
features2
)
assert
fp_features
.
shape
==
torch
.
Size
([
1
,
16
,
50
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment