Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
b64375c8
Commit
b64375c8
authored
May 19, 2020
by
wuyuefeng
Committed by
zhangwenwei
May 19, 2020
Browse files
Feat pointnet2 modules
parent
fb2120b9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
346 additions
and
2 deletions
+346
-2
mmdet3d/ops/__init__.py
mmdet3d/ops/__init__.py
+5
-2
mmdet3d/ops/pointnet_modules/__init__.py
mmdet3d/ops/pointnet_modules/__init__.py
+4
-0
mmdet3d/ops/pointnet_modules/point_fp_module.py
mmdet3d/ops/pointnet_modules/point_fp_module.py
+76
-0
mmdet3d/ops/pointnet_modules/point_sa_module.py
mmdet3d/ops/pointnet_modules/point_sa_module.py
+176
-0
tests/test_pointnet_modules.py
tests/test_pointnet_modules.py
+85
-0
No files found.
mmdet3d/ops/__init__.py
View file @
b64375c8
...
...
@@ -4,9 +4,11 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
from
.ball_query
import
ball_query
from
.furthest_point_sample
import
furthest_point_sample
from
.gather_points
import
gather_points
from
.group_points
import
group_points
,
grouping_operation
from
.group_points
import
(
GroupAll
,
QueryAndGroup
,
group_points
,
grouping_operation
)
from
.interpolate
import
three_interpolate
,
three_nn
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.pointnet_modules
import
PointFPModule
,
PointSAModule
,
PointSAModuleMSG
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.sparse_block
import
(
SparseBasicBlock
,
SparseBottleneck
,
...
...
@@ -22,5 +24,6 @@ __all__ = [
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'make_sparse_convmodule'
,
'ball_query'
,
'furthest_point_sample'
,
'three_interpolate'
,
'three_nn'
,
'gather_points'
,
'grouping_operation'
,
'group_points'
'group_points'
,
'GroupAll'
,
'QueryAndGroup'
,
'PointSAModule'
,
'PointSAModuleMSG'
,
'PointFPModule'
]
mmdet3d/ops/pointnet_modules/__init__.py
0 → 100644
View file @
b64375c8
from
.point_fp_module
import
PointFPModule
from
.point_sa_module
import
PointSAModule
,
PointSAModuleMSG
__all__
=
[
'PointSAModuleMSG'
,
'PointSAModule'
,
'PointFPModule'
]
mmdet3d/ops/pointnet_modules/point_fp_module.py
0 → 100644
View file @
b64375c8
from
typing
import
List
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmdet3d.ops
import
three_interpolate
,
three_nn
class
PointFPModule
(
nn
.
Module
):
"""Point feature propagation module used in PointNets.
Propagate the features from one set to another.
Args:
mlp_channels (list[int]): List of mlp channels.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
"""
def
__init__
(
self
,
mlp_channels
:
List
[
int
],
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
)):
super
().
__init__
()
self
.
mlps
=
nn
.
Sequential
()
for
i
in
range
(
len
(
mlp_channels
)
-
1
):
self
.
mlps
.
add_module
(
f
'layer
{
i
}
'
,
ConvModule
(
mlp_channels
[
i
],
mlp_channels
[
i
+
1
],
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
conv_cfg
=
dict
(
type
=
'Conv2d'
),
norm_cfg
=
norm_cfg
))
def
forward
(
self
,
target
:
torch
.
Tensor
,
source
:
torch
.
Tensor
,
target_feats
:
torch
.
Tensor
,
source_feats
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""forward.
Args:
target (Tensor): (B, n, 3) tensor of the xyz positions of
the target features.
source (Tensor): (B, m, 3) tensor of the xyz positions of
the source features.
target_feats (Tensor): (B, C1, n) tensor of the features to be
propagated to.
source_feats (Tensor): (B, C2, m) tensor of features
to be propagated.
Return:
Tensor: (B, M, N) M = mlp[-1], tensor of the target features.
"""
if
source
is
not
None
:
dist
,
idx
=
three_nn
(
target
,
source
)
dist_reciprocal
=
1.0
/
(
dist
+
1e-8
)
norm
=
torch
.
sum
(
dist_reciprocal
,
dim
=
2
,
keepdim
=
True
)
weight
=
dist_reciprocal
/
norm
interpolated_feats
=
three_interpolate
(
source_feats
,
idx
,
weight
)
else
:
interpolated_feats
=
source_feats
.
expand
(
*
source_feats
.
size
()[
0
:
2
],
target
.
size
(
1
))
if
target_feats
is
not
None
:
new_features
=
torch
.
cat
([
interpolated_feats
,
target_feats
],
dim
=
1
)
# (B, C2 + C1, n)
else
:
new_features
=
interpolated_feats
new_features
=
new_features
.
unsqueeze
(
-
1
)
new_features
=
self
.
mlps
(
new_features
)
return
new_features
.
squeeze
(
-
1
)
mmdet3d/ops/pointnet_modules/point_sa_module.py
0 → 100644
View file @
b64375c8
from
typing
import
List
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
ConvModule
from
mmdet3d.ops
import
(
GroupAll
,
QueryAndGroup
,
furthest_point_sample
,
gather_points
)
class
PointSAModuleMSG
(
nn
.
Module
):
"""Point set abstraction module with multi-scale grouping used in Pointnets.
Args:
num_point (int): Number of points.
radii (list[float]): List of radius in each ball query.
sample_nums (list[int]): Number of samples in each ball query.
mlp_channels (list[int]): Specify of the pointnet before
the global pooling for each scale.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz.
Default: True.
pool_mod (str): Type of pooling method.
Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False.
"""
def
__init__
(
self
,
num_point
:
int
,
radii
:
List
[
float
],
sample_nums
:
List
[
int
],
mlp_channels
:
List
[
List
[
int
]],
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
use_xyz
:
bool
=
True
,
pool_mod
=
'max'
,
normalize_xyz
:
bool
=
False
):
super
().
__init__
()
assert
len
(
radii
)
==
len
(
sample_nums
)
==
len
(
mlp_channels
)
assert
pool_mod
in
[
'max'
,
'avg'
]
self
.
num_point
=
num_point
self
.
pool_mod
=
pool_mod
self
.
groupers
=
nn
.
ModuleList
()
self
.
mlps
=
nn
.
ModuleList
()
for
i
in
range
(
len
(
radii
)):
radius
=
radii
[
i
]
sample_num
=
sample_nums
[
i
]
if
num_point
is
not
None
:
grouper
=
QueryAndGroup
(
radius
,
sample_num
,
use_xyz
=
use_xyz
,
normalize_xyz
=
normalize_xyz
)
else
:
grouper
=
GroupAll
(
use_xyz
)
self
.
groupers
.
append
(
grouper
)
mlp_spec
=
mlp_channels
[
i
]
if
use_xyz
:
mlp_spec
[
0
]
+=
3
mlp
=
nn
.
Sequential
()
for
i
in
range
(
len
(
mlp_spec
)
-
1
):
mlp
.
add_module
(
f
'layer
{
i
}
'
,
ConvModule
(
mlp_spec
[
i
],
mlp_spec
[
i
+
1
],
kernel_size
=
(
1
,
1
),
stride
=
(
1
,
1
),
conv_cfg
=
dict
(
type
=
'Conv2d'
),
norm_cfg
=
norm_cfg
))
self
.
mlps
.
append
(
mlp
)
def
forward
(
self
,
points_xyz
:
torch
.
Tensor
,
features
:
torch
.
Tensor
=
None
,
indices
:
torch
.
Tensor
=
None
)
->
(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
):
"""forward.
Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
features (Tensor): (B, C, N) features of each point.
Default: None.
indices (Tensor): (B, num_point) Index of the features.
Default: None.
Returns:
Tensor: (B, M, 3) where M is the number of points.
New features xyz.
Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number
of points. New feature descriptors.
Tensor: (B, M) where M is the number of points.
Index of the features.
"""
new_features_list
=
[]
xyz_flipped
=
points_xyz
.
transpose
(
1
,
2
).
contiguous
()
if
indices
is
None
:
indices
=
furthest_point_sample
(
points_xyz
,
self
.
num_point
)
else
:
assert
(
indices
.
shape
[
1
]
==
self
.
num_point
)
new_xyz
=
gather_points
(
xyz_flipped
,
indices
).
transpose
(
1
,
2
).
contiguous
()
if
self
.
num_point
is
not
None
else
None
for
i
in
range
(
len
(
self
.
groupers
)):
# (B, C, num_point, nsample)
new_features
=
self
.
groupers
[
i
](
points_xyz
,
new_xyz
,
features
)
# (B, mlp[-1], num_point, nsample)
new_features
=
self
.
mlps
[
i
](
new_features
)
if
self
.
pool_mod
==
'max'
:
# (B, mlp[-1], num_point, 1)
new_features
=
F
.
max_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)])
elif
self
.
pool_mod
==
'avg'
:
# (B, mlp[-1], num_point, 1)
new_features
=
F
.
avg_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)])
else
:
raise
NotImplementedError
new_features
=
new_features
.
squeeze
(
-
1
)
# (B, mlp[-1], num_point)
new_features_list
.
append
(
new_features
)
return
new_xyz
,
torch
.
cat
(
new_features_list
,
dim
=
1
),
indices
class
PointSAModule
(
PointSAModuleMSG
):
"""Point set abstraction module used in Pointnets.
Args:
mlp_channels (list[int]): Specify of the pointnet before
the global pooling for each scale.
num_point (int): Number of points.
Default: None.
radius (float): Radius to group with.
Default: None.
num_sample (int): Number of samples in each ball query.
Default: None.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz.
Default: True.
pool_mod (str): Type of pooling method.
Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False.
"""
def
__init__
(
self
,
mlp_channels
:
List
[
int
],
num_point
:
int
=
None
,
radius
:
float
=
None
,
num_sample
:
int
=
None
,
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
use_xyz
:
bool
=
True
,
pool_mod
:
str
=
'max'
,
normalize_xyz
:
bool
=
False
):
super
().
__init__
(
mlp_channels
=
[
mlp_channels
],
num_point
=
num_point
,
radii
=
[
radius
],
sample_nums
=
[
num_sample
],
norm_cfg
=
norm_cfg
,
use_xyz
=
use_xyz
,
pool_mod
=
pool_mod
,
normalize_xyz
=
normalize_xyz
)
tests/test_pointnet_modules.py
0 → 100644
View file @
b64375c8
import
numpy
as
np
import
torch
def
test_pointnet_sa_module_msg
():
from
mmdet3d.ops
import
PointSAModuleMSG
self
=
PointSAModuleMSG
(
num_point
=
16
,
radii
=
[
0.2
,
0.4
],
sample_nums
=
[
4
,
8
],
mlp_channels
=
[[
12
,
16
],
[
12
,
32
]],
norm_cfg
=
dict
(
type
=
'BN2d'
),
use_xyz
=
False
,
pool_mod
=
'max'
).
cuda
()
assert
self
.
mlps
[
0
].
layer0
.
conv
.
in_channels
==
12
assert
self
.
mlps
[
0
].
layer0
.
conv
.
out_channels
==
16
assert
self
.
mlps
[
1
].
layer0
.
conv
.
in_channels
==
12
assert
self
.
mlps
[
1
].
layer0
.
conv
.
out_channels
==
32
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz
=
torch
.
from_numpy
(
xyz
[...,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C, N)
features
=
xyz
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# test forward
new_xyz
,
new_features
,
inds
=
self
(
xyz
,
features
)
assert
new_xyz
.
shape
==
torch
.
Size
([
1
,
16
,
3
])
assert
new_features
.
shape
==
torch
.
Size
([
1
,
48
,
16
])
assert
inds
.
shape
==
torch
.
Size
([
1
,
16
])
def
test_pointnet_sa_module
():
from
mmdet3d.ops
import
PointSAModule
self
=
PointSAModule
(
num_point
=
16
,
radius
=
0.2
,
num_sample
=
8
,
mlp_channels
=
[
12
,
32
],
norm_cfg
=
dict
(
type
=
'BN2d'
),
use_xyz
=
True
,
pool_mod
=
'max'
).
cuda
()
assert
self
.
mlps
[
0
].
layer0
.
conv
.
in_channels
==
15
assert
self
.
mlps
[
0
].
layer0
.
conv
.
out_channels
==
32
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz
=
torch
.
from_numpy
(
xyz
[...,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C, N)
features
=
xyz
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# test forward
new_xyz
,
new_features
,
inds
=
self
(
xyz
,
features
)
assert
new_xyz
.
shape
==
torch
.
Size
([
1
,
16
,
3
])
assert
new_features
.
shape
==
torch
.
Size
([
1
,
32
,
16
])
assert
inds
.
shape
==
torch
.
Size
([
1
,
16
])
def
test_pointnet_fp_module
():
from
mmdet3d.ops
import
PointFPModule
self
=
PointFPModule
(
mlp_channels
=
[
24
,
16
]).
cuda
()
assert
self
.
mlps
.
layer0
.
conv
.
in_channels
==
24
assert
self
.
mlps
.
layer0
.
conv
.
out_channels
==
16
xyz
=
np
.
load
(
'tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy'
)
# (B, N, 3)
xyz1
=
torch
.
from_numpy
(
xyz
[
0
::
2
,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C1, N)
features1
=
xyz1
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
# (B, M, 3)
xyz2
=
torch
.
from_numpy
(
xyz
[
1
::
3
,
:
3
]).
view
(
1
,
-
1
,
3
).
cuda
()
# (B, C2, N)
features2
=
xyz2
.
repeat
([
1
,
1
,
4
]).
transpose
(
1
,
2
).
contiguous
().
cuda
()
fp_features
=
self
(
xyz1
,
xyz2
,
features1
,
features2
)
assert
fp_features
.
shape
==
torch
.
Size
([
1
,
16
,
50
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment