Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
3c57cc41
Commit
3c57cc41
authored
Jul 15, 2022
by
jshilong
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor] Refactort the interface of 3DSSD
parent
bd73d3b9
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
246 additions
and
178 deletions
+246
-178
configs/3dssd/3dssd_4x4_kitti-3d-car.py
configs/3dssd/3dssd_4x4_kitti-3d-car.py
+29
-29
configs/_base_/datasets/kitti-3d-car.py
configs/_base_/datasets/kitti-3d-car.py
+4
-4
configs/_base_/models/3dssd.py
configs/_base_/models/3dssd.py
+11
-12
configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py
...parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py
+3
-3
configs/votenet/votenet_16x8_sunrgbd-3d-10class.py
configs/votenet/votenet_16x8_sunrgbd-3d-10class.py
+1
-0
mmdet3d/models/dense_heads/ssd_3d_head.py
mmdet3d/models/dense_heads/ssd_3d_head.py
+154
-128
mmdet3d/models/detectors/ssd3dnet.py
mmdet3d/models/detectors/ssd3dnet.py
+2
-2
tests/test_models/test_detectors/test_3dssd.py
tests/test_models/test_detectors/test_3dssd.py
+42
-0
No files found.
configs/3dssd/3dssd_4x4_kitti-3d-car.py
View file @
3c57cc41
...
@@ -53,8 +53,9 @@ train_pipeline = [
...
@@ -53,8 +53,9 @@ train_pipeline = [
# 3DSSD can get a higher performance without this transform
# 3DSSD can get a higher performance without this transform
# dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
# dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
dict
(
type
=
'PointSample'
,
num_points
=
16384
),
dict
(
type
=
'PointSample'
,
num_points
=
16384
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'gt_bboxes_3d'
,
'gt_labels_3d'
])
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
,
'gt_bboxes_3d'
,
'gt_labels_3d'
])
]
]
test_pipeline
=
[
test_pipeline
=
[
...
@@ -79,22 +80,14 @@ test_pipeline = [
...
@@ -79,22 +80,14 @@ test_pipeline = [
dict
(
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'PointSample'
,
num_points
=
16384
),
dict
(
type
=
'PointSample'
,
num_points
=
16384
),
dict
(
]),
type
=
'DefaultFormatBundle3D'
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
class_names
=
class_names
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
])
]
]
data
=
dict
(
train_dataloader
=
dict
(
samples_per_gpu
=
4
,
batch_size
=
4
,
dataset
=
dict
(
dataset
=
dict
(
pipeline
=
train_pipeline
,
)))
workers_per_gpu
=
4
,
test_dataloader
=
dict
(
dataset
=
dict
(
pipeline
=
test_pipeline
))
train
=
dict
(
dataset
=
dict
(
pipeline
=
train_pipeline
)),
val_dataloader
=
dict
(
dataset
=
dict
(
pipeline
=
test_pipeline
))
val
=
dict
(
pipeline
=
test_pipeline
),
test
=
dict
(
pipeline
=
test_pipeline
))
evaluation
=
dict
(
interval
=
2
)
# model settings
# model settings
model
=
dict
(
model
=
dict
(
...
@@ -105,17 +98,24 @@ model = dict(
...
@@ -105,17 +98,24 @@ model = dict(
# optimizer
# optimizer
lr
=
0.002
# max learning rate
lr
=
0.002
# max learning rate
optim
iz
er
=
dict
(
type
=
'AdamW'
,
lr
=
lr
,
weight_decay
=
0
)
optim
_wrapp
er
=
dict
(
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
type
=
'OptimWrapper'
,
lr_config
=
dict
(
policy
=
'step'
,
warmup
=
None
,
step
=
[
45
,
60
])
optimizer
=
dict
(
type
=
'AdamW'
,
lr
=
lr
,
weight_decay
=
0.
),
# runtime settings
clip_grad
=
dict
(
max_norm
=
35
,
norm_type
=
2
),
runner
=
dict
(
type
=
'EpochBasedRunner'
,
max_epochs
=
80
)
)
# yapf:disable
# training schedule for 1x
log_config
=
dict
(
train_cfg
=
dict
(
type
=
'EpochBasedTrainLoop'
,
max_epochs
=
80
,
val_interval
=
2
)
interval
=
30
,
val_cfg
=
dict
(
type
=
'ValLoop'
)
hooks
=
[
test_cfg
=
dict
(
type
=
'TestLoop'
)
dict
(
type
=
'TextLoggerHook'
),
dict
(
type
=
'TensorboardLoggerHook'
)
# learning rate
])
param_scheduler
=
[
# yapf:enable
dict
(
type
=
'MultiStepLR'
,
begin
=
0
,
end
=
80
,
by_epoch
=
True
,
milestones
=
[
45
,
60
],
gamma
=
0.1
)
]
configs/_base_/datasets/kitti-3d-car.py
View file @
3c57cc41
...
@@ -69,9 +69,9 @@ test_pipeline = [
...
@@ -69,9 +69,9 @@ test_pipeline = [
translation_std
=
[
0
,
0
,
0
]),
translation_std
=
[
0
,
0
,
0
]),
dict
(
type
=
'RandomFlip3D'
),
dict
(
type
=
'RandomFlip3D'
),
dict
(
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
)
,
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
)
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
]),
]),
])
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
]
]
# construct a pipeline for data and gt loading in show function
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
# please keep its loading function consistent with test_pipeline (e.g. client)
...
@@ -82,7 +82,7 @@ eval_pipeline = [
...
@@ -82,7 +82,7 @@ eval_pipeline = [
load_dim
=
4
,
load_dim
=
4
,
use_dim
=
4
,
use_dim
=
4
,
file_client_args
=
file_client_args
),
file_client_args
=
file_client_args
),
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
,
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
]
]
train_dataloader
=
dict
(
train_dataloader
=
dict
(
batch_size
=
6
,
batch_size
=
6
,
...
...
configs/_base_/models/3dssd.py
View file @
3c57cc41
model
=
dict
(
model
=
dict
(
type
=
'SSD3DNet'
,
type
=
'SSD3DNet'
,
data_preprocessor
=
dict
(
type
=
'Det3DDataPreprocessor'
),
backbone
=
dict
(
backbone
=
dict
(
type
=
'PointNet2SAMSG'
,
type
=
'PointNet2SAMSG'
,
in_channels
=
4
,
in_channels
=
4
,
...
@@ -20,7 +21,6 @@ model = dict(
...
@@ -20,7 +21,6 @@ model = dict(
normalize_xyz
=
False
)),
normalize_xyz
=
False
)),
bbox_head
=
dict
(
bbox_head
=
dict
(
type
=
'SSD3DHead'
,
type
=
'SSD3DHead'
,
in_channels
=
256
,
vote_module_cfg
=
dict
(
vote_module_cfg
=
dict
(
in_channels
=
256
,
in_channels
=
256
,
num_points
=
256
,
num_points
=
256
,
...
@@ -48,30 +48,29 @@ model = dict(
...
@@ -48,30 +48,29 @@ model = dict(
conv_cfg
=
dict
(
type
=
'Conv1d'
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.1
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.1
),
bias
=
True
),
bias
=
True
),
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.1
),
objectness_loss
=
dict
(
objectness_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
type
=
'
mmdet.
CrossEntropyLoss'
,
use_sigmoid
=
True
,
use_sigmoid
=
True
,
reduction
=
'sum'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
loss_weight
=
1.0
),
center_loss
=
dict
(
center_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
type
=
'
mmdet.
SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
dir_class_loss
=
dict
(
dir_class_loss
=
dict
(
type
=
'CrossEntropyLoss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
type
=
'
mmdet.
CrossEntropyLoss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
dir_res_loss
=
dict
(
dir_res_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
type
=
'
mmdet.
SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
size_res_loss
=
dict
(
size_res_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
type
=
'
mmdet.
SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
corner_loss
=
dict
(
corner_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
type
=
'mmdet.SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
),
vote_loss
=
dict
(
type
=
'SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
)),
vote_loss
=
dict
(
type
=
'mmdet.SmoothL1Loss'
,
reduction
=
'sum'
,
loss_weight
=
1.0
)),
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
sample_mod
=
'spec'
,
pos_distance_thr
=
10.0
,
expand_dims_length
=
0.05
),
sample_mod
e
=
'spec'
,
pos_distance_thr
=
10.0
,
expand_dims_length
=
0.05
),
test_cfg
=
dict
(
test_cfg
=
dict
(
nms_cfg
=
dict
(
type
=
'nms'
,
iou_thr
=
0.1
),
nms_cfg
=
dict
(
type
=
'nms'
,
iou_thr
=
0.1
),
sample_mod
=
'spec'
,
sample_mod
e
=
'spec'
,
score_thr
=
0.0
,
score_thr
=
0.0
,
per_class_proposal
=
True
,
per_class_proposal
=
True
,
max_output_num
=
100
))
max_output_num
=
100
))
configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py
View file @
3c57cc41
...
@@ -57,9 +57,9 @@ test_pipeline = [
...
@@ -57,9 +57,9 @@ test_pipeline = [
translation_std
=
[
0
,
0
,
0
]),
translation_std
=
[
0
,
0
,
0
]),
dict
(
type
=
'RandomFlip3D'
),
dict
(
type
=
'RandomFlip3D'
),
dict
(
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
)
,
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
)
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
])
,
])
dict
(
type
=
'Pack3DDetInputs'
,
keys
=
[
'points'
])
]
]
# construct a pipeline for data and gt loading in show function
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
# please keep its loading function consistent with test_pipeline (e.g. client)
...
...
configs/votenet/votenet_16x8_sunrgbd-3d-10class.py
View file @
3c57cc41
# TODO refactor the config of sunrgbd
_base_
=
[
_base_
=
[
'../_base_/datasets/sunrgbd-3d-10class.py'
,
'../_base_/models/votenet.py'
,
'../_base_/datasets/sunrgbd-3d-10class.py'
,
'../_base_/models/votenet.py'
,
'../_base_/schedules/schedule_3x.py'
,
'../_base_/default_runtime.py'
'../_base_/schedules/schedule_3x.py'
,
'../_base_/default_runtime.py'
...
...
mmdet3d/models/dense_heads/ssd_3d_head.py
View file @
3c57cc41
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
mmcv
import
ConfigDict
from
mmcv.ops.nms
import
batched_nms
from
mmcv.ops.nms
import
batched_nms
from
mmcv.runner
import
force_fp32
from
mmengine
import
InstanceData
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
mmdet3d.core.bbox.structures
import
(
DepthInstance3DBoxes
,
from
mmdet3d.core.bbox.structures
import
(
DepthInstance3DBoxes
,
...
@@ -9,6 +13,7 @@ from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
...
@@ -9,6 +13,7 @@ from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
rotation_3d_in_axis
)
rotation_3d_in_axis
)
from
mmdet3d.registry
import
MODELS
from
mmdet3d.registry
import
MODELS
from
mmdet.core
import
multi_apply
from
mmdet.core
import
multi_apply
from
...core
import
BaseInstance3DBoxes
from
..builder
import
build_loss
from
..builder
import
build_loss
from
.vote_head
import
VoteHead
from
.vote_head
import
VoteHead
...
@@ -21,7 +26,6 @@ class SSD3DHead(VoteHead):
...
@@ -21,7 +26,6 @@ class SSD3DHead(VoteHead):
num_classes (int): The number of class.
num_classes (int): The number of class.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
decoding boxes.
in_channels (int): The number of input feature channel.
train_cfg (dict): Config for training.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
test_cfg (dict): Config for testing.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
...
@@ -41,25 +45,21 @@ class SSD3DHead(VoteHead):
...
@@ -41,25 +45,21 @@ class SSD3DHead(VoteHead):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
:
int
,
bbox_coder
,
bbox_coder
:
Union
[
ConfigDict
,
dict
],
in_channels
=
256
,
train_cfg
:
Optional
[
dict
]
=
None
,
train_cfg
=
None
,
test_cfg
:
Optional
[
dict
]
=
None
,
test_cfg
=
None
,
vote_module_cfg
:
Optional
[
dict
]
=
None
,
vote_module_cfg
=
None
,
vote_aggregation_cfg
:
Optional
[
dict
]
=
None
,
vote_aggregation_cfg
=
None
,
pred_layer_cfg
:
Optional
[
dict
]
=
None
,
pred_layer_cfg
=
None
,
objectness_loss
:
Optional
[
dict
]
=
None
,
conv_cfg
=
dict
(
type
=
'Conv1d'
),
center_loss
:
Optional
[
dict
]
=
None
,
norm_cfg
=
dict
(
type
=
'BN1d'
),
dir_class_loss
:
Optional
[
dict
]
=
None
,
act_cfg
=
dict
(
type
=
'ReLU'
),
dir_res_loss
:
Optional
[
dict
]
=
None
,
objectness_loss
=
None
,
size_res_loss
:
Optional
[
dict
]
=
None
,
center_loss
=
None
,
corner_loss
:
Optional
[
dict
]
=
None
,
dir_class_loss
=
None
,
vote_loss
:
Optional
[
dict
]
=
None
,
dir_res_loss
=
None
,
init_cfg
:
Optional
[
dict
]
=
None
)
->
None
:
size_res_loss
=
None
,
corner_loss
=
None
,
vote_loss
=
None
,
init_cfg
=
None
):
super
(
SSD3DHead
,
self
).
__init__
(
super
(
SSD3DHead
,
self
).
__init__
(
num_classes
,
num_classes
,
bbox_coder
,
bbox_coder
,
...
@@ -68,8 +68,6 @@ class SSD3DHead(VoteHead):
...
@@ -68,8 +68,6 @@ class SSD3DHead(VoteHead):
vote_module_cfg
=
vote_module_cfg
,
vote_module_cfg
=
vote_module_cfg
,
vote_aggregation_cfg
=
vote_aggregation_cfg
,
vote_aggregation_cfg
=
vote_aggregation_cfg
,
pred_layer_cfg
=
pred_layer_cfg
,
pred_layer_cfg
=
pred_layer_cfg
,
conv_cfg
=
conv_cfg
,
norm_cfg
=
norm_cfg
,
objectness_loss
=
objectness_loss
,
objectness_loss
=
objectness_loss
,
center_loss
=
center_loss
,
center_loss
=
center_loss
,
dir_class_loss
=
dir_class_loss
,
dir_class_loss
=
dir_class_loss
,
...
@@ -78,24 +76,23 @@ class SSD3DHead(VoteHead):
...
@@ -78,24 +76,23 @@ class SSD3DHead(VoteHead):
size_res_loss
=
size_res_loss
,
size_res_loss
=
size_res_loss
,
semantic_loss
=
None
,
semantic_loss
=
None
,
init_cfg
=
init_cfg
)
init_cfg
=
init_cfg
)
self
.
corner_loss
=
build_loss
(
corner_loss
)
self
.
corner_loss
=
build_loss
(
corner_loss
)
self
.
vote_loss
=
build_loss
(
vote_loss
)
self
.
vote_loss
=
build_loss
(
vote_loss
)
self
.
num_candidates
=
vote_module_cfg
[
'num_points'
]
self
.
num_candidates
=
vote_module_cfg
[
'num_points'
]
def
_get_cls_out_channels
(
self
):
def
_get_cls_out_channels
(
self
)
->
int
:
"""Return the channel number of classification outputs."""
"""Return the channel number of classification outputs."""
# Class numbers (k) + objectness (1)
# Class numbers (k) + objectness (1)
return
self
.
num_classes
return
self
.
num_classes
def
_get_reg_out_channels
(
self
):
def
_get_reg_out_channels
(
self
)
->
int
:
"""Return the channel number of regression outputs."""
"""Return the channel number of regression outputs."""
# Bbox classification and regression
# Bbox classification and regression
# (center residual (3), size regression (3)
# (center residual (3), size regression (3)
# heading class+residual (num_dir_bins*2)),
# heading class+residual (num_dir_bins*2)),
return
3
+
3
+
self
.
num_dir_bins
*
2
return
3
+
3
+
self
.
num_dir_bins
*
2
def
_extract_input
(
self
,
feat_dict
)
:
def
_extract_input
(
self
,
feat_dict
:
dict
)
->
Tuple
:
"""Extract inputs from features dictionary.
"""Extract inputs from features dictionary.
Args:
Args:
...
@@ -112,86 +109,87 @@ class SSD3DHead(VoteHead):
...
@@ -112,86 +109,87 @@ class SSD3DHead(VoteHead):
return
seed_points
,
seed_features
,
seed_indices
return
seed_points
,
seed_features
,
seed_indices
@
force_fp32
(
apply_to
=
(
'bbox_preds'
,
))
def
loss_by_feat
(
def
loss
(
self
,
self
,
bbox_preds
,
points
:
List
[
torch
.
Tensor
]
,
points
,
bbox_preds_dict
:
dict
,
gt_bboxes_3d
,
batch_gt_instances_3d
:
List
[
InstanceData
]
,
gt_labels_3d
,
batch_pts_semantic_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
pts_semantic_mask
=
None
,
batch_pts_instance_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
pts_instance_mask
=
None
,
batch_input_metas
:
List
[
dict
]
=
None
,
img_metas
=
Non
e
,
ret_target
:
bool
=
Fals
e
,
gt_bboxes_ignore
=
None
)
:
**
kwargs
)
->
dict
:
"""Compute loss.
"""Compute loss.
Args:
Args:
bbox_preds (dict): Predictions from forward of SSD3DHead.
points (list[torch.Tensor]): Input points.
points (list[torch.Tensor]): Input points.
gt_
bbox
es_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bbox
_preds_dict (dict): Predictions from forward of vote head.
bboxes of each sample.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
gt_instances. It usually includes ``bboxes_3d`` and
pts_semantic_mask (list[torch.Tensor]): Point-wise
``labels_3d`` attributes.
s
emantic mask
.
batch_pts_semantic_mask (list[tensor]): S
emantic mask
pts_instance_mask (list[torch.Tensor]): Point-wise
of points cloud. Defaults to None. Defaults to None.
i
nstance mask
.
batch_pts_semantic_mask (list[tensor]): I
nstance mask
img_metas (list[dict]): Contain pcd and img's meta info
.
of points cloud. Defaults to None. Defaults to None
.
gt_bboxes_ignore (list[torch.Tensor]): Specify
batch_input_metas (list[dict]): Contain pcd and img's meta info.
which bounding
.
ret_target (bool): Return targets or not. Defaults to False
.
Returns:
Returns:
dict: Losses of 3DSSD.
dict: Losses of 3DSSD.
"""
"""
targets
=
self
.
get_targets
(
points
,
gt_bboxes_3d
,
gt_labels_3d
,
pts_semantic_mask
,
pts_instance_mask
,
targets
=
self
.
get_targets
(
points
,
bbox_preds_dict
,
bbox_preds
)
batch_gt_instances_3d
,
batch_pts_semantic_mask
,
batch_pts_instance_mask
)
(
vote_targets
,
center_targets
,
size_res_targets
,
dir_class_targets
,
(
vote_targets
,
center_targets
,
size_res_targets
,
dir_class_targets
,
dir_res_targets
,
mask_targets
,
centerness_targets
,
corner3d_targets
,
dir_res_targets
,
mask_targets
,
centerness_targets
,
corner3d_targets
,
vote_mask
,
positive_mask
,
negative_mask
,
centerness_weights
,
vote_mask
,
positive_mask
,
negative_mask
,
centerness_weights
,
box_loss_weights
,
heading_res_loss_weight
)
=
targets
box_loss_weights
,
heading_res_loss_weight
)
=
targets
# calculate centerness loss
# calculate centerness loss
centerness_loss
=
self
.
objectness
_loss
(
centerness_loss
=
self
.
loss_
objectness
(
bbox_preds
[
'obj_scores'
].
transpose
(
2
,
1
),
bbox_preds
_dict
[
'obj_scores'
].
transpose
(
2
,
1
),
centerness_targets
,
centerness_targets
,
weight
=
centerness_weights
)
weight
=
centerness_weights
)
# calculate center loss
# calculate center loss
center_loss
=
self
.
center
_loss
(
center_loss
=
self
.
loss_
center
(
bbox_preds
[
'center_offset'
],
bbox_preds
_dict
[
'center_offset'
],
center_targets
,
center_targets
,
weight
=
box_loss_weights
.
unsqueeze
(
-
1
))
weight
=
box_loss_weights
.
unsqueeze
(
-
1
))
# calculate direction class loss
# calculate direction class loss
dir_class_loss
=
self
.
dir_class
_loss
(
dir_class_loss
=
self
.
loss_
dir_class
(
bbox_preds
[
'dir_class'
].
transpose
(
1
,
2
),
bbox_preds
_dict
[
'dir_class'
].
transpose
(
1
,
2
),
dir_class_targets
,
dir_class_targets
,
weight
=
box_loss_weights
)
weight
=
box_loss_weights
)
# calculate direction residual loss
# calculate direction residual loss
dir_res_loss
=
self
.
dir_res
_loss
(
dir_res_loss
=
self
.
loss_
dir_res
(
bbox_preds
[
'dir_res_norm'
],
bbox_preds
_dict
[
'dir_res_norm'
],
dir_res_targets
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
self
.
num_dir_bins
),
dir_res_targets
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
self
.
num_dir_bins
),
weight
=
heading_res_loss_weight
)
weight
=
heading_res_loss_weight
)
# calculate size residual loss
# calculate size residual loss
size_loss
=
self
.
size_res
_loss
(
size_loss
=
self
.
loss_
size_res
(
bbox_preds
[
'size'
],
bbox_preds
_dict
[
'size'
],
size_res_targets
,
size_res_targets
,
weight
=
box_loss_weights
.
unsqueeze
(
-
1
))
weight
=
box_loss_weights
.
unsqueeze
(
-
1
))
# calculate corner loss
# calculate corner loss
one_hot_dir_class_targets
=
dir_class_targets
.
new_zeros
(
one_hot_dir_class_targets
=
dir_class_targets
.
new_zeros
(
bbox_preds
[
'dir_class'
].
shape
)
bbox_preds
_dict
[
'dir_class'
].
shape
)
one_hot_dir_class_targets
.
scatter_
(
2
,
dir_class_targets
.
unsqueeze
(
-
1
),
one_hot_dir_class_targets
.
scatter_
(
2
,
dir_class_targets
.
unsqueeze
(
-
1
),
1
)
1
)
pred_bbox3d
=
self
.
bbox_coder
.
decode
(
pred_bbox3d
=
self
.
bbox_coder
.
decode
(
dict
(
dict
(
center
=
bbox_preds
[
'center'
],
center
=
bbox_preds
_dict
[
'center'
],
dir_res
=
bbox_preds
[
'dir_res'
],
dir_res
=
bbox_preds
_dict
[
'dir_res'
],
dir_class
=
one_hot_dir_class_targets
,
dir_class
=
one_hot_dir_class_targets
,
size
=
bbox_preds
[
'size'
]))
size
=
bbox_preds
_dict
[
'size'
]))
pred_bbox3d
=
pred_bbox3d
.
reshape
(
-
1
,
pred_bbox3d
.
shape
[
-
1
])
pred_bbox3d
=
pred_bbox3d
.
reshape
(
-
1
,
pred_bbox3d
.
shape
[
-
1
])
pred_bbox3d
=
img
_metas
[
0
][
'box_type_3d'
](
pred_bbox3d
=
batch_input
_metas
[
0
][
'box_type_3d'
](
pred_bbox3d
.
clone
(),
pred_bbox3d
.
clone
(),
box_dim
=
pred_bbox3d
.
shape
[
-
1
],
box_dim
=
pred_bbox3d
.
shape
[
-
1
],
with_yaw
=
self
.
bbox_coder
.
with_rot
,
with_yaw
=
self
.
bbox_coder
.
with_rot
,
...
@@ -204,7 +202,7 @@ class SSD3DHead(VoteHead):
...
@@ -204,7 +202,7 @@ class SSD3DHead(VoteHead):
# calculate vote loss
# calculate vote loss
vote_loss
=
self
.
vote_loss
(
vote_loss
=
self
.
vote_loss
(
bbox_preds
[
'vote_offset'
].
transpose
(
1
,
2
),
bbox_preds
_dict
[
'vote_offset'
].
transpose
(
1
,
2
),
vote_targets
,
vote_targets
,
weight
=
vote_mask
.
unsqueeze
(
-
1
))
weight
=
vote_mask
.
unsqueeze
(
-
1
))
...
@@ -219,57 +217,74 @@ class SSD3DHead(VoteHead):
...
@@ -219,57 +217,74 @@ class SSD3DHead(VoteHead):
return
losses
return
losses
def
get_targets
(
self
,
def
get_targets
(
points
,
self
,
gt_bboxes_3d
,
points
:
List
[
Tensor
],
gt_labels_3d
,
bbox_preds_dict
:
dict
=
None
,
pts_semantic_mask
=
None
,
batch_gt_instances_3d
:
List
[
InstanceData
]
=
None
,
pts_instance_mask
=
None
,
batch_pts_semantic_mask
:
List
[
torch
.
Tensor
]
=
None
,
bbox_preds
=
None
):
batch_pts_instance_mask
:
List
[
torch
.
Tensor
]
=
None
,
"""Generate targets of ssd3d head.
)
->
Tuple
[
Tensor
]:
"""Generate targets of 3DSSD head.
Args:
Args:
points (list[torch.Tensor]): Points of each batch.
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bbox_preds_dict (dict): Bounding box predictions of
bboxes of each batch.
vote head. Defaults to None.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
gt_instances. It usually includes ``bboxes`` and ``labels``
label of each batch.
attributes. Defaults to None.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
label of each batch.
point clouds. Defaults to None.
bbox_preds (torch.Tensor): Bounding box predictions of ssd3d head.
batch_pts_instance_mask (list[tensor]): Instance gt mask for
point clouds. Defaults to None.
Returns:
Returns:
tuple[torch.Tensor]: Targets of
ssd3d
head.
tuple[torch.Tensor]: Targets of
3DSSD
head.
"""
"""
# find empty example
batch_gt_labels_3d
=
[
for
index
in
range
(
len
(
gt_labels_3d
)):
gt_instances_3d
.
labels_3d
if
len
(
gt_labels_3d
[
index
])
==
0
:
for
gt_instances_3d
in
batch_gt_instances_3d
fake_box
=
gt_bboxes_3d
[
index
].
tensor
.
new_zeros
(
]
1
,
gt_bboxes_3d
[
index
].
tensor
.
shape
[
-
1
])
batch_gt_bboxes_3d
=
[
gt_bboxes_3d
[
index
]
=
gt_bboxes_3d
[
index
].
new_box
(
fake_box
)
gt_instances_3d
.
bboxes_3d
gt_labels_3d
[
index
]
=
gt_labels_3d
[
index
].
new_zeros
(
1
)
for
gt_instances_3d
in
batch_gt_instances_3d
]
if
pts_semantic_mask
is
None
:
# find empty example
pts_semantic_mask
=
[
None
for
i
in
range
(
len
(
gt_labels_3d
))]
for
index
in
range
(
len
(
batch_gt_labels_3d
)):
pts_instance_mask
=
[
None
for
i
in
range
(
len
(
gt_labels_3d
))]
if
len
(
batch_gt_labels_3d
[
index
])
==
0
:
fake_box
=
batch_gt_bboxes_3d
[
index
].
tensor
.
new_zeros
(
1
,
batch_gt_bboxes_3d
[
index
].
tensor
.
shape
[
-
1
])
batch_gt_bboxes_3d
[
index
]
=
batch_gt_bboxes_3d
[
index
].
new_box
(
fake_box
)
batch_gt_labels_3d
[
index
]
=
batch_gt_labels_3d
[
index
].
new_zeros
(
1
)
if
batch_pts_semantic_mask
is
None
:
batch_pts_semantic_mask
=
[
None
for
_
in
range
(
len
(
batch_gt_labels_3d
))
]
batch_pts_instance_mask
=
[
None
for
_
in
range
(
len
(
batch_gt_labels_3d
))
]
aggregated_points
=
[
aggregated_points
=
[
bbox_preds
[
'aggregated_points'
][
i
]
bbox_preds
_dict
[
'aggregated_points'
][
i
]
for
i
in
range
(
len
(
gt_labels_3d
))
for
i
in
range
(
len
(
batch_
gt_labels_3d
))
]
]
seed_points
=
[
seed_points
=
[
bbox_preds
[
'seed_points'
][
i
,
:
self
.
num_candidates
].
detach
()
bbox_preds
_dict
[
'seed_points'
][
i
,
:
self
.
num_candidates
].
detach
()
for
i
in
range
(
len
(
gt_labels_3d
))
for
i
in
range
(
len
(
batch_
gt_labels_3d
))
]
]
(
vote_targets
,
center_targets
,
size_res_targets
,
dir_class_targets
,
(
vote_targets
,
center_targets
,
size_res_targets
,
dir_class_targets
,
dir_res_targets
,
mask_targets
,
centerness_targets
,
corner3d_targets
,
dir_res_targets
,
mask_targets
,
centerness_targets
,
corner3d_targets
,
vote_mask
,
positive_mask
,
negative_mask
)
=
multi_apply
(
vote_mask
,
positive_mask
,
negative_mask
)
=
multi_apply
(
self
.
get_targets_single
,
points
,
gt_bboxes_3d
,
gt_labels_3d
,
self
.
get_targets_single
,
points
,
batch_
gt_bboxes_3d
,
pts_semantic_mask
,
pts_instance_mask
,
aggregated_points
,
batch_gt_labels_3d
,
batch_pts_semantic_mask
,
seed_points
)
batch_pts_instance_mask
,
aggregated_points
,
seed_points
)
center_targets
=
torch
.
stack
(
center_targets
)
center_targets
=
torch
.
stack
(
center_targets
)
positive_mask
=
torch
.
stack
(
positive_mask
)
positive_mask
=
torch
.
stack
(
positive_mask
)
...
@@ -283,7 +298,7 @@ class SSD3DHead(VoteHead):
...
@@ -283,7 +298,7 @@ class SSD3DHead(VoteHead):
vote_targets
=
torch
.
stack
(
vote_targets
)
vote_targets
=
torch
.
stack
(
vote_targets
)
vote_mask
=
torch
.
stack
(
vote_mask
)
vote_mask
=
torch
.
stack
(
vote_mask
)
center_targets
-=
bbox_preds
[
'aggregated_points'
]
center_targets
-=
bbox_preds
_dict
[
'aggregated_points'
]
centerness_weights
=
(
positive_mask
+
centerness_weights
=
(
positive_mask
+
negative_mask
).
unsqueeze
(
-
1
).
repeat
(
negative_mask
).
unsqueeze
(
-
1
).
repeat
(
...
@@ -308,13 +323,14 @@ class SSD3DHead(VoteHead):
...
@@ -308,13 +323,14 @@ class SSD3DHead(VoteHead):
heading_res_loss_weight
)
heading_res_loss_weight
)
def
get_targets_single
(
self
,
def
get_targets_single
(
self
,
points
,
points
:
Tensor
,
gt_bboxes_3d
,
gt_bboxes_3d
:
BaseInstance3DBoxes
,
gt_labels_3d
,
gt_labels_3d
:
Tensor
,
pts_semantic_mask
=
None
,
pts_semantic_mask
:
Optional
[
Tensor
]
=
None
,
pts_instance_mask
=
None
,
pts_instance_mask
:
Optional
[
Tensor
]
=
None
,
aggregated_points
=
None
,
aggregated_points
:
Optional
[
Tensor
]
=
None
,
seed_points
=
None
):
seed_points
:
Optional
[
Tensor
]
=
None
,
**
kwargs
):
"""Generate targets of ssd3d head for single batch.
"""Generate targets of ssd3d head for single batch.
Args:
Args:
...
@@ -440,41 +456,50 @@ class SSD3DHead(VoteHead):
...
@@ -440,41 +456,50 @@ class SSD3DHead(VoteHead):
centerness_targets
,
corner3d_targets
,
vote_mask
,
positive_mask
,
centerness_targets
,
corner3d_targets
,
vote_mask
,
positive_mask
,
negative_mask
)
negative_mask
)
def
get_bboxes
(
self
,
points
,
bbox_preds
,
input_metas
,
rescale
=
False
):
def
predict_by_feat
(
self
,
points
:
List
[
torch
.
Tensor
],
"""Generate bboxes from 3DSSD head predictions.
bbox_preds_dict
:
dict
,
batch_input_metas
:
List
[
dict
],
**
kwargs
)
->
List
[
InstanceData
]:
"""Generate bboxes from vote head predictions.
Args:
Args:
points (torch.Tensor): Input points.
points (
List[
torch.Tensor
]
): Input points
of multiple samples
.
bbox_preds (dict): Predictions from
sdd3d
head.
bbox_preds
_dict
(dict): Predictions from
vote
head.
input_metas (list[dict]):
Point cloud and image's meta info.
batch_
input_metas (list[dict]):
Each item
rescale (bool): Whether to rescale bboxes
.
contains the meta information of each sample
.
Returns:
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData cantains 3d Bounding boxes and corresponding
scores and labels.
"""
"""
# decode boxes
# decode boxes
sem_scores
=
F
.
sigmoid
(
bbox_preds
[
'obj_scores'
]).
transpose
(
1
,
2
)
sem_scores
=
F
.
sigmoid
(
bbox_preds
_dict
[
'obj_scores'
]).
transpose
(
1
,
2
)
obj_scores
=
sem_scores
.
max
(
-
1
)[
0
]
obj_scores
=
sem_scores
.
max
(
-
1
)[
0
]
bbox3d
=
self
.
bbox_coder
.
decode
(
bbox_preds
)
bbox3d
=
self
.
bbox_coder
.
decode
(
bbox_preds_dict
)
batch_size
=
bbox3d
.
shape
[
0
]
batch_size
=
bbox3d
.
shape
[
0
]
results
=
list
(
)
points
=
torch
.
stack
(
points
)
results_list
=
[]
for
b
in
range
(
batch_size
):
for
b
in
range
(
batch_size
):
temp_results
=
InstanceData
()
bbox_selected
,
score_selected
,
labels
=
self
.
multiclass_nms_single
(
bbox_selected
,
score_selected
,
labels
=
self
.
multiclass_nms_single
(
obj_scores
[
b
],
sem_scores
[
b
],
bbox3d
[
b
],
points
[
b
,
...,
:
3
],
obj_scores
[
b
],
sem_scores
[
b
],
bbox3d
[
b
],
points
[
b
,
...,
:
3
],
input_metas
[
b
])
batch_
input_metas
[
b
])
bbox
=
input_metas
[
b
][
'box_type_3d'
](
bbox
=
batch_
input_metas
[
b
][
'box_type_3d'
](
bbox_selected
.
clone
(),
bbox_selected
.
clone
(),
box_dim
=
bbox_selected
.
shape
[
-
1
],
box_dim
=
bbox_selected
.
shape
[
-
1
],
with_yaw
=
self
.
bbox_coder
.
with_rot
)
with_yaw
=
self
.
bbox_coder
.
with_rot
)
results
.
append
((
bbox
,
score_selected
,
labels
))
return
results
temp_results
.
bboxes_3d
=
bbox
temp_results
.
scores_3d
=
score_selected
temp_results
.
labels_3d
=
labels
results_list
.
append
(
temp_results
)
return
results_list
def
multiclass_nms_single
(
self
,
obj_scores
,
sem_scores
,
bbox
,
points
,
def
multiclass_nms_single
(
self
,
obj_scores
:
Tensor
,
sem_scores
:
Tensor
,
input_meta
):
bbox
:
Tensor
,
points
:
Tensor
,
input_meta
:
dict
)
->
Tuple
[
Tensor
]:
"""Multi-class nms in single batch.
"""Multi-class nms in single batch.
Args:
Args:
...
@@ -538,7 +563,8 @@ class SSD3DHead(VoteHead):
...
@@ -538,7 +563,8 @@ class SSD3DHead(VoteHead):
return
bbox_selected
,
score_selected
,
labels
return
bbox_selected
,
score_selected
,
labels
def
_assign_targets_by_points_inside
(
self
,
bboxes_3d
,
points
):
def
_assign_targets_by_points_inside
(
self
,
bboxes_3d
:
BaseInstance3DBoxes
,
points
:
Tensor
)
->
Tuple
:
"""Compute assignment by checking whether point is inside bbox.
"""Compute assignment by checking whether point is inside bbox.
Args:
Args:
...
...
mmdet3d/models/detectors/ssd3dnet.py
View file @
3c57cc41
...
@@ -16,11 +16,11 @@ class SSD3DNet(VoteNet):
...
@@ -16,11 +16,11 @@ class SSD3DNet(VoteNet):
train_cfg
=
None
,
train_cfg
=
None
,
test_cfg
=
None
,
test_cfg
=
None
,
init_cfg
=
None
,
init_cfg
=
None
,
pretrained
=
None
):
**
kwargs
):
super
(
SSD3DNet
,
self
).
__init__
(
super
(
SSD3DNet
,
self
).
__init__
(
backbone
=
backbone
,
backbone
=
backbone
,
bbox_head
=
bbox_head
,
bbox_head
=
bbox_head
,
train_cfg
=
train_cfg
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
test_cfg
=
test_cfg
,
init_cfg
=
init_cfg
,
init_cfg
=
init_cfg
,
pretrained
=
pretrained
)
**
kwargs
)
tests/test_models/test_detectors/test_3dssd.py
0 → 100644
View file @
3c57cc41
import
unittest
import
torch
from
mmengine
import
DefaultScope
from
mmdet3d.registry
import
MODELS
from
tests.utils.model_utils
import
(
_create_detector_inputs
,
_get_detector_cfg
,
_setup_seed
)
class
Test3DSSD
(
unittest
.
TestCase
):
def
test_3dssd
(
self
):
import
mmdet3d.models
assert
hasattr
(
mmdet3d
.
models
,
'SSD3DNet'
)
DefaultScope
.
get_instance
(
'test_ssd3d'
,
scope_name
=
'mmdet3d'
)
_setup_seed
(
0
)
voxel_net_cfg
=
_get_detector_cfg
(
'3dssd/3dssd_4x4_kitti-3d-car.py'
)
model
=
MODELS
.
build
(
voxel_net_cfg
)
num_gt_instance
=
3
data
=
[
_create_detector_inputs
(
num_gt_instance
=
num_gt_instance
,
num_classes
=
1
)
]
if
torch
.
cuda
.
is_available
():
model
=
model
.
cuda
()
# test simple_test
with
torch
.
no_grad
():
batch_inputs
,
data_samples
=
model
.
data_preprocessor
(
data
,
True
)
results
=
model
.
forward
(
batch_inputs
,
data_samples
,
mode
=
'predict'
)
self
.
assertEqual
(
len
(
results
),
len
(
data
))
self
.
assertIn
(
'bboxes_3d'
,
results
[
0
].
pred_instances_3d
)
self
.
assertIn
(
'scores_3d'
,
results
[
0
].
pred_instances_3d
)
self
.
assertIn
(
'labels_3d'
,
results
[
0
].
pred_instances_3d
)
losses
=
model
.
forward
(
batch_inputs
,
data_samples
,
mode
=
'loss'
)
self
.
assertGreater
(
losses
[
'centerness_loss'
],
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment