OpenDAS / mmdetection3d

Commit 48306d57, authored Jun 10, 2022 by ZCMax, committed by ChaimZhu on Jul 20, 2022

[Fix] Add fcos_mono_3d_head unittest and fix bugs

Parent: effec8c3

Showing 2 changed files with 248 additions and 53 deletions:

mmdet3d/models/dense_heads/fcos_mono3d_head.py (+65, -53)
tests/test_models/test_dense_heads/test_fcos_mono3d_head.py (+183, -0)
mmdet3d/models/dense_heads/fcos_mono3d_head.py

@@ -5,14 +5,13 @@ import numpy as np
 import torch
 from mmcv.cnn import Scale, normal_init
 from mmcv.runner import force_fp32
+from mmengine.data import InstanceData
 from torch import nn as nn

 from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam,
                           xywhr2xyxyr)
-from mmdet3d.models.builder import build_loss
-from mmdet3d.registry import MODELS
+from mmdet3d.registry import MODELS, TASK_UTILS
 from mmdet.core import multi_apply
-from mmdet.core.bbox.builder import build_bbox_coder
 from .anchor_free_mono3d_head import AnchorFreeMono3DHead

 INF = 1e8
@@ -96,9 +95,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
             norm_cfg=norm_cfg,
             init_cfg=init_cfg,
             **kwargs)
-        self.loss_centerness = build_loss(loss_centerness)
+        self.loss_centerness = MODELS.build(loss_centerness)
         bbox_coder['code_size'] = self.bbox_code_size
-        self.bbox_coder = build_bbox_coder(bbox_coder)
+        self.bbox_coder = TASK_UTILS.build(bbox_coder)

     def _init_layers(self):
         """Initialize layers of the head."""
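The constructor change above replaces the old helper functions with mmengine-style registry calls: the centerness loss is now built through MODELS.build and the box coder through TASK_UTILS.build, matching the import change in the first hunk. A minimal, hedged sketch of that pattern, with config values borrowed from the new unit test below (not part of the commit):

# Registry-based construction: build() looks up cfg['type'] and instantiates
# the registered class with the remaining keys as keyword arguments.
from mmdet3d.registry import MODELS, TASK_UTILS

loss_centerness = MODELS.build(
    dict(type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
bbox_coder = TASK_UTILS.build(dict(type='FCOS3DBBoxCoder', code_size=9))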
@@ -281,7 +280,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                 is a 4D-tensor, the channel number is num_points * 1.
             batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                 gt_instance_3d. It usually includes ``bboxes``、``labels``
-                、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
+                、``bboxes_3d``、``labels3d``、``depths``、``centers_2d`` and
                 attributes.
             batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
@@ -467,15 +466,15 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
     @force_fp32(
         apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds', 'attr_preds',
                   'centernesses'))
-    def get_bboxes(self,
-                   cls_scores,
-                   bbox_preds,
-                   dir_cls_preds,
-                   attr_preds,
-                   centernesses,
-                   img_metas,
-                   cfg=None,
-                   rescale=None):
+    def get_results(self,
+                    cls_scores,
+                    bbox_preds,
+                    dir_cls_preds,
+                    attr_preds,
+                    centernesses,
+                    batch_img_metas,
+                    cfg=None,
+                    rescale=None):
         """Transform network output for a batch into bbox predictions.

         Args:
@@ -490,7 +489,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                 Has shape (N, num_points * num_attrs, H, W)
             centernesses (list[Tensor]): Centerness for each scale level with
                 shape (N, num_points * 1, H, W)
-            img_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
             cfg (mmcv.Config): Test / postprocessing configuration,
                 if None, test_cfg would be used
@@ -512,7 +511,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                       bbox_preds[0].device)
         result_list = []
-        for img_id in range(len(img_metas)):
+        for img_id in range(len(batch_img_metas)):
             cls_score_list = [
                 cls_scores[i][img_id].detach() for i in range(num_levels)
             ]
@@ -544,24 +543,26 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
             centerness_pred_list = [
                 centernesses[i][img_id].detach() for i in range(num_levels)
             ]
-            input_meta = img_metas[img_id]
-            det_bboxes = self._get_bboxes_single(
-                cls_score_list, bbox_pred_list, dir_cls_pred_list,
-                attr_pred_list, centerness_pred_list, mlvl_points, input_meta,
-                cfg, rescale)
-            result_list.append(det_bboxes)
+            img_meta = batch_img_metas[img_id]
+            results = self._get_results_single(
+                cls_score_list, bbox_pred_list, dir_cls_pred_list,
+                attr_pred_list, centerness_pred_list, mlvl_points, img_meta,
+                cfg, rescale)
+            result_list.append(results)
         return result_list

-    def _get_bboxes_single(self,
-                           cls_scores,
-                           bbox_preds,
-                           dir_cls_preds,
-                           attr_preds,
-                           centernesses,
-                           mlvl_points,
-                           input_meta,
-                           cfg,
-                           rescale=False):
+    def _get_results_single(self,
+                            cls_scores,
+                            bbox_preds,
+                            dir_cls_preds,
+                            attr_preds,
+                            centernesses,
+                            mlvl_points,
+                            img_meta,
+                            cfg,
+                            rescale=False):
         """Transform outputs for a single batch item into bbox predictions.

         Args:
@@ -578,7 +579,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                 with shape (num_points, H, W).
             mlvl_points (list[Tensor]): Box reference for a single scale level
                 with shape (num_total_points, 2).
-            input_meta (dict): Metadata of input image.
+            img_meta (dict): Metadata of input image.
             cfg (mmcv.Config): Test / postprocessing configuration,
                 if None, test_cfg would be used.
             rescale (bool): If True, return boxes in original image space.
@@ -586,11 +587,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         Returns:
             tuples[Tensor]: Predicted 3D boxes, scores, labels and attributes.
         """
-        view = np.array(input_meta['cam2img'])
-        scale_factor = input_meta['scale_factor']
+        view = np.array(img_meta['cam2img'])
+        scale_factor = img_meta['scale_factor']
         cfg = self.test_cfg if cfg is None else cfg
         assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
-        mlvl_centers2d = []
+        mlvl_centers_2d = []
         mlvl_bboxes = []
         mlvl_scores = []
         mlvl_dir_scores = []
@@ -630,26 +631,26 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                 bbox_pred[:, :2] /= bbox_pred[:, :2].new_tensor(scale_factor)
             pred_center2d = bbox_pred[:, :3].clone()
             bbox_pred[:, :3] = points_img2cam(bbox_pred[:, :3], view)
-            mlvl_centers2d.append(pred_center2d)
+            mlvl_centers_2d.append(pred_center2d)
             mlvl_bboxes.append(bbox_pred)
             mlvl_scores.append(scores)
             mlvl_dir_scores.append(dir_cls_score)
             mlvl_attr_scores.append(attr_score)
             mlvl_centerness.append(centerness)

-        mlvl_centers2d = torch.cat(mlvl_centers2d)
+        mlvl_centers_2d = torch.cat(mlvl_centers_2d)
         mlvl_bboxes = torch.cat(mlvl_bboxes)
         mlvl_dir_scores = torch.cat(mlvl_dir_scores)

         # change local yaw to global yaw for 3D nms
-        cam2img = mlvl_centers2d.new_zeros((4, 4))
+        cam2img = mlvl_centers_2d.new_zeros((4, 4))
         cam2img[:view.shape[0], :view.shape[1]] = \
-            mlvl_centers2d.new_tensor(view)
-        mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers2d,
+            mlvl_centers_2d.new_tensor(view)
+        mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers_2d,
                                                  mlvl_dir_scores,
                                                  self.dir_offset, cam2img)

-        mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](
+        mlvl_bboxes_for_nms = xywhr2xyxyr(img_meta['box_type_3d'](
             mlvl_bboxes, box_dim=self.bbox_code_size,
             origin=(0.5, 0.5, 0.5)).bev)
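In the yaw-decoding block above, img_meta['cam2img'] may hold a 3x3 intrinsic matrix (as in the new unit test), so the code first copies it into the top-left block of a zero 4x4 tensor before handing it to decode_yaw. A standalone, hedged sketch of that padding step, with made-up intrinsic values for illustration:

# Sketch of the cam2img padding performed before decode_yaw (illustrative values).
import torch

view = torch.tensor([[1260.8, 0.0, 808.0],
                     [0.0, 1260.8, 495.3],
                     [0.0, 0.0, 1.0]])          # 3x3 intrinsics from img_meta
cam2img = view.new_zeros((4, 4))                # fixed-size 4x4 container
cam2img[:view.shape[0], :view.shape[1]] = view  # copy intrinsics into the top-left block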
@@ -669,16 +670,22 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
             mlvl_attr_scores)
         bboxes, scores, labels, dir_scores, attrs = results
         attrs = attrs.to(labels.dtype)  # change data type to int
-        bboxes = input_meta['box_type_3d'](
+        bboxes = img_meta['box_type_3d'](
             bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
         # Note that the predictions use origin (0.5, 0.5, 0.5)
-        # Due to the ground truth centers2d are the gravity center of objects
+        # Due to the ground truth centers_2d are the gravity center of objects
         # v0.10.0 fix inplace operation to the input tensor of cam_box3d
         # So here we also need to add origin=(0.5, 0.5, 0.5)
         if not self.pred_attrs:
             attrs = None

-        return bboxes, scores, labels, attrs
+        results = InstanceData()
+        results.bboxes_3d = bboxes
+        results.scores_3d = scores
+        results.labels_3d = labels
+        results.attr_labels = attrs
+        return results

     @staticmethod
     def pts2Dto3D(points, view):
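With this hunk, _get_results_single (and therefore get_results) no longer returns a plain tuple of tensors; each image's predictions are packed into an mmengine InstanceData container with the fields bboxes_3d, scores_3d, labels_3d and attr_labels. A small, self-contained sketch of how such a container behaves, using dummy tensors rather than the head's real outputs:

# InstanceData stores per-instance fields as attributes instead of tuple slots.
import torch
from mmengine.data import InstanceData

results = InstanceData()
results.scores_3d = torch.rand(5)
results.labels_3d = torch.randint(0, 10, (5, ))
results.attr_labels = torch.randint(0, 9, (5, ))

# Downstream code reads predictions by name, e.g. results.scores_3d,
# exactly as the new unit test below does with the output of get_results().
assert results.scores_3d.shape == results.labels_3d.shape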
@@ -738,7 +745,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                 (num_points, 2).
             batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                 gt_instance_3d. It usually includes ``bboxes``、``labels``
-                、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
+                、``bboxes_3d``、``labels3d``、``depths``、``centers_2d`` and
                 attributes.

         Returns:
@@ -761,6 +768,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         # the number of points per img, per lvl
         num_points = [center.size(0) for center in points]

+        if 'attr_labels' not in batch_gt_instances_3d[0]:
+            for gt_instances_3d in batch_gt_instances_3d:
+                gt_instances_3d.attr_labels = gt_instances_3d.labels.new_full(
+                    gt_instances_3d.labels.shape, self.attr_background_label)
+
         # get labels and bbox_targets of each image
         _, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
             attr_targets_list = multi_apply(
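The added guard fills in default attribute labels when the ground-truth instances carry no attr_labels field, so the targets stay well-formed for heads trained without attribute annotation. A tiny sketch of the filling step; the value 9 below merely stands in for self.attr_background_label and is an assumption for illustration:

# new_full() creates a tensor of the given shape filled with one value,
# inheriting dtype and device from the source tensor.
import torch

labels = torch.randint(0, 10, (3, ))    # per-instance class labels
attr_background_label = 9               # assumed background attribute index
attr_labels = labels.new_full(labels.shape, attr_background_label)
print(attr_labels)                      # tensor([9, 9, 9])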
@@ -822,7 +834,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         gt_labels = gt_instances_3d.labels
         gt_bboxes_3d = gt_instances_3d.bboxes_3d
         gt_labels_3d = gt_instances_3d.labels_3d
-        centers2d = gt_instances_3d.centers2d
+        centers_2d = gt_instances_3d.centers_2d
         depths = gt_instances_3d.depths
         attr_labels = gt_instances_3d.attr_labels
@@ -848,7 +860,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         regress_ranges = regress_ranges[:, None, :].expand(
             num_points, num_gts, 2)
         gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
-        centers2d = centers2d[None].expand(num_points, num_gts, 2)
+        centers_2d = centers_2d[None].expand(num_points, num_gts, 2)
         gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                                  self.bbox_code_size)
         depths = depths[None, :, None].expand(num_points, num_gts, 1)
@@ -856,8 +868,8 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
         xs = xs[:, None].expand(num_points, num_gts)
         ys = ys[:, None].expand(num_points, num_gts)
-        delta_xs = (xs - centers2d[..., 0])[..., None]
-        delta_ys = (ys - centers2d[..., 1])[..., None]
+        delta_xs = (xs - centers_2d[..., 0])[..., None]
+        delta_ys = (ys - centers_2d[..., 1])[..., None]
         bbox_targets_3d = torch.cat(
             (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)
@@ -871,8 +883,8 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
             'False has not been implemented for FCOS3D.'
         # condition1: inside a `center bbox`
         radius = self.center_sample_radius
-        center_xs = centers2d[..., 0]
-        center_ys = centers2d[..., 1]
+        center_xs = centers_2d[..., 0]
+        center_ys = centers_2d[..., 1]
         center_gts = torch.zeros_like(gt_bboxes)
         stride = center_xs.new_zeros(center_xs.shape)
tests/test_models/test_dense_heads/test_fcos_mono3d_head.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import mmcv
import numpy as np
import torch
from mmengine.data import InstanceData

from mmdet3d.core.bbox import CameraInstance3DBoxes
from mmdet3d.models.dense_heads import FCOSMono3DHead


class TestFCOSMono3DHead(TestCase):

    def test_fcos_mono3d_head_loss(self):
        """Tests FCOS3D head loss and inference."""

        img_metas = [
            dict(
                cam2img=[[1260.8474446004698, 0.0, 807.968244525554],
                         [0.0, 1260.8474446004698, 495.3344268742088],
                         [0.0, 0.0, 1.0]],
                scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32),
                box_type_3d=CameraInstance3DBoxes)
        ]

        train_cfg = dict(
            allowed_border=0,
            code_weight=[1.0, 1.0, 0.2, 1.0, 1.0, 1.0, 1.0, 0.05, 0.05],
            pos_weight=-1,
            debug=False)

        test_cfg = dict(
            use_rotate_nms=True,
            nms_across_levels=False,
            nms_pre=1000,
            nms_thr=0.8,
            score_thr=0.05,
            min_bbox_size=0,
            max_per_img=200)

        train_cfg = mmcv.Config(train_cfg)
        test_cfg = mmcv.Config(test_cfg)

        fcos_mono3d_head = FCOSMono3DHead(
            num_classes=10,
            in_channels=256,
            stacked_convs=2,
            feat_channels=256,
            use_direction_classifier=True,
            diff_rad_by_sin=True,
            pred_attrs=True,
            pred_velo=True,
            dir_offset=0.7854,  # pi/4
            dir_limit_offset=0,
            strides=[8, 16, 32, 64, 128],
            group_reg_dims=(2, 1, 3, 1, 2),  # offset, depth, size, rot, velo
            cls_branch=(256, ),
            reg_branch=(
                (256, ),  # offset
                (256, ),  # depth
                (256, ),  # size
                (256, ),  # rot
                ()  # velo
            ),
            dir_branch=(256, ),
            attr_branch=(256, ),
            loss_cls=dict(
                type='mmdet.FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=1.0),
            loss_bbox=dict(
                type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
            loss_dir=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=1.0),
            loss_attr=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=1.0),
            loss_centerness=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=True,
                loss_weight=1.0),
            bbox_coder=dict(type='FCOS3DBBoxCoder', code_size=9),
            norm_on_bbox=True,
            centerness_on_reg=True,
            center_sampling=True,
            conv_bias=True,
            dcn_on_last_conv=False,
            train_cfg=train_cfg,
            test_cfg=test_cfg)

        # The FCOS3D head expects multiple levels of features per image.
        feats = [
            torch.rand([1, 256, 116, 200], dtype=torch.float32),
            torch.rand([1, 256, 58, 100], dtype=torch.float32),
            torch.rand([1, 256, 29, 50], dtype=torch.float32),
            torch.rand([1, 256, 15, 25], dtype=torch.float32),
            torch.rand([1, 256, 8, 13], dtype=torch.float32)
        ]

        # Test forward.
        ret_dict = fcos_mono3d_head.forward(feats)

        self.assertEqual(
            len(ret_dict), 5, 'the length of forward feature should be 5')
        self.assertEqual(
            len(ret_dict[0]), 5, 'each feature should have 5 levels')
        self.assertEqual(
            ret_dict[0][0].shape, torch.Size([1, 10, 116, 200]),
            'the first level feature shape should be [1, 10, 116, 200]')

        # When truth is non-empty then all losses
        # should be nonzero for random inputs.
        gt_instances_3d = InstanceData()

        gt_bboxes = torch.rand([3, 4], dtype=torch.float32)
        gt_bboxes_3d = CameraInstance3DBoxes(torch.rand([3, 9]), box_dim=9)
        gt_labels = torch.randint(0, 10, [3])
        gt_labels_3d = gt_labels
        centers_2d = torch.rand([3, 2], dtype=torch.float32)
        depths = torch.rand([3], dtype=torch.float32)
        attr_labels = torch.randint(0, 9, [3])

        gt_instances_3d.bboxes_3d = gt_bboxes_3d
        gt_instances_3d.labels_3d = gt_labels_3d
        gt_instances_3d.bboxes = gt_bboxes
        gt_instances_3d.labels = gt_labels
        gt_instances_3d.centers_2d = centers_2d
        gt_instances_3d.depths = depths
        gt_instances_3d.attr_labels = attr_labels

        gt_losses = fcos_mono3d_head.loss(*ret_dict, [gt_instances_3d],
                                          img_metas)

        gt_cls_loss = gt_losses['loss_cls'].item()
        gt_siz_loss = gt_losses['loss_size'].item()
        gt_ctr_loss = gt_losses['loss_centerness'].item()
        gt_off_loss = gt_losses['loss_offset'].item()
        gt_dep_loss = gt_losses['loss_depth'].item()
        gt_rot_loss = gt_losses['loss_rotsin'].item()
        gt_vel_loss = gt_losses['loss_velo'].item()
        gt_dir_loss = gt_losses['loss_dir'].item()
        gt_atr_loss = gt_losses['loss_attr'].item()

        self.assertGreater(gt_cls_loss, 0, 'cls loss should be positive')
        self.assertGreater(gt_siz_loss, 0, 'size loss should be positive')
        self.assertGreater(gt_ctr_loss, 0,
                           'centerness loss should be positive')
        self.assertGreater(gt_off_loss, 0, 'offset loss should be positive')
        self.assertGreater(gt_dep_loss, 0, 'depth loss should be positive')
        self.assertGreater(gt_rot_loss, 0, 'rotsin loss should be positive')
        self.assertGreater(gt_vel_loss, 0, 'velocity loss should be positive')
        self.assertGreater(gt_dir_loss, 0,
                           'direction loss should be positive')
        self.assertGreater(gt_atr_loss, 0, 'attribute loss should be positive')

        # Test get_results.
        results_list = fcos_mono3d_head.get_results(*ret_dict, img_metas)
        self.assertEqual(
            len(results_list), 1,
            'there should be one results instance per image in the batch')
        results = results_list[0]
        pred_bboxes_3d = results.bboxes_3d
        pred_scores_3d = results.scores_3d
        pred_labels_3d = results.labels_3d
        pred_attr_labels = results.attr_labels
        self.assertEqual(
            pred_bboxes_3d.tensor.shape, torch.Size([200, 9]),
            'the shape of predicted 3d bboxes should be [200, 9]')
        self.assertEqual(
            pred_scores_3d.shape, torch.Size([200]),
            'the shape of predicted 3d bbox scores should be [200]')
        self.assertEqual(
            pred_labels_3d.shape, torch.Size([200]),
            'the shape of predicted 3d bbox labels should be [200]')
        self.assertEqual(
            pred_attr_labels.shape, torch.Size([200]),
            'the shape of predicted 3d bbox attribute labels should be [200]')