Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f717eb62
Commit
f717eb62
authored
Jun 04, 2020
by
wuyuefeng
Committed by
zhangwenwei
Jun 04, 2020
Browse files
Votenet
parent
ac3590a1
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
499 additions
and
61 deletions
+499
-61
mmdet3d/models/losses/__init__.py
mmdet3d/models/losses/__init__.py
+5
-1
mmdet3d/models/losses/chamfer_distance.py
mmdet3d/models/losses/chamfer_distance.py
+119
-0
mmdet3d/models/model_utils/__init__.py
mmdet3d/models/model_utils/__init__.py
+3
-0
mmdet3d/models/model_utils/vote_module.py
mmdet3d/models/model_utils/vote_module.py
+11
-50
mmdet3d/ops/__init__.py
mmdet3d/ops/__init__.py
+3
-4
mmdet3d/ops/roiaware_pool3d/__init__.py
mmdet3d/ops/roiaware_pool3d/__init__.py
+6
-2
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+26
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
+76
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
+5
-0
tests/test_heads.py
tests/test_heads.py
+84
-0
tests/test_losses.py
tests/test_losses.py
+68
-0
tests/test_nms.py
tests/test_nms.py
+57
-0
tests/test_roiaware_pool3d.py
tests/test_roiaware_pool3d.py
+28
-1
tests/test_vote_module.py
tests/test_vote_module.py
+8
-3
No files found.
mmdet3d/models/losses/__init__.py
View file @
f717eb62
from
mmdet.models.losses
import
FocalLoss
,
SmoothL1Loss
,
binary_cross_entropy
from
mmdet.models.losses
import
FocalLoss
,
SmoothL1Loss
,
binary_cross_entropy
from
.chamfer_distance
import
ChamferDistance
,
chamfer_distance
__all__
=
[
'FocalLoss'
,
'SmoothL1Loss'
,
'binary_cross_entropy'
]
__all__
=
[
'FocalLoss'
,
'SmoothL1Loss'
,
'binary_cross_entropy'
,
'ChamferDistance'
,
'chamfer_distance'
]
mmdet3d/models/losses/chamfer_distance.py
0 → 100644
View file @
f717eb62
import
torch
import
torch.nn
as
nn
from
torch.nn.functional
import
l1_loss
,
mse_loss
,
smooth_l1_loss
from
mmdet.models.builder
import
LOSSES
def
chamfer_distance
(
src
,
dst
,
src_weight
=
1.0
,
dst_weight
=
1.0
,
criterion_mode
=
'l2'
,
reduction
=
'mean'
):
"""Calculate Chamfer Distance of two sets.
Args:
src (tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance.
dst (tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance.
src_weight (tensor or float): Weight of source loss.
dst_weight (tensor or float): Weight of destination loss.
criterion_mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean.
Returns:
tuple: Source and Destination loss with indices.
- loss_src (Tensor): The min distance from source to destination.
- loss_dst (Tensor): The min distance from destination to source.
- indices1 (Tensor): Index the min distance point for each point
in source to destination.
- indices2 (Tensor): Index the min distance point for each point
in destination to source.
"""
if
criterion_mode
==
'smooth_l1'
:
criterion
=
smooth_l1_loss
elif
criterion_mode
==
'l1'
:
criterion
=
l1_loss
elif
criterion_mode
==
'l2'
:
criterion
=
mse_loss
else
:
raise
NotImplementedError
src_expand
=
src
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
dst
.
shape
[
1
],
1
)
dst_expand
=
dst
.
unsqueeze
(
1
).
repeat
(
1
,
src
.
shape
[
1
],
1
,
1
)
distance
=
criterion
(
src_expand
,
dst_expand
,
reduction
=
'none'
).
sum
(
-
1
)
src2dst_distance
,
indices1
=
torch
.
min
(
distance
,
dim
=
2
)
# (B,N)
dst2src_distance
,
indices2
=
torch
.
min
(
distance
,
dim
=
1
)
# (B,M)
loss_src
=
(
src2dst_distance
*
src_weight
)
loss_dst
=
(
dst2src_distance
*
dst_weight
)
if
reduction
==
'sum'
:
loss_src
=
torch
.
sum
(
loss_src
)
loss_dst
=
torch
.
sum
(
loss_dst
)
elif
reduction
==
'mean'
:
loss_src
=
torch
.
mean
(
loss_src
)
loss_dst
=
torch
.
mean
(
loss_dst
)
elif
reduction
==
'none'
:
pass
else
:
raise
NotImplementedError
return
loss_src
,
loss_dst
,
indices1
,
indices2
@
LOSSES
.
register_module
()
class
ChamferDistance
(
nn
.
Module
):
"""Calculate Chamfer Distance of two sets.
Args:
mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean.
loss_src_weight (float): Weight of loss_source.
loss_dst_weight (float): Weight of loss_target.
"""
def
__init__
(
self
,
mode
=
'l2'
,
reduction
=
'mean'
,
loss_src_weight
=
1.0
,
loss_dst_weight
=
1.0
):
super
(
ChamferDistance
,
self
).
__init__
()
assert
mode
in
[
'smooth_l1'
,
'l1'
,
'l2'
]
assert
reduction
in
[
'none'
,
'sum'
,
'mean'
]
self
.
mode
=
mode
self
.
reduction
=
reduction
self
.
loss_src_weight
=
loss_src_weight
self
.
loss_dst_weight
=
loss_dst_weight
def
forward
(
self
,
source
,
target
,
src_weight
=
1.0
,
dst_weight
=
1.0
,
reduction_override
=
None
,
return_indices
=
False
,
**
kwargs
):
assert
reduction_override
in
(
None
,
'none'
,
'mean'
,
'sum'
)
reduction
=
(
reduction_override
if
reduction_override
else
self
.
reduction
)
loss_source
,
loss_target
,
indices1
,
indices2
=
chamfer_distance
(
source
,
target
,
src_weight
,
dst_weight
,
self
.
mode
,
reduction
)
loss_source
*=
self
.
loss_src_weight
loss_target
*=
self
.
loss_dst_weight
if
return_indices
:
return
loss_source
,
loss_target
,
indices1
,
indices2
else
:
return
loss_source
,
loss_target
mmdet3d/models/model_utils/__init__.py
0 → 100644
View file @
f717eb62
from
.vote_module
import
VoteModule
__all__
=
[
'VoteModule'
]
mmdet3d/
op
s/vote_module.py
→
mmdet3d/
models/model_util
s/vote_module.py
View file @
f717eb62
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
from
torch.nn.functional
import
l1_loss
,
mse_loss
,
smooth_l1_loss
from
mmdet3d.models.builder
import
build_loss
class
VoteModule
(
nn
.
Module
):
class
VoteModule
(
nn
.
Module
):
...
@@ -22,7 +23,7 @@ class VoteModule(nn.Module):
...
@@ -22,7 +23,7 @@ class VoteModule(nn.Module):
Default: dict(type='BN1d').
Default: dict(type='BN1d').
norm_feats (bool): Whether to normalize features.
norm_feats (bool): Whether to normalize features.
Default: True.
Default: True.
loss_weight (float): Weight
of vot
ing
loss.
vote_loss (dict): config
of vot
e
loss.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -33,13 +34,13 @@ class VoteModule(nn.Module):
...
@@ -33,13 +34,13 @@ class VoteModule(nn.Module):
conv_cfg
=
dict
(
type
=
'Conv1d'
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
norm_feats
=
True
,
norm_feats
=
True
,
loss_weight
=
1.0
):
vote_loss
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
vote_per_seed
=
vote_per_seed
self
.
vote_per_seed
=
vote_per_seed
self
.
gt_per_seed
=
gt_per_seed
self
.
gt_per_seed
=
gt_per_seed
self
.
norm_feats
=
norm_feats
self
.
norm_feats
=
norm_feats
self
.
loss_weight
=
loss_weight
self
.
vote_loss
=
build_loss
(
vote_loss
)
prev_channels
=
in_channels
prev_channels
=
in_channels
vote_conv_list
=
list
()
vote_conv_list
=
list
()
...
@@ -118,57 +119,17 @@ class VoteModule(nn.Module):
...
@@ -118,57 +119,17 @@ class VoteModule(nn.Module):
seed_gt_votes_mask
=
torch
.
gather
(
vote_targets_mask
,
1
,
seed_gt_votes_mask
=
torch
.
gather
(
vote_targets_mask
,
1
,
seed_indices
).
float
()
seed_indices
).
float
()
pos_num
=
torch
.
sum
(
seed_gt_votes_mask
)
seed_indices_expand
=
seed_indices
.
unsqueeze
(
-
1
).
repeat
(
seed_indices_expand
=
seed_indices
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
3
*
self
.
gt_per_seed
)
1
,
1
,
3
*
self
.
gt_per_seed
)
seed_gt_votes
=
torch
.
gather
(
vote_targets
,
1
,
seed_indices_expand
)
seed_gt_votes
=
torch
.
gather
(
vote_targets
,
1
,
seed_indices_expand
)
seed_gt_votes
+=
seed_points
.
repeat
(
1
,
1
,
3
)
seed_gt_votes
+=
seed_points
.
repeat
(
1
,
1
,
3
)
distance
=
self
.
nn_distance
(
weight
=
seed_gt_votes_mask
/
(
torch
.
sum
(
seed_gt_votes_mask
)
+
1e-6
)
distance
=
self
.
vote_loss
(
vote_points
.
view
(
batch_size
*
num_seed
,
-
1
,
3
),
vote_points
.
view
(
batch_size
*
num_seed
,
-
1
,
3
),
seed_gt_votes
.
view
(
batch_size
*
num_seed
,
-
1
,
3
),
seed_gt_votes
.
view
(
batch_size
*
num_seed
,
-
1
,
3
),
mode
=
'l1'
)[
2
]
dst_weight
=
weight
.
view
(
batch_size
*
num_seed
,
1
))[
1
]
votes_distance
=
torch
.
min
(
distance
,
dim
=
1
)[
0
]
vote_loss
=
torch
.
sum
(
torch
.
min
(
distance
,
dim
=
1
)[
0
])
votes_dist
=
votes_distance
.
view
(
batch_size
,
num_seed
)
vote_loss
=
torch
.
sum
(
votes_dist
*
seed_gt_votes_mask
)
/
(
pos_num
+
1e-6
)
return
self
.
loss_weight
*
vote_loss
def
nn_distance
(
self
,
points1
,
points2
,
mode
=
'smooth_l1'
):
return
vote_loss
"""Find the nearest neighbor from point1 to point2
Args:
points1 (Tensor): points to find the Nearest neighbor.
points2 (Tensor): points to find the Nearest neighbor.
mode (str): Specify the function (smooth_l1, l1 or l2)
to calculate distance.
Returns:
tuple[Tensor]:
- distance1: the nearest distance from points1 to points2.
- index1: the index of the nearest neighbor for points1.
- distance2: the nearest distance from points2 to points1.
- index2: the index of the nearest neighbor for points2.
"""
assert
mode
in
[
'smooth_l1'
,
'l1'
,
'l2'
]
N
=
points1
.
shape
[
1
]
M
=
points2
.
shape
[
1
]
pc1_expand_tile
=
points1
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
M
,
1
)
pc2_expand_tile
=
points2
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
,
1
)
if
mode
==
'smooth_l1'
:
pc_dist
=
torch
.
sum
(
smooth_l1_loss
(
pc1_expand_tile
,
pc2_expand_tile
),
dim
=-
1
)
elif
mode
==
'l1'
:
pc_dist
=
torch
.
sum
(
l1_loss
(
pc1_expand_tile
,
pc2_expand_tile
),
dim
=-
1
)
# (B,N,M)
elif
mode
==
'l2'
:
pc_dist
=
torch
.
sum
(
mse_loss
(
pc1_expand_tile
,
pc2_expand_tile
),
dim
=-
1
)
# (B,N,M)
else
:
raise
NotImplementedError
distance1
,
index1
=
torch
.
min
(
pc_dist
,
dim
=
2
)
# (B,N)
distance2
,
index2
=
torch
.
min
(
pc_dist
,
dim
=
1
)
# (B,M)
return
distance1
,
index1
,
distance2
,
index2
mmdet3d/ops/__init__.py
View file @
f717eb62
...
@@ -9,11 +9,10 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
...
@@ -9,11 +9,10 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
from
.interpolate
import
three_interpolate
,
three_nn
from
.interpolate
import
three_interpolate
,
three_nn
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.norm
import
NaiveSyncBatchNorm1d
,
NaiveSyncBatchNorm2d
from
.pointnet_modules
import
PointFPModule
,
PointSAModule
,
PointSAModuleMSG
from
.pointnet_modules
import
PointFPModule
,
PointSAModule
,
PointSAModuleMSG
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_
cpu
,
from
.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_
batch
,
points_in_boxes_gpu
)
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.sparse_block
import
(
SparseBasicBlock
,
SparseBottleneck
,
from
.sparse_block
import
(
SparseBasicBlock
,
SparseBottleneck
,
make_sparse_convmodule
)
make_sparse_convmodule
)
from
.vote_module
import
VoteModule
from
.voxel
import
DynamicScatter
,
Voxelization
,
dynamic_scatter
,
voxelization
from
.voxel
import
DynamicScatter
,
Voxelization
,
dynamic_scatter
,
voxelization
__all__
=
[
__all__
=
[
...
@@ -26,5 +25,5 @@ __all__ = [
...
@@ -26,5 +25,5 @@ __all__ = [
'make_sparse_convmodule'
,
'ball_query'
,
'furthest_point_sample'
,
'make_sparse_convmodule'
,
'ball_query'
,
'furthest_point_sample'
,
'three_interpolate'
,
'three_nn'
,
'gather_points'
,
'grouping_operation'
,
'three_interpolate'
,
'three_nn'
,
'gather_points'
,
'grouping_operation'
,
'group_points'
,
'GroupAll'
,
'QueryAndGroup'
,
'PointSAModule'
,
'group_points'
,
'GroupAll'
,
'QueryAndGroup'
,
'PointSAModule'
,
'PointSAModuleMSG'
,
'PointFPModule'
,
'
VoteModule
'
'PointSAModuleMSG'
,
'PointFPModule'
,
'
points_in_boxes_batch
'
]
]
mmdet3d/ops/roiaware_pool3d/__init__.py
View file @
f717eb62
from
.points_in_boxes
import
points_in_boxes_cpu
,
points_in_boxes_gpu
from
.points_in_boxes
import
(
points_in_boxes_batch
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.roiaware_pool3d
import
RoIAwarePool3d
from
.roiaware_pool3d
import
RoIAwarePool3d
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
]
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'points_in_boxes_batch'
]
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
View file @
f717eb62
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
point_indices
)
point_indices
)
return
point_indices
return
point_indices
def
points_in_boxes_batch
(
points
,
boxes
):
"""Find points that are in boxes (CUDA)
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0
"""
assert
boxes
.
shape
[
0
]
==
points
.
shape
[
0
]
assert
boxes
.
shape
[
2
]
==
7
batch_size
,
num_points
,
_
=
points
.
shape
num_boxes
=
boxes
.
shape
[
1
]
box_idxs_of_pts
=
points
.
new_zeros
((
batch_size
,
num_points
,
num_boxes
),
dtype
=
torch
.
int
).
fill_
(
0
)
roiaware_pool3d_ext
.
points_in_boxes_batch
(
boxes
.
contiguous
(),
points
.
contiguous
(),
box_idxs_of_pts
)
return
box_idxs_of_pts
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
View file @
f717eb62
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
}
}
}
}
__global__
void
points_in_boxes_batch_kernel
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
batch_size
||
pt_idx
>=
pts_num
)
return
;
boxes
+=
bs_idx
*
boxes_num
*
7
;
pts
+=
bs_idx
*
pts_num
*
3
+
pt_idx
*
3
;
box_idx_of_points
+=
bs_idx
*
pts_num
*
boxes_num
+
pt_idx
*
boxes_num
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
0
;
for
(
int
k
=
0
;
k
<
boxes_num
;
k
++
)
{
cur_in_flag
=
check_pt_in_box3d
(
pts
,
boxes
+
k
*
7
,
local_x
,
local_y
);
if
(
cur_in_flag
)
{
box_idx_of_points
[
k
]
=
1
;
}
cur_in_flag
=
0
;
}
}
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
int
*
box_idx_of_points
)
{
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
#endif
#endif
}
}
void
points_in_boxes_batch_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
//LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
batch_size
);
dim3
threads
(
THREADS_PER_BLOCK
);
points_in_boxes_batch_kernel
<<<
blocks
,
threads
>>>
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
return
1
;
return
1
;
}
}
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT
(
boxes_tensor
);
CHECK_INPUT
(
pts_tensor
);
CHECK_INPUT
(
box_idx_of_points_tensor
);
int
batch_size
=
boxes_tensor
.
size
(
0
);
int
boxes_num
=
boxes_tensor
.
size
(
1
);
int
pts_num
=
pts_tensor
.
size
(
1
);
const
float
*
boxes
=
boxes_tensor
.
data_ptr
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data_ptr
<
float
>
();
int
*
box_idx_of_points
=
box_idx_of_points_tensor
.
data_ptr
<
int
>
();
points_in_boxes_batch_launcher
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
View file @
f717eb62
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
at
::
Tensor
box_idx_of_points_tensor
);
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
)
{
at
::
Tensor
pooled_features
,
int
pool_method
)
{
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"roiaware pool3d backward (CUDA)"
);
"roiaware pool3d backward (CUDA)"
);
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
"points_in_boxes_gpu forward (CUDA)"
);
"points_in_boxes_gpu forward (CUDA)"
);
m
.
def
(
"points_in_boxes_batch"
,
&
points_in_boxes_batch
,
"points_in_boxes_batch forward (CUDA)"
);
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
"points_in_boxes_cpu forward (CPU)"
);
"points_in_boxes_cpu forward (CPU)"
);
}
}
tests/test_heads.py
View file @
f717eb62
...
@@ -170,3 +170,87 @@ def test_parta2_rpnhead_getboxes():
...
@@ -170,3 +170,87 @@ def test_parta2_rpnhead_getboxes():
assert
result_list
[
0
][
'labels_3d'
].
shape
==
torch
.
Size
([
512
])
assert
result_list
[
0
][
'labels_3d'
].
shape
==
torch
.
Size
([
512
])
assert
result_list
[
0
][
'cls_preds'
].
shape
==
torch
.
Size
([
512
,
3
])
assert
result_list
[
0
][
'cls_preds'
].
shape
==
torch
.
Size
([
512
,
3
])
assert
result_list
[
0
][
'boxes_3d'
].
shape
==
torch
.
Size
([
512
,
7
])
assert
result_list
[
0
][
'boxes_3d'
].
shape
==
torch
.
Size
([
512
,
7
])
def
test_vote_head
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
from
mmdet3d.models.dense_heads
import
VoteHead
bbox_head_cfg
=
dict
(
num_classes
=
10
,
bbox_coder
=
dict
(
type
=
'PartialBinBasedBBoxCoder'
,
num_sizes
=
10
,
num_dir_bins
=
5
,
with_rot
=
True
,
mean_sizes
=
[[
2.114256
,
1.620300
,
0.927272
],
[
0.791118
,
1.279516
,
0.718182
],
[
0.923508
,
1.867419
,
0.845495
],
[
0.591958
,
0.552978
,
0.827272
],
[
0.699104
,
0.454178
,
0.75625
],
[
0.69519
,
1.346299
,
0.736364
],
[
0.528526
,
1.002642
,
1.172878
],
[
0.500618
,
0.632163
,
0.683424
],
[
0.404671
,
1.071108
,
1.688889
],
[
0.76584
,
1.398258
,
0.472728
]]),
vote_moudule_cfg
=
dict
(
in_channels
=
64
,
vote_per_seed
=
1
,
gt_per_seed
=
3
,
conv_channels
=
(
64
,
64
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
norm_feats
=
True
,
vote_loss
=
dict
(
type
=
'ChamferDistance'
,
mode
=
'l1'
,
reduction
=
'none'
,
loss_dst_weight
=
10.0
)),
vote_aggregation_cfg
=
dict
(
num_point
=
256
,
radius
=
0.3
,
num_sample
=
16
,
mlp_channels
=
[
64
,
32
,
32
,
32
],
use_xyz
=
True
,
normalize_xyz
=
True
),
feat_channels
=
(
64
,
64
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
objectness_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
class_weight
=
[
0.2
,
0.8
],
reduction
=
'sum'
,
loss_weight
=
5.0
),
center_loss
=
dict
(
type
=
'ChamferDistance'
,
mode
=
'l2'
,
reduction
=
'sum'
,
loss_src_weight
=
10.0
,
loss_dst_weight
=
10.0
),
dir_class_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
dir_res_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
10.0
),
size_class_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
size_res_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
10.0
/
3.0
),
semantic_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
))
train_cfg
=
dict
(
pos_distance_thr
=
0.3
,
neg_distance_thr
=
0.6
,
sample_mod
=
'vote'
)
self
=
VoteHead
(
train_cfg
=
train_cfg
,
**
bbox_head_cfg
).
cuda
()
fp_xyz
=
[
torch
.
rand
([
2
,
64
,
3
],
dtype
=
torch
.
float32
).
cuda
()]
fp_features
=
[
torch
.
rand
([
2
,
64
,
64
],
dtype
=
torch
.
float32
).
cuda
()]
fp_indices
=
[
torch
.
randint
(
0
,
128
,
[
2
,
64
]).
cuda
()]
input_dict
=
dict
(
fp_xyz
=
fp_xyz
,
fp_features
=
fp_features
,
fp_indices
=
fp_indices
)
# test forward
ret_dict
=
self
(
input_dict
,
'vote'
)
assert
ret_dict
[
'center'
].
shape
==
torch
.
Size
([
2
,
256
,
3
])
assert
ret_dict
[
'obj_scores'
].
shape
==
torch
.
Size
([
2
,
256
,
2
])
assert
ret_dict
[
'size_res'
].
shape
==
torch
.
Size
([
2
,
256
,
10
,
3
])
assert
ret_dict
[
'dir_res'
].
shape
==
torch
.
Size
([
2
,
256
,
5
])
tests/test_losses.py
0 → 100644
View file @
f717eb62
import
pytest
import
torch
def
test_chamfer_disrance
():
from
mmdet3d.models.losses
import
ChamferDistance
,
chamfer_distance
with
pytest
.
raises
(
AssertionError
):
# test invalid mode
ChamferDistance
(
mode
=
'smoothl1'
)
# test invalid type of reduction
ChamferDistance
(
mode
=
'l2'
,
reduction
=
None
)
self
=
ChamferDistance
(
mode
=
'l2'
,
reduction
=
'sum'
,
loss_src_weight
=
1.0
,
loss_dst_weight
=
1.0
)
source
=
torch
.
tensor
([[[
-
0.9888
,
0.9683
,
-
0.8494
],
[
-
6.4536
,
4.5146
,
1.6861
],
[
2.0482
,
5.6936
,
-
1.4701
],
[
-
0.5173
,
5.6472
,
2.1748
],
[
-
2.8010
,
5.4423
,
-
1.2158
],
[
2.4018
,
2.4389
,
-
0.2403
],
[
-
2.8811
,
3.8486
,
1.4750
],
[
-
0.2031
,
3.8969
,
-
1.5245
],
[
1.3827
,
4.9295
,
1.1537
],
[
-
2.6961
,
2.2621
,
-
1.0976
]],
[[
0.3692
,
1.8409
,
-
1.4983
],
[
1.9995
,
6.3602
,
0.1798
],
[
-
2.1317
,
4.6011
,
-
0.7028
],
[
2.4158
,
3.1482
,
0.3169
],
[
-
0.5836
,
3.6250
,
-
1.2650
],
[
-
1.9862
,
1.6182
,
-
1.4901
],
[
2.5992
,
1.2847
,
-
0.8471
],
[
-
0.3467
,
5.3681
,
-
1.4755
],
[
-
0.8576
,
3.3400
,
-
1.7399
],
[
2.7447
,
4.6349
,
0.1994
]]])
target
=
torch
.
tensor
([[[
-
0.4758
,
1.0094
,
-
0.8645
],
[
-
0.3130
,
0.8564
,
-
0.9061
],
[
-
0.1560
,
2.0394
,
-
0.8936
],
[
-
0.3685
,
1.6467
,
-
0.8271
],
[
-
0.2740
,
2.2212
,
-
0.7980
]],
[[
1.4856
,
2.5299
,
-
1.0047
],
[
2.3262
,
3.3065
,
-
0.9475
],
[
2.4593
,
2.5870
,
-
0.9423
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]]])
loss_source
,
loss_target
,
indices1
,
indices2
=
self
(
source
,
target
,
return_indices
=
True
)
assert
torch
.
allclose
(
loss_source
,
torch
.
tensor
(
219.5936
))
assert
torch
.
allclose
(
loss_target
,
torch
.
tensor
(
22.3705
))
assert
(
indices1
==
indices1
.
new_tensor
([[
0
,
4
,
4
,
4
,
4
,
2
,
4
,
4
,
4
,
3
],
[
0
,
1
,
0
,
1
,
0
,
4
,
2
,
0
,
0
,
1
]])).
all
()
assert
(
indices2
==
indices2
.
new_tensor
([[
0
,
0
,
0
,
0
,
0
],
[
0
,
3
,
6
,
0
,
0
]])).
all
()
loss_source
,
loss_target
,
indices1
,
indices2
=
chamfer_distance
(
source
,
target
,
reduction
=
'sum'
)
assert
torch
.
allclose
(
loss_source
,
torch
.
tensor
(
219.5936
))
assert
torch
.
allclose
(
loss_target
,
torch
.
tensor
(
22.3705
))
assert
(
indices1
==
indices1
.
new_tensor
([[
0
,
4
,
4
,
4
,
4
,
2
,
4
,
4
,
4
,
3
],
[
0
,
1
,
0
,
1
,
0
,
4
,
2
,
0
,
0
,
1
]])).
all
()
assert
(
indices2
==
indices2
.
new_tensor
([[
0
,
0
,
0
,
0
,
0
],
[
0
,
3
,
6
,
0
,
0
]])).
all
()
tests/test_nms.py
0 → 100644
View file @
f717eb62
import
torch
def
test_aligned_3d_nms
():
from
mmdet3d.core.post_processing
import
aligned_3d_nms
boxes
=
torch
.
tensor
([[
1.2261
,
0.6679
,
-
1.2678
,
2.6547
,
1.0428
,
0.1000
],
[
5.0919
,
0.6512
,
0.7238
,
5.4821
,
1.2451
,
2.1095
],
[
6.8392
,
-
1.2205
,
0.8570
,
7.6920
,
0.3220
,
3.2223
],
[
3.6900
,
-
0.4235
,
-
1.0380
,
4.4415
,
0.2671
,
-
0.1442
],
[
4.8071
,
-
1.4311
,
0.7004
,
5.5788
,
-
0.6837
,
1.2487
],
[
2.1807
,
-
1.5811
,
-
1.1289
,
3.0151
,
-
0.1346
,
-
0.5351
],
[
4.4631
,
-
4.2588
,
-
1.1403
,
5.3012
,
-
3.4463
,
-
0.3212
],
[
4.7607
,
-
3.3311
,
0.5993
,
5.2976
,
-
2.7874
,
1.2273
],
[
3.1265
,
0.7113
,
-
0.0296
,
3.8944
,
1.3532
,
0.9785
],
[
5.5828
,
-
3.5350
,
1.0105
,
8.2841
,
-
0.0405
,
3.3614
],
[
3.0003
,
-
2.1099
,
-
1.0608
,
5.3423
,
0.0328
,
0.6252
],
[
2.7148
,
0.6082
,
-
1.1738
,
3.6995
,
1.2375
,
-
0.0209
],
[
4.9263
,
-
0.2152
,
0.2889
,
5.6963
,
0.3416
,
1.3471
],
[
5.0713
,
1.3459
,
-
0.2598
,
5.6278
,
1.9300
,
1.2835
],
[
4.5985
,
-
2.3996
,
-
0.3393
,
5.2705
,
-
1.7306
,
0.5698
],
[
4.1386
,
0.5658
,
0.0422
,
4.8937
,
1.1983
,
0.9911
],
[
2.7694
,
-
1.9822
,
-
1.0637
,
4.0691
,
0.3575
,
-
0.1393
],
[
4.6464
,
-
3.0123
,
-
1.0694
,
5.1421
,
-
2.4450
,
-
0.3758
],
[
3.4754
,
0.4443
,
-
1.1282
,
4.6727
,
1.3786
,
0.2550
],
[
2.5905
,
-
0.3504
,
-
1.1202
,
3.1599
,
0.1153
,
-
0.3036
],
[
4.1336
,
-
3.4813
,
1.1477
,
6.2091
,
-
0.8776
,
2.6757
],
[
3.9966
,
0.2069
,
-
1.1148
,
5.0841
,
1.0525
,
-
0.0648
],
[
4.3216
,
-
1.8647
,
0.4733
,
6.2069
,
0.6671
,
3.3363
],
[
4.7683
,
0.4286
,
-
0.0500
,
5.5642
,
1.2906
,
0.8902
],
[
1.7337
,
0.7625
,
-
1.0058
,
3.0675
,
1.3617
,
0.3849
],
[
4.7193
,
-
3.3687
,
-
0.9635
,
5.1633
,
-
2.7656
,
1.1001
],
[
4.4704
,
-
2.7744
,
-
1.1127
,
5.0971
,
-
2.0228
,
-
0.3150
],
[
2.7027
,
0.6122
,
-
0.9169
,
3.3083
,
1.2117
,
0.6129
],
[
4.8789
,
-
2.0025
,
0.8385
,
5.5214
,
-
1.3668
,
1.3552
],
[
3.7856
,
-
1.7582
,
-
0.1738
,
5.3373
,
-
0.6300
,
0.5558
]])
scores
=
torch
.
tensor
([
3.6414e-03
,
2.2901e-02
,
2.7576e-04
,
1.2238e-02
,
5.9310e-04
,
1.2659e-01
,
2.4104e-02
,
5.0742e-03
,
2.3581e-03
,
2.0946e-07
,
8.8039e-01
,
1.9127e-01
,
5.0469e-05
,
9.3638e-03
,
3.0663e-03
,
9.4350e-03
,
5.3380e-02
,
1.7895e-01
,
2.0048e-01
,
1.1294e-03
,
3.0304e-08
,
2.0237e-01
,
1.0894e-08
,
6.7972e-02
,
6.7156e-01
,
9.3986e-04
,
7.9470e-01
,
3.9736e-01
,
1.8000e-04
,
7.9151e-04
])
cls
=
torch
.
tensor
([
8
,
8
,
8
,
3
,
3
,
1
,
3
,
3
,
7
,
8
,
0
,
6
,
7
,
8
,
3
,
7
,
2
,
7
,
6
,
3
,
8
,
6
,
6
,
7
,
6
,
8
,
7
,
6
,
3
,
1
])
pick
=
aligned_3d_nms
(
boxes
,
scores
,
cls
,
0.25
)
expected_pick
=
torch
.
tensor
([
10
,
26
,
24
,
27
,
21
,
18
,
17
,
5
,
23
,
16
,
6
,
1
,
3
,
15
,
13
,
7
,
0
,
14
,
8
,
19
,
25
,
29
,
4
,
2
,
28
,
12
,
9
,
20
,
22
])
assert
torch
.
all
(
pick
==
expected_pick
)
tests/test_roiaware_pool3d.py
View file @
f717eb62
import
pytest
import
pytest
import
torch
import
torch
from
mmdet3d.ops.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
from
mmdet3d.ops.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_batch
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
points_in_boxes_gpu
)
...
@@ -83,3 +84,29 @@ def test_points_in_boxes_cpu():
...
@@ -83,3 +84,29 @@ def test_points_in_boxes_cpu():
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
assert
point_indices
.
shape
==
torch
.
Size
([
2
,
15
])
assert
point_indices
.
shape
==
torch
.
Size
([
2
,
15
])
assert
(
point_indices
==
expected_point_indices
).
all
()
assert
(
point_indices
==
expected_point_indices
).
all
()
def
test_points_in_boxes_batch
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
boxes
=
torch
.
tensor
(
[[[
1.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
,
0.3
],
[
-
10.0
,
23.0
,
16.0
,
10
,
20
,
20
,
0.5
]]],
dtype
=
torch
.
float32
).
cuda
(
)
# boxes (m, 7) with bottom center in lidar coordinate
pts
=
torch
.
tensor
(
[[[
1
,
2
,
3.3
],
[
1.2
,
2.5
,
3.0
],
[
0.8
,
2.1
,
3.5
],
[
1.6
,
2.6
,
3.6
],
[
0.8
,
1.2
,
3.9
],
[
-
9.2
,
21.0
,
18.2
],
[
3.8
,
7.9
,
6.3
],
[
4.7
,
3.5
,
-
12.2
],
[
3.8
,
7.6
,
-
2
],
[
-
10.6
,
-
12.9
,
-
20
],
[
-
16
,
-
18
,
9
],
[
-
21.3
,
-
52
,
-
5
],
[
0
,
0
,
0
],
[
6
,
7
,
8
],
[
-
2
,
-
3
,
-
4
]]],
dtype
=
torch
.
float32
).
cuda
()
# points (n, 3) in lidar coordinate
point_indices
=
points_in_boxes_batch
(
points
=
pts
,
boxes
=
boxes
)
expected_point_indices
=
torch
.
tensor
(
[[[
1
,
0
],
[
1
,
0
],
[
1
,
0
],
[
1
,
0
],
[
1
,
0
],
[
0
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]]],
dtype
=
torch
.
int32
).
cuda
()
assert
point_indices
.
shape
==
torch
.
Size
([
1
,
15
,
2
])
assert
(
point_indices
==
expected_point_indices
).
all
()
tests/test_vot
ing
_module.py
→
tests/test_vot
e
_module.py
View file @
f717eb62
import
torch
import
torch
def
test_vot
ing
_module
():
def
test_vot
e
_module
():
from
mmdet3d.
op
s
import
VoteModule
from
mmdet3d.
models.model_util
s
import
VoteModule
self
=
VoteModule
(
vote_per_seed
=
3
,
in_channels
=
8
)
vote_loss
=
dict
(
type
=
'ChamferDistance'
,
mode
=
'l1'
,
reduction
=
'none'
,
loss_dst_weight
=
10.0
)
self
=
VoteModule
(
vote_per_seed
=
3
,
in_channels
=
8
,
vote_loss
=
vote_loss
)
seed_xyz
=
torch
.
rand
([
2
,
64
,
3
],
dtype
=
torch
.
float32
)
# (b, npoints, 3)
seed_xyz
=
torch
.
rand
([
2
,
64
,
3
],
dtype
=
torch
.
float32
)
# (b, npoints, 3)
seed_features
=
torch
.
rand
(
seed_features
=
torch
.
rand
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment