Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f28fbecb
Commit
f28fbecb
authored
May 07, 2020
by
zhangwenwei
Browse files
Merge branch 'master' into box3d_structure
parents
8c7d0586
f584b970
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
395 additions
and
7 deletions
+395
-7
.gitlab-ci.yml
.gitlab-ci.yml
+1
-1
mmdet3d/apis/train.py
mmdet3d/apis/train.py
+3
-2
mmdet3d/core/bbox/box_torch_ops.py
mmdet3d/core/bbox/box_torch_ops.py
+20
-0
mmdet3d/datasets/pipelines/indoor_sample.py
mmdet3d/datasets/pipelines/indoor_sample.py
+67
-0
mmdet3d/models/losses/__init__.py
mmdet3d/models/losses/__init__.py
+2
-2
mmdet3d/models/roi_heads/__init__.py
mmdet3d/models/roi_heads/__init__.py
+3
-0
mmdet3d/models/roi_heads/mask_heads/__init__.py
mmdet3d/models/roi_heads/mask_heads/__init__.py
+3
-0
mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py
...3d/models/roi_heads/mask_heads/pointwise_semantic_head.py
+161
-0
tests/test_indoor_sample.py
tests/test_indoor_sample.py
+58
-0
tests/test_roiaware_pool3d.py
tests/test_roiaware_pool3d.py
+2
-2
tests/test_semantic_heads.py
tests/test_semantic_heads.py
+75
-0
No files found.
.gitlab-ci.yml
View file @
f28fbecb
...
@@ -26,7 +26,7 @@ before_script:
...
@@ -26,7 +26,7 @@ before_script:
script:
  - echo "Start building..."
  # cocoapi has no PyPI wheel for this setup; install straight from GitHub.
  - pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
  # mmdetection is installed from master (this commit removes the @v2.0 pin).
  - pip install git+https://github.com/open-mmlab/mmdetection.git
  - python -c "import mmdet; print(mmdet.__version__)"
  # Editable install of this repo with all optional extras.
  - pip install -v -e .[all]
  - python -c "import mmdet3d; print(mmdet3d.__version__)"
...
...
mmdet3d/apis/train.py
View file @
f28fbecb
...
@@ -90,14 +90,15 @@ def train_detector(model,
...
@@ -90,14 +90,15 @@ def train_detector(model,
if
fp16_cfg
is
not
None
:
if
fp16_cfg
is
not
None
:
optimizer_config
=
Fp16OptimizerHook
(
optimizer_config
=
Fp16OptimizerHook
(
**
cfg
.
optimizer_config
,
**
fp16_cfg
,
distributed
=
distributed
)
**
cfg
.
optimizer_config
,
**
fp16_cfg
,
distributed
=
distributed
)
elif
distributed
:
elif
distributed
and
'type'
not
in
cfg
.
optimizer_config
:
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
else
:
else
:
optimizer_config
=
cfg
.
optimizer_config
optimizer_config
=
cfg
.
optimizer_config
# register hooks
# register hooks
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
cfg
.
checkpoint_config
,
cfg
.
log_config
,
cfg
.
get
(
'momentum_config'
,
None
))
if
distributed
:
if
distributed
:
runner
.
register_hook
(
DistSamplerSeedHook
())
runner
.
register_hook
(
DistSamplerSeedHook
())
...
...
mmdet3d/core/bbox/box_torch_ops.py
View file @
f28fbecb
...
@@ -190,3 +190,23 @@ def rotation_2d(points, angles):
...
@@ -190,3 +190,23 @@ def rotation_2d(points, angles):
rot_cos
=
torch
.
cos
(
angles
)
rot_cos
=
torch
.
cos
(
angles
)
rot_mat_T
=
torch
.
stack
([[
rot_cos
,
-
rot_sin
],
[
rot_sin
,
rot_cos
]])
rot_mat_T
=
torch
.
stack
([[
rot_cos
,
-
rot_sin
],
[
rot_sin
,
rot_cos
]])
return
torch
.
einsum
(
'aij,jka->aik'
,
points
,
rot_mat_T
)
return
torch
.
einsum
(
'aij,jka->aik'
,
points
,
rot_mat_T
)
def enlarge_box3d_lidar(boxes3d, extra_width):
    """Return a copy of ``boxes3d`` grown by ``extra_width`` on every side.

    Args:
        boxes3d (torch.Tensor | np.ndarray): bottom-center boxes with
            shape [N, 7], (x, y, z, w, l, h, ry) in LiDAR coords.
        extra_width (float): a fixed amount added to each side of a box.

    Returns:
        torch.Tensor | np.ndarray: enlarged boxes, same type as the input.
    """
    # Copy so the caller's boxes are left untouched; numpy and torch
    # spell "copy" differently.
    enlarged = boxes3d.copy() if isinstance(boxes3d, np.ndarray) else \
        boxes3d.clone()
    # w, l, h each grow by extra_width on both sides.
    enlarged[:, 3:6] += extra_width * 2
    # Boxes are bottom-centered: lowering z by extra_width keeps the
    # enlarged box centered on the original after the height increase.
    enlarged[:, 2] -= extra_width
    return enlarged
mmdet3d/datasets/pipelines/indoor_sample.py
0 → 100644
View file @
f28fbecb
import
numpy
as
np
from
mmdet.datasets.builder
import
PIPELINES
@
PIPELINES
.
register_module
()
class PointSample(object):
    """Point Sample.

    Pipeline transform that subsamples ``results['points']`` down to a
    fixed number of points, and applies the same selection to
    ``results['pts_instance_mask']`` / ``results['pts_semantic_mask']``
    when both masks are present.

    Args:
        num_points (int): Number of points to be sampled.
    """

    def __init__(self, num_points):
        self.num_points = num_points

    def points_random_sampling(self,
                               points,
                               num_samples,
                               replace=None,
                               return_choices=False):
        """Points Random Sampling.

        Sample points to a certain number.

        Args:
            points (ndarray): 3D points, shape (N, point_dim).
            num_samples (int): Number of samples to be sampled.
            replace (bool, optional): Whether the sample is with or without
                replacement. Defaults to None, in which case replacement is
                used only when there are fewer points than requested
                samples.
            return_choices (bool): Whether to return the chosen indices.

        Returns:
            ndarray | tuple[ndarray, ndarray]: The sampled points, plus the
                generated random indices if ``return_choices`` is True.
        """
        if replace is None:
            # Sampling without replacement is impossible when we have
            # fewer points than requested samples.
            replace = (points.shape[0] < num_samples)
        choices = np.random.choice(
            points.shape[0], num_samples, replace=replace)
        if return_choices:
            return points[choices], choices
        else:
            return points[choices]

    def __call__(self, results):
        """Sample the points in ``results`` in place and return it."""
        points = results.get('points', None)
        points, choices = self.points_random_sampling(
            points, self.num_points, return_choices=True)
        pts_instance_mask = results.get('pts_instance_mask', None)
        pts_semantic_mask = results.get('pts_semantic_mask', None)
        results['points'] = points
        # Keep per-point masks aligned with the sampled points.
        if pts_instance_mask is not None and pts_semantic_mask is not None:
            pts_instance_mask = pts_instance_mask[choices]
            pts_semantic_mask = pts_semantic_mask[choices]
            results['pts_instance_mask'] = pts_instance_mask
            results['pts_semantic_mask'] = pts_semantic_mask
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += '(num_points={})'.format(self.num_points)
        return repr_str
mmdet3d/models/losses/__init__.py
View file @
f28fbecb
# Re-export mmdet loss components so mmdet3d code can import them from
# mmdet3d.models.losses.
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy

__all__ = ['FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy']
mmdet3d/models/roi_heads/__init__.py
View file @
f28fbecb
# Expose the mask heads at the roi_heads package level.
from .mask_heads import PointwiseSemanticHead

__all__ = ['PointwiseSemanticHead']
mmdet3d/models/roi_heads/mask_heads/__init__.py
0 → 100644
View file @
f28fbecb
# Public API of the mask_heads sub-package.
from .pointwise_semantic_head import PointwiseSemanticHead

__all__ = ['PointwiseSemanticHead']
mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py
0 → 100644
View file @
f28fbecb
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet3d.core
import
multi_apply
from
mmdet3d.core.bbox
import
box_torch_ops
from
mmdet3d.models.builder
import
build_loss
from
mmdet3d.ops.roiaware_pool3d
import
points_in_boxes_gpu
from
mmdet.models
import
HEADS
@HEADS.register_module()
class PointwiseSemanticHead(nn.Module):
    """Semantic segmentation head for point-wise segmentation.

    Predict point-wise segmentation and part regression results for PartA2.
    See https://arxiv.org/abs/1907.03670 for more details.

    Args:
        in_channels (int): the number of input channel.
        num_classes (int): the number of class.
        extra_width (float): boxes enlarge width.
        seg_score_thr (float): foreground score threshold used in
            ``forward`` to mask out part predictions of background points.
        loss_seg (dict): Config of segmentation loss.
        loss_part (dict): Config of part prediction loss.
    """

    def __init__(self,
                 in_channels,
                 num_classes=3,
                 extra_width=0.2,
                 seg_score_thr=0.3,
                 loss_seg=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_part=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0)):
        super(PointwiseSemanticHead, self).__init__()
        self.extra_width = extra_width
        self.num_classes = num_classes
        self.seg_score_thr = seg_score_thr
        # Binary foreground/background score per point.
        self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True)
        # 3-dim intra-box part location regression per point.
        self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True)
        self.loss_seg = build_loss(loss_seg)
        self.loss_part = build_loss(loss_part)

    def forward(self, x):
        """Predict segmentation logits and part offsets.

        Args:
            x (torch.Tensor): point features, shape [N, in_channels].

        Returns:
            dict: ``seg_preds`` [N, 1] and ``part_preds`` [N, 3] (raw
                logits used by ``loss``), plus ``part_feats`` [N, 4] —
                detached sigmoid part offsets (zeroed for points whose
                seg score is below ``seg_score_thr``) concatenated with
                the detached seg scores.
        """
        seg_preds = self.seg_cls_layer(x)  # (N, 1)
        part_preds = self.seg_reg_layer(x)  # (N, 3)
        seg_scores = torch.sigmoid(seg_preds).detach()
        seg_mask = (seg_scores > self.seg_score_thr)
        part_offsets = torch.sigmoid(part_preds).clone().detach()
        # Zero out part offsets of points predicted as background.
        part_offsets[seg_mask.view(-1) == 0] = 0
        part_feats = torch.cat((part_offsets, seg_scores),
                               dim=-1)  # shape (npoints, 4)
        return dict(
            seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)

    def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d):
        """generate segmentation and part prediction targets for one sample

        Args:
            voxel_centers (torch.Tensor): shape [voxel_num, 3],
                the center of voxels
            gt_bboxes_3d (torch.Tensor): shape [box_num, 7], gt boxes
            gt_labels_3d (torch.Tensor): shape [box_num], class label of gt

        Returns:
            tuple : segmentation targets with shape [voxel_num]
                (-1 = ignore, ``num_classes`` = background),
                part prediction targets with shape [voxel_num, 3]
        """
        enlarged_gt_boxes = box_torch_ops.enlarge_box3d_lidar(
            gt_bboxes_3d, extra_width=self.extra_width)
        part_targets = voxel_centers.new_zeros((voxel_centers.shape[0], 3),
                                               dtype=torch.float32)
        # Index of the gt box each voxel center falls inside, -1 if none.
        box_idx = points_in_boxes_gpu(
            voxel_centers.unsqueeze(0),
            gt_bboxes_3d.unsqueeze(0)).squeeze(0)  # -1 ~ box_num
        enlarge_box_idx = points_in_boxes_gpu(
            voxel_centers.unsqueeze(0),
            enlarged_gt_boxes.unsqueeze(0)).squeeze(0).long()  # -1 ~ box_num
        # Prepend the background label so that shifted index 0
        # (box_idx == -1) maps to self.num_classes.
        gt_labels_pad = F.pad(
            gt_labels_3d, (1, 0), mode='constant', value=self.num_classes)
        seg_targets = gt_labels_pad[(box_idx.long() + 1)]
        fg_pt_flag = box_idx > -1
        # Points inside the enlarged box but outside the original box are
        # ambiguous; mark them as ignore (-1).
        ignore_flag = fg_pt_flag ^ (enlarge_box_idx > -1)
        seg_targets[ignore_flag] = -1
        for k in range(gt_bboxes_3d.shape[0]):
            k_box_flag = box_idx == k
            # no point in current box (caused by velodyne reduce)
            if not k_box_flag.any():
                continue
            fg_voxels = voxel_centers[k_box_flag]
            transformed_voxels = fg_voxels - gt_bboxes_3d[k, 0:3]
            # Rotate into the box's canonical frame, then normalize the
            # offsets by the box dimensions.
            transformed_voxels = box_torch_ops.rotation_3d_in_axis(
                transformed_voxels.unsqueeze(0),
                -gt_bboxes_3d[k, 6].view(1),
                axis=2)
            part_targets[k_box_flag] = transformed_voxels / gt_bboxes_3d[
                k, 3:6] + voxel_centers.new_tensor([0.5, 0.5, 0])
        part_targets = torch.clamp(part_targets, min=0)
        return seg_targets, part_targets

    def get_targets(self, voxels_dict, gt_bboxes_3d, gt_labels_3d):
        """generate targets for a whole batch

        Args:
            voxels_dict (dict): contains 'voxel_centers' and 'coors';
                coors[:, 0] is the batch index of each voxel.
            gt_bboxes_3d (list[torch.Tensor]): gt boxes of each sample.
            gt_labels_3d (list[torch.Tensor]): gt labels of each sample.

        Returns:
            dict: concatenated ``seg_targets`` and ``part_targets``.
        """
        batch_size = len(gt_labels_3d)
        voxel_center_list = []
        for idx in range(batch_size):
            # Select the voxels that belong to sample ``idx``.
            coords_idx = voxels_dict['coors'][:, 0] == idx
            voxel_center_list.append(voxels_dict['voxel_centers'][coords_idx])
        seg_targets, part_targets = multi_apply(self.get_targets_single,
                                                voxel_center_list,
                                                gt_bboxes_3d, gt_labels_3d)
        seg_targets = torch.cat(seg_targets, dim=0)
        part_targets = torch.cat(part_targets, dim=0)
        return dict(seg_targets=seg_targets, part_targets=part_targets)

    def loss(self, seg_preds, part_preds, seg_targets, part_targets):
        """Calculate point-wise segmentation and part prediction losses.

        Args:
            seg_preds (torch.Tensor): prediction of binary
                segmentation with shape [voxel_num, 1].
            part_preds (torch.Tensor): prediction of part
                with shape [voxel_num, 3].
            seg_targets (torch.Tensor): target of segmentation
                with shape [voxel_num] (-1 = ignore,
                ``num_classes`` = background).
            part_targets (torch.Tensor): target of part with
                shape [voxel_num, 3].

        Returns:
            dict: loss of segmentation and part prediction.
        """
        # Foreground: a real class label, i.e. neither ignore (-1) nor
        # background (num_classes).
        pos_mask = (seg_targets > -1) & (seg_targets < self.num_classes)
        binary_seg_target = pos_mask.long()
        pos = pos_mask.float()
        neg = (seg_targets == self.num_classes).float()
        # Ignored points get zero weight; the rest are normalized by the
        # number of foreground points (at least 1).
        seg_weights = pos + neg
        pos_normalizer = pos.sum()
        seg_weights = seg_weights / torch.clamp(pos_normalizer, min=1.0)
        loss_seg = self.loss_seg(seg_preds, binary_seg_target, seg_weights)
        if pos_normalizer > 0:
            # Part loss is only defined on foreground points.
            loss_part = self.loss_part(part_preds[pos_mask],
                                       part_targets[pos_mask])
        else:
            # fake a part loss
            loss_part = loss_seg.new_tensor(0)
        return dict(loss_seg=loss_seg, loss_part=loss_part)
tests/test_indoor_sample.py
0 → 100644
View file @
f28fbecb
import
numpy
as
np
from
mmdet3d.datasets.pipelines.indoor_sample
import
PointSample
def test_indoor_sample():
    """Check PointSample on ScanNet-style (with masks) and SUNRGBD-style
    (points only) inputs, with a fixed seed so the choices are known."""
    np.random.seed(0)
    scannet_sample_points = PointSample(5)
    scannet_results = dict()
    scannet_points = np.array([[1.0719866, -0.7870435, 0.8408122, 0.9196809],
                               [1.103661, 0.81065744, 2.6616862, 2.7405548],
                               [1.0276475, 1.5061463, 2.6174362, 2.6963048],
                               [-0.9709588, 0.6750515, 0.93901765, 1.0178864],
                               [1.0578915, 1.1693821, 0.87503505, 0.95390373],
                               [0.05560996, -1.5688863, 1.2440368, 1.3229055],
                               [-0.15731563, -1.7735453, 2.7535574, 2.832426],
                               [1.1188195, -0.99211365, 2.5551798, 2.6340485],
                               [-0.9186557, -1.7041215, 2.0562649, 2.1351335],
                               [-1.0128691, -1.3394243, 0.040936, 0.1198047]])
    scannet_results['points'] = scannet_points
    scannet_pts_instance_mask = np.array(
        [15, 12, 11, 38, 0, 18, 17, 12, 17, 0])
    scannet_results['pts_instance_mask'] = scannet_pts_instance_mask
    scannet_pts_semantic_mask = np.array([38, 1, 1, 40, 0, 40, 1, 1, 1, 0])
    scannet_results['pts_semantic_mask'] = scannet_pts_semantic_mask
    scannet_results = scannet_sample_points(scannet_results)
    scannet_points_result = scannet_results.get('points', None)
    scannet_instance_labels_result = scannet_results.get(
        'pts_instance_mask', None)
    scannet_semantic_labels_result = scannet_results.get(
        'pts_semantic_mask', None)
    # Indices produced by np.random.choice(10, 5) under seed 0: points and
    # both masks must all be gathered with these same indices.
    scannet_choices = np.array([2, 8, 4, 9, 1])
    assert np.allclose(scannet_points[scannet_choices],
                       scannet_points_result)
    assert np.all(scannet_pts_instance_mask[scannet_choices] ==
                  scannet_instance_labels_result)
    assert np.all(scannet_pts_semantic_mask[scannet_choices] ==
                  scannet_semantic_labels_result)

    # SUNRGBD-style input: points only, no instance/semantic masks.
    np.random.seed(0)
    sunrgbd_sample_points = PointSample(5)
    sunrgbd_results = dict()
    sunrgbd_point_cloud = np.array(
        [[-1.8135729e-01, 1.4695230e+00, -1.2780589e+00, 7.8938007e-03],
         [1.2581362e-03, 2.0561588e+00, -1.0341064e+00, 2.5184631e-01],
         [6.8236995e-01, 3.3611867e+00, -9.2599887e-01, 3.5995382e-01],
         [-2.9432583e-01, 1.8714852e+00, -9.0929651e-01, 3.7665617e-01],
         [-0.5024875, 1.8032674, -1.1403012, 0.14565146],
         [-0.520559, 1.6324949, -0.9896099, 0.2963428],
         [0.95929825, 2.9402404, -0.8746674, 0.41128528],
         [-0.74624217, 1.5244724, -0.8678476, 0.41810507],
         [0.56485355, 1.5747732, -0.804522, 0.4814307],
         [-0.0913099, 1.3673826, -1.2800645, 0.00588822]])
    sunrgbd_results['points'] = sunrgbd_point_cloud
    sunrgbd_results = sunrgbd_sample_points(sunrgbd_results)
    # Same seed, same pool size -> same choices as above.
    sunrgbd_choices = np.array([2, 8, 4, 9, 1])
    sunrgbd_points_result = sunrgbd_results.get('points', None)
    assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices],
                       sunrgbd_points_result)
tests/test_roiaware_pool3d.py
View file @
f28fbecb
...
@@ -6,8 +6,8 @@ from mmdet3d.ops.roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
...
@@ -6,8 +6,8 @@ from mmdet3d.ops.roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
def
test_RoIAwarePool3d
():
def
test_RoIAwarePool3d
():
if
not
torch
.
cuda
.
is_available
(
# RoIAwarePool3d only support gpu version currently.
):
# RoIAwarePool3d only support gpu version currently.
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
roiaware_pool3d_max
=
RoIAwarePool3d
(
roiaware_pool3d_max
=
RoIAwarePool3d
(
out_size
=
4
,
max_pts_per_voxel
=
128
,
mode
=
'max'
)
out_size
=
4
,
max_pts_per_voxel
=
128
,
mode
=
'max'
)
...
...
tests/test_semantic_heads.py
0 → 100644
View file @
f28fbecb
import
pytest
import
torch
def test_PointwiseSemanticHead():
    """Smoke-test forward / get_targets / loss of PointwiseSemanticHead
    on a tiny hand-made batch. Skipped without CUDA."""
    # PointwiseSemanticHead only support gpu version currently.
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    from mmdet3d.models.builder import build_head
    head_cfg = dict(
        type='PointwiseSemanticHead',
        in_channels=8,
        extra_width=0.2,
        seg_score_thr=0.3,
        num_classes=3,
        loss_seg=dict(
            type='FocalLoss',
            use_sigmoid=True,
            reduction='sum',
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_part=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
    self = build_head(head_cfg)
    self.cuda()
    # test forward
    voxel_features = torch.rand([4, 8], dtype=torch.float32).cuda()
    feats_dict = self.forward(voxel_features)
    assert feats_dict['seg_preds'].shape == torch.Size(
        [voxel_features.shape[0], 1])
    assert feats_dict['part_preds'].shape == torch.Size(
        [voxel_features.shape[0], 3])
    assert feats_dict['part_feats'].shape == torch.Size(
        [voxel_features.shape[0], 4])
    voxel_centers = torch.tensor(
        [[6.56126, 0.9648336, -1.7339306],
         [6.8162713, -2.480431, -1.3616394],
         [11.643568, -4.744306, -1.3580885],
         [23.482342, 6.5036807, 0.5806964]],
        dtype=torch.float32).cuda()  # n, point_features
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32).cuda()  # n, 4(batch, ind_x, ind_y, ind_z)
    voxel_dict = dict(voxel_centers=voxel_centers, coors=coordinates)
    # Two samples, one gt box each; boxes are placed away from the voxel
    # centers so no voxel falls inside any box.
    gt_bboxes = list(
        torch.tensor(
            [[[6.4118, -3.4305, -1.7291, 1.7033, 3.4693, 1.6197, -0.9091]],
             [[16.9107, 9.7925, -1.9201, 1.6097, 3.2786, 1.5307, -2.4056]]],
            dtype=torch.float32).cuda())
    gt_labels = list(torch.tensor([[0], [1]], dtype=torch.int64).cuda())
    # test get_targets
    target_dict = self.get_targets(voxel_dict, gt_bboxes, gt_labels)
    assert target_dict['seg_targets'].shape == torch.Size(
        [voxel_features.shape[0]])
    assert target_dict['part_targets'].shape == torch.Size(
        [voxel_features.shape[0], 3])
    # test loss
    loss_dict = self.loss(feats_dict['seg_preds'], feats_dict['part_preds'],
                          target_dict['seg_targets'],
                          target_dict['part_targets'])
    assert loss_dict['loss_seg'] > 0
    assert loss_dict['loss_part'] == 0  # no points in gt_boxes
    # The combined loss must be differentiable end to end.
    total_loss = loss_dict['loss_seg'] + loss_dict['loss_part']
    total_loss.backward()
if __name__ == '__main__':
    # Allow running this test file directly without pytest.
    test_PointwiseSemanticHead()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment