OpenDAS / mmdetection3d · Commit 522cc20d

[Refactor] Refactor ShapeAwareHead and FreeAnchor3DHead

Authored Jul 15, 2022 by VVsssssk; committed Jul 20, 2022 by ChaimZhu
Parent: 3c57cc41
Showing 14 changed files with 377 additions and 182 deletions (+377 −182)
configs/_base_/models/hv_pointpillars_fpn_nus.py  (+1 −1)
configs/_base_/schedules/cyclic_20e.py  (+1 −1)
configs/_base_/schedules/schedule_2x.py  (+1 −1)
configs/free_anchor/hv_pointpillars_fpn_sbn-all_free-anchor_4x8_2x_nus-3d.py  (+5 −3)
configs/free_anchor/hv_pointpillars_regnet-1.6gf_fpn_sbn-all_free-anchor_strong-aug_4x8_3x_nus-3d.py  (+20 −6)
configs/free_anchor/hv_pointpillars_regnet-3.2gf_fpn_sbn-all_free-anchor_strong-aug_4x8_3x_nus-3d.py  (+19 −5)
configs/pointpillars/hv_pointpillars_secfpn_sbn-all_2x8_2x_lyft-3d.py  (+0 −6)
configs/ssn/hv_ssn_secfpn_sbn-all_2x16_2x_nus-3d.py  (+29 −28)
mmdet3d/models/dense_heads/base_3d_dense_head.py  (+2 −2)
mmdet3d/models/dense_heads/free_anchor3d_head.py  (+58 −52)
mmdet3d/models/dense_heads/shape_aware_head.py  (+97 −76)
tests/test_models/test_dense_heads/test_freeanchors.py  (+71 −0)
tests/test_models/test_dense_heads/test_ssn.py  (+71 −0)
tests/utils/model_utils.py  (+2 −1)
configs/_base_/models/hv_pointpillars_fpn_nus.py  —  view file @ 522cc20d

@@ -32,7 +32,7 @@ model = dict(
         layer_strides=[2, 2, 2],
         out_channels=[64, 128, 256]),
     pts_neck=dict(
-        type='FPN',
+        type='mmdet.FPN',
         norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
         act_cfg=dict(type='ReLU'),
         in_channels=[64, 128, 256],
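Side note on the `mmdet.` prefix: with the MMEngine registry, a `type` string can carry a scope prefix so that the module is looked up in another repository's registry (here, `FPN` from mmdet rather than mmdet3d). A minimal sketch of the pattern; the channel values below are illustrative, not the ones in this config:

# Hedged sketch: 'mmdet.FPN' is resolved from mmdet's MODELS registry;
# a bare 'FPN' would first be looked up in the current (mmdet3d) scope.
pts_neck = dict(
    type='mmdet.FPN',            # cross-scope lookup via the prefix
    in_channels=[64, 128, 256],  # illustrative values
    out_channels=256,
    num_outs=3)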
configs/_base_/schedules/cyclic_20e.py  —  view file @ 522cc20d

@@ -47,6 +47,6 @@ param_scheduler = [
 ]
 # runtime settings
-train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
+train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20)
 val_cfg = dict()
 test_cfg = dict()
configs/_base_/schedules/schedule_2x.py  —  view file @ 522cc20d

@@ -8,7 +8,7 @@ optim_wrapper = dict(
     clip_grad=dict(max_norm=35, norm_type=2))
 # training schedule for 2x
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=24)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
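Both schedule changes move validation from every epoch (`val_interval=1`) to a single validation pass at the end of training. A downstream config that still wants per-epoch validation can override just that field; a hedged config-fragment sketch (the `_base_` path is hypothetical and depends on where the config lives):

# Hypothetical downstream config: inherit the 2x schedule and restore
# per-epoch validation by overriding a single key.
_base_ = ['../_base_/schedules/schedule_2x.py']

train_cfg = dict(val_interval=1)  # validate after every epoch again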
configs/free_anchor/hv_pointpillars_fpn_sbn-all_free-anchor_4x8_2x_nus-3d.py  —  view file @ 522cc20d

@@ -34,14 +34,16 @@ model = dict(
         dir_offset=-0.7854,  # -pi / 4
         bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
         loss_cls=dict(
-            type='FocalLoss',
+            type='mmdet.FocalLoss',
             use_sigmoid=True,
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.8),
+        loss_bbox=dict(
+            type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.8),
         loss_dir=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
+            type='mmdet.CrossEntropyLoss', use_sigmoid=False,
+            loss_weight=0.2)),
     # model training and testing settings
     train_cfg=dict(
         pts=dict(
             code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.25, 0.25])))
configs/free_anchor/hv_pointpillars_regnet-1.6gf_fpn_sbn-all_free-anchor_strong-aug_4x8_3x_nus-3d.py  —  view file @ 522cc20d

@@ -60,11 +60,25 @@ train_pipeline = [
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectNameFilter', classes=class_names),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
-data = dict(train=dict(pipeline=train_pipeline))
-lr_config = dict(step=[28, 34])
-runner = dict(max_epochs=36)
-evaluation = dict(interval=36)
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+train_cfg = dict(max_epochs=36, val_interval=36)
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0 / 1000,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=24,
+        by_epoch=True,
+        milestones=[28, 34],
+        gamma=0.1)
+]
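The two free-anchor RegNet configs apply the same migration of runtime settings; condensed, the mapping is the one below. This is a config-fragment sketch rather than a standalone script: `train_pipeline` is assumed to be defined earlier in the file, as it is in both configs, and the comments are mine.

# Removed (mmcv 1.x style runtime keys):
#   data = dict(train=dict(pipeline=train_pipeline))
#   lr_config = dict(step=[28, 34])
#   runner = dict(max_epochs=36)
#   evaluation = dict(interval=36)

# Added (mmengine style equivalents):
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
train_cfg = dict(max_epochs=36, val_interval=36)
param_scheduler = [
    # linear warmup over the first 1000 iterations
    dict(type='LinearLR', start_factor=1.0 / 1000,
         by_epoch=False, begin=0, end=1000),
    # step decay at epochs 28 and 34
    dict(type='MultiStepLR', by_epoch=True, begin=0,
         milestones=[28, 34], gamma=0.1),
]

Note that the `end` epoch of the `MultiStepLR` stage is set per config in the diff (24 in the 1.6gf file, 36 in the 3.2gf file), so it is left out of the sketch above.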
configs/free_anchor/hv_pointpillars_regnet-3.2gf_fpn_sbn-all_free-anchor_strong-aug_4x8_3x_nus-3d.py  —  view file @ 522cc20d

@@ -60,11 +60,25 @@ train_pipeline = [
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectNameFilter', classes=class_names),
     dict(type='PointShuffle'),
     dict(type='DefaultFormatBundle3D', class_names=class_names),
     dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
-data = dict(train=dict(pipeline=train_pipeline))
-lr_config = dict(step=[28, 34])
-runner = dict(max_epochs=36)
-evaluation = dict(interval=36)
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+train_cfg = dict(max_epochs=36, val_interval=36)
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0 / 1000,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=36,
+        by_epoch=True,
+        milestones=[28, 34],
+        gamma=0.1)
+]
configs/pointpillars/hv_pointpillars_secfpn_sbn-all_2x8_2x_lyft-3d.py  —  view file @ 522cc20d

@@ -41,9 +41,3 @@ model = dict(
         ],
         rotations=[0, 1.57],
         reshape_out=True)))
-# For Lyft dataset, we usually evaluate the model at the end of training.
-# Since the models are trained by 24 epochs by default, we set evaluation
-# interval to be 24. Please change the interval accordingly if you do not
-# use a default schedule.
-train_cfg = dict(val_interval=24)
configs/ssn/hv_ssn_secfpn_sbn-all_2x16_2x_nus-3d.py  —  view file @ 522cc20d

@@ -29,8 +29,9 @@ train_pipeline = [
     dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
 test_pipeline = [
     dict(
         type='LoadPointsFromFile',
         coord_type='LIDAR',
         load_dim=5,
         use_dim=5),

@@ -48,20 +49,18 @@ test_pipeline = [
             translation_std=[0, 0, 0]),
         dict(type='RandomFlip3D'),
         dict(
-            type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-        dict(
-            type='DefaultFormatBundle3D',
-            class_names=class_names,
-            with_label=False),
-        dict(type='Collect3D', keys=['points'])
-    ])
-]
-data = dict(
-    samples_per_gpu=2,
-    workers_per_gpu=4,
-    train=dict(pipeline=train_pipeline, classes=class_names),
-    val=dict(pipeline=test_pipeline, classes=class_names),
-    test=dict(pipeline=test_pipeline, classes=class_names))
+            type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+    ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=4,
+    dataset=dict(pipeline=train_pipeline, metainfo=dict(CLASSES=class_names)))
+test_dataloader = dict(
+    dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
+val_dataloader = dict(
+    dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
 # model settings
 model = dict(

@@ -148,84 +147,86 @@ model = dict(
             dir_limit_offset=0,
             bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
             loss_cls=dict(
-                type='FocalLoss',
+                type='mmdet.FocalLoss',
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 loss_weight=1.0),
-            loss_bbox=dict(
-                type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+            loss_bbox=dict(
+                type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
             loss_dir=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
+                type='mmdet.CrossEntropyLoss', use_sigmoid=False,
+                loss_weight=0.2)),
     # model training and testing settings
     train_cfg=dict(
         _delete_=True,
         pts=dict(assigner=[
             dict(  # bicycle
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.5, neg_iou_thr=0.35, min_pos_iou=0.35,
                 ignore_iof_thr=-1),
             dict(  # motorcycle
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.5, neg_iou_thr=0.3, min_pos_iou=0.3,
                 ignore_iof_thr=-1),
             dict(  # pedestrian
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.6, neg_iou_thr=0.4, min_pos_iou=0.4,
                 ignore_iof_thr=-1),
             dict(  # traffic cone
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.6, neg_iou_thr=0.4, min_pos_iou=0.4,
                 ignore_iof_thr=-1),
             dict(  # barrier
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.55, neg_iou_thr=0.4, min_pos_iou=0.4,
                 ignore_iof_thr=-1),
             dict(  # car
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.6, neg_iou_thr=0.45, min_pos_iou=0.45,
                 ignore_iof_thr=-1),
             dict(  # truck
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.55, neg_iou_thr=0.4, min_pos_iou=0.4,
                 ignore_iof_thr=-1),
             dict(  # trailer
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.5, neg_iou_thr=0.35, min_pos_iou=0.35,
                 ignore_iof_thr=-1),
             dict(  # bus
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.55, neg_iou_thr=0.4, min_pos_iou=0.4,
                 ignore_iof_thr=-1),
             dict(  # construction vehicle
-                type='MaxIoUAssigner',
+                type='Max3DIoUAssigner',
                 iou_calculator=dict(type='BboxOverlapsNearest3D'),
                 pos_iou_thr=0.5, neg_iou_thr=0.35,
                 ...
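The SSN config also shows the pipeline-side half of the migration: the two formatting steps `DefaultFormatBundle3D` + `Collect3D` collapse into a single `Pack3DDetInputs` transform that both packs the inputs and selects which keys to keep. A minimal config-fragment sketch, assuming the loading and augmentation transforms are defined as in the file above:

# Config-fragment sketch of the new packing step at the end of the
# training pipeline.
train_pipeline = [
    # ... loading and augmentation transforms as in the file above ...
    dict(type='PointShuffle'),
    dict(
        type='Pack3DDetInputs',
        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']),
]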
mmdet3d/models/dense_heads/base_3d_dense_head.py  —  view file @ 522cc20d

@@ -264,8 +264,8 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
                                 cfg: ConfigDict,
                                 rescale: bool = False,
                                 **kwargs) -> InstanceData:
-        """Transform a single image's features extracted from the head into
-        bbox results.
+        """Transform a single points sample's features extracted from the head
+        into bbox results.

         Args:
             cls_score_list (list[Tensor]): Box scores from all scale
mmdet3d/models/dense_heads/free_anchor3d_head.py  —  view file @ 522cc20d

 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, List

 import torch
 from mmcv.runner import force_fp32
 from torch import Tensor
 from torch.nn import functional as F

 from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
 from mmdet3d.core.utils import InstanceList, OptInstanceList
 from mmdet3d.registry import MODELS
 from .anchor3d_head import Anchor3DHead
 from .train_mixins import get_direction_target

@@ -29,27 +32,26 @@ class FreeAnchor3DHead(Anchor3DHead):
     """  # noqa: E501

     def __init__(self,
-                 pre_anchor_topk=50,
-                 bbox_thr=0.6,
-                 gamma=2.0,
-                 alpha=0.5,
-                 init_cfg=None,
-                 **kwargs):
+                 pre_anchor_topk: int = 50,
+                 bbox_thr: float = 0.6,
+                 gamma: float = 2.0,
+                 alpha: float = 0.5,
+                 init_cfg: dict = None,
+                 **kwargs) -> None:
         super().__init__(init_cfg=init_cfg, **kwargs)
         self.pre_anchor_topk = pre_anchor_topk
         self.bbox_thr = bbox_thr
         self.gamma = gamma
         self.alpha = alpha

     @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
-    def loss(self,
-             cls_scores,
-             bbox_preds,
-             dir_cls_preds,
-             gt_bboxes,
-             gt_labels,
-             input_metas,
-             gt_bboxes_ignore=None):
+    def loss_by_feat(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            dir_cls_preds: List[Tensor],
+            batch_gt_instances_3d: InstanceList,
+            batch_input_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
         """Calculate loss of FreeAnchor head.

         Args:

@@ -59,11 +61,14 @@ class FreeAnchor3DHead(Anchor3DHead):
                 different samples
             dir_cls_preds (list[torch.Tensor]): Direction predictions of
                 different samples
-            gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth boxes.
-            gt_labels (list[torch.Tensor]): Ground truth labels.
-            input_metas (list[dict]): List of input meta information.
-            gt_bboxes_ignore (list[:obj:`BaseInstance3DBoxes`], optional):
-                Ground truth boxes that should be ignored. Defaults to None.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.

         Returns:
             dict[str, torch.Tensor]: Loss items.

@@ -72,10 +77,10 @@ class FreeAnchor3DHead(Anchor3DHead):
                 - negative_bag_loss (torch.Tensor): Loss of negative samples.
         """
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
-        assert len(featmap_sizes) == self.anchor_generator.num_levels
+        assert len(featmap_sizes) == self.prior_generator.num_levels

-        anchor_list = self.get_anchors(featmap_sizes, input_metas)
-        anchors = [torch.cat(anchor) for anchor in anchor_list]
+        anchor_list = self.get_anchors(featmap_sizes, batch_input_metas)
+        mlvl_anchors = [torch.cat(anchor) for anchor in anchor_list]

         # concatenate each level
         cls_scores = [

@@ -98,24 +103,24 @@ class FreeAnchor3DHead(Anchor3DHead):
         bbox_preds = torch.cat(bbox_preds, dim=1)
         dir_cls_preds = torch.cat(dir_cls_preds, dim=1)

-        cls_prob = torch.sigmoid(cls_scores)
+        cls_probs = torch.sigmoid(cls_scores)
         box_prob = []
         num_pos = 0
         positive_losses = []
-        for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_,
-                dir_cls_preds_) in enumerate(
-                    zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds,
-                        dir_cls_preds)):
-
-            gt_bboxes_ = gt_bboxes_.tensor.to(anchors_.device)
+        for _, (anchors, gt_instance_3d, cls_prob, bbox_pred,
+                dir_cls_pred) in enumerate(
+                    zip(mlvl_anchors, batch_gt_instances_3d, cls_probs,
+                        bbox_preds, dir_cls_preds)):
+
+            gt_bboxes = gt_instance_3d.bboxes_3d.tensor.to(anchors.device)
+            gt_labels = gt_instance_3d.labels_3d.to(anchors.device)

             with torch.no_grad():
                 # box_localization: a_{j}^{loc}, shape: [j, 4]
-                pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)
+                pred_boxes = self.bbox_coder.decode(anchors, bbox_pred)

                 # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
-                object_box_iou = bbox_overlaps_nearest_3d(
-                    gt_bboxes_, pred_boxes)
+                object_box_iou = bbox_overlaps_nearest_3d(
+                    gt_bboxes, pred_boxes)

                 # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                 t1 = self.bbox_thr

@@ -125,9 +130,9 @@ class FreeAnchor3DHead(Anchor3DHead):
                     min=0, max=1)

                 # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
-                num_obj = gt_labels_.size(0)
+                num_obj = gt_labels.size(0)
                 indices = torch.stack(
-                    [torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
+                    [torch.arange(num_obj).type_as(gt_labels), gt_labels],
                     dim=0)

                 object_cls_box_prob = torch.sparse_coo_tensor(

@@ -147,11 +152,11 @@ class FreeAnchor3DHead(Anchor3DHead):
                 indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
                 if indices.numel() == 0:
                     image_box_prob = torch.zeros(
-                        anchors_.size(0),
+                        anchors.size(0),
                         self.num_classes).type_as(object_box_prob)
                 else:
                     nonzero_box_prob = torch.where(
-                        (gt_labels_.unsqueeze(dim=-1) == indices[0]),
+                        (gt_labels.unsqueeze(dim=-1) == indices[0]),
                         object_box_prob[:, indices[1]],
                         torch.tensor([0]).type_as(object_box_prob)).max(
                             dim=0).values

@@ -160,14 +165,13 @@ class FreeAnchor3DHead(Anchor3DHead):
                     image_box_prob = torch.sparse_coo_tensor(
                         indices.flip([0]),
                         nonzero_box_prob,
-                        size=(anchors_.size(0), self.num_classes)).to_dense()
+                        size=(anchors.size(0), self.num_classes)).to_dense()
                 # end

             box_prob.append(image_box_prob)

             # construct bags for objects
-            match_quality_matrix = bbox_overlaps_nearest_3d(gt_bboxes_, anchors_)
+            match_quality_matrix = bbox_overlaps_nearest_3d(gt_bboxes, anchors)
             _, matched = torch.topk(
                 match_quality_matrix,
                 self.pre_anchor_topk,

@@ -177,15 +181,15 @@ class FreeAnchor3DHead(Anchor3DHead):
             # matched_cls_prob: P_{ij}^{cls}
             matched_cls_prob = torch.gather(
-                cls_prob_[matched], 2,
-                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
-                                                 1)).squeeze(2)
+                cls_prob[matched], 2,
+                gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
+                                                1)).squeeze(2)

             # matched_box_prob: P_{ij}^{loc}
-            matched_anchors = anchors_[matched]
+            matched_anchors = anchors[matched]
             matched_object_targets = self.bbox_coder.encode(
                 matched_anchors,
-                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))
+                gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors))

             # direction classification loss
             loss_dir = None

@@ -198,15 +202,16 @@ class FreeAnchor3DHead(Anchor3DHead):
                     self.dir_limit_offset,
                     one_hot=False)
                 loss_dir = self.loss_dir(
-                    dir_cls_preds_[matched].transpose(-2, -1),
+                    dir_cls_pred[matched].transpose(-2, -1),
                     matched_dir_targets,
                     reduction_override='none')

             # generate bbox weights
             if self.diff_rad_by_sin:
-                bbox_preds_[matched], matched_object_targets = \
-                    self.add_sin_difference(bbox_preds_[matched],
-                                            matched_object_targets)
+                bbox_preds_clone = bbox_pred.clone()
+                bbox_preds_clone[matched], matched_object_targets = \
+                    self.add_sin_difference(bbox_preds_clone[matched],
+                                            matched_object_targets)

             bbox_weights = matched_anchors.new_ones(matched_anchors.size())
             # Use pop is not right, check performance
             code_weight = self.train_cfg.get('code_weight', None)

@@ -214,7 +219,7 @@ class FreeAnchor3DHead(Anchor3DHead):
                 bbox_weights = bbox_weights * bbox_weights.new_tensor(
                     code_weight)
             loss_bbox = self.loss_bbox(
-                bbox_preds_[matched],
+                bbox_preds_clone[matched],
                 matched_object_targets,
                 bbox_weights,
                 reduction_override='none').sum(-1)

@@ -224,7 +229,7 @@ class FreeAnchor3DHead(Anchor3DHead):
             matched_box_prob = torch.exp(-loss_bbox)

             # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
-            num_pos += len(gt_bboxes_)
+            num_pos += len(gt_bboxes)
             positive_losses.append(
                 self.positive_bag_loss(matched_cls_prob, matched_box_prob))

@@ -244,7 +249,8 @@ class FreeAnchor3DHead(Anchor3DHead):
         }
         return losses

-    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
+    def positive_bag_loss(self, matched_cls_prob: Tensor,
+                          matched_box_prob: Tensor) -> Tensor:
         """Generate positive bag loss.

         Args:

@@ -266,7 +272,7 @@ class FreeAnchor3DHead(Anchor3DHead):
         return self.alpha * F.binary_cross_entropy(
             bag_prob, torch.ones_like(bag_prob), reduction='none')

-    def negative_bag_loss(self, cls_prob, box_prob):
+    def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
         """Generate negative bag loss.

         Args:
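For callers, the visible change in FreeAnchor3DHead is the `loss_by_feat` signature: ground truth now arrives as a list of `InstanceData` objects carrying `bboxes_3d` and `labels_3d` instead of separate `gt_bboxes`/`gt_labels` lists. A hedged sketch of packaging such an input; the boxes and labels are random placeholders, and the commented-out call assumes `head`, `cls_scores`, `bbox_preds` and `dir_cls_preds` already exist:

import torch
from mmengine.data import InstanceData

from mmdet3d.core import LiDARInstance3DBoxes

# One sample's ground truth, packaged the way loss_by_feat expects it.
gt_instances_3d = InstanceData()
gt_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
    torch.rand([5, 9]), box_dim=9)            # 5 boxes with a 9-dim box code
gt_instances_3d.labels_3d = torch.randint(0, 10, [5])

# losses = head.loss_by_feat(
#     cls_scores, bbox_preds, dir_cls_preds,
#     batch_gt_instances_3d=[gt_instances_3d],
#     batch_input_metas=[dict(box_type_3d=LiDARInstance3DBoxes)])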
mmdet3d/models/dense_heads/shape_aware_head.py  —  view file @ 522cc20d

 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
 from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from mmcv.cnn import ConvModule
-from mmcv.runner import BaseModule
+from mmengine.data import InstanceData
+from mmengine.model import BaseModule
 from torch import Tensor
 from torch import nn as nn

 from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
 from mmdet3d.core.utils import InstanceList, OptInstanceList
 from mmdet3d.registry import MODELS
 from mmdet.core import multi_apply
 from ..builder import build_head

@@ -33,29 +37,30 @@ class BaseShapeHead(BaseModule):
         in_channels (int): Input channels for convolutional layers.
         shared_conv_channels (tuple, optional): Channels for shared
             convolutional layers. Default: (64, 64).
-        shared_conv_strides (tuple, optional): Strides for shared
+        shared_conv_strides (tuple): Strides for shared
             convolutional layers. Default: (1, 1).
-        use_direction_classifier (bool, optional): Whether to use direction
+        use_direction_classifier (bool): Whether to use direction
             classifier. Default: True.
-        conv_cfg (dict, optional): Config of conv layer.
+        conv_cfg (dict): Config of conv layer.
             Default: dict(type='Conv2d')
-        norm_cfg (dict, optional): Config of norm layer.
+        norm_cfg (dict): Config of norm layer.
             Default: dict(type='BN2d').
-        bias (bool | str, optional): Type of bias. Default: False.
+        bias (bool | str): Type of bias. Default: False.
         init_cfg (dict or list[dict], optional): Initialization config dict.
     """

     def __init__(self,
-                 num_cls,
-                 num_base_anchors,
-                 box_code_size,
-                 in_channels,
-                 shared_conv_channels=(64, 64),
-                 shared_conv_strides=(1, 1),
-                 use_direction_classifier=True,
-                 conv_cfg=dict(type='Conv2d'),
-                 norm_cfg=dict(type='BN2d'),
-                 bias=False,
-                 init_cfg=None):
+                 num_cls: int,
+                 num_base_anchors: int,
+                 box_code_size: int,
+                 in_channels: int,
+                 shared_conv_channels: Tuple = (64, 64),
+                 shared_conv_strides: Tuple = (1, 1),
+                 use_direction_classifier: bool = True,
+                 conv_cfg: Dict = dict(type='Conv2d'),
+                 norm_cfg: Dict = dict(type='BN2d'),
+                 bias: bool = False,
+                 init_cfg: Optional[dict] = None) -> None:
         super().__init__(init_cfg=init_cfg)
         self.num_cls = num_cls
         self.num_base_anchors = num_base_anchors

@@ -122,7 +127,7 @@ class BaseShapeHead(BaseModule):
                     bias_prob=0.01)
             ])

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Dict:
         """Forward function for SmallHead.

         Args:

@@ -171,13 +176,16 @@ class ShapeAwareHead(Anchor3DHead):
     Args:
         tasks (dict): Shape-aware groups of multi-class objects.
-        assign_per_class (bool, optional): Whether to do assignment for each
+        assign_per_class (bool): Whether to do assignment for each
             class. Default: True.
         kwargs (dict): Other arguments are the same as those in
             :class:`Anchor3DHead`.
         init_cfg (dict or list[dict], optional): Initialization config dict.
     """

-    def __init__(self, tasks, assign_per_class=True, init_cfg=None, **kwargs):
+    def __init__(self,
+                 tasks: Dict,
+                 assign_per_class: bool = True,
+                 init_cfg: Optional[dict] = None,
+                 **kwargs) -> Dict:
         self.tasks = tasks
         self.featmap_sizes = []
         super().__init__(

@@ -198,10 +206,10 @@ class ShapeAwareHead(Anchor3DHead):
         self.heads = nn.ModuleList()
         cls_ptr = 0
         for task in self.tasks:
-            sizes = self.anchor_generator.sizes[cls_ptr:cls_ptr +
-                                                task['num_class']]
+            sizes = self.prior_generator.sizes[cls_ptr:cls_ptr +
+                                               task['num_class']]
             num_size = torch.tensor(sizes).reshape(-1, 3).size(0)
-            num_rot = len(self.anchor_generator.rotations)
+            num_rot = len(self.prior_generator.rotations)
             num_base_anchors = num_rot * num_size
             branch = dict(
                 type='BaseShapeHead',

@@ -214,7 +222,7 @@ class ShapeAwareHead(Anchor3DHead):
             self.heads.append(build_head(branch))
             cls_ptr += task['num_class']

-    def forward_single(self, x):
+    def forward_single(self, x: Tensor) -> Tuple[Tensor]:
         """Forward function on a single-scale feature map.

         Args:

@@ -241,15 +249,18 @@ class ShapeAwareHead(Anchor3DHead):
         for i, task in enumerate(self.tasks):
             for _ in range(task['num_class']):
                 self.featmap_sizes.append(results[i]['featmap_size'])
-        assert len(self.featmap_sizes) == len(self.anchor_generator.ranges), \
+        assert len(self.featmap_sizes) == len(self.prior_generator.ranges), \
             'Length of feature map sizes must be equal to length of ' + \
             'different ranges of anchor generator.'

         return cls_score, bbox_pred, dir_cls_preds

-    def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels,
-                    label_weights, bbox_targets, bbox_weights, dir_targets,
-                    dir_weights, num_total_samples):
+    def loss_single(self, cls_score: Tensor, bbox_pred: Tensor,
+                    dir_cls_preds: Tensor, labels: Tensor,
+                    label_weights: Tensor, bbox_targets: Tensor,
+                    bbox_weights: Tensor, dir_targets: Tensor,
+                    dir_weights: Tensor,
+                    num_total_samples: int) -> Tuple[Tensor]:
         """Calculate loss of Single-level results.

         Args:

@@ -309,27 +320,30 @@ class ShapeAwareHead(Anchor3DHead):
         return loss_cls, loss_bbox, loss_dir

-    def loss(self,
-             cls_scores,
-             bbox_preds,
-             dir_cls_preds,
-             gt_bboxes,
-             gt_labels,
-             input_metas,
-             gt_bboxes_ignore=None):
-        """Calculate losses.
+    def loss_by_feat(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            dir_cls_preds: List[Tensor],
+            batch_gt_instances_3d: InstanceList,
+            batch_input_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
+        """Calculate the loss based on the features extracted by the detection
+        head.

         Args:
             cls_scores (list[torch.Tensor]): Multi-level class scores.
             bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
             dir_cls_preds (list[torch.Tensor]): Multi-level direction
                 class predictions.
-            gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Gt bboxes
-                of each sample.
-            gt_labels (list[torch.Tensor]): Gt labels of each sample.
-            input_metas (list[dict]): Contain pcd and img's meta info.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_input_metas (list[dict]): Contain pcd and sample's meta info.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.

         Returns:
             dict[str, list[torch.Tensor]]: Classification, bbox, and

@@ -342,13 +356,12 @@ class ShapeAwareHead(Anchor3DHead):
         """
         device = cls_scores[0].device
         anchor_list = self.get_anchors(
-            self.featmap_sizes, input_metas, device=device)
+            self.featmap_sizes, batch_input_metas, device=device)
         cls_reg_targets = self.anchor_target_3d(
             anchor_list,
-            gt_bboxes,
-            input_metas,
-            gt_bboxes_ignore_list=gt_bboxes_ignore,
-            gt_labels_list=gt_labels,
+            batch_gt_instances_3d,
+            batch_input_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
             num_classes=self.num_classes,
             sampling=self.sampling)

@@ -376,21 +389,22 @@ class ShapeAwareHead(Anchor3DHead):
         return dict(
             loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)

-    def get_bboxes(self,
-                   cls_scores,
-                   bbox_preds,
-                   dir_cls_preds,
-                   input_metas,
-                   cfg=None,
-                   rescale=False):
-        """Get bboxes of anchor head.
+    def predict_by_feat(self,
+                        cls_scores: List[Tensor],
+                        bbox_preds: List[Tensor],
+                        dir_cls_preds: List[Tensor],
+                        batch_input_metas: List[dict],
+                        cfg: Optional[dict] = None,
+                        rescale: List[Tensor] = False) -> List[tuple]:
+        """Transform a batch of output features extracted from the head into
+        bbox results.

         Args:
             cls_scores (list[torch.Tensor]): Multi-level class scores.
             bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
             dir_cls_preds (list[torch.Tensor]): Multi-level direction
                 class predictions.
-            input_metas (list[dict]): Contain pcd and img's meta info.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
             cfg (:obj:`ConfigDict`, optional): Training or testing config.
                 Default: None.
             rescale (list[torch.Tensor], optional): Whether to rescale bbox.

@@ -404,13 +418,13 @@ class ShapeAwareHead(Anchor3DHead):
         num_levels = len(cls_scores)
         assert num_levels == 1, 'Only support single level inference.'
         device = cls_scores[0].device
-        mlvl_anchors = self.anchor_generator.grid_anchors(
+        mlvl_anchors = self.prior_generator.grid_anchors(
             self.featmap_sizes, device=device)
         # `anchor` is a list of anchors for different classes
         mlvl_anchors = [torch.cat(anchor, dim=0) for anchor in mlvl_anchors]

         result_list = []
-        for img_id in range(len(input_metas)):
+        for img_id in range(len(batch_input_metas)):
             cls_score_list = [
                 cls_scores[i][img_id].detach() for i in range(num_levels)
             ]

@@ -421,22 +435,25 @@ class ShapeAwareHead(Anchor3DHead):
                 dir_cls_preds[i][img_id].detach() for i in range(num_levels)
             ]

-            input_meta = input_metas[img_id]
-            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
-                                               dir_cls_pred_list,
-                                               mlvl_anchors, input_meta, cfg,
-                                               rescale)
+            input_meta = batch_input_metas[img_id]
+            proposals = self._predict_by_feat_single(cls_score_list,
+                                                     bbox_pred_list,
+                                                     dir_cls_pred_list,
+                                                     mlvl_anchors, input_meta,
+                                                     cfg, rescale)
             result_list.append(proposals)
         return result_list

-    def get_bboxes_single(self,
-                          cls_scores,
-                          bbox_preds,
-                          dir_cls_preds,
-                          mlvl_anchors,
-                          input_meta,
-                          cfg=None,
-                          rescale=False):
-        """Get bboxes of single branch.
+    def _predict_by_feat_single(self,
+                                cls_scores: Tensor,
+                                bbox_preds: Tensor,
+                                dir_cls_preds: Tensor,
+                                mlvl_anchors: List[Tensor],
+                                input_meta: List[dict],
+                                cfg: Dict = None,
+                                rescale: List[Tensor] = False):
+        """Transform a single point's features extracted from the head into
+        bbox results.

         Args:
             cls_scores (torch.Tensor): Class score in single batch.

@@ -447,7 +464,7 @@ class ShapeAwareHead(Anchor3DHead):
                 in single batch.
             input_meta (list[dict]): Contain pcd and img's meta info.
             cfg (:obj:`ConfigDict`): Training or testing config.
-            rescale (list[torch.Tensor], optional): whether to rescale bbox.
+            rescale (list[torch.Tensor]): whether to rescale bbox.
                 Default: False.

         Returns:

@@ -513,4 +530,8 @@ class ShapeAwareHead(Anchor3DHead):
             dir_rot + self.dir_offset +
             np.pi * dir_scores.to(bboxes.dtype))
         bboxes = input_meta['box_type_3d'](bboxes, box_dim=self.box_code_size)
-        return bboxes, scores, labels
+        results = InstanceData()
+        results.bboxes_3d = bboxes
+        results.scores_3d = scores
+        results.labels_3d = labels
+        return results
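On the prediction side, ShapeAwareHead now returns one `InstanceData` per sample with `bboxes_3d`, `scores_3d` and `labels_3d` fields instead of a `(bboxes, scores, labels)` tuple. A small illustrative helper (not part of the repo) that restores the old tuple form, assuming `results_list` is the output of `predict_by_feat`:

def unpack_predictions(results_list):
    """Hypothetical helper: convert the InstanceData results returned by
    predict_by_feat back into (bboxes, scores, labels) tuples, mirroring
    the old get_bboxes return value."""
    return [(res.bboxes_3d, res.scores_3d, res.labels_3d)
            for res in results_list]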
tests/test_models/test_dense_heads/test_freeanchors.py  —  new file (0 → 100644), view file @ 522cc20d

import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestFreeAnchor(unittest.TestCase):

    def test_freeanchor(self):
        import mmdet3d.models
        assert hasattr(mmdet3d.models.dense_heads, 'FreeAnchor3DHead')
        DefaultScope.get_instance('test_freeanchor', scope_name='mmdet3d')
        _setup_seed(0)
        freeanchor_cfg = _get_detector_cfg(
            'free_anchor/hv_pointpillars_fpn_sbn-all_free-'
            'anchor_4x8_2x_nus-3d.py')
        model = MODELS.build(freeanchor_cfg)
        num_gt_instance = 50
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9)
        ]
        aug_data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9),
            _create_detector_inputs(
                num_gt_instance=num_gt_instance + 1, gt_bboxes_dim=9)
        ]
        # test_aug_test
        metainfo = {
            'pcd_scale_factor': 1,
            'pcd_horizontal_flip': 1,
            'pcd_vertical_flip': 1,
            'box_type_3d': LiDARInstance3DBoxes
        }
        for item in aug_data:
            item['data_sample'].set_metainfo(metainfo)
        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', results[0].pred_instances_3d)
            self.assertIn('scores_3d', results[0].pred_instances_3d)
            self.assertIn('labels_3d', results[0].pred_instances_3d)

            batch_inputs, data_samples = model.data_preprocessor(
                aug_data, True)
            aug_results = model.forward(
                batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', aug_results[0].pred_instances_3d)
            self.assertIn('scores_3d', aug_results[0].pred_instances_3d)
            self.assertIn('labels_3d', aug_results[0].pred_instances_3d)
            self.assertIn('bboxes_3d', aug_results[1].pred_instances_3d)
            self.assertIn('scores_3d', aug_results[1].pred_instances_3d)
            self.assertIn('labels_3d', aug_results[1].pred_instances_3d)

            losses = model.forward(batch_inputs, data_samples, mode='loss')

            self.assertGreater(losses['positive_bag_loss'], 0)
            self.assertGreater(losses['negative_bag_loss'], 0)
tests/test_models/test_dense_heads/test_ssn.py  —  new file (0 → 100644), view file @ 522cc20d

import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestSSN(unittest.TestCase):

    def test_ssn(self):
        import mmdet3d.models
        assert hasattr(mmdet3d.models.dense_heads, 'ShapeAwareHead')
        DefaultScope.get_instance('test_ssn', scope_name='mmdet3d')
        _setup_seed(0)
        ssn_cfg = _get_detector_cfg('ssn/hv_ssn_secfpn_sbn-all_2x16_2x_nus-3d.py')
        model = MODELS.build(ssn_cfg)
        num_gt_instance = 50
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9)
        ]
        aug_data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9),
            _create_detector_inputs(
                num_gt_instance=num_gt_instance + 1, gt_bboxes_dim=9)
        ]
        # test_aug_test
        metainfo = {
            'pcd_scale_factor': 1,
            'pcd_horizontal_flip': 1,
            'pcd_vertical_flip': 1,
            'box_type_3d': LiDARInstance3DBoxes
        }
        for item in aug_data:
            item['data_sample'].set_metainfo(metainfo)
        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', results[0].pred_instances_3d)
            self.assertIn('scores_3d', results[0].pred_instances_3d)
            self.assertIn('labels_3d', results[0].pred_instances_3d)

            batch_inputs, data_samples = model.data_preprocessor(
                aug_data, True)
            aug_results = model.forward(
                batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', aug_results[0].pred_instances_3d)
            self.assertIn('scores_3d', aug_results[0].pred_instances_3d)
            self.assertIn('labels_3d', aug_results[0].pred_instances_3d)
            self.assertIn('bboxes_3d', aug_results[1].pred_instances_3d)
            self.assertIn('scores_3d', aug_results[1].pred_instances_3d)
            self.assertIn('labels_3d', aug_results[1].pred_instances_3d)

            losses = model.forward(batch_inputs, data_samples, mode='loss')

            self.assertGreater(losses['loss_cls'][0], 0)
            self.assertGreater(losses['loss_bbox'][0], 0)
            self.assertGreater(losses['loss_dir'][0], 0)
tests/utils/model_utils.py  —  view file @ 522cc20d

@@ -76,6 +76,7 @@ def _create_detector_inputs(seed=0,
                             with_img=False,
                             num_gt_instance=20,
                             points_feat_dim=4,
+                            gt_bboxes_dim=7,
                             num_classes=3):
     _setup_seed(seed)
     inputs_dict = dict()

@@ -88,7 +89,7 @@ def _create_detector_inputs(seed=0,
     gt_instance_3d = InstanceData()
     gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
-        torch.rand([num_gt_instance, 7]))
+        torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
     gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
     data_sample = Det3DDataSample(
         metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
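The new `gt_bboxes_dim` argument is what lets the two tests above fabricate nuScenes-style ground truth, whose box code is 9-dimensional (the extra two dimensions carry the box velocity). A usage sketch taken from those tests:

from tests.utils.model_utils import _create_detector_inputs

# Fake detector inputs with 50 ground-truth boxes using a 9-dim box code.
data = [_create_detector_inputs(num_gt_instance=50, gt_bboxes_dim=9)]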