Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
d490f024
Commit
d490f024
authored
Jun 09, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor] Refactor monoflex head and unittest
parent
98cc28e2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
299 additions
and
164 deletions
+299
-164
mmdet3d/models/dense_heads/monoflex_head.py
mmdet3d/models/dense_heads/monoflex_head.py
+231
-164
tests/test_models/test_dense_heads/test_monoflex_head.py
tests/test_models/test_dense_heads/test_monoflex_head.py
+68
-0
No files found.
mmdet3d/models/dense_heads/monoflex_head.py
View file @
d490f024
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn
import
xavier_init
from
mmcv.runner
import
force_fp32
from
mmengine.config
import
ConfigDict
from
mmengine.data
import
InstanceData
from
torch
import
Tensor
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
mmdet3d.core
import
Det3DDataSample
from
mmdet3d.core.bbox.builder
import
build_bbox_coder
from
mmdet3d.core.bbox.builder
import
build_bbox_coder
from
mmdet3d.core.utils
import
get_ellip_gaussian_2D
from
mmdet3d.core.utils
import
get_ellip_gaussian_2D
from
mmdet3d.models.builder
import
build_loss
from
mmdet3d.models.model_utils
import
EdgeFusionModule
from
mmdet3d.models.model_utils
import
EdgeFusionModule
from
mmdet3d.models.utils
import
(
filter_outside_objs
,
get_edge_indices
,
from
mmdet3d.models.utils
import
(
filter_outside_objs
,
get_edge_indices
,
get_keypoints
,
handle_proj_objs
)
get_keypoints
,
handle_proj_objs
)
...
@@ -63,7 +69,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -63,7 +69,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
Default: dict(type='L1Loss', loss_weight=0.1).
Default: dict(type='L1Loss', loss_weight=0.1).
loss_dims: (dict, optional): Config of dimensions loss.
loss_dims: (dict, optional): Config of dimensions loss.
Default: dict(type='L1Loss', loss_weight=0.1).
Default: dict(type='L1Loss', loss_weight=0.1).
loss_offsets2d: (dict, optional): Config of offsets2d loss.
loss_offsets
_
2d: (dict, optional): Config of offsets
_
2d loss.
Default: dict(type='L1Loss', loss_weight=0.1).
Default: dict(type='L1Loss', loss_weight=0.1).
loss_direct_depth: (dict, optional): Config of directly regression depth loss.
loss_direct_depth: (dict, optional): Config of directly regression depth loss.
Default: dict(type='L1Loss', loss_weight=0.1).
Default: dict(type='L1Loss', loss_weight=0.1).
...
@@ -81,27 +87,33 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -81,27 +87,33 @@ class MonoFlexHead(AnchorFreeMono3DHead):
"""
# noqa: E501
"""
# noqa: E501
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
:
int
,
in_channels
,
in_channels
:
int
,
use_edge_fusion
,
use_edge_fusion
:
bool
,
edge_fusion_inds
,
edge_fusion_inds
:
List
[
Tuple
],
edge_heatmap_ratio
,
edge_heatmap_ratio
:
float
,
filter_outside_objs
=
True
,
filter_outside_objs
:
bool
=
True
,
loss_cls
=
dict
(
type
=
'GaussianFocalLoss'
,
loss_weight
=
1.0
),
loss_cls
:
dict
=
dict
(
loss_bbox
=
dict
(
type
=
'IoULoss'
,
loss_weight
=
0.1
),
type
=
'mmdet.GaussianFocalLoss'
,
loss_weight
=
1.0
),
loss_dir
=
dict
(
type
=
'MultiBinLoss'
,
loss_weight
=
0.1
),
loss_bbox
:
dict
=
dict
(
type
=
'mmdet.IoULoss'
,
loss_weight
=
0.1
),
loss_keypoints
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_dir
:
dict
=
dict
(
type
=
'MultiBinLoss'
,
loss_weight
=
0.1
),
loss_dims
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_keypoints
:
dict
=
dict
(
loss_offsets2d
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
loss_direct_depth
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_dims
:
dict
=
dict
(
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
loss_keypoints_depth
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_offsets_2d
:
dict
=
dict
(
loss_combined_depth
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
loss_attr
=
None
,
loss_direct_depth
:
dict
=
dict
(
bbox_coder
=
dict
(
type
=
'MonoFlexCoder'
,
code_size
=
7
),
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
norm_cfg
=
dict
(
type
=
'BN'
),
loss_keypoints_depth
:
dict
=
dict
(
init_cfg
=
None
,
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
init_bias
=-
2.19
,
loss_combined_depth
:
dict
=
dict
(
**
kwargs
):
type
=
'mmdet.L1Loss'
,
loss_weight
=
0.1
),
loss_attr
:
Optional
[
dict
]
=
None
,
bbox_coder
:
dict
=
dict
(
type
=
'MonoFlexCoder'
,
code_size
=
7
),
norm_cfg
:
Union
[
ConfigDict
,
dict
]
=
dict
(
type
=
'BN'
),
init_cfg
:
Optional
[
Union
[
ConfigDict
,
dict
]]
=
None
,
init_bias
:
float
=
-
2.19
,
**
kwargs
)
->
None
:
self
.
use_edge_fusion
=
use_edge_fusion
self
.
use_edge_fusion
=
use_edge_fusion
self
.
edge_fusion_inds
=
edge_fusion_inds
self
.
edge_fusion_inds
=
edge_fusion_inds
super
().
__init__
(
super
().
__init__
(
...
@@ -117,13 +129,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -117,13 +129,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
self
.
filter_outside_objs
=
filter_outside_objs
self
.
filter_outside_objs
=
filter_outside_objs
self
.
edge_heatmap_ratio
=
edge_heatmap_ratio
self
.
edge_heatmap_ratio
=
edge_heatmap_ratio
self
.
init_bias
=
init_bias
self
.
init_bias
=
init_bias
self
.
loss_dir
=
build
_loss
(
loss_dir
)
self
.
loss_dir
=
MODELS
.
build
(
loss_dir
)
self
.
loss_keypoints
=
build
_loss
(
loss_keypoints
)
self
.
loss_keypoints
=
MODELS
.
build
(
loss_keypoints
)
self
.
loss_dims
=
build
_loss
(
loss_dims
)
self
.
loss_dims
=
MODELS
.
build
(
loss_dims
)
self
.
loss_offsets2d
=
build
_loss
(
loss_offsets2d
)
self
.
loss_offsets
_
2d
=
MODELS
.
build
(
loss_offsets
_
2d
)
self
.
loss_direct_depth
=
build
_loss
(
loss_direct_depth
)
self
.
loss_direct_depth
=
MODELS
.
build
(
loss_direct_depth
)
self
.
loss_keypoints_depth
=
build
_loss
(
loss_keypoints_depth
)
self
.
loss_keypoints_depth
=
MODELS
.
build
(
loss_keypoints_depth
)
self
.
loss_combined_depth
=
build
_loss
(
loss_combined_depth
)
self
.
loss_combined_depth
=
MODELS
.
build
(
loss_combined_depth
)
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
def
_init_edge_module
(
self
):
def
_init_edge_module
(
self
):
...
@@ -185,13 +197,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -185,13 +197,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if
self
.
use_edge_fusion
:
if
self
.
use_edge_fusion
:
self
.
_init_edge_module
()
self
.
_init_edge_module
()
def
forward_train
(
self
,
x
,
input_metas
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
def
forward_train
(
self
,
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
x
:
List
[
Tensor
],
gt_bboxes_ignore
,
proposal_cfg
,
**
kwargs
):
batch_data_samples
:
List
[
Det3DDataSample
],
proposal_cfg
:
Optional
[
ConfigDict
]
=
None
,
**
kwargs
):
"""
"""
Args:
Args:
x (list[Tensor]): Features from FPN.
x (list[Tensor]): Features from FPN.
input
_metas (list[dict]): Meta information of each image, e.g.,
batch_img
_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
shape (num_gts, 4).
shape (num_gts, 4).
...
@@ -201,7 +215,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -201,7 +215,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (num_gts, self.bbox_code_size).
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
centers
_
2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
shape (num_gts,).
...
@@ -216,29 +230,75 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -216,29 +230,75 @@ class MonoFlexHead(AnchorFreeMono3DHead):
losses: (dict[str, Tensor]): A dictionary of loss components.
losses: (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
proposal_list (list[Tensor]): Proposals of each image.
"""
"""
outs
=
self
(
x
,
input_metas
)
"""
if
gt_labels
is
None
:
Args:
loss_inputs
=
outs
+
(
gt_bboxes
,
gt_bboxes_3d
,
centers2d
,
depths
,
x (list[Tensor]): Features from FPN.
attr_labels
,
input_metas
)
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each image and corresponding
annotations.
proposal_cfg (mmengine.Config, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
Returns:
tuple or Tensor: When `proposal_cfg` is None, the detector is a
\
normal one-stage detector, The return value is the losses.
- losses: (dict[str, Tensor]): A dictionary of loss components.
When the `proposal_cfg` is not None, the head is used as a
`rpn_head`, the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- results_list (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
with shape (num_instances, C), the last dimension C of a
3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where
C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2
dims of velocity.
"""
batch_gt_instances_3d
=
[]
batch_gt_instances_ignore
=
[]
batch_img_metas
=
[]
for
data_sample
in
batch_data_samples
:
batch_img_metas
.
append
(
data_sample
.
metainfo
)
batch_gt_instances_3d
.
append
(
data_sample
.
gt_instances_3d
)
if
'ignored_instances'
in
data_sample
:
batch_gt_instances_ignore
.
append
(
data_sample
.
ignored_instances
)
else
:
else
:
loss_inputs
=
outs
+
(
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
batch_gt_instances_ignore
.
append
(
None
)
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
input_metas
)
# monoflex head needs img_metas for feature extraction
losses
=
self
.
loss
(
*
loss_inputs
,
gt_bboxes_ignore
=
gt_bboxes_ignore
)
outs
=
self
(
x
,
batch_img_metas
)
loss_inputs
=
outs
+
(
batch_gt_instances_3d
,
batch_img_metas
,
batch_gt_instances_ignore
)
losses
=
self
.
loss
(
*
loss_inputs
)
if
proposal_cfg
is
None
:
if
proposal_cfg
is
None
:
return
losses
return
losses
else
:
else
:
proposal_list
=
self
.
get_bboxes
(
batch_img_metas
=
[
*
outs
,
input_metas
,
cfg
=
proposal_cfg
)
data_sample
.
metainfo
for
data_sample
in
batch_data_samples
return
losses
,
proposal_list
]
results_list
=
self
.
get_results
(
*
outs
,
batch_img_metas
=
batch_img_metas
,
cfg
=
proposal_cfg
)
return
losses
,
results_list
def
forward
(
self
,
feats
,
input_metas
):
def
forward
(
self
,
feats
:
List
[
Tensor
],
batch_img_metas
:
List
[
dict
]
):
"""Forward features from the upstream network.
"""Forward features from the upstream network.
Args:
Args:
feats (list[Tensor]): Features from the upstream network, each is
feats (list[Tensor]): Features from the upstream network, each is
a 4D-tensor.
a 4D-tensor.
input
_metas (list[dict]): Meta information of each image, e.g.,
batch_img
_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
Returns:
Returns:
...
@@ -250,21 +310,21 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -250,21 +310,21 @@ class MonoFlexHead(AnchorFreeMono3DHead):
level, each is a 4D-tensor, the channel number is
level, each is a 4D-tensor, the channel number is
num_points * bbox_code_size.
num_points * bbox_code_size.
"""
"""
mlvl_
input
_metas
=
[
input
_metas
for
i
in
range
(
len
(
feats
))]
mlvl_
batch_img
_metas
=
[
batch_img
_metas
for
i
in
range
(
len
(
feats
))]
return
multi_apply
(
self
.
forward_single
,
feats
,
mlvl_
input
_metas
)
return
multi_apply
(
self
.
forward_single
,
feats
,
mlvl_
batch_img
_metas
)
def
forward_single
(
self
,
x
,
input_metas
):
def
forward_single
(
self
,
x
:
Tensor
,
batch_img_metas
:
List
[
dict
]
):
"""Forward features of a single scale level.
"""Forward features of a single scale level.
Args:
Args:
x (Tensor): Feature maps from a specific FPN feature level.
x (Tensor): Feature maps from a specific FPN feature level.
input
_metas (list[dict]): Meta information of each image, e.g.,
batch_img
_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
Returns:
Returns:
tuple: Scores for each class, bbox predictions.
tuple: Scores for each class, bbox predictions.
"""
"""
img_h
,
img_w
=
input
_metas
[
0
][
'pad_shape'
][:
2
]
img_h
,
img_w
=
batch_img
_metas
[
0
][
'pad_shape'
][:
2
]
batch_size
,
_
,
feat_h
,
feat_w
=
x
.
shape
batch_size
,
_
,
feat_h
,
feat_w
=
x
.
shape
downsample_ratio
=
img_h
/
feat_h
downsample_ratio
=
img_h
/
feat_h
...
@@ -275,7 +335,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -275,7 +335,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if
self
.
use_edge_fusion
:
if
self
.
use_edge_fusion
:
# calculate the edge indices for the batch data
# calculate the edge indices for the batch data
edge_indices_list
=
get_edge_indices
(
edge_indices_list
=
get_edge_indices
(
input
_metas
,
downsample_ratio
,
device
=
x
.
device
)
batch_img
_metas
,
downsample_ratio
,
device
=
x
.
device
)
edge_lens
=
[
edge_lens
=
[
edge_indices
.
shape
[
0
]
for
edge_indices
in
edge_indices_list
edge_indices
.
shape
[
0
]
for
edge_indices
in
edge_indices_list
]
]
...
@@ -313,13 +373,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -313,13 +373,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
return
cls_score
,
bbox_pred
return
cls_score
,
bbox_pred
def
get_bboxes
(
self
,
cls_scores
,
bbox_preds
,
input_metas
):
@
force_fp32
(
apply_to
=
(
'cls_scores'
,
'bbox_preds'
))
def
get_results
(
self
,
cls_scores
:
List
[
Tensor
],
bbox_preds
:
List
[
Tensor
],
batch_img_metas
:
List
[
dict
]):
"""Generate bboxes from bbox head predictions.
"""Generate bboxes from bbox head predictions.
Args:
Args:
cls_scores (list[Tensor]): Box scores for each scale level.
cls_scores (list[Tensor]): Box scores for each scale level.
bbox_preds (list[Tensor]): Box regression for each scale.
bbox_preds (list[Tensor]): Box regression for each scale.
input
_metas (list[dict]): Meta information of each image, e.g.,
batch_img
_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
rescale (bool): If True, return boxes in original image space.
Returns:
Returns:
...
@@ -329,18 +391,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -329,18 +391,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
cam2imgs
=
torch
.
stack
([
cam2imgs
=
torch
.
stack
([
cls_scores
[
0
].
new_tensor
(
input_meta
[
'cam2img'
])
cls_scores
[
0
].
new_tensor
(
input_meta
[
'cam2img'
])
for
input_meta
in
input
_metas
for
input_meta
in
batch_img
_metas
])
])
batch_bboxes
,
batch_scores
,
batch_topk_labels
=
self
.
decode_heatmap
(
batch_bboxes
,
batch_scores
,
batch_topk_labels
=
self
.
decode_heatmap
(
cls_scores
[
0
],
cls_scores
[
0
],
bbox_preds
[
0
],
bbox_preds
[
0
],
input
_metas
,
batch_img
_metas
,
cam2imgs
=
cam2imgs
,
cam2imgs
=
cam2imgs
,
topk
=
100
,
topk
=
100
,
kernel
=
3
)
kernel
=
3
)
result_list
=
[]
result_list
=
[]
for
img_id
in
range
(
len
(
input
_metas
)):
for
img_id
in
range
(
len
(
batch_img
_metas
)):
bboxes
=
batch_bboxes
[
img_id
]
bboxes
=
batch_bboxes
[
img_id
]
scores
=
batch_scores
[
img_id
]
scores
=
batch_scores
[
img_id
]
...
@@ -351,20 +413,29 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -351,20 +413,29 @@ class MonoFlexHead(AnchorFreeMono3DHead):
scores
=
scores
[
keep_idx
]
scores
=
scores
[
keep_idx
]
labels
=
labels
[
keep_idx
]
labels
=
labels
[
keep_idx
]
bboxes
=
input
_metas
[
img_id
][
'box_type_3d'
](
bboxes
=
batch_img
_metas
[
img_id
][
'box_type_3d'
](
bboxes
,
box_dim
=
self
.
bbox_code_size
,
origin
=
(
0.5
,
0.5
,
0.5
))
bboxes
,
box_dim
=
self
.
bbox_code_size
,
origin
=
(
0.5
,
0.5
,
0.5
))
attrs
=
None
attrs
=
None
result_list
.
append
((
bboxes
,
scores
,
labels
,
attrs
))
results
=
InstanceData
()
results
.
bboxes_3d
=
bboxes
results
.
scores_3d
=
scores
results
.
labels_3d
=
labels
if
attrs
is
not
None
:
results
.
attr_labels
=
attrs
result_list
.
append
(
results
)
return
result_list
return
result_list
def
decode_heatmap
(
self
,
def
decode_heatmap
(
self
,
cls_score
,
cls_score
:
Tensor
,
reg_pred
,
reg_pred
:
Tensor
,
input_metas
,
batch_img_metas
:
List
[
dict
]
,
cam2imgs
,
cam2imgs
:
Tensor
,
topk
=
100
,
topk
:
int
=
100
,
kernel
=
3
):
kernel
:
int
=
3
):
"""Transform outputs into detections raw bbox predictions.
"""Transform outputs into detections raw bbox predictions.
Args:
Args:
...
@@ -372,7 +443,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -372,7 +443,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (B, num_classes, H, W).
shape (B, num_classes, H, W).
reg_pred (Tensor): Box regression map.
reg_pred (Tensor): Box regression map.
shape (B, channel, H , W).
shape (B, channel, H , W).
input
_metas (List[dict]): Meta information of each image, e.g.,
batch_img
_metas (List[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
cam2imgs (Tensor): Camera intrinsic matrix.
cam2imgs (Tensor): Camera intrinsic matrix.
shape (N, 4, 4)
shape (N, 4, 4)
...
@@ -391,7 +462,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -391,7 +462,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
- batch_topk_labels (Tensor): Categories of each 3D box.
- batch_topk_labels (Tensor): Categories of each 3D box.
shape (B, k)
shape (B, k)
"""
"""
img_h
,
img_w
=
input
_metas
[
0
][
'pad_shape'
][:
2
]
img_h
,
img_w
=
batch_img
_metas
[
0
][
'pad_shape'
][:
2
]
batch_size
,
_
,
feat_h
,
feat_w
=
cls_score
.
shape
batch_size
,
_
,
feat_h
,
feat_w
=
cls_score
.
shape
downsample_ratio
=
img_h
/
feat_h
downsample_ratio
=
img_h
/
feat_h
...
@@ -404,13 +475,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -404,13 +475,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
regression
=
transpose_and_gather_feat
(
reg_pred
,
batch_index
)
regression
=
transpose_and_gather_feat
(
reg_pred
,
batch_index
)
regression
=
regression
.
view
(
-
1
,
8
)
regression
=
regression
.
view
(
-
1
,
8
)
pred_base_centers2d
=
torch
.
cat
(
pred_base_centers
_
2d
=
torch
.
cat
(
[
topk_xs
.
view
(
-
1
,
1
),
[
topk_xs
.
view
(
-
1
,
1
),
topk_ys
.
view
(
-
1
,
1
).
float
()],
dim
=
1
)
topk_ys
.
view
(
-
1
,
1
).
float
()],
dim
=
1
)
preds
=
self
.
bbox_coder
.
decode
(
regression
,
batch_topk_labels
,
preds
=
self
.
bbox_coder
.
decode
(
regression
,
batch_topk_labels
,
downsample_ratio
,
cam2imgs
)
downsample_ratio
,
cam2imgs
)
pred_locations
=
self
.
bbox_coder
.
decode_location
(
pred_locations
=
self
.
bbox_coder
.
decode_location
(
pred_base_centers2d
,
preds
[
'offsets2d'
],
preds
[
'combined_depth'
],
pred_base_centers
_
2d
,
preds
[
'offsets
_
2d'
],
preds
[
'combined_depth'
],
cam2imgs
,
downsample_ratio
)
cam2imgs
,
downsample_ratio
)
pred_yaws
=
self
.
bbox_coder
.
decode_orientation
(
pred_yaws
=
self
.
bbox_coder
.
decode_orientation
(
preds
[
'orientations'
]).
unsqueeze
(
-
1
)
preds
[
'orientations'
]).
unsqueeze
(
-
1
)
...
@@ -419,8 +490,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -419,8 +490,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
batch_bboxes
=
batch_bboxes
.
view
(
batch_size
,
-
1
,
self
.
bbox_code_size
)
batch_bboxes
=
batch_bboxes
.
view
(
batch_size
,
-
1
,
self
.
bbox_code_size
)
return
batch_bboxes
,
batch_scores
,
batch_topk_labels
return
batch_bboxes
,
batch_scores
,
batch_topk_labels
def
get_predictions
(
self
,
pred_reg
,
labels3d
,
centers2d
,
reg_mask
,
def
get_predictions
(
self
,
pred_reg
,
labels3d
,
centers
_
2d
,
reg_mask
,
batch_indices
,
input
_metas
,
downsample_ratio
):
batch_indices
,
batch_img
_metas
,
downsample_ratio
):
"""Prepare predictions for computing loss.
"""Prepare predictions for computing loss.
Args:
Args:
...
@@ -428,14 +499,14 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -428,14 +499,14 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (B, channel, H , W).
shape (B, channel, H , W).
labels3d (Tensor): Labels of each 3D box.
labels3d (Tensor): Labels of each 3D box.
shape (B * max_objs, )
shape (B * max_objs, )
centers2d (Tensor): Coords of each projected 3D box
centers
_
2d (Tensor): Coords of each projected 3D box
center on image. shape (N, 2)
center on image. shape (N, 2)
reg_mask (Tensor): Indexes of the existence of the 3D box.
reg_mask (Tensor): Indexes of the existence of the 3D box.
shape (B * max_objs, )
shape (B * max_objs, )
batch_indices (Tenosr): Batch indices of the 3D box.
batch_indices (Tenosr): Batch indices of the 3D box.
shape (N, 3)
shape (N, 3)
input
_metas (list[dict]): Meta information of each image,
batch_img
_metas (list[dict]): Meta information of each image,
e.g.,
e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
downsample_ratio (int): The stride of feature map.
downsample_ratio (int): The stride of feature map.
Returns:
Returns:
...
@@ -444,50 +515,41 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -444,50 +515,41 @@ class MonoFlexHead(AnchorFreeMono3DHead):
batch
,
channel
=
pred_reg
.
shape
[
0
],
pred_reg
.
shape
[
1
]
batch
,
channel
=
pred_reg
.
shape
[
0
],
pred_reg
.
shape
[
1
]
w
=
pred_reg
.
shape
[
3
]
w
=
pred_reg
.
shape
[
3
]
cam2imgs
=
torch
.
stack
([
cam2imgs
=
torch
.
stack
([
centers2d
.
new_tensor
(
i
nput
_meta
[
'cam2img'
])
centers
_
2d
.
new_tensor
(
i
mg
_meta
[
'cam2img'
])
for
i
nput
_meta
in
input
_metas
for
i
mg
_meta
in
batch_img
_metas
])
])
# (batch_size, 4, 4) -> (N, 4, 4)
# (batch_size, 4, 4) -> (N, 4, 4)
cam2imgs
=
cam2imgs
[
batch_indices
,
:,
:]
cam2imgs
=
cam2imgs
[
batch_indices
,
:,
:]
centers2d_inds
=
centers2d
[:,
1
]
*
w
+
centers2d
[:,
0
]
centers
_
2d_inds
=
centers
_
2d
[:,
1
]
*
w
+
centers
_
2d
[:,
0
]
centers2d_inds
=
centers2d_inds
.
view
(
batch
,
-
1
)
centers
_
2d_inds
=
centers
_
2d_inds
.
view
(
batch
,
-
1
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers2d_inds
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers
_
2d_inds
)
pred_regression_pois
=
pred_regression
.
view
(
-
1
,
channel
)[
reg_mask
]
pred_regression_pois
=
pred_regression
.
view
(
-
1
,
channel
)[
reg_mask
]
preds
=
self
.
bbox_coder
.
decode
(
pred_regression_pois
,
labels3d
,
preds
=
self
.
bbox_coder
.
decode
(
pred_regression_pois
,
labels3d
,
downsample_ratio
,
cam2imgs
)
downsample_ratio
,
cam2imgs
)
return
preds
return
preds
def
get_targets
(
self
,
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_3d_list
,
def
get_targets
(
self
,
batch_gt_instances_3d
:
List
[
InstanceData
],
gt_labels_3d_list
,
centers2d_list
,
depths_list
,
feat_shape
,
feat_shape
:
Tuple
[
int
],
batch_img_metas
:
List
[
dict
]):
img_shape
,
input_metas
):
"""Get training targets for batch images.
"""Get training targets for batch images.
``
``
Args:
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
image, shape (num_gt, 4).
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels_list (list[Tensor]): Ground truth labels of each
、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
box, shape (num_gt,).
attributes.
gt_bboxes_3d_list (list[:obj:`CameraInstance3DBoxes`]): 3D
Ground truth bboxes of each image,
shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of
each box, shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D
image, shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
feat_shape (tuple[int]): Feature map shape with value,
feat_shape (tuple[int]): Feature map shape with value,
shape (B, _, H, W).
shape (B, _, H, W).
img_shape (tuple[int]): Image shape in [h, w] format.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
input_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
Returns:
Returns:
tuple[Tensor, dict]: The Tensor value is the targets of
tuple[Tensor, dict]: The Tensor value is the targets of
center heatmap, the dict has components below:
center heatmap, the dict has components below:
- base_centers2d_target (Tensor): Coords of each projected 3D box
- base_centers_2d_target (Tensor): Coords of each projected
center on image. shape (B * max_objs, 2), [dtype: int]
3D box center on image. shape (B * max_objs, 2),
[dtype: int]
- labels3d (Tensor): Labels of each 3D box.
- labels3d (Tensor): Labels of each 3D box.
shape (N, )
shape (N, )
- reg_mask (Tensor): Mask of the existence of the 3D box.
- reg_mask (Tensor): Mask of the existence of the 3D box.
...
@@ -504,14 +566,36 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -504,14 +566,36 @@ class MonoFlexHead(AnchorFreeMono3DHead):
of each 3D box. shape (N, 3)
of each 3D box. shape (N, 3)
- orientations_target (Tensor): Orientation (encoded local yaw)
- orientations_target (Tensor): Orientation (encoded local yaw)
target of each 3D box. shape (N, )
target of each 3D box. shape (N, )
- offsets2d_target (Tensor): Offsets target of each projected
- offsets
_
2d_target (Tensor): Offsets target of each projected
3D box. shape (N, 2)
3D box. shape (N, 2)
- dimensions_target (Tensor): Dimensions target of each 3D box.
- dimensions_target (Tensor): Dimensions target of each 3D box.
shape (N, 3)
shape (N, 3)
- downsample_ratio (int): The stride of feature map.
- downsample_ratio (int): The stride of feature map.
"""
"""
img_h
,
img_w
=
img_shape
[:
2
]
gt_bboxes_list
=
[
gt_instances_3d
.
bboxes
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels_list
=
[
gt_instances_3d
.
labels
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_bboxes_3d_list
=
[
gt_instances_3d
.
bboxes_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels_3d_list
=
[
gt_instances_3d
.
labels_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
centers_2d_list
=
[
gt_instances_3d
.
centers_2d
for
gt_instances_3d
in
batch_gt_instances_3d
]
depths_list
=
[
gt_instances_3d
.
depths
for
gt_instances_3d
in
batch_gt_instances_3d
]
img_h
,
img_w
=
batch_img_metas
[
0
][
'pad_shape'
][:
2
]
batch_size
,
_
,
feat_h
,
feat_w
=
feat_shape
batch_size
,
_
,
feat_h
,
feat_w
=
feat_shape
width_ratio
=
float
(
feat_w
/
img_w
)
# 1/4
width_ratio
=
float
(
feat_w
/
img_w
)
# 1/4
...
@@ -523,16 +607,16 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -523,16 +607,16 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if
self
.
filter_outside_objs
:
if
self
.
filter_outside_objs
:
filter_outside_objs
(
gt_bboxes_list
,
gt_labels_list
,
filter_outside_objs
(
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_3d_list
,
gt_labels_3d_list
,
gt_bboxes_3d_list
,
gt_labels_3d_list
,
centers2d_list
,
input
_metas
)
centers
_
2d_list
,
batch_img
_metas
)
# transform centers2d to base centers2d for regression and
# transform centers
_
2d to base centers
_
2d for regression and
# heatmap generation.
# heatmap generation.
# centers2d = int(base_centers2d) + offsets2d
# centers
_
2d = int(base_centers
_
2d) + offsets
_
2d
base_centers2d_list
,
offsets2d_list
,
trunc_mask_list
=
\
base_centers
_
2d_list
,
offsets
_
2d_list
,
trunc_mask_list
=
\
handle_proj_objs
(
centers2d_list
,
gt_bboxes_list
,
input
_metas
)
handle_proj_objs
(
centers
_
2d_list
,
gt_bboxes_list
,
batch_img
_metas
)
keypoints2d_list
,
keypoints_mask_list
,
keypoints_depth_mask_list
=
\
keypoints2d_list
,
keypoints_mask_list
,
keypoints_depth_mask_list
=
\
get_keypoints
(
gt_bboxes_3d_list
,
centers2d_list
,
input
_metas
)
get_keypoints
(
gt_bboxes_3d_list
,
centers
_
2d_list
,
batch_img
_metas
)
center_heatmap_target
=
gt_bboxes_list
[
-
1
].
new_zeros
(
center_heatmap_target
=
gt_bboxes_list
[
-
1
].
new_zeros
(
[
batch_size
,
self
.
num_classes
,
feat_h
,
feat_w
])
[
batch_size
,
self
.
num_classes
,
feat_h
,
feat_w
])
...
@@ -542,11 +626,11 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -542,11 +626,11 @@ class MonoFlexHead(AnchorFreeMono3DHead):
gt_bboxes
=
gt_bboxes_list
[
batch_id
]
*
width_ratio
gt_bboxes
=
gt_bboxes_list
[
batch_id
]
*
width_ratio
gt_labels
=
gt_labels_list
[
batch_id
]
gt_labels
=
gt_labels_list
[
batch_id
]
# project base centers2d from input image to feat map
# project base centers
_
2d from input image to feat map
gt_base_centers2d
=
base_centers2d_list
[
batch_id
]
*
width_ratio
gt_base_centers
_
2d
=
base_centers
_
2d_list
[
batch_id
]
*
width_ratio
trunc_masks
=
trunc_mask_list
[
batch_id
]
trunc_masks
=
trunc_mask_list
[
batch_id
]
for
j
,
base_center2d
in
enumerate
(
gt_base_centers2d
):
for
j
,
base_center2d
in
enumerate
(
gt_base_centers
_
2d
):
if
trunc_masks
[
j
]:
if
trunc_masks
[
j
]:
# for outside objects, generate ellipse heatmap
# for outside objects, generate ellipse heatmap
base_center2d_x_int
,
base_center2d_y_int
=
\
base_center2d_x_int
,
base_center2d_y_int
=
\
...
@@ -579,40 +663,40 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -579,40 +663,40 @@ class MonoFlexHead(AnchorFreeMono3DHead):
[
base_center2d_x_int
,
base_center2d_y_int
],
radius
)
[
base_center2d_x_int
,
base_center2d_y_int
],
radius
)
avg_factor
=
max
(
1
,
center_heatmap_target
.
eq
(
1
).
sum
())
avg_factor
=
max
(
1
,
center_heatmap_target
.
eq
(
1
).
sum
())
num_ctrs
=
[
centers2d
.
shape
[
0
]
for
centers2d
in
centers2d_list
]
num_ctrs
=
[
centers
_
2d
.
shape
[
0
]
for
centers
_
2d
in
centers
_
2d_list
]
max_objs
=
max
(
num_ctrs
)
max_objs
=
max
(
num_ctrs
)
batch_indices
=
[
batch_indices
=
[
centers2d_list
[
0
].
new_full
((
num_ctrs
[
i
],
),
i
)
centers
_
2d_list
[
0
].
new_full
((
num_ctrs
[
i
],
),
i
)
for
i
in
range
(
batch_size
)
for
i
in
range
(
batch_size
)
]
]
batch_indices
=
torch
.
cat
(
batch_indices
,
dim
=
0
)
batch_indices
=
torch
.
cat
(
batch_indices
,
dim
=
0
)
reg_mask
=
torch
.
zeros
(
reg_mask
=
torch
.
zeros
(
(
batch_size
,
max_objs
),
(
batch_size
,
max_objs
),
dtype
=
torch
.
bool
).
to
(
base_centers2d_list
[
0
].
device
)
dtype
=
torch
.
bool
).
to
(
base_centers
_
2d_list
[
0
].
device
)
gt_bboxes_3d
=
input
_metas
[
'box_type_3d'
].
cat
(
gt_bboxes_3d_list
)
gt_bboxes_3d
=
batch_img
_metas
[
0
][
'box_type_3d'
].
cat
(
gt_bboxes_3d_list
)
gt_bboxes_3d
=
gt_bboxes_3d
.
to
(
base_centers2d_list
[
0
].
device
)
gt_bboxes_3d
=
gt_bboxes_3d
.
to
(
base_centers
_
2d_list
[
0
].
device
)
# encode original local yaw to multibin format
# encode original local yaw to multibin format
orienations_target
=
self
.
bbox_coder
.
encode
(
gt_bboxes_3d
)
orienations_target
=
self
.
bbox_coder
.
encode
(
gt_bboxes_3d
)
batch_base_centers2d
=
base_centers2d_list
[
0
].
new_zeros
(
batch_base_centers
_
2d
=
base_centers
_
2d_list
[
0
].
new_zeros
(
(
batch_size
,
max_objs
,
2
))
(
batch_size
,
max_objs
,
2
))
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
reg_mask
[
i
,
:
num_ctrs
[
i
]]
=
1
reg_mask
[
i
,
:
num_ctrs
[
i
]]
=
1
batch_base_centers2d
[
i
,
:
num_ctrs
[
i
]]
=
base_centers2d_list
[
i
]
batch_base_centers
_
2d
[
i
,
:
num_ctrs
[
i
]]
=
base_centers
_
2d_list
[
i
]
flatten_reg_mask
=
reg_mask
.
flatten
()
flatten_reg_mask
=
reg_mask
.
flatten
()
# transform base centers2d from input scale to output scale
# transform base centers
_
2d from input scale to output scale
batch_base_centers2d
=
batch_base_centers2d
.
view
(
-
1
,
2
)
*
width_ratio
batch_base_centers
_
2d
=
batch_base_centers
_
2d
.
view
(
-
1
,
2
)
*
width_ratio
dimensions_target
=
gt_bboxes_3d
.
tensor
[:,
3
:
6
]
dimensions_target
=
gt_bboxes_3d
.
tensor
[:,
3
:
6
]
labels_3d
=
torch
.
cat
(
gt_labels_3d_list
)
labels_3d
=
torch
.
cat
(
gt_labels_3d_list
)
keypoints2d_target
=
torch
.
cat
(
keypoints2d_list
)
keypoints2d_target
=
torch
.
cat
(
keypoints2d_list
)
keypoints_mask
=
torch
.
cat
(
keypoints_mask_list
)
keypoints_mask
=
torch
.
cat
(
keypoints_mask_list
)
keypoints_depth_mask
=
torch
.
cat
(
keypoints_depth_mask_list
)
keypoints_depth_mask
=
torch
.
cat
(
keypoints_depth_mask_list
)
offsets2d_target
=
torch
.
cat
(
offsets2d_list
)
offsets
_
2d_target
=
torch
.
cat
(
offsets
_
2d_list
)
bboxes2d
=
torch
.
cat
(
gt_bboxes_list
)
bboxes2d
=
torch
.
cat
(
gt_bboxes_list
)
# transform FCOS style bbox into [x1, y1, x2, y2] format.
# transform FCOS style bbox into [x1, y1, x2, y2] format.
...
@@ -621,7 +705,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -621,7 +705,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
depths
=
torch
.
cat
(
depths_list
)
depths
=
torch
.
cat
(
depths_list
)
target_labels
=
dict
(
target_labels
=
dict
(
base_centers2d_target
=
batch_base_centers2d
.
int
(),
base_centers
_
2d_target
=
batch_base_centers
_
2d
.
int
(),
labels3d
=
labels_3d
,
labels3d
=
labels_3d
,
reg_mask
=
flatten_reg_mask
,
reg_mask
=
flatten_reg_mask
,
batch_indices
=
batch_indices
,
batch_indices
=
batch_indices
,
...
@@ -631,24 +715,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -631,24 +715,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
keypoints_mask
=
keypoints_mask
,
keypoints_mask
=
keypoints_mask
,
keypoints_depth_mask
=
keypoints_depth_mask
,
keypoints_depth_mask
=
keypoints_depth_mask
,
orienations_target
=
orienations_target
,
orienations_target
=
orienations_target
,
offsets2d_target
=
offsets2d_target
,
offsets
_
2d_target
=
offsets
_
2d_target
,
dimensions_target
=
dimensions_target
,
dimensions_target
=
dimensions_target
,
downsample_ratio
=
1
/
width_ratio
)
downsample_ratio
=
1
/
width_ratio
)
return
center_heatmap_target
,
avg_factor
,
target_labels
return
center_heatmap_target
,
avg_factor
,
target_labels
def
loss
(
self
,
def
loss
(
self
,
cls_scores
,
cls_scores
:
List
[
Tensor
],
bbox_preds
,
bbox_preds
:
List
[
Tensor
],
gt_bboxes
,
batch_gt_instances_3d
:
List
[
InstanceData
],
gt_labels
,
batch_img_metas
:
List
[
dict
],
gt_bboxes_3d
,
batch_gt_instances_ignore
:
Optional
[
List
[
InstanceData
]]
=
None
):
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
input_metas
,
gt_bboxes_ignore
=
None
):
"""Compute loss of the head.
"""Compute loss of the head.
Args:
Args:
...
@@ -657,48 +735,37 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -657,48 +735,37 @@ class MonoFlexHead(AnchorFreeMono3DHead):
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
number is bbox_code_size.
number is bbox_code_size.
shape (B, 7, H, W).
shape (B, 7, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels (list[Tensor]): Class indices corresponding to each box.
、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
shape (num_gts, ).
attributes.
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D boxes ground
batch_img_metas (list[dict]): Meta information of each image, e.g.,
truth. it is the flipped gt_bboxes
gt_labels_3d (list[Tensor]): Same as gt_labels.
centers2d (list[Tensor]): 2D centers on the image.
shape (num_gts, 2).
depths (list[Tensor]): Depth ground truth.
shape (num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
In kitti it's None.
input_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss.
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
Default: None.
data that is ignored during training and testing.
Defaults to None.
Returns:
Returns:
dict[str, Tensor]: A dictionary of loss components.
dict[str, Tensor]: A dictionary of loss components.
"""
"""
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
attr_labels
is
None
assert
batch_gt_instances_ignore
is
None
assert
gt_bboxes_ignore
is
None
center2d_heatmap
=
cls_scores
[
0
]
center2d_heatmap
=
cls_scores
[
0
]
pred_reg
=
bbox_preds
[
0
]
pred_reg
=
bbox_preds
[
0
]
center2d_heatmap_target
,
avg_factor
,
target_labels
=
\
center2d_heatmap_target
,
avg_factor
,
target_labels
=
\
self
.
get_targets
(
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
self
.
get_targets
(
batch_gt_instances_3d
,
gt_labels_3d
,
centers2d
,
depths
,
center2d_heatmap
.
shape
,
center2d_heatmap
.
shape
,
input_metas
[
0
][
'pad_shape'
],
batch_img_metas
)
input_metas
)
preds
=
self
.
get_predictions
(
preds
=
self
.
get_predictions
(
pred_reg
=
pred_reg
,
pred_reg
=
pred_reg
,
labels3d
=
target_labels
[
'labels3d'
],
labels3d
=
target_labels
[
'labels3d'
],
centers2d
=
target_labels
[
'base_centers2d_target'
],
centers
_
2d
=
target_labels
[
'base_centers
_
2d_target'
],
reg_mask
=
target_labels
[
'reg_mask'
],
reg_mask
=
target_labels
[
'reg_mask'
],
batch_indices
=
target_labels
[
'batch_indices'
],
batch_indices
=
target_labels
[
'batch_indices'
],
input_metas
=
input
_metas
,
batch_img_metas
=
batch_img
_metas
,
downsample_ratio
=
target_labels
[
'downsample_ratio'
])
downsample_ratio
=
target_labels
[
'downsample_ratio'
])
# heatmap loss
# heatmap loss
...
@@ -726,8 +793,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -726,8 +793,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
target_labels
[
'dimensions_target'
])
target_labels
[
'dimensions_target'
])
# offsets for center heatmap
# offsets for center heatmap
loss_offsets2d
=
self
.
loss_offsets2d
(
preds
[
'offsets2d'
],
loss_offsets
_
2d
=
self
.
loss_offsets
_
2d
(
target_labels
[
'offsets2d_target'
])
preds
[
'offsets_2d'
],
target_labels
[
'offsets
_
2d_target'
])
# directly regressed depth loss with direct depth uncertainty loss
# directly regressed depth loss with direct depth uncertainty loss
direct_depth_weights
=
torch
.
exp
(
-
preds
[
'direct_depth_uncertainty'
])
direct_depth_weights
=
torch
.
exp
(
-
preds
[
'direct_depth_uncertainty'
])
...
@@ -764,7 +831,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
...
@@ -764,7 +831,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
loss_keypoints
=
loss_keypoints
,
loss_keypoints
=
loss_keypoints
,
loss_dir
=
loss_dir
,
loss_dir
=
loss_dir
,
loss_dims
=
loss_dims
,
loss_dims
=
loss_dims
,
loss_offsets2d
=
loss_offsets2d
,
loss_offsets
_
2d
=
loss_offsets
_
2d
,
loss_direct_depth
=
loss_direct_depth
,
loss_direct_depth
=
loss_direct_depth
,
loss_keypoints_depth
=
loss_keypoints_depth
,
loss_keypoints_depth
=
loss_keypoints_depth
,
loss_combined_depth
=
loss_combined_depth
)
loss_combined_depth
=
loss_combined_depth
)
...
...
tests/test_models/test_dense_heads/test_monoflex_head.py
0 → 100644
View file @
d490f024
# Copyright (c) OpenMMLab. All rights reserved.
from
unittest
import
TestCase
import
numpy
as
np
import
torch
from
mmdet3d.models.dense_heads
import
MonoFlexHead
class
TestMonoFlexHead
(
TestCase
):
def
test_monoflex_head_loss
(
self
):
"""Tests MonoFlex head loss and inference."""
input_metas
=
[
dict
(
img_shape
=
(
110
,
110
),
pad_shape
=
(
128
,
128
))]
monoflex_head
=
MonoFlexHead
(
num_classes
=
3
,
in_channels
=
64
,
use_edge_fusion
=
True
,
edge_fusion_inds
=
[(
1
,
0
)],
edge_heatmap_ratio
=
1
/
8
,
stacked_convs
=
0
,
feat_channels
=
64
,
use_direction_classifier
=
False
,
diff_rad_by_sin
=
False
,
pred_attrs
=
False
,
pred_velo
=
False
,
dir_offset
=
0
,
strides
=
None
,
group_reg_dims
=
((
4
,
),
(
2
,
),
(
20
,
),
(
3
,
),
(
3
,
),
(
8
,
8
),
(
1
,
),
(
1
,
)),
cls_branch
=
(
256
,
),
reg_branch
=
((
256
,
),
(
256
,
),
(
256
,
),
(
256
,
),
(
256
,
),
(
256
,
),
(
256
,
),
(
256
,
)),
num_attrs
=
0
,
bbox_code_size
=
7
,
dir_branch
=
(),
attr_branch
=
(),
bbox_coder
=
dict
(
type
=
'MonoFlexCoder'
,
depth_mode
=
'exp'
,
base_depth
=
(
26.494627
,
16.05988
),
depth_range
=
[
0.1
,
100
],
combine_depth
=
True
,
uncertainty_range
=
[
-
10
,
10
],
base_dims
=
((
3.8840
,
1.5261
,
1.6286
,
0.4259
,
0.1367
,
0.1022
),
(
0.8423
,
1.7607
,
0.6602
,
0.2349
,
0.1133
,
0.1427
),
(
1.7635
,
1.7372
,
0.5968
,
0.1766
,
0.0948
,
0.1242
)),
dims_mode
=
'linear'
,
multibin
=
True
,
num_dir_bins
=
4
,
bin_centers
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
-
np
.
pi
/
2
],
bin_margin
=
np
.
pi
/
6
,
code_size
=
7
),
conv_bias
=
True
,
dcn_on_last_conv
=
False
)
# Monoflex head expects a single level of features per image
feats
=
[
torch
.
rand
([
1
,
64
,
32
,
32
],
dtype
=
torch
.
float32
)]
# Test forward
cls_score
,
out_reg
=
monoflex_head
.
forward
(
feats
,
input_metas
)
self
.
assertEqual
(
cls_score
[
0
].
shape
,
torch
.
Size
([
1
,
3
,
32
,
32
]),
'the shape of cls_score should be [1, 3, 32, 32]'
)
self
.
assertEqual
(
out_reg
[
0
].
shape
,
torch
.
Size
([
1
,
50
,
32
,
32
]),
'the shape of out_reg should be [1, 50, 32, 32]'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment