OpenDAS / mmdetection3d
Commit 0e17beab
Authored Jul 18, 2022 by jshilong
Committed by ChaimZhu on Jul 20, 2022
[Refactor] Refactor H3D
Parent: 9ebb75da
Showing 12 changed files with 808 additions and 697 deletions (+808 -697)
configs/_base_/models/h3dnet.py                         +42   -32
configs/h3dnet/debug.py                                  +0   -69
configs/h3dnet/h3dnet_3x8_scannet-3d-18class.py          +7    -2
mmdet3d/datasets/scannet_dataset.py                      +0   -42
mmdet3d/models/dense_heads/vote_head.py                 +93   -28
mmdet3d/models/detectors/h3dnet.py                      +97  -116
mmdet3d/models/detectors/point_rcnn.py                  +11   -11
mmdet3d/models/roi_heads/bbox_heads/h3d_bbox_head.py   +267  -204
mmdet3d/models/roi_heads/h3d_roi_head.py                +58   -87
mmdet3d/models/roi_heads/mask_heads/primitive_head.py  +160   -91
tests/test_models/test_detectors/test_h3dnet.py         +51    -0
tests/utils/model_utils.py                              +22   -15
configs/_base_/models/h3dnet.py  (view file @ 0e17beab)

@@ -30,7 +30,7 @@ primitive_z_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -47,14 +47,16 @@ primitive_z_cfg = dict(
         loss_src_weight=0.5,
         loss_dst_weight=0.5),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss',
+        reduction='sum',
+        loss_weight=1.0),
     train_cfg=dict(
-        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))

 primitive_xy_cfg = dict(
     type='PrimitiveHead',
@@ -88,7 +90,7 @@ primitive_xy_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -105,14 +107,16 @@ primitive_xy_cfg = dict(
         loss_src_weight=0.5,
         loss_dst_weight=0.5),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss',
+        reduction='sum',
+        loss_weight=1.0),
     train_cfg=dict(
-        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))

 primitive_line_cfg = dict(
     type='PrimitiveHead',
@@ -146,7 +150,7 @@ primitive_line_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -163,17 +167,20 @@ primitive_line_cfg = dict(
         loss_src_weight=1.0,
         loss_dst_weight=1.0),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=2.0),
+        type='mmdet.CrossEntropyLoss',
+        reduction='sum',
+        loss_weight=2.0),
     train_cfg=dict(
-        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))

 model = dict(
     type='H3DNet',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='MultiBackbone',
         num_streams=4,
@@ -221,10 +228,8 @@ model = dict(
             normalize_xyz=True),
         pred_layer_cfg=dict(
             in_channels=128, shared_conv_channels=(128, 128), bias=True),
-        conv_cfg=dict(type='Conv1d'),
-        norm_cfg=dict(type='BN1d'),
         objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
            class_weight=[0.2, 0.8],
            reduction='sum',
            loss_weight=5.0),
@@ -235,15 +240,15 @@ model = dict(
            loss_src_weight=10.0,
            loss_dst_weight=10.0),
        dir_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
        dir_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
        size_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
        size_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
        semantic_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
    roi_head=dict(
        type='H3DRoIHead',
        primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
@@ -267,7 +272,6 @@ model = dict(
            mlp_channels=[128 + 12, 128, 64, 32],
            use_xyz=True,
            normalize_xyz=True),
-        feat_channels=(128, 128),
        primitive_refine_channels=[128, 128, 128],
        upper_thresh=100.0,
        surface_thresh=0.5,
@@ -275,7 +279,7 @@ model = dict(
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
            class_weight=[0.2, 0.8],
            reduction='sum',
            loss_weight=5.0),
@@ -286,41 +290,47 @@ model = dict(
            loss_src_weight=10.0,
            loss_dst_weight=10.0),
        dir_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=0.1),
        dir_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
        size_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=0.1),
        size_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
        semantic_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=0.1),
        cues_objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
            class_weight=[0.3, 0.7],
            reduction='mean',
            loss_weight=5.0),
        cues_semantic_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
            class_weight=[0.3, 0.7],
            reduction='mean',
            loss_weight=5.0),
        proposal_objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
            class_weight=[0.2, 0.8],
            reduction='none',
            loss_weight=5.0),
        primitive_center_loss=dict(
-            type='MSELoss', reduction='none', loss_weight=1.0))),
+            type='mmdet.MSELoss', reduction='none', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
-            pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
+            pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
        rpn_proposal=dict(use_nms=False),
        rcnn=dict(
            pos_distance_thr=0.3,
            neg_distance_thr=0.6,
-            sample_mod='vote',
+            sample_mode='vote',
            far_threshold=0.6,
            near_threshold=0.3,
            mask_surface_threshold=0.3,
@@ -329,13 +339,13 @@ model = dict(
            label_line_threshold=0.3)),
    test_cfg=dict(
        rpn=dict(
-            sample_mod='seed',
+            sample_mode='seed',
            nms_thr=0.25,
            score_thr=0.05,
            per_class_proposal=True,
            use_nms=False),
        rcnn=dict(
-            sample_mod='seed',
+            sample_mode='seed',
            nms_thr=0.25,
            score_thr=0.05,
            per_class_proposal=True)))
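The recurring change in this config is the registry scope prefix: losses that are implemented in mmdet ('CrossEntropyLoss', 'SmoothL1Loss', 'MSELoss') are now written as 'mmdet.<Name>' so that mmdet3d's registry resolves them in the parent mmdet scope. A minimal sketch of how one such entry is built, assuming an mmdet3d dev-1.x environment with mmdet installed (the weights are copied from the config above):

# Sketch only: builds one objectness loss from a scoped config entry.
from mmdet3d.registry import MODELS  # child registry whose parent scope is mmdet

loss_cfg = dict(
    type='mmdet.CrossEntropyLoss',  # 'mmdet.' prefix = look up in the mmdet scope
    class_weight=[0.4, 0.6],
    reduction='mean',
    loss_weight=30.0)
loss_objectness = MODELS.build(loss_cfg)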
configs/h3dnet/debug.py  (deleted, 100644 → 0; view file @ 9ebb75da)

-_base_ = [
-    '../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py',
-    '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
-]
-
-# model settings
-model = dict(
-    rpn_head=dict(
-        num_classes=18,
-        bbox_coder=dict(
-            type='PartialBinBasedBBoxCoder',
-            num_sizes=18,
-            num_dir_bins=24,
-            with_rot=False,
-            mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
-                        [1.876858, 1.8425595, 1.1931566],
-                        [0.61328, 0.6148609, 0.7182701],
-                        [1.3955007, 1.5121545, 0.83443564],
-                        [0.97949594, 1.0675149, 0.6329687],
-                        [0.531663, 0.5955577, 1.7500148],
-                        [0.9624706, 0.72462326, 1.1481868],
-                        [0.83221924, 1.0490936, 1.6875663],
-                        [0.21132214, 0.4206159, 0.5372846],
-                        [1.4440073, 1.8970833, 0.26985747],
-                        [1.0294262, 1.4040797, 0.87554324],
-                        [1.3766412, 0.65521795, 1.6813129],
-                        [0.6650819, 0.71111923, 1.298853],
-                        [0.41999173, 0.37906948, 1.7513971],
-                        [0.59359556, 0.5912492, 0.73919016],
-                        [0.50867593, 0.50656086, 0.30136237],
-                        [1.1511526, 1.0546296, 0.49706793],
-                        [0.47535285, 0.49249494, 0.5802117]])),
-    roi_head=dict(
-        bbox_head=dict(
-            num_classes=18,
-            bbox_coder=dict(
-                type='PartialBinBasedBBoxCoder',
-                num_sizes=18,
-                num_dir_bins=24,
-                with_rot=False,
-                mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
-                            [1.876858, 1.8425595, 1.1931566],
-                            [0.61328, 0.6148609, 0.7182701],
-                            [1.3955007, 1.5121545, 0.83443564],
-                            [0.97949594, 1.0675149, 0.6329687],
-                            [0.531663, 0.5955577, 1.7500148],
-                            [0.9624706, 0.72462326, 1.1481868],
-                            [0.83221924, 1.0490936, 1.6875663],
-                            [0.21132214, 0.4206159, 0.5372846],
-                            [1.4440073, 1.8970833, 0.26985747],
-                            [1.0294262, 1.4040797, 0.87554324],
-                            [1.3766412, 0.65521795, 1.6813129],
-                            [0.6650819, 0.71111923, 1.298853],
-                            [0.41999173, 0.37906948, 1.7513971],
-                            [0.59359556, 0.5912492, 0.73919016],
-                            [0.50867593, 0.50656086, 0.30136237],
-                            [1.1511526, 1.0546296, 0.49706793],
-                            [0.47535285, 0.49249494, 0.5802117]]))))
-
-train_dataloader = dict(
-    batch_size=3,
-    num_workers=2,
-)
-
-# yapf:disable
-default_hooks = dict(
-    logger=dict(type='LoggerHook', interval=30))
-# yapf:enable
configs/h3dnet/h3dnet_3x8_scannet-3d-18class.py  (view file @ 0e17beab)

@@ -57,8 +57,13 @@ model = dict(
                             [1.1511526, 1.0546296, 0.49706793],
                             [0.47535285, 0.49249494, 0.5802117]]))))

-data = dict(samples_per_gpu=3, workers_per_gpu=2)
+train_dataloader = dict(
+    batch_size=3,
+    num_workers=2,
+)

 # yapf:disable
-log_config = dict(interval=30)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=30))
 # yapf:enable
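The new keys above are plain dict overrides that MMEngine merges over the values inherited from the _base_ configs. A small sketch of that merge behaviour, using mmengine only; the inherited interval of 50 is a made-up stand-in for whatever the base runtime config defines:

# Sketch only: nested dict overrides are merged key by key, not replaced wholesale.
from mmengine import Config

base = Config(dict(default_hooks=dict(logger=dict(type='LoggerHook', interval=50))))
base.merge_from_dict(dict(default_hooks=dict(logger=dict(interval=30))))
assert base.default_hooks.logger.interval == 30     # overridden
assert base.default_hooks.logger.type == 'LoggerHook'  # inherited value kept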
mmdet3d/datasets/scannet_dataset.py  (view file @ 0e17beab)

@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union
 import numpy as np

-from mmdet3d.core import show_result
 from mmdet3d.core.bbox import DepthInstance3DBoxes
 from mmdet3d.registry import DATASETS
 from .det3d_dataset import Det3DDataset
-from .pipelines import Compose
 from .seg3d_dataset import Seg3DDataset
@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset):
         return ann_info

-    def _build_default_pipeline(self):
-        """Build the default pipeline for this dataset."""
-        pipeline = [
-            dict(
-                type='LoadPointsFromFile',
-                coord_type='DEPTH',
-                shift_height=False,
-                load_dim=6,
-                use_dim=[0, 1, 2]),
-            dict(type='GlobalAlignment', rotation_axis=2),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=self.CLASSES,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ]
-        return Compose(pipeline)
-
-    def show(self, results, out_dir, show=True, pipeline=None):
-        """Results visualization.
-
-        Args:
-            results (list[dict]): List of bounding boxes results.
-            out_dir (str): Output directory of visualization result.
-            show (bool): Visualize the results online.
-            pipeline (list[dict], optional): raw data loading for showing.
-                Default: None.
-        """
-        assert out_dir is not None, 'Expect out_dir, got none.'
-        pipeline = self._get_pipeline(pipeline)
-        for i, result in enumerate(results):
-            data_info = self.get_data_info[i]
-            pts_path = data_info['lidar_points']['lidar_path']
-            file_name = osp.split(pts_path)[-1].split('.')[0]
-            points = self._extract_data(i, pipeline, 'points').numpy()
-            gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
-            pred_bboxes = result['boxes_3d'].tensor.numpy()
-            show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
-                        show)
-

 @DATASETS.register_module()
 class ScanNetSegDataset(Seg3DDataset):
mmdet3d/models/dense_heads/vote_head.py  (view file @ 0e17beab)

 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule
 from mmengine import ConfigDict, InstanceData
+from torch import Tensor
 from torch.nn import functional as F

 from mmdet3d.core.post_processing import aligned_3d_nms
@@ -161,7 +162,7 @@ class VoteHead(BaseModule):
                 points: List[torch.Tensor],
                 feats_dict: Dict[str, torch.Tensor],
                 batch_data_samples: List[Det3DDataSample],
-                rescale=True,
+                use_nms: bool = True,
                 **kwargs) -> List[InstanceData]:
         """
         Args:
@@ -169,8 +170,8 @@ class VoteHead(BaseModule):
             feats_dict (dict): Features from FPN or backbone..
             batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                 Samples. It usually includes meta information of data.
-            rescale (bool): Whether rescale the resutls to
-                the original scale.
+            use_nms (bool): Whether do the nms for predictions.
+                Defaults to True.

         Returns:
             list[:obj:`InstanceData`]: List of processed predictions. Each
@@ -178,6 +179,9 @@ class VoteHead(BaseModule):
                 scores and labels.
         """
         preds_dict = self(feats_dict)
+        # `preds_dict` can be used in H3DNET
+        feats_dict.update(preds_dict)
+
         batch_size = len(batch_data_samples)
         batch_input_metas = []
         for batch_index in range(batch_size):
@@ -185,12 +189,73 @@ class VoteHead(BaseModule):
             batch_input_metas.append(metainfo)

         results_list = self.predict_by_feat(
-            points, preds_dict, batch_input_metas, rescale=rescale, **kwargs)
+            points, preds_dict, batch_input_metas, use_nms=use_nms, **kwargs)
         return results_list

-    def loss(self,
-             points: List[torch.Tensor],
-             feats_dict: Dict[str, torch.Tensor],
-             batch_data_samples: List[Det3DDataSample],
-             **kwargs) -> dict:
+    def loss_and_predict(self,
+                         points: List[torch.Tensor],
+                         feats_dict: Dict[str, torch.Tensor],
+                         batch_data_samples: List[Det3DDataSample],
+                         ret_target: bool = False,
+                         proposal_cfg: dict = None,
+                         **kwargs) -> Tuple:
+        """
+        Args:
+            points (list[tensor]): Points cloud of multiple samples.
+            feats_dict (dict): Predictions from backbone or FPN.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.
+            ret_target (bool): Whether return the assigned target.
+                Defaults to False.
+            proposal_cfg (dict): Configure for proposal process.
+                Defaults to True.
+
+        Returns:
+            tuple: Contains loss and predictions after post-process.
+        """
+        preds_dict = self.forward(feats_dict)
+        feats_dict.update(preds_dict)
+
+        batch_gt_instance_3d = []
+        batch_gt_instances_ignore = []
+        batch_input_metas = []
+        batch_pts_semantic_mask = []
+        batch_pts_instance_mask = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+            batch_gt_instances_ignore.append(
+                data_sample.get('ignored_instances', None))
+            batch_pts_semantic_mask.append(
+                data_sample.gt_pts_seg.get('pts_semantic_mask', None))
+            batch_pts_instance_mask.append(
+                data_sample.gt_pts_seg.get('pts_instance_mask', None))
+
+        loss_inputs = (points, preds_dict, batch_gt_instance_3d)
+        losses = self.loss_by_feat(
+            *loss_inputs,
+            batch_pts_semantic_mask=batch_pts_semantic_mask,
+            batch_pts_instance_mask=batch_pts_instance_mask,
+            batch_input_metas=batch_input_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
+            ret_target=ret_target,
+            **kwargs)
+
+        results_list = self.predict_by_feat(
+            points,
+            preds_dict,
+            batch_input_metas,
+            use_nms=proposal_cfg.use_nms,
+            **kwargs)
+
+        return losses, results_list
+
+    def loss(self,
+             points: List[torch.Tensor],
+             feats_dict: Dict[str, torch.Tensor],
+             batch_data_samples: List[Det3DDataSample],
+             ret_target: bool = False,
+             **kwargs) -> dict:
         """
         Args:
             points (list[tensor]): Points cloud of multiple samples.
@@ -198,6 +263,8 @@ class VoteHead(BaseModule):
             batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
                 contains the meta information of each sample and
                 corresponding annotations.
+            ret_target (bool): Whether return the assigned target.
+                Defaults to False.

         Returns:
             dict: A dictionary of loss components.
@@ -224,7 +291,9 @@ class VoteHead(BaseModule):
             batch_pts_semantic_mask=batch_pts_semantic_mask,
             batch_pts_instance_mask=batch_pts_instance_mask,
             batch_input_metas=batch_input_metas,
-            batch_gt_instances_ignore=batch_gt_instances_ignore)
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
+            ret_target=ret_target,
+            **kwargs)
         return losses

     def forward(self, feat_dict: dict) -> dict:
@@ -330,7 +399,7 @@ class VoteHead(BaseModule):
             batch_pts_semantic_mask (list[tensor]): Instance mask
                 of points cloud. Defaults to None.
             batch_input_metas (list[dict]): Contain pcd and img's meta info.
-            ret_target (bool): Return targets or not.
+            ret_target (bool): Return targets or not. Defaults to False.

         Returns:
             dict: Losses of Votenet.
@@ -671,9 +740,10 @@ class VoteHead(BaseModule):
                 while using vote head in rpn stage.

         Returns:
-            list[:obj:`InstanceData`]: List of processed predictions. Each
-                InstanceData cantains 3d Bounding boxes and corresponding
-                scores and labels.
+            list[:obj:`InstanceData`] or Tensor: Return list of processed
+                predictions when `use_nms` is True. Each InstanceData cantains
+                3d Bounding boxes and corresponding scores and labels.
+                Return raw bboxes when `use_nms` is False.
         """
         # decode boxes
         stack_points = torch.stack(points)
@@ -683,9 +753,9 @@ class VoteHead(BaseModule):
         batch_size = bbox3d.shape[0]
         results_list = list()
+        if use_nms:
             for b in range(batch_size):
                 temp_results = InstanceData()
-            if use_nms:
                 bbox_selected, score_selected, labels = \
                     self.multiclass_nms_single(obj_scores[b],
                                                sem_scores[b],
@@ -700,20 +770,15 @@ class VoteHead(BaseModule):
                 temp_results.scores_3d = score_selected
                 temp_results.labels_3d = labels
                 results_list.append(temp_results)
-            else:
-                bbox = batch_input_metas[b]['box_type_3d'](
-                    bbox_selected,
-                    box_dim=bbox_selected.shape[-1],
-                    with_yaw=self.bbox_coder.with_rot)
-                temp_results.bboxes_3d = bbox
-                temp_results.obj_scores_3d = obj_scores[b]
-                temp_results.sem_scores_3d = obj_scores[b]
-                results_list.append(temp_results)
             return results_list
+        else:
+            # TODO unify it when refactor the Augtest
+            return bbox3d

-    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
-                              input_meta):
+    def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
+                              bbox: Tensor, points: Tensor,
+                              input_meta: dict) -> Tuple:
         """Multi-class nms in single batch.

         Args:
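The new `loss_and_predict` reads `proposal_cfg.use_nms` with attribute access, so the proposal config the detector hands over is expected to behave like mmengine's ConfigDict (which this file already imports) rather than a plain dict. A minimal, runnable sketch of that assumption; the key values are illustrative:

# Sketch only: ConfigDict supports both attribute- and key-style access.
from mmengine import ConfigDict

proposal_cfg = ConfigDict(use_nms=False, sample_mode='seed', nms_thr=0.25)
assert proposal_cfg.use_nms is False          # attribute access used by loss_and_predict
assert proposal_cfg['sample_mode'] == 'seed'  # ordinary dict access still works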
mmdet3d/models/detectors/h3dnet.py  (view file @ 0e17beab)

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
 import torch
+from torch import Tensor

-from mmdet3d.core import merge_aug_bboxes_3d
 from mmdet3d.registry import MODELS
+from ...core import Det3DDataSample
 from .two_stage import TwoStage3DDetector
@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector):
     r"""H3DNet model.

     Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
+
+    Args:
+        backbone (dict): Config dict of detector's backbone.
+        neck (dict, optional): Config dict of neck. Defaults to None.
+        rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
+        roi_head (dict, optional): Config dict of roi head. Defaults to None.
+        train_cfg (dict, optional): Config dict of training hyper-parameters.
+            Defaults to None.
+        test_cfg (dict, optional): Config dict of test hyper-parameters.
+            Defaults to None.
+        init_cfg (dict, optional): the config to control the
+            initialization. Default to None.
+        data_preprocessor (dict or ConfigDict, optional): The pre-process
+            config of :class:`BaseDataPreprocessor`. it usually includes,
+            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
     """

     def __init__(self,
-                 backbone,
-                 neck=None,
-                 rpn_head=None,
-                 roi_head=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 rpn_head: Optional[dict] = None,
+                 roi_head: Optional[dict] = None,
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 **kwargs) -> None:
         super(H3DNet, self).__init__(
             backbone=backbone,
             neck=neck,
@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector):
             roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained,
-            init_cfg=init_cfg)
+            init_cfg=init_cfg,
+            data_preprocessor=data_preprocessor,
+            **kwargs)

-    def forward_train(self,
-                      points,
-                      img_metas,
-                      gt_bboxes_3d,
-                      gt_labels_3d,
-                      pts_semantic_mask=None,
-                      pts_instance_mask=None,
-                      gt_bboxes_ignore=None):
-        """Forward of training.
-
-        Args:
-            points (list[torch.Tensor]): Points of each batch.
-            img_metas (list): Image metas.
-            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): point-wise instance
-                label of each batch.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
-
-        Returns:
-            dict: Losses.
-        """
-        points_cat = torch.stack(points)
-
-        feats_dict = self.extract_feat(points_cat)
-        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
-        feats_dict['fp_features'] = [feats_dict['hd_feature']]
-        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
-
-        losses = dict()
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
-            feats_dict.update(rpn_outs)
-
-            rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
-                               pts_semantic_mask, pts_instance_mask, img_metas)
-            rpn_losses = self.rpn_head.loss(
-                rpn_outs,
-                *rpn_loss_inputs,
-                gt_bboxes_ignore=gt_bboxes_ignore,
-                ret_target=True)
-            feats_dict['targets'] = rpn_losses.pop('targets')
-            losses.update(rpn_losses)
-
-            # Generate rpn proposals
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = (points, rpn_outs, img_metas)
-            proposal_list = self.rpn_head.get_bboxes(
-                *proposal_inputs, use_nms=proposal_cfg.use_nms)
-            feats_dict['proposal_list'] = proposal_list
-        else:
-            raise NotImplementedError
-
-        roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
-                                                 gt_bboxes_3d, gt_labels_3d,
-                                                 pts_semantic_mask,
-                                                 pts_instance_mask,
-                                                 gt_bboxes_ignore)
-        losses.update(roi_losses)
-
-        return losses
-
-    def simple_test(self, points, img_metas, imgs=None, rescale=False):
-        """Forward of testing.
-
-        Args:
-            points (list[torch.Tensor]): Points of each sample.
-            img_metas (list): Image metas.
-            rescale (bool): Whether to rescale results.
-
-        Returns:
-            list: Predicted 3d boxes.
-        """
-        points_cat = torch.stack(points)
-
-        feats_dict = self.extract_feat(points_cat)
-        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
-        feats_dict['fp_features'] = [feats_dict['hd_feature']]
-        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
-
-        if self.with_rpn:
-            proposal_cfg = self.test_cfg.rpn
-            rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
-            feats_dict.update(rpn_outs)
-            # Generate rpn proposals
-            proposal_list = self.rpn_head.get_bboxes(
-                points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
-            feats_dict['proposal_list'] = proposal_list
-        else:
-            raise NotImplementedError
-
-        return self.roi_head.simple_test(
-            feats_dict, img_metas, points_cat, rescale=rescale)
-
-    def aug_test(self, points, img_metas, imgs=None, rescale=False):
-        """Test with augmentation."""
-        points_cat = [torch.stack(pts) for pts in points]
-        feats_dict = self.extract_feats(points_cat, img_metas)
-        for feat_dict in feats_dict:
-            feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
-            feat_dict['fp_features'] = [feat_dict['hd_feature']]
-            feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
-
-        # only support aug_test for one sample
-        aug_bboxes = []
-        for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
-                                                img_metas):
-            if self.with_rpn:
-                proposal_cfg = self.test_cfg.rpn
-                rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
-                feat_dict.update(rpn_outs)
-                # Generate rpn proposals
-                proposal_list = self.rpn_head.get_bboxes(
-                    points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
-                feat_dict['proposal_list'] = proposal_list
-            else:
-                raise NotImplementedError
-
-            bbox_results = self.roi_head.simple_test(
-                feat_dict,
-                self.test_cfg.rcnn.sample_mod,
-                img_meta,
-                pts_cat,
-                suffix='_optimized')
-            aug_bboxes.append(bbox_results)
-
-        # after merging, bboxes will be rescaled to the original image size
-        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
-                                            self.bbox_head.test_cfg)
-
-        return [merged_bboxes]
-
-    def extract_feats(self, points, img_metas):
-        """Extract features of multiple samples."""
-        return [
-            self.extract_feat(pts, img_meta)
-            for pts, img_meta in zip(points, img_metas)
-        ]
+    def extract_feat(self, batch_inputs_dict: dict) -> None:
+        """Directly extract features from the backbone+neck.
+
+        Args:
+            batch_inputs_dict (dict): The model input dict which include
+                'points'.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+
+        Returns:
+            dict: Dict of feature.
+        """
+        stack_points = torch.stack(batch_inputs_dict['points'])
+        x = self.backbone(stack_points)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
+             batch_data_samples: List[Det3DDataSample],
+             **kwargs) -> dict:
+        """
+        Args:
+            batch_inputs_dict (dict): The model input dict which include
+                'points' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance_3d`.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        feats_dict = self.extract_feat(batch_inputs_dict)
+        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
+        feats_dict['fp_features'] = [feats_dict['hd_feature']]
+        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
+
+        losses = dict()
+        if self.with_rpn:
+            proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                              self.test_cfg.rpn)
+            # note, the feats_dict would be added new key & value in rpn_head
+            rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
+                batch_inputs_dict['points'],
+                feats_dict,
+                batch_data_samples,
+                ret_target=True,
+                proposal_cfg=proposal_cfg)
+            feats_dict['targets'] = rpn_losses.pop('targets')
+            losses.update(rpn_losses)
+            feats_dict['rpn_proposals'] = rpn_proposals
+        else:
+            raise NotImplementedError
+
+        roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
+                                        feats_dict, batch_data_samples,
+                                        **kwargs)
+        losses.update(roi_losses)
+
+        return losses
+
+    def predict(self, batch_input_dict: Dict,
+                batch_data_samples: List[Det3DDataSample]
+                ) -> List[Det3DDataSample]:
+        """Get model predictions.
+
+        Args:
+            points (list[torch.Tensor]): Points of each sample.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.
+
+        Returns:
+            list: Predicted 3d boxes.
+        """
+        feats_dict = self.extract_feat(batch_input_dict)
+        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
+        feats_dict['fp_features'] = [feats_dict['hd_feature']]
+        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
+
+        if self.with_rpn:
+            proposal_cfg = self.test_cfg.rpn
+            rpn_proposals = self.rpn_head.predict(
+                batch_input_dict['points'],
+                feats_dict,
+                batch_data_samples,
+                use_nms=proposal_cfg.use_nms)
+            feats_dict['rpn_proposals'] = rpn_proposals
+        else:
+            raise NotImplementedError
+
+        results_list = self.roi_head.predict(
+            batch_input_dict['points'],
+            feats_dict,
+            batch_data_samples,
+            suffix='_optimized')
+
+        return self.convert_to_datasample(results_list)
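Both the new `loss()` and `predict()` remap the multi-backbone outputs so that downstream heads see the usual VoteNet-style keys. A standalone sketch of that remapping with made-up tensor shapes (only the key handling matches the code above):

# Sketch only: alias the last feature-propagation level of backbone stream 0
# under the generic names the RPN head expects.
import torch

feats_dict = {
    'fp_xyz_net0': [torch.rand(2, 1024, 3)],              # hypothetical shapes
    'hd_feature': torch.rand(2, 256, 1024),
    'fp_indices_net0': [torch.randint(0, 40000, (2, 1024))],
}
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]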
mmdet3d/models/detectors/point_rcnn.py  (view file @ 0e17beab)

@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector):
             x = self.neck(x)
         return x

-    def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d):
+    def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
         """Forward of training.

         Args:
             points (list[torch.Tensor]): Points of each batch.
-            img_metas (list[dict]): Meta information of each sample.
+            input_metas (list[dict]): Meta information of each sample.
             gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
             gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector):
             dict: Losses.
         """
         losses = dict()
-        points_cat = torch.stack(points)
-        x = self.extract_feat(points_cat)
+        stack_points = torch.stack(points)
+        x = self.extract_feat(stack_points)

         # features for rcnn
         backbone_feats = x['fp_features'].clone()
@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector):
             points=points,
             gt_bboxes_3d=gt_bboxes_3d,
             gt_labels_3d=gt_labels_3d,
-            img_metas=img_metas)
+            input_metas=input_metas)
         losses.update(rpn_loss)

-        bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds,
-                                             img_metas)
+        bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
+                                             cls_preds, input_metas)
         proposal_list = [
             dict(
                 boxes_3d=bboxes,
@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector):
         ]
         rcnn_feats.update({'points_cls_preds': cls_preds})

-        roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas,
+        roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
                                                  proposal_list, gt_bboxes_3d,
                                                  gt_labels_3d)
         losses.update(roi_losses)
@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector):
         Returns:
             list: Predicted 3d boxes.
         """
-        points_cat = torch.stack(points)
-        x = self.extract_feat(points_cat)
+        stack_points = torch.stack(points)
+        x = self.extract_feat(stack_points)

         # features for rcnn
         backbone_feats = x['fp_features'].clone()
         backbone_xyz = x['fp_xyz'].clone()
@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector):
         rcnn_feats.update({'points_cls_preds': cls_preds})

         bbox_list = self.rpn_head.get_bboxes(
-            points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale)
+            stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
         proposal_list = [
             dict(
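The `points_cat` to `stack_points` rename keeps the same behaviour: the per-sample point clouds are stacked into one batched tensor before feature extraction. A minimal sketch of that pattern; the sizes below are illustrative, the only real requirement is that every sample has the same number of points and channels:

# Sketch only: stacking a list of equally sized point clouds into (B, N, C).
import torch

points = [torch.rand(16384, 4) for _ in range(2)]  # two samples, xyz + intensity
stack_points = torch.stack(points)
assert stack_points.shape == (2, 16384, 4)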
mmdet3d/models/roi_heads/bbox_heads/h3d_bbox_head.py
View file @
0e17beab
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
from
mmcv.runner
import
BaseModule
from
mmcv.runner
import
BaseModule
from
mmengine
import
InstanceData
from
torch
import
Tensor
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
mmdet3d.core
import
build_bbox_coder
from
mmdet3d.core
import
BaseInstance3DBoxes
,
Det3DDataSample
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.bbox
import
DepthInstance3DBoxes
from
mmdet3d.core.post_processing
import
aligned_3d_nms
from
mmdet3d.core.post_processing
import
aligned_3d_nms
from
mmdet3d.models.builder
import
build_loss
from
mmdet3d.models.losses
import
chamfer_distance
from
mmdet3d.models.losses
import
chamfer_distance
from
mmdet3d.ops
import
build_sa_module
from
mmdet3d.ops
import
build_sa_module
from
mmdet3d.registry
import
MODELS
from
mmdet3d.registry
import
MODELS
,
TASK_UTILS
from
mmdet.core
import
multi_apply
from
mmdet.core
import
multi_apply
...
@@ -25,66 +28,73 @@ class H3DBboxHead(BaseModule):
...
@@ -25,66 +28,73 @@ class H3DBboxHead(BaseModule):
line_matching_cfg (dict): Config for line primitive matching.
line_matching_cfg (dict): Config for line primitive matching.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
decoding boxes.
train_cfg (dict): Config for training.
train_cfg (dict): Config for training.
Defaults to None.
test_cfg (dict): Config for testing.
test_cfg (dict): Config for testing.
Defaults to None.
gt_per_seed (int): Number of ground truth votes generated
gt_per_seed (int): Number of ground truth votes generated
from each seed point.
from each seed point.
Defaults to 1.
num_proposal (int): Number of proposal votes generated.
num_proposal (int): Number of proposal votes generated.
feat_channels (tuple[int]): Convolution channels of
Defaults to 256.
prediction layer.
primitive_feat_refine_streams (int): The number of mlps to
primitive_feat_refine_streams (int): The number of mlps to
refine primitive feature.
refine primitive feature.
Defaults to 2.
primitive_refine_channels (tuple[int]): Convolution channels of
primitive_refine_channels (tuple[int]): Convolution channels of
prediction layer.
prediction layer.
Defaults to [128, 128, 128].
upper_thresh (float): Threshold for line matching.
upper_thresh (float): Threshold for line matching.
Defaults to 100.
surface_thresh (float): Threshold for surface matching.
surface_thresh (float): Threshold for surface matching.
line_thresh (float): Threshold for line matching.
Defaults to 0.5.
line_thresh (float): Threshold for line matching. Defaults to 0.5.
conv_cfg (dict): Config of convolution in prediction layer.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
Defaults to None.
objectness_loss (dict): Config of objectness loss.
norm_cfg (dict): Config of BN in prediction layer. Defaults to None.
center_loss (dict): Config of center loss.
objectness_loss (dict): Config of objectness loss. Defaults to None.
center_loss (dict): Config of center loss. Defaults to None.
dir_class_loss (dict): Config of direction classification loss.
dir_class_loss (dict): Config of direction classification loss.
Defaults to None.
dir_res_loss (dict): Config of direction residual regression loss.
dir_res_loss (dict): Config of direction residual regression loss.
Defaults to None.
size_class_loss (dict): Config of size classification loss.
size_class_loss (dict): Config of size classification loss.
Defaults to None.
size_res_loss (dict): Config of size residual regression loss.
size_res_loss (dict): Config of size residual regression loss.
Defaults to None.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
Defaults to None.
cues_objectness_loss (dict): Config of cues objectness loss.
cues_objectness_loss (dict): Config of cues objectness loss.
Defaults to None.
cues_semantic_loss (dict): Config of cues semantic loss.
cues_semantic_loss (dict): Config of cues semantic loss.
Defaults to None.
proposal_objectness_loss (dict): Config of proposal objectness
proposal_objectness_loss (dict): Config of proposal objectness
loss.
loss.
Defaults to None.
primitive_center_loss (dict): Config of primitive center regression
primitive_center_loss (dict): Config of primitive center regression
loss.
loss.
Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
:
int
,
suface_matching_cfg
,
suface_matching_cfg
:
dict
,
line_matching_cfg
,
line_matching_cfg
:
dict
,
bbox_coder
,
bbox_coder
:
dict
,
train_cfg
=
None
,
train_cfg
:
Optional
[
dict
]
=
None
,
test_cfg
=
None
,
test_cfg
:
Optional
[
dict
]
=
None
,
gt_per_seed
=
1
,
gt_per_seed
:
int
=
1
,
num_proposal
=
256
,
num_proposal
:
int
=
256
,
feat_channels
=
(
128
,
128
),
primitive_feat_refine_streams
:
int
=
2
,
primitive_feat_refine_streams
=
2
,
primitive_refine_channels
:
List
[
int
]
=
[
128
,
128
,
128
],
primitive_refine_channels
=
[
128
,
128
,
128
],
upper_thresh
:
float
=
100.0
,
upper_thresh
=
100.0
,
surface_thresh
:
float
=
0.5
,
surface_thresh
=
0.5
,
line_thresh
:
float
=
0.5
,
line_thresh
=
0.5
,
conv_cfg
:
dict
=
dict
(
type
=
'Conv1d'
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
:
dict
=
dict
(
type
=
'BN1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
objectness_loss
:
Optional
[
dict
]
=
None
,
objectness_loss
=
None
,
center_loss
:
Optional
[
dict
]
=
None
,
center_loss
=
None
,
dir_class_loss
:
Optional
[
dict
]
=
None
,
dir_class_loss
=
None
,
dir_res_loss
:
Optional
[
dict
]
=
None
,
dir_res_loss
=
None
,
size_class_loss
:
Optional
[
dict
]
=
None
,
size_class_loss
=
None
,
size_res_loss
:
Optional
[
dict
]
=
None
,
size_res_loss
=
None
,
semantic_loss
:
Optional
[
dict
]
=
None
,
semantic_loss
=
None
,
cues_objectness_loss
:
Optional
[
dict
]
=
None
,
cues_objectness_loss
=
None
,
cues_semantic_loss
:
Optional
[
dict
]
=
None
,
cues_semantic_loss
=
None
,
proposal_objectness_loss
:
Optional
[
dict
]
=
None
,
proposal_objectness_loss
=
None
,
primitive_center_loss
:
Optional
[
dict
]
=
None
,
primitive_center_loss
=
None
,
init_cfg
:
dict
=
None
):
init_cfg
=
None
):
super
(
H3DBboxHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
super
(
H3DBboxHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
train_cfg
=
train_cfg
self
.
train_cfg
=
train_cfg
...
@@ -96,22 +106,22 @@ class H3DBboxHead(BaseModule):
...
@@ -96,22 +106,22 @@ class H3DBboxHead(BaseModule):
self
.
surface_thresh
=
surface_thresh
self
.
surface_thresh
=
surface_thresh
self
.
line_thresh
=
line_thresh
self
.
line_thresh
=
line_thresh
self
.
objectness
_loss
=
build_loss
(
objectness_loss
)
self
.
loss_
objectness
=
MODELS
.
build
(
objectness_loss
)
self
.
center
_loss
=
build_loss
(
center_loss
)
self
.
loss_
center
=
MODELS
.
build
(
center_loss
)
self
.
dir_class
_loss
=
build_loss
(
dir_class_loss
)
self
.
loss_
dir_class
=
MODELS
.
build
(
dir_class_loss
)
self
.
dir_res
_loss
=
build_loss
(
dir_res_loss
)
self
.
loss_
dir_res
=
MODELS
.
build
(
dir_res_loss
)
self
.
size_class
_loss
=
build_loss
(
size_class_loss
)
self
.
loss_
size_class
=
MODELS
.
build
(
size_class_loss
)
self
.
size_res
_loss
=
build_loss
(
size_res_loss
)
self
.
loss_
size_res
=
MODELS
.
build
(
size_res_loss
)
self
.
semantic
_loss
=
build_loss
(
semantic_loss
)
self
.
loss_
semantic
=
MODELS
.
build
(
semantic_loss
)
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
self
.
bbox_coder
=
TASK_UTILS
.
build
(
bbox_coder
)
self
.
num_sizes
=
self
.
bbox_coder
.
num_sizes
self
.
num_sizes
=
self
.
bbox_coder
.
num_sizes
self
.
num_dir_bins
=
self
.
bbox_coder
.
num_dir_bins
self
.
num_dir_bins
=
self
.
bbox_coder
.
num_dir_bins
self
.
cues_objectness
_loss
=
build_loss
(
cues_objectness_loss
)
self
.
loss_
cues_objectness
=
MODELS
.
build
(
cues_objectness_loss
)
self
.
cues_semantic
_loss
=
build_loss
(
cues_semantic_loss
)
self
.
loss_
cues_semantic
=
MODELS
.
build
(
cues_semantic_loss
)
self
.
proposal_objectness
_loss
=
build_loss
(
proposal_objectness_loss
)
self
.
loss_
proposal_objectness
=
MODELS
.
build
(
proposal_objectness_loss
)
self
.
primitive_center
_loss
=
build_loss
(
primitive_center_loss
)
self
.
loss_
primitive_center
=
MODELS
.
build
(
primitive_center_loss
)
assert
suface_matching_cfg
[
'mlp_channels'
][
-
1
]
==
\
assert
suface_matching_cfg
[
'mlp_channels'
][
-
1
]
==
\
line_matching_cfg
[
'mlp_channels'
][
-
1
]
line_matching_cfg
[
'mlp_channels'
][
-
1
]
...
@@ -202,16 +212,14 @@ class H3DBboxHead(BaseModule):
...
@@ -202,16 +212,14 @@ class H3DBboxHead(BaseModule):
bbox_coder
[
'num_sizes'
]
*
4
+
self
.
num_classes
)
bbox_coder
[
'num_sizes'
]
*
4
+
self
.
num_classes
)
self
.
bbox_pred
.
append
(
nn
.
Conv1d
(
prev_channel
,
conv_out_channel
,
1
))
self
.
bbox_pred
.
append
(
nn
.
Conv1d
(
prev_channel
,
conv_out_channel
,
1
))
def
forward
(
self
,
feats_dict
,
sample_mod
):
def
forward
(
self
,
feats_dict
:
dict
):
"""Forward pass.
"""Forward pass.
Args:
Args:
feats_dict (dict): Feature dict from backbone.
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns:
Returns:
dict: Predictions of
vote
head.
dict: Predictions of head.
"""
"""
ret_dict
=
{}
ret_dict
=
{}
aggregated_points
=
feats_dict
[
'aggregated_points'
]
aggregated_points
=
feats_dict
[
'aggregated_points'
]
...
@@ -236,7 +244,7 @@ class H3DBboxHead(BaseModule):
...
@@ -236,7 +244,7 @@ class H3DBboxHead(BaseModule):
dim
=
1
)
dim
=
1
)
# Extract the surface and line centers of rpn proposals
# Extract the surface and line centers of rpn proposals
rpn_proposals
=
feats_dict
[
'proposal
_list
'
]
rpn_proposals
=
feats_dict
[
'
rpn_
proposal
s
'
]
rpn_proposals_bbox
=
DepthInstance3DBoxes
(
rpn_proposals_bbox
=
DepthInstance3DBoxes
(
rpn_proposals
.
reshape
(
-
1
,
7
).
clone
(),
rpn_proposals
.
reshape
(
-
1
,
7
).
clone
(),
box_dim
=
rpn_proposals
.
shape
[
-
1
],
box_dim
=
rpn_proposals
.
shape
[
-
1
],
...
@@ -310,36 +318,29 @@ class H3DBboxHead(BaseModule):
...
@@ -310,36 +318,29 @@ class H3DBboxHead(BaseModule):
ret_dict
[
key
+
'_optimized'
]
=
refine_decode_res
[
key
]
ret_dict
[
key
+
'_optimized'
]
=
refine_decode_res
[
key
]
return
ret_dict
return
ret_dict
def
loss
(
self
,
def
loss
(
bbox_preds
,
self
,
points
,
points
:
List
[
Tensor
],
gt_bboxes_3d
,
feats_dict
:
dict
,
gt_labels_3d
,
rpn_targets
:
Tuple
=
None
,
pts_semantic_mask
=
None
,
batch_data_samples
:
List
[
Det3DDataSample
]
=
None
,
pts_instance_mask
=
None
,
):
img_metas
=
None
,
"""
rpn_targets
=
None
,
gt_bboxes_ignore
=
None
):
"""Compute loss.
Args:
Args:
bbox_preds (dict): Predictions from forward of h3d bbox head.
points (list[tensor]): Points cloud of multiple samples.
points (list[torch.Tensor]): Input points.
feats_dict (dict): Predictions from backbone or FPN.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
rpn_targets (Tuple, Optional): The target of sample from RPN.
bboxes of each sample.
Defaults to None.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
batch_data_samples (list[:obj:`Det3DDataSample`], Optional):
pts_semantic_mask (list[torch.Tensor]): Point-wise
Each item contains the meta information of each sample
semantic mask.
and corresponding annotations. Defaults to None.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
rpn_targets (Tuple) : Targets generated by rpn head.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
Returns:
Returns:
dict:
Losses of H3dnet
.
dict:
A dictionary of loss components
.
"""
"""
preds
=
self
(
feats_dict
)
feats_dict
.
update
(
preds
)
(
vote_targets
,
vote_target_masks
,
size_class_targets
,
size_res_targets
,
(
vote_targets
,
vote_target_masks
,
size_class_targets
,
size_res_targets
,
dir_class_targets
,
dir_res_targets
,
center_targets
,
_
,
mask_targets
,
dir_class_targets
,
dir_res_targets
,
center_targets
,
_
,
mask_targets
,
valid_gt_masks
,
objectness_targets
,
objectness_weights
,
valid_gt_masks
,
objectness_targets
,
objectness_weights
...
@@ -349,7 +350,7 @@ class H3DBboxHead(BaseModule):
         # calculate refined proposal loss
         refined_proposal_loss = self.get_proposal_stage_loss(
-            bbox_preds,
+            feats_dict,
             size_class_targets,
             size_res_targets,
             dir_class_targets,
...
@@ -364,36 +365,60 @@ class H3DBboxHead(BaseModule):
         for key in refined_proposal_loss.keys():
             losses[key + '_optimized'] = refined_proposal_loss[key]
+        batch_gt_instance_3d = []
+        batch_input_metas = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+        temp_loss = self.loss_by_feat(points, feats_dict, batch_gt_instance_3d)
+        losses.update(temp_loss)
+        return losses
+
+    def loss_by_feat(self, points: List[torch.Tensor], feats_dict: dict,
+                     batch_gt_instances_3d: List[InstanceData],
+                     **kwargs) -> dict:
+        """Compute loss.
+
+        Args:
+            points (list[torch.Tensor]): Input points.
+            feats_dict (dict): Predictions from forward of vote head.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes`` and ``labels``
+                attributes.
+
+        Returns:
+            dict: Losses of H3DNet.
+        """
         bbox3d_optimized = self.bbox_coder.decode(
-            bbox_preds, suffix='_optimized')
+            feats_dict, suffix='_optimized')

-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
-                                   pts_semantic_mask, pts_instance_mask,
-                                   bbox_preds)
+        targets = self.get_targets(points, feats_dict, batch_gt_instances_3d)

         (cues_objectness_label, cues_sem_label, proposal_objectness_label,
          cues_mask, cues_match_mask, proposal_objectness_mask,
          cues_matching_label, obj_surface_line_center) = targets

         # match scores for each geometric primitive
-        objectness_scores = bbox_preds['matching_score']
+        objectness_scores = feats_dict['matching_score']
         # match scores for the semantics of primitives
-        objectness_scores_sem = bbox_preds['semantic_matching_score']
+        objectness_scores_sem = feats_dict['semantic_matching_score']

-        primitive_objectness_loss = self.cues_objectness_loss(
+        primitive_objectness_loss = self.loss_cues_objectness(
             objectness_scores.transpose(2, 1),
             cues_objectness_label,
             weight=cues_mask,
             avg_factor=cues_mask.sum() + 1e-6)

-        primitive_sem_loss = self.cues_semantic_loss(
+        primitive_sem_loss = self.loss_cues_semantic(
             objectness_scores_sem.transpose(2, 1),
             cues_sem_label,
             weight=cues_mask,
             avg_factor=cues_mask.sum() + 1e-6)

-        objectness_scores = bbox_preds['obj_scores_optimized']
-        objectness_loss_refine = self.proposal_objectness_loss(
+        objectness_scores = feats_dict['obj_scores_optimized']
+        objectness_loss_refine = self.loss_proposal_objectness(
             objectness_scores.transpose(2, 1), proposal_objectness_label)
         primitive_matching_loss = (objectness_loss_refine *
                                    cues_match_mask).sum() / (
...
@@ -419,7 +444,7 @@ class H3DBboxHead(BaseModule):
         pred_surface_line_center = torch.cat(
             (pred_obj_surface_center, pred_obj_line_center), 1)
-        square_dist = self.primitive_center_loss(pred_surface_line_center,
+        square_dist = self.loss_primitive_center(pred_surface_line_center,
                                                  obj_surface_line_center)
         match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
...
@@ -434,58 +459,102 @@ class H3DBboxHead(BaseModule):
             primitive_sem_matching_loss=primitive_sem_matching_loss,
             primitive_centroid_reg_loss=primitive_centroid_reg_loss)
-        losses.update(refined_loss)
-
-        return losses
+        return refined_loss
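The refactored loss path above replaces the old per-field ground-truth arguments with a list of InstanceData containers, one per sample. A minimal sketch (not part of the commit) of what such a container holds; the import paths follow this revision of the codebase and the box tensor is a placeholder:

import torch
from mmengine import InstanceData  # import path used by this revision
from mmdet3d.core import DepthInstance3DBoxes  # as in tests/utils/model_utils.py below

gt_instances_3d = InstanceData()
# five placeholder boxes in (x, y, z, dx, dy, dz, yaw) format
gt_instances_3d.bboxes_3d = DepthInstance3DBoxes(torch.rand(5, 7))
gt_instances_3d.labels_3d = torch.randint(0, 18, (5, ))
# H3DBboxHead.loss_by_feat receives one such object per sample:
# losses = head.loss_by_feat(points, feats_dict, [gt_instances_3d, ...])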
-    def get_bboxes(self,
-                   points,
-                   bbox_preds,
-                   input_metas,
-                   rescale=False,
-                   suffix=''):
+    def predict(self,
+                points: List[torch.Tensor],
+                feats_dict: Dict[str, torch.Tensor],
+                batch_data_samples: List[Det3DDataSample],
+                suffix='_optimized',
+                **kwargs) -> List[InstanceData]:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Features from FPN or backbone..
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes meta information of data.
+            suffix (str): suffix for tensor in feats_dict.
+                Defaults to '_optimized'.
+
+        Returns:
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+            InstanceData contains 3d Bounding boxes and corresponding
+            scores and labels.
+        """
+        preds_dict = self(feats_dict)
+        # `preds_dict` can be used in H3DNET
+        feats_dict.update(preds_dict)
+        batch_size = len(batch_data_samples)
+        batch_input_metas = []
+        for batch_index in range(batch_size):
+            metainfo = batch_data_samples[batch_index].metainfo
+            batch_input_metas.append(metainfo)
+
+        results_list = self.predict_by_feat(
+            points, feats_dict, batch_input_metas, suffix=suffix, **kwargs)
+        return results_list
+
+    def predict_by_feat(self,
+                        points: List[torch.Tensor],
+                        feats_dict: dict,
+                        batch_input_metas: List[dict],
+                        suffix='_optimized',
+                        **kwargs) -> List[InstanceData]:
         """Generate bboxes from vote head predictions.

         Args:
-            points (torch.Tensor): Input points.
-            bbox_preds (dict): Predictions from vote head.
-            input_metas (list[dict]): Point cloud and image's meta info.
-            rescale (bool): Whether to rescale bboxes.
+            points (List[torch.Tensor]): Input points of multiple samples.
+            feats_dict (dict): Predictions from previous components.
+            batch_input_metas (list[dict]): Each item
+                contains the meta information of each sample.
+            suffix (str): suffix for tensor in feats_dict.
+                Defaults to '_optimized'.

         Returns:
-            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+            list[:obj:`InstanceData`]: Return list of processed
+            predictions. Each InstanceData cantains
+            3d Bounding boxes and corresponding scores and labels.
         """
         # decode boxes
         obj_scores = F.softmax(
-            bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
-        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
+            feats_dict['obj_scores' + suffix], dim=-1)[..., -1]
+        sem_scores = F.softmax(feats_dict['sem_scores'], dim=-1)

         prediction_collection = {}
-        prediction_collection['center'] = bbox_preds['center' + suffix]
-        prediction_collection['dir_class'] = bbox_preds['dir_class']
-        prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
-        prediction_collection['size_class'] = bbox_preds['size_class']
-        prediction_collection['size_res'] = bbox_preds['size_res' + suffix]
+        prediction_collection['center'] = feats_dict['center' + suffix]
+        prediction_collection['dir_class'] = feats_dict['dir_class']
+        prediction_collection['dir_res'] = feats_dict['dir_res' + suffix]
+        prediction_collection['size_class'] = feats_dict['size_class']
+        prediction_collection['size_res'] = feats_dict['size_res' + suffix]

         bbox3d = self.bbox_coder.decode(prediction_collection)

         batch_size = bbox3d.shape[0]
-        results = list()
+        results_list = list()
+        points = torch.stack(points)
         for b in range(batch_size):
+            temp_results = InstanceData()
             bbox_selected, score_selected, labels = self.multiclass_nms_single(
                 obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
-                input_metas[b])
-            bbox = input_metas[b]['box_type_3d'](
+                batch_input_metas[b])
+            bbox = batch_input_metas[b]['box_type_3d'](
                 bbox_selected,
                 box_dim=bbox_selected.shape[-1],
                 with_yaw=self.bbox_coder.with_rot)
-            results.append((bbox, score_selected, labels))
-
-        return results
+            temp_results.bboxes_3d = bbox
+            temp_results.scores_3d = score_selected
+            temp_results.labels_3d = labels
+            results_list.append(temp_results)
+
+        return results_list
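A small sketch (not from the commit) of consuming the new return type: the old get_bboxes yielded (bbox, score, labels) tuples, while predict_by_feat returns one InstanceData per sample with named fields.

from typing import List
from mmengine import InstanceData  # import path used by this revision

def unpack_results(results_list: List[InstanceData]):
    """Regroup the per-sample fields set in predict_by_feat above."""
    return [(res.bboxes_3d, res.scores_3d, res.labels_3d)
            for res in results_list]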
-    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
-                              input_meta):
+    def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
+                              bbox: Tensor, points: Tensor,
+                              input_meta: dict) -> Tuple:
         """Multi-class nms in single batch.

         Args:
...
@@ -586,13 +655,13 @@ class H3DBboxHead(BaseModule):
             dict: Losses of aggregation module.
         """
         # calculate objectness loss
-        objectness_loss = self.objectness_loss(
+        objectness_loss = self.loss_objectness(
             bbox_preds['obj_scores' + suffix].transpose(2, 1),
             objectness_targets,
             weight=objectness_weights)

         # calculate center loss
-        source2target_loss, target2source_loss = self.center_loss(
+        source2target_loss, target2source_loss = self.loss_center(
             bbox_preds['center' + suffix],
             center_targets,
             src_weight=box_loss_weights,
...
@@ -600,7 +669,7 @@ class H3DBboxHead(BaseModule):
         center_loss = source2target_loss + target2source_loss

         # calculate direction class loss
-        dir_class_loss = self.dir_class_loss(
+        dir_class_loss = self.loss_dir_class(
             bbox_preds['dir_class' + suffix].transpose(2, 1),
             dir_class_targets,
             weight=box_loss_weights)
...
@@ -612,11 +681,11 @@ class H3DBboxHead(BaseModule):
         heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
         dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
                         heading_label_one_hot).sum(dim=-1)
-        dir_res_loss = self.dir_res_loss(
+        dir_res_loss = self.loss_dir_res(
             dir_res_norm, dir_res_targets, weight=box_loss_weights)

         # calculate size class loss
-        size_class_loss = self.size_class_loss(
+        size_class_loss = self.loss_size_class(
             bbox_preds['size_class' + suffix].transpose(2, 1),
             size_class_targets,
             weight=box_loss_weights)
...
@@ -631,13 +700,13 @@ class H3DBboxHead(BaseModule):
                              one_hot_size_targets_expand).sum(dim=2)
         box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
             1, 1, 3)
-        size_res_loss = self.size_res_loss(
+        size_res_loss = self.loss_size_res(
             size_residual_norm,
             size_res_targets,
             weight=box_loss_weights_expand)

         # calculate semantic loss
-        semantic_loss = self.semantic_loss(
+        semantic_loss = self.loss_semantic(
             bbox_preds['sem_scores' + suffix].transpose(2, 1),
             mask_targets,
             weight=box_loss_weights)
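The renames in these hunks (objectness_loss -> loss_objectness and so on) only touch attribute names; each attribute remains a config-built loss module called with the same weight/avg_factor keywords. A hedged sketch, instantiating mmdet's cross-entropy loss directly instead of going through the registry; the shapes are placeholders:

import torch
from mmdet.models.losses import CrossEntropyLoss  # built via MODELS.build(...) inside the head

loss_objectness = CrossEntropyLoss(reduction='mean', loss_weight=1.0)

logits = torch.randn(2, 2, 256)          # (batch, num_classes, num_proposals)
targets = torch.randint(0, 2, (2, 256))  # per-proposal binary labels
weights = torch.ones(2, 256)
value = loss_objectness(
    logits, targets, weight=weights, avg_factor=weights.sum() + 1e-6)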
...
@@ -653,91 +722,93 @@ class H3DBboxHead(BaseModule):
         return losses

     def get_targets(self,
                     points,
-                    gt_bboxes_3d,
-                    gt_labels_3d,
-                    pts_semantic_mask=None,
-                    pts_instance_mask=None,
-                    bbox_preds=None):
-        """Generate targets of proposal module.
+                    feats_dict: Optional[dict] = None,
+                    batch_gt_instances_3d: Optional[List[InstanceData]] = None,
+                    ):
+        """Generate targets of vote head.

         Args:
             points (list[torch.Tensor]): Points of each batch.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): Point-wise instance
-                label of each batch.
-            bbox_preds (torch.Tensor): Bounding box predictions of vote head.
+            feats_dict (dict, optional): Predictions of previous
+                components. Defaults to None.
+            batch_gt_instances_3d (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances. It usually includes
+                ``bboxes_3d`` and ``labels_3d`` attributes.

         Returns:
-            tuple[torch.Tensor]: Targets of proposal module.
+            tuple[torch.Tensor]: Targets of vote head.
         """
         # find empty example
         valid_gt_masks = list()
         gt_num = list()
-        for index in range(len(gt_labels_3d)):
-            if len(gt_labels_3d[index]) == 0:
-                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
-                    1, gt_bboxes_3d[index].tensor.shape[-1])
-                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
-                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
-                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
+        batch_gt_labels_3d = [
+            gt_instances_3d.labels_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        batch_gt_bboxes_3d = [
+            gt_instances_3d.bboxes_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        for index in range(len(batch_gt_labels_3d)):
+            if len(batch_gt_labels_3d[index]) == 0:
+                fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
+                    1, batch_gt_bboxes_3d[index].tensor.shape[-1])
+                batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
+                    fake_box)
+                batch_gt_labels_3d[index] = batch_gt_labels_3d[
+                    index].new_zeros(1)
+                valid_gt_masks.append(batch_gt_labels_3d[index].new_zeros(1))
                 gt_num.append(1)
             else:
-                valid_gt_masks.append(gt_labels_3d[index].new_ones(
-                    gt_labels_3d[index].shape))
-                gt_num.append(gt_labels_3d[index].shape[0])
-
-        if pts_semantic_mask is None:
-            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
-            pts_instance_mask = [None for i in range(len(gt_labels_3d))]
+                valid_gt_masks.append(batch_gt_labels_3d[index].new_ones(
+                    batch_gt_labels_3d[index].shape))
+                gt_num.append(batch_gt_labels_3d[index].shape[0])

         aggregated_points = [
-            bbox_preds['aggregated_points'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['aggregated_points'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         surface_center_pred = [
-            bbox_preds['surface_center_pred'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['surface_center_pred'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         line_center_pred = [
-            bbox_preds['pred_line_center'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['pred_line_center'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         surface_center_object = [
-            bbox_preds['surface_center_object'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['surface_center_object'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         line_center_object = [
-            bbox_preds['line_center_object'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['line_center_object'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         surface_sem_pred = [
-            bbox_preds['surface_sem_pred'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['surface_sem_pred'][i]
+            for i in range(len(batch_gt_labels_3d))
        ]

         line_sem_pred = [
-            bbox_preds['sem_cls_scores_line'][i]
-            for i in range(len(gt_labels_3d))
+            feats_dict['sem_cls_scores_line'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         (cues_objectness_label, cues_sem_label, proposal_objectness_label,
          cues_mask, cues_match_mask, proposal_objectness_mask,
          cues_matching_label, obj_surface_line_center) = multi_apply(
-             self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
-             pts_semantic_mask, pts_instance_mask, aggregated_points,
-             surface_center_pred, line_center_pred, surface_center_object,
-             line_center_object, surface_sem_pred, line_sem_pred)
+             self._get_targets_single, points, batch_gt_bboxes_3d,
+             batch_gt_labels_3d, aggregated_points, surface_center_pred,
+             line_center_pred, surface_center_object, line_center_object,
+             surface_sem_pred, line_sem_pred)

         cues_objectness_label = torch.stack(cues_objectness_label)
         cues_sem_label = torch.stack(cues_sem_label)
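get_targets still fans out over the batch with multi_apply. As a reading aid (not the library implementation), the helper behaves roughly like this:

from functools import partial

def multi_apply_sketch(func, *args, **kwargs):
    """Apply ``func`` to each sample and regroup the per-sample tuples
    into tuples of lists (one list per returned field)."""
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))

# e.g. func(points[i], boxes[i], labels[i]) -> (label_i, mask_i)
# multi_apply_sketch(func, points, boxes, labels) -> ([label_0, ...], [mask_0, ...])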
...
@@ -753,19 +824,17 @@ class H3DBboxHead(BaseModule):
          proposal_objectness_mask, cues_matching_label,
          obj_surface_line_center)

-    def get_targets_single(self,
-                           points,
-                           gt_bboxes_3d,
-                           gt_labels_3d,
-                           pts_semantic_mask=None,
-                           pts_instance_mask=None,
-                           aggregated_points=None,
-                           pred_surface_center=None,
-                           pred_line_center=None,
-                           pred_obj_surface_center=None,
-                           pred_obj_line_center=None,
-                           pred_surface_sem=None,
-                           pred_line_sem=None):
+    def _get_targets_single(self,
+                            points: Tensor,
+                            gt_bboxes_3d: BaseInstance3DBoxes,
+                            gt_labels_3d: Tensor,
+                            aggregated_points: Optional[Tensor] = None,
+                            pred_surface_center: Optional[Tensor] = None,
+                            pred_line_center: Optional[Tensor] = None,
+                            pred_obj_surface_center: Optional[Tensor] = None,
+                            pred_obj_line_center: Optional[Tensor] = None,
+                            pred_surface_sem: Optional[Tensor] = None,
+                            pred_line_sem: Optional[Tensor] = None):
         """Generate targets for primitive cues for single batch.

         Args:
...
@@ -773,10 +842,6 @@ class H3DBboxHead(BaseModule):
             gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                 boxes of each batch.
             gt_labels_3d (torch.Tensor): Labels of each batch.
-            pts_semantic_mask (torch.Tensor): Point-wise semantic
-                label of each batch.
-            pts_instance_mask (torch.Tensor): Point-wise instance
-                label of each batch.
             aggregated_points (torch.Tensor): Aggregated points from
                 vote aggregation layer.
             pred_surface_center (torch.Tensor): Prediction of surface center.
...
@@ -847,12 +912,10 @@ class H3DBboxHead(BaseModule):
         euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
         objectness_label_surface = euclidean_dist_line.new_zeros(
             num_proposals * 6, dtype=torch.long)
-        objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals * 6)
         objectness_label_line = euclidean_dist_line.new_zeros(
             num_proposals * 12, dtype=torch.long)
-        objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals * 12)
         objectness_label_surface_sem = euclidean_dist_line.new_zeros(
             num_proposals * 6, dtype=torch.long)
         objectness_label_line_sem = euclidean_dist_line.new_zeros(
...
mmdet3d/models/roi_heads/h3d_roi_head.py
View file @ 0e17beab

 # Copyright (c) OpenMMLab. All rights reserved.
-from mmdet3d.core.bbox import bbox3d2result
+from typing import Dict, List
+
+from mmengine import InstanceData
+from torch import Tensor
+
 from mmdet3d.registry import MODELS
+from ...core import Det3DDataSample
 from .base_3droi_head import Base3DRoIHead
...
@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead):
     """

     def __init__(self,
-                 primitive_list,
-                 bbox_head=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
+                 primitive_list: List[dict],
+                 bbox_head: dict = None,
+                 train_cfg: dict = None,
+                 test_cfg: dict = None,
+                 init_cfg: dict = None):
         super(H3DRoIHead, self).__init__(
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained,
             init_cfg=init_cfg)

         # Primitive module
         assert len(primitive_list) == 3
...
@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead):
         one."""
         pass

-    def init_bbox_head(self, bbox_head):
-        """Initialize box head."""
+    def init_bbox_head(self, dummy_args, bbox_head):
+        """Initialize box head.
+
+        Args:
+            dummy_args (optional): Just to compatible with
+                the interface in base class
+            bbox_head (dict): Config for bbox head.
+        """
         bbox_head['train_cfg'] = self.train_cfg
         bbox_head['test_cfg'] = self.test_cfg
         self.bbox_head = MODELS.build(bbox_head)
...
@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead):
         """Initialize assigner and sampler."""
         pass

-    def forward_train(self,
-                      feats_dict,
-                      img_metas,
-                      points,
-                      gt_bboxes_3d,
-                      gt_labels_3d,
-                      pts_semantic_mask,
-                      pts_instance_mask,
-                      gt_bboxes_ignore=None):
+    def loss(self, points: List[Tensor], feats_dict: dict,
+             batch_data_samples: List[Det3DDataSample], **kwargs):
         """Training forward function of PartAggregationROIHead.

         Args:
-            feats_dict (dict): Contains features from the first stage.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise
-                semantic mask.
-            pts_instance_mask (list[torch.Tensor]): Point-wise
-                instance mask.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding boxes to ignore.
+            points (list[torch.Tensor]): Point cloud of each sample.
+            feats_dict (dict): Dict of feature.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance_3d`.

         Returns:
             dict: losses from each head.
         """
         losses = dict()

-        sample_mod = self.train_cfg.sample_mod
-        assert sample_mod in ['vote', 'seed', 'random']
-
-        result_z = self.primitive_z(feats_dict, sample_mod)
-        feats_dict.update(result_z)
-
-        result_xy = self.primitive_xy(feats_dict, sample_mod)
-        feats_dict.update(result_xy)
-
-        result_line = self.primitive_line(feats_dict, sample_mod)
-        feats_dict.update(result_line)
-
-        primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
-                                 gt_labels_3d, pts_semantic_mask,
-                                 pts_instance_mask, img_metas,
-                                 gt_bboxes_ignore)
+        primitive_loss_inputs = (points, feats_dict, batch_data_samples)
+        # note the feats_dict would be added new key and value in each head.
         loss_z = self.primitive_z.loss(*primitive_loss_inputs)
+        losses.update(loss_z)
         loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
+        losses.update(loss_xy)
         loss_line = self.primitive_line.loss(*primitive_loss_inputs)
-        losses.update(loss_z)
-        losses.update(loss_xy)
         losses.update(loss_line)

         targets = feats_dict.pop('targets')

-        bbox_results = self.bbox_head(feats_dict, sample_mod)
-        feats_dict.update(bbox_results)
-        bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
-                                        gt_labels_3d, pts_semantic_mask,
-                                        pts_instance_mask, img_metas, targets,
-                                        gt_bboxes_ignore)
+        bbox_loss = self.bbox_head.loss(
+            points,
+            feats_dict,
+            rpn_targets=targets,
+            batch_data_samples=batch_data_samples)
         losses.update(bbox_loss)

         return losses

-    def simple_test(self, feats_dict, img_metas, points, rescale=False):
-        """Simple testing forward function of PartAggregationROIHead.
-
-        Note:
-            This function assumes that the batch size is 1
-
-        Args:
-            feats_dict (dict): Contains features from the first stage.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            points (torch.Tensor): Input points.
-            rescale (bool): Whether to rescale results.
-
-        Returns:
-            dict: Bbox results of one frame.
-        """
-        sample_mod = self.test_cfg.sample_mod
-        assert sample_mod in ['vote', 'seed', 'random']
-
-        result_z = self.primitive_z(feats_dict, sample_mod)
+    def predict(self,
+                points: List[Tensor],
+                feats_dict: Dict[str, Tensor],
+                batch_data_samples: List[Det3DDataSample],
+                suffix='_optimized',
+                **kwargs) -> List[InstanceData]:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Features from FPN or backbone..
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes meta information of data.
+
+        Returns:
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+            InstanceData contains 3d Bounding boxes and corresponding
+            scores and labels.
+        """
+
+        result_z = self.primitive_z(feats_dict)
         feats_dict.update(result_z)

-        result_xy = self.primitive_xy(feats_dict, sample_mod)
+        result_xy = self.primitive_xy(feats_dict)
         feats_dict.update(result_xy)

-        result_line = self.primitive_line(feats_dict, sample_mod)
+        result_line = self.primitive_line(feats_dict)
         feats_dict.update(result_line)

-        bbox_preds = self.bbox_head(feats_dict, sample_mod)
+        bbox_preds = self.bbox_head(feats_dict)
         feats_dict.update(bbox_preds)
-        bbox_list = self.bbox_head.get_bboxes(
-            points,
-            feats_dict,
-            img_metas,
-            rescale=rescale,
-            suffix='_optimized')
-        bbox_results = [
-            bbox3d2result(bboxes, scores, labels)
-            for bboxes, scores, labels in bbox_list
-        ]
-        return bbox_results
+        results_list = self.bbox_head.predict(
+            points, feats_dict, batch_data_samples, suffix=suffix)
+
+        return results_list
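Both loss and predict in the refactored RoI head thread a single feats_dict through every stage, each stage contributing new keys. A minimal sketch of that dict-passing pattern; the stage list is a placeholder for the three primitive heads and the bbox head used above:

def run_pipeline(feats_dict: dict, stages) -> dict:
    """Each stage maps the accumulated feature dict to new entries."""
    for stage in stages:
        feats_dict.update(stage(feats_dict))
    return feats_dict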
mmdet3d/models/roi_heads/mask_heads/primitive_head.py
View file @ 0e17beab

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional
+
 import torch
 from mmcv.cnn import ConvModule
 from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule
+from mmengine import InstanceData
 from torch import nn as nn
 from torch.nn import functional as F

-from mmdet3d.models.builder import build_loss
+from mmdet3d.core import Det3DDataSample
 from mmdet3d.models.model_utils import VoteModule
 from mmdet3d.ops import build_sa_module
 from mmdet3d.registry import MODELS
...
@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule):
     """

     def __init__(self,
-                 num_dims,
-                 num_classes,
-                 primitive_mode,
-                 train_cfg=None,
-                 test_cfg=None,
-                 vote_module_cfg=None,
-                 vote_aggregation_cfg=None,
-                 feat_channels=(128, 128),
-                 upper_thresh=100.0,
-                 surface_thresh=0.5,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d'),
-                 objectness_loss=None,
-                 center_loss=None,
-                 semantic_reg_loss=None,
-                 semantic_cls_loss=None,
-                 init_cfg=None):
+                 num_dims: int,
+                 num_classes: int,
+                 primitive_mode: str,
+                 train_cfg: dict = None,
+                 test_cfg: dict = None,
+                 vote_module_cfg: dict = None,
+                 vote_aggregation_cfg: dict = None,
+                 feat_channels: tuple = (128, 128),
+                 upper_thresh: float = 100.0,
+                 surface_thresh: float = 0.5,
+                 conv_cfg: dict = dict(type='Conv1d'),
+                 norm_cfg: dict = dict(type='BN1d'),
+                 objectness_loss: dict = None,
+                 center_loss: dict = None,
+                 semantic_reg_loss: dict = None,
+                 semantic_cls_loss: dict = None,
+                 init_cfg: dict = None):
         super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
+        # bounding boxes centers, face centers and edge centers
         assert primitive_mode in ['z', 'xy', 'line']
         # The dimension of primitive semantic information.
         self.num_dims = num_dims
...
@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule):
         self.upper_thresh = upper_thresh
         self.surface_thresh = surface_thresh

-        self.objectness_loss = build_loss(objectness_loss)
-        self.center_loss = build_loss(center_loss)
-        self.semantic_reg_loss = build_loss(semantic_reg_loss)
-        self.semantic_cls_loss = build_loss(semantic_cls_loss)
+        self.loss_objectness = MODELS.build(objectness_loss)
+        self.loss_center = MODELS.build(center_loss)
+        self.loss_semantic_reg = MODELS.build(semantic_reg_loss)
+        self.loss_semantic_cls = MODELS.build(semantic_cls_loss)

         assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
             'in_channels']
...
@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule):
         self.conv_pred.add_module('conv_out',
                                   nn.Conv1d(prev_channel, conv_out_channel, 1))

-    def forward(self, feats_dict, sample_mod):
+    @property
+    def sample_mode(self):
+        if self.training:
+            sample_mode = self.train_cfg.sample_mode
+        else:
+            sample_mode = self.test_cfg.sample_mode
+        assert sample_mode in ['vote', 'seed', 'random']
+        return sample_mode
+
+    def forward(self, feats_dict):
         """Forward pass.

         Args:
             feats_dict (dict): Feature dict from backbone.
-            sample_mod (str): Sample mode for vote aggregation layer.
-                valid modes are "vote", "seed" and "random".

         Returns:
             dict: Predictions of primitive head.
         """
-        assert sample_mod in ['vote', 'seed', 'random']
+        sample_mode = self.sample_mode

         seed_points = feats_dict['fp_xyz_net0'][-1]
         seed_features = feats_dict['hd_feature']
...
@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule):
         results['vote_features_' + self.primitive_mode] = vote_features

         # 2. aggregate vote_points
-        if sample_mod == 'vote':
+        if sample_mode == 'vote':
             # use fps in vote_aggregation
             sample_indices = None
-        elif sample_mod == 'seed':
+        elif sample_mode == 'seed':
             # FPS on seed and choose the votes corresponding to the seeds
             sample_indices = furthest_point_sample(seed_points,
                                                    self.num_proposal)
-        elif sample_mod == 'random':
+        elif sample_mode == 'random':
             # Random sampling from the votes
             batch_size, num_seed = seed_points.shape[:2]
             sample_indices = torch.randint(
...
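The new sample_mode property reads the vote-aggregation sampling strategy from train_cfg or test_cfg depending on the module's training flag. A standalone illustration (not from the commit), using SimpleNamespace in place of real config objects:

from types import SimpleNamespace
import torch.nn as nn

class SampleModeDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.train_cfg = SimpleNamespace(sample_mode='vote')
        self.test_cfg = SimpleNamespace(sample_mode='seed')

    @property
    def sample_mode(self):
        cfg = self.train_cfg if self.training else self.test_cfg
        return cfg.sample_mode

demo = SampleModeDemo()
assert demo.sample_mode == 'vote'   # nn.Module starts in training mode
demo.eval()
assert demo.sample_mode == 'seed'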
...
@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule):
         results['pred_' + self.primitive_mode + '_center'] = center
         return results

     def loss(self,
-             bbox_preds,
-             points,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             pts_semantic_mask=None,
-             pts_instance_mask=None,
-             img_metas=None,
-             gt_bboxes_ignore=None):
+             points: List[torch.Tensor],
+             feats_dict: Dict[str, torch.Tensor],
+             batch_data_samples: List[Det3DDataSample],
+             **kwargs) -> dict:
+        """
+        Args:
+            points (list[tensor]): Points cloud of multiple samples.
+            feats_dict (dict): Predictions from backbone or FPN.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        preds = self(feats_dict)
+        feats_dict.update(preds)
+
+        batch_gt_instance_3d = []
+        batch_gt_instances_ignore = []
+        batch_input_metas = []
+        batch_pts_semantic_mask = []
+        batch_pts_instance_mask = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+            batch_gt_instances_ignore.append(
+                data_sample.get('ignored_instances', None))
+            batch_pts_semantic_mask.append(
+                data_sample.gt_pts_seg.get('pts_semantic_mask', None))
+            batch_pts_instance_mask.append(
+                data_sample.gt_pts_seg.get('pts_instance_mask', None))
+
+        loss_inputs = (points, feats_dict, batch_gt_instance_3d)
+        losses = self.loss_by_feat(
+            *loss_inputs,
+            batch_pts_semantic_mask=batch_pts_semantic_mask,
+            batch_pts_instance_mask=batch_pts_instance_mask,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
+        )
+        return losses
+
+    def loss_by_feat(
+            self,
+            points: List[torch.Tensor],
+            feats_dict: dict,
+            batch_gt_instances_3d: List[InstanceData],
+            batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
+            batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
+            **kwargs):
         """Compute loss.

         Args:
-            bbox_preds (dict): Predictions from forward of primitive head.
             points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise
-                semantic mask.
-            pts_instance_mask (list[torch.Tensor]): Point-wise
-                instance mask.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
-            ret_target (bool): Return targets or not. Defaults to False.
+            feats_dict (dict): Predictions of previous modules.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_pts_semantic_mask (list[tensor]): Semantic mask
+                of points cloud. Defaults to None.
+            batch_pts_semantic_mask (list[tensor]): Instance mask
+                of points cloud. Defaults to None.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.

         Returns:
             dict: Losses of Primitive Head.
         """
-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
-                                   pts_semantic_mask, pts_instance_mask,
-                                   bbox_preds)
+        targets = self.get_targets(points, feats_dict, batch_gt_instances_3d,
+                                   batch_pts_semantic_mask,
+                                   batch_pts_instance_mask)

         (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
          gt_sem_cls_label, gt_primitive_mask) = targets

         losses = {}
         # Compute the loss of primitive existence flag
-        pred_flag = bbox_preds['pred_flag_' + self.primitive_mode]
-        flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long())
+        pred_flag = feats_dict['pred_flag_' + self.primitive_mode]
+        flag_loss = self.loss_objectness(pred_flag, gt_primitive_mask.long())
         losses['flag_loss_' + self.primitive_mode] = flag_loss

         # calculate vote loss
         vote_loss = self.vote_module.get_loss(
-            bbox_preds['seed_points'],
-            bbox_preds['vote_' + self.primitive_mode],
-            bbox_preds['seed_indices'], point_mask, point_offset)
+            feats_dict['seed_points'],
+            feats_dict['vote_' + self.primitive_mode],
+            feats_dict['seed_indices'], point_mask, point_offset)
         losses['vote_loss_' + self.primitive_mode] = vote_loss

-        num_proposal = bbox_preds['aggregated_points_' +
-                                  self.primitive_mode].shape[1]
-        primitive_center = bbox_preds['center_' + self.primitive_mode]
+        num_proposal = feats_dict['aggregated_points_' +
+                                  self.primitive_mode].shape[1]
+        primitive_center = feats_dict['center_' + self.primitive_mode]
         if self.primitive_mode != 'line':
-            primitive_semantic = bbox_preds['size_residuals_' +
-                                            self.primitive_mode].contiguous()
+            primitive_semantic = feats_dict['size_residuals_' +
+                                            self.primitive_mode].contiguous()
         else:
             primitive_semantic = None
-        semancitc_scores = bbox_preds['sem_cls_scores_' +
-                                      self.primitive_mode].transpose(2, 1)
+        semancitc_scores = feats_dict['sem_cls_scores_' +
+                                      self.primitive_mode].transpose(2, 1)

         gt_primitive_mask = gt_primitive_mask / \
...
@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule):
         return losses

     def get_targets(self,
                     points,
-                    gt_bboxes_3d,
-                    gt_labels_3d,
-                    pts_semantic_mask=None,
-                    pts_instance_mask=None,
-                    bbox_preds=None):
+                    bbox_preds: Optional[dict] = None,
+                    batch_gt_instances_3d: List[InstanceData] = None,
+                    batch_pts_semantic_mask: List[torch.Tensor] = None,
+                    batch_pts_instance_mask: List[torch.Tensor] = None,
+                    ):
         """Generate targets of primitive head.

         Args:
             points (list[torch.Tensor]): Points of each batch.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): Point-wise instance
-                label of each batch.
-            bbox_preds (dict): Predictions from forward of primitive head.
+            bbox_preds (torch.Tensor): Bounding box predictions of
+                primitive head.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
+                multiple images.
+            batch_pts_instance_mask (list[tensor]): Instance gt mask for
+                multiple images.

         Returns:
             tuple[torch.Tensor]: Targets of primitive head.
         """
-        for index in range(len(gt_labels_3d)):
-            if len(gt_labels_3d[index]) == 0:
-                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
-                    1, gt_bboxes_3d[index].tensor.shape[-1])
-                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
-                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
-
-        if pts_semantic_mask is None:
-            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
-            pts_instance_mask = [None for i in range(len(gt_labels_3d))]
+        batch_gt_labels_3d = [
+            gt_instances_3d.labels_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        batch_gt_bboxes_3d = [
+            gt_instances_3d.bboxes_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        for index in range(len(batch_gt_labels_3d)):
+            if len(batch_gt_labels_3d[index]) == 0:
+                fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
+                    1, batch_gt_bboxes_3d[index].tensor.shape[-1])
+                batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
+                    fake_box)
+                batch_gt_labels_3d[index] = batch_gt_labels_3d[
+                    index].new_zeros(1)
+
+        if batch_pts_semantic_mask is None:
+            batch_pts_semantic_mask = [
+                None for _ in range(len(batch_gt_labels_3d))
+            ]
+            batch_pts_instance_mask = [
+                None for _ in range(len(batch_gt_labels_3d))
+            ]

         (point_mask, point_sem,
-         point_offset) = multi_apply(self.get_targets_single, points,
-                                     gt_bboxes_3d, gt_labels_3d,
-                                     pts_semantic_mask, pts_instance_mask)
+         point_offset) = multi_apply(self.get_targets_single, points,
+                                     batch_gt_bboxes_3d, batch_gt_labels_3d,
+                                     batch_pts_semantic_mask,
+                                     batch_pts_instance_mask)

         point_mask = torch.stack(point_mask)
         point_sem = torch.stack(point_sem)
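Both target generators pad samples that have no ground-truth boxes with a single all-zero placeholder so the per-sample assignment code never sees an empty tensor. A minimal sketch of that padding step with stand-in tensors (7-value boxes, as in the test utilities below):

import torch

def pad_empty_gt(gt_boxes: torch.Tensor, gt_labels: torch.Tensor):
    """Replace an empty annotation set with one all-zero placeholder."""
    if len(gt_labels) == 0:
        gt_boxes = gt_boxes.new_zeros(1, gt_boxes.shape[-1])
        gt_labels = gt_labels.new_zeros(1)
    return gt_boxes, gt_labels

boxes, labels = pad_empty_gt(torch.empty(0, 7), torch.empty(0, dtype=torch.long))
assert boxes.shape == (1, 7) and labels.shape == (1, )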
...
@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule):
         vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
                                                  3)

-        center_loss = self.center_loss(
+        center_loss = self.loss_center(
             vote_xyz_reshape,
             gt_primitive_center,
             dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]
...
@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule):
         if self.primitive_mode != 'line':
             size_xyz_reshape = primitive_semantic.view(
                 batch_size * num_proposal, -1, self.num_dims).contiguous()
-            size_loss = self.semantic_reg_loss(
+            size_loss = self.loss_semantic_reg(
                 size_xyz_reshape,
                 gt_primitive_semantic,
                 dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
...
@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule):
             size_loss = center_loss.new_tensor(0.0)

         # Semantic cls loss
-        sem_cls_loss = self.semantic_cls_loss(
+        sem_cls_loss = self.loss_semantic_cls(
             semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)

         return center_loss, size_loss, sem_cls_loss
...
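The center and size regression calls above pass dst_weight and keep index [1] of the returned pair: the configured loss measures distances in both directions and the head uses only the target-to-prediction term. A rough, pure-PyTorch sketch of that idea (not the library implementation; the shipped loss uses an L1 variant and config-defined weights):

import torch

def chamfer_like(src: torch.Tensor, dst: torch.Tensor, dst_weight: torch.Tensor):
    """src: (N, P, 3), dst: (N, Q, 3), dst_weight: (N, Q)."""
    dist = torch.cdist(src, dst)             # (N, P, Q) pairwise distances
    src2dst = dist.min(dim=2).values.mean()  # each prediction to nearest target
    dst2src = (dist.min(dim=1).values * dst_weight).sum() / \
        dst_weight.sum().clamp(min=1e-6)     # each target to nearest prediction
    return src2dst, dst2src

src = torch.rand(2, 8, 3)
dst = torch.rand(2, 4, 3)
weight = torch.ones(2, 4)
loss_dst = chamfer_like(src, dst, weight)[1]  # keep only the dst-weighted direction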
tests/test_models/test_detectors/test_h3dnet.py
0 → 100644
View file @ 0e17beab

import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestH3D(unittest.TestCase):

    def test_h3dnet(self):
        import mmdet3d.models

        assert hasattr(mmdet3d.models, 'H3DNet')
        DefaultScope.get_instance('test_H3DNet', scope_name='mmdet3d')
        _setup_seed(0)
        voxel_net_cfg = _get_detector_cfg(
            'h3dnet/h3dnet_3x8_scannet-3d-18class.py')
        model = MODELS.build(voxel_net_cfg)

        num_gt_instance = 5
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance,
                points_feat_dim=4,
                bboxes_3d_type='depth',
                with_pts_semantic_mask=True,
                with_pts_instance_mask=True)
        ]

        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', results[0].pred_instances_3d)
            self.assertIn('scores_3d', results[0].pred_instances_3d)
            self.assertIn('labels_3d', results[0].pred_instances_3d)

            # save the memory
            with torch.no_grad():
                losses = model.forward(batch_inputs, data_samples, mode='loss')

                self.assertGreater(losses['vote_loss'], 0)
                self.assertGreater(losses['objectness_loss'], 0)
                self.assertGreater(losses['center_loss'], 0)
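For reference, a small sketch (not part of the test) of what the assertions above inspect: in 'predict' mode the detector returns one Det3DDataSample per input whose pred_instances_3d carries boxes, scores and labels.

def summarize(results):
    """Print how many boxes each returned Det3DDataSample carries."""
    for i, sample in enumerate(results):
        pred = sample.pred_instances_3d
        print(f'sample {i}: {len(pred.labels_3d)} predicted boxes')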
tests/utils/model_utils.py
View file @ 0e17beab

...
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 from mmengine import InstanceData

-from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes, PointData
+from mmdet3d.core import (CameraInstance3DBoxes, DepthInstance3DBoxes,
+                          Det3DDataSample, LiDARInstance3DBoxes, PointData)


 def _setup_seed(seed):
...
@@ -71,8 +72,7 @@ def _get_detector_cfg(fname):
     return model


 def _create_detector_inputs(
     seed=0,
     with_points=True,
     with_img=False,
     num_gt_instance=20,
...
@@ -82,8 +82,14 @@ def _create_detector_inputs(
     gt_bboxes_dim=7,
     with_pts_semantic_mask=False,
     with_pts_instance_mask=False,
-):
+        bboxes_3d_type='lidar'):
     _setup_seed(seed)
+    assert bboxes_3d_type in ('lidar', 'depth', 'cam')
+    bbox_3d_class = {
+        'lidar': LiDARInstance3DBoxes,
+        'depth': DepthInstance3DBoxes,
+        'cam': CameraInstance3DBoxes
+    }
     if with_points:
         points = torch.rand([num_points, points_feat_dim])
     else:
...
@@ -93,12 +99,13 @@ def _create_detector_inputs(
     else:
         img = None
     inputs_dict = dict(img=img, points=points)

     gt_instance_3d = InstanceData()
-    gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
+    gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
         torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
     gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
     data_sample = Det3DDataSample(
-        metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
+        metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
     data_sample.gt_instances_3d = gt_instance_3d
     data_sample.gt_pts_seg = PointData()
     if with_pts_instance_mask:
...
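A short usage sketch of the extended helper (argument names come from the diff above; the values mirror the H3DNet test, which needs depth-frame boxes for the ScanNet config):

from tests.utils.model_utils import _create_detector_inputs

data = _create_detector_inputs(
    num_gt_instance=5,
    points_feat_dim=4,
    bboxes_3d_type='depth',       # selects DepthInstance3DBoxes internally
    with_pts_semantic_mask=True,
    with_pts_instance_mask=True)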