Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
5db1ead3
Commit
5db1ead3
authored
May 31, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor] Base + AnchorFreeMono3DHead + FCOSMono3DHead model interface
parent
a79b105b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
114 additions
and
146 deletions
+114
-146
mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
+17
-39
mmdet3d/models/dense_heads/base_mono3d_dense_head.py
mmdet3d/models/dense_heads/base_mono3d_dense_head.py
+67
-47
mmdet3d/models/dense_heads/fcos_mono3d_head.py
mmdet3d/models/dense_heads/fcos_mono3d_head.py
+30
-60
No files found.
mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
View file @
5db1ead3
...
@@ -400,15 +400,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
...
@@ -400,15 +400,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
bbox_preds
,
bbox_preds
,
dir_cls_preds
,
dir_cls_preds
,
attr_preds
,
attr_preds
,
gt_bboxes
,
batch_gt_instances_3d
,
gt_labels
,
batch_img_metas
,
gt_bboxes_3d
,
batch_gt_instances_ignore
=
None
):
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
img_metas
,
gt_bboxes_ignore
=
None
):
"""Compute loss of the head.
"""Compute loss of the head.
Args:
Args:
...
@@ -424,20 +418,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
...
@@ -424,20 +418,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
attr_preds (list[Tensor]): Box scores for each scale level,
attr_preds (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
each is a 4D-tensor, the channel number is
num_points * num_attrs.
num_points * num_attrs.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels (list[Tensor]): class indices corresponding to each box
、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
gt_bboxes_3d (list[Tensor]): 3D Ground truth bboxes for each
attributes.
image with shape (num_gts, bbox_code_size).
batch_img_metas (list[dict]): Meta information of each image, e.g.,
gt_labels_3d (list[Tensor]): 3D class indices of each box.
centers2d (list[Tensor]): Projected 3D centers onto 2D images.
depths (list[Tensor]): Depth of projected centers on 2D images.
attr_labels (list[Tensor], optional): Attribute indices
corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss.
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -474,29 +464,17 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
...
@@ -474,29 +464,17 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
get_targets
(
self
,
points
,
gt_bboxes_list
,
gt_labels_list
,
def
get_targets
(
self
,
points
,
batch_gt_instances_3d
):
gt_bboxes_3d_list
,
gt_labels_3d_list
,
centers2d_list
,
depths_list
,
attr_labels_list
):
"""Compute regression, classification and centerss targets for points
"""Compute regression, classification and centerss targets for points
in multiple images.
in multiple images.
Args:
Args:
points (list[Tensor]): Points of each fpn level, each has shape
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
(num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
each has shape (num_gt, 4).
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels_list (list[Tensor]): Ground truth labels of each box,
、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
each has shape (num_gt,).
attributes.
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
...
mmdet3d/models/dense_heads/base_mono3d_dense_head.py
View file @
5db1ead3
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
List
,
Optional
from
mmcv.runner
import
BaseModule
from
mmcv.runner
import
BaseModule
from
mmengine.config
import
ConfigDict
from
torch
import
Tensor
from
mmdet3d.core
import
Det3DDataSample
class
BaseMono3DDenseHead
(
BaseModule
,
metaclass
=
ABCMeta
):
class
BaseMono3DDenseHead
(
BaseModule
,
metaclass
=
ABCMeta
):
"""Base class for Monocular 3D DenseHeads."""
"""Base class for Monocular 3D DenseHeads."""
def
__init__
(
self
,
init_cfg
=
None
)
:
def
__init__
(
self
,
init_cfg
:
Optional
[
dict
]
=
None
)
->
None
:
super
(
BaseMono3DDenseHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
super
(
BaseMono3DDenseHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
@
abstractmethod
@
abstractmethod
...
@@ -15,64 +21,78 @@ class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
...
@@ -15,64 +21,78 @@ class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Compute losses of the head."""
"""Compute losses of the head."""
pass
pass
def
get_bboxes
(
self
,
*
args
,
**
kwargs
):
warnings
.
warn
(
'`get_bboxes` is deprecated and will be removed in '
'the future. Please use `get_results` instead.'
)
return
self
.
get_results
(
*
args
,
**
kwargs
)
@
abstractmethod
@
abstractmethod
def
get_
bboxes
(
self
,
**
kwargs
):
def
get_
results
(
self
,
*
args
,
**
kwargs
):
"""Transform network output
for
a batch into bbox
p
re
diction
s."""
"""Transform network output
s of
a batch into
3D
bbox re
sult
s."""
pass
pass
def
forward_train
(
self
,
def
forward_train
(
self
,
x
,
x
:
List
[
Tensor
],
img_metas
,
batch_data_samples
:
List
[
Det3DDataSample
],
gt_bboxes
,
proposal_cfg
:
Optional
[
ConfigDict
]
=
None
,
gt_labels
=
None
,
gt_bboxes_3d
=
None
,
gt_labels_3d
=
None
,
centers2d
=
None
,
depths
=
None
,
attr_labels
=
None
,
gt_bboxes_ignore
=
None
,
proposal_cfg
=
None
,
**
kwargs
):
**
kwargs
):
"""
"""
Args:
Args:
x (list[Tensor]): Features from FPN.
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
image size, scaling factor, etc.
contains the meta information of each image and corresponding
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
annotations.
shape (num_gts, 4).
proposal_cfg (mmengine.Config, optional): Test / postprocessing
gt_labels (list[Tensor]): Ground truth labels of each box,
configuration, if None, test_cfg would be used.
shape (num_gts,).
Defaults to None.
gt_bboxes_3d (list[Tensor]): 3D ground truth bboxes of the image,
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
attr_labels (list[Tensor]): Attribute labels of each box,
shape (num_gts,).
gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns:
Returns:
tuple:
tuple or Tensor: When `proposal_cfg` is None, the detector is a
\
losses: (dict[str, Tensor]): A dictionary of loss components.
normal one-stage detector, The return value is the losses.
proposal_list (list[Tensor]): Proposals of each image.
- losses: (dict[str, Tensor]): A dictionary of loss components.
When the `proposal_cfg` is not None, the head is used as a
`rpn_head`, the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- results_list (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
with shape (num_instances, C), the last dimension C of a
3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where
C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2
dims of velocity.
"""
"""
outs
=
self
(
x
)
outs
=
self
(
x
)
if
gt_labels
is
None
:
batch_gt_instances_3d
=
[]
loss_inputs
=
outs
+
(
gt_bboxes
,
gt_bboxes_3d
,
centers2d
,
depths
,
batch_gt_instances_ignore
=
[]
attr_labels
,
img_metas
)
batch_img_metas
=
[]
else
:
for
data_sample
in
batch_data_samples
:
loss_inputs
=
outs
+
(
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
batch_img_metas
.
append
(
data_sample
.
metainfo
)
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
batch_gt_instances_3d
.
append
(
data_sample
.
gt_instances_3d
)
img_metas
)
if
'ignored_instances'
in
data_sample
:
losses
=
self
.
loss
(
*
loss_inputs
,
gt_bboxes_ignore
=
gt_bboxes_ignore
)
batch_gt_instances_ignore
.
append
(
data_sample
.
ignored_instances
)
else
:
batch_gt_instances_ignore
.
append
(
None
)
loss_inputs
=
outs
+
(
batch_gt_instances_3d
,
batch_img_metas
,
batch_gt_instances_ignore
)
losses
=
self
.
loss
(
*
loss_inputs
)
if
proposal_cfg
is
None
:
if
proposal_cfg
is
None
:
return
losses
return
losses
else
:
else
:
proposal_list
=
self
.
get_bboxes
(
*
outs
,
img_metas
,
cfg
=
proposal_cfg
)
batch_img_metas
=
[
return
losses
,
proposal_list
data_sample
.
metainfo
for
data_sample
in
batch_data_samples
]
results_list
=
self
.
get_results
(
*
outs
,
batch_img_metas
=
batch_img_metas
,
cfg
=
proposal_cfg
)
return
losses
,
results_list
mmdet3d/models/dense_heads/fcos_mono3d_head.py
View file @
5db1ead3
...
@@ -259,15 +259,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -259,15 +259,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dir_cls_preds
,
dir_cls_preds
,
attr_preds
,
attr_preds
,
centernesses
,
centernesses
,
gt_bboxes
,
batch_gt_instances_3d
,
gt_labels
,
batch_img_metas
,
gt_bboxes_3d
,
batch_gt_instances_ignore
=
None
):
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
img_metas
,
gt_bboxes_ignore
=
None
):
"""Compute loss of the head.
"""Compute loss of the head.
Args:
Args:
...
@@ -285,21 +279,16 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -285,21 +279,16 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
num_points * num_attrs.
num_points * num_attrs.
centernesses (list[Tensor]): Centerness for each scale level, each
centernesses (list[Tensor]): Centerness for each scale level, each
is a 4D-tensor, the channel number is num_points * 1.
is a 4D-tensor, the channel number is num_points * 1.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels (list[Tensor]): class indices corresponding to each box
、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
gt_bboxes_3d (list[Tensor]): 3D boxes ground truth with shape of
attributes.
(num_gts, code_size).
batch_img_metas (list[dict]): Meta information of each image, e.g.,
gt_labels_3d (list[Tensor]): same as gt_labels
centers2d (list[Tensor]): 2D centers on the image with shape of
(num_gts, 2).
depths (list[Tensor]): Depth ground truth with shape of
(num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss.
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
Returns:
dict[str, Tensor]: A dictionary of loss components.
dict[str, Tensor]: A dictionary of loss components.
...
@@ -310,9 +299,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -310,9 +299,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
all_level_points
=
self
.
get_points
(
featmap_sizes
,
bbox_preds
[
0
].
dtype
,
all_level_points
=
self
.
get_points
(
featmap_sizes
,
bbox_preds
[
0
].
dtype
,
bbox_preds
[
0
].
device
)
bbox_preds
[
0
].
device
)
labels_3d
,
bbox_targets_3d
,
centerness_targets
,
attr_targets
=
\
labels_3d
,
bbox_targets_3d
,
centerness_targets
,
attr_targets
=
\
self
.
get_targets
(
self
.
get_targets
(
all_level_points
,
batch_gt_instances_3d
)
all_level_points
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
)
num_imgs
=
cls_scores
[
0
].
size
(
0
)
num_imgs
=
cls_scores
[
0
].
size
(
0
)
# flatten cls_scores, bbox_preds, dir_cls_preds and centerness
# flatten cls_scores, bbox_preds, dir_cls_preds and centerness
...
@@ -742,29 +729,17 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -742,29 +729,17 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dim
=-
1
)
+
stride
//
2
dim
=-
1
)
+
stride
//
2
return
points
return
points
def
get_targets
(
self
,
points
,
gt_bboxes_list
,
gt_labels_list
,
def
get_targets
(
self
,
points
,
batch_gt_instances_3d
):
gt_bboxes_3d_list
,
gt_labels_3d_list
,
centers2d_list
,
depths_list
,
attr_labels_list
):
"""Compute regression, classification and centerss targets for points
"""Compute regression, classification and centerss targets for points
in multiple images.
in multiple images.
Args:
Args:
points (list[Tensor]): Points of each fpn level, each has shape
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
(num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
each has shape (num_gt, 4).
gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels_list (list[Tensor]): Ground truth labels of each box,
、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
each has shape (num_gt,).
attributes.
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
Returns:
Returns:
tuple:
tuple:
...
@@ -786,23 +761,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -786,23 +761,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
# the number of points per img, per lvl
# the number of points per img, per lvl
num_points
=
[
center
.
size
(
0
)
for
center
in
points
]
num_points
=
[
center
.
size
(
0
)
for
center
in
points
]
if
attr_labels_list
is
None
:
attr_labels_list
=
[
gt_labels
.
new_full
(
gt_labels
.
shape
,
self
.
attr_background_label
)
for
gt_labels
in
gt_labels_list
]
# get labels and bbox_targets of each image
# get labels and bbox_targets of each image
_
,
_
,
labels_3d_list
,
bbox_targets_3d_list
,
centerness_targets_list
,
\
_
,
_
,
labels_3d_list
,
bbox_targets_3d_list
,
centerness_targets_list
,
\
attr_targets_list
=
multi_apply
(
attr_targets_list
=
multi_apply
(
self
.
_get_target_single
,
self
.
_get_target_single
,
gt_bboxes_list
,
batch_gt_instances_3d
,
gt_labels_list
,
gt_bboxes_3d_list
,
gt_labels_3d_list
,
centers2d_list
,
depths_list
,
attr_labels_list
,
points
=
concat_points
,
points
=
concat_points
,
regress_ranges
=
concat_regress_ranges
,
regress_ranges
=
concat_regress_ranges
,
num_points_per_lvl
=
num_points
)
num_points_per_lvl
=
num_points
)
...
@@ -850,12 +813,19 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
...
@@ -850,12 +813,19 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
return
concat_lvl_labels_3d
,
concat_lvl_bbox_targets_3d
,
\
return
concat_lvl_labels_3d
,
concat_lvl_bbox_targets_3d
,
\
concat_lvl_centerness_targets
,
concat_lvl_attr_targets
concat_lvl_centerness_targets
,
concat_lvl_attr_targets
def
_get_target_single
(
self
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
def
_get_target_single
(
self
,
gt_instances_3d
,
points
,
regress_ranges
,
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
num_points_per_lvl
):
points
,
regress_ranges
,
num_points_per_lvl
):
"""Compute regression and classification targets for a single image."""
"""Compute regression and classification targets for a single image."""
num_points
=
points
.
size
(
0
)
num_points
=
points
.
size
(
0
)
num_gts
=
gt_labels
.
size
(
0
)
num_gts
=
len
(
gt_instances_3d
)
gt_bboxes
=
gt_instances_3d
.
bboxes
gt_labels
=
gt_instances_3d
.
labels
gt_bboxes_3d
=
gt_instances_3d
.
bboxes_3d
gt_labels_3d
=
gt_instances_3d
.
labels_3d
centers2d
=
gt_instances_3d
.
centers2d
depths
=
gt_instances_3d
.
depths
attr_labels
=
gt_instances_3d
.
attr_labels
if
not
isinstance
(
gt_bboxes_3d
,
torch
.
Tensor
):
if
not
isinstance
(
gt_bboxes_3d
,
torch
.
Tensor
):
gt_bboxes_3d
=
gt_bboxes_3d
.
tensor
.
to
(
gt_bboxes
.
device
)
gt_bboxes_3d
=
gt_bboxes_3d
.
tensor
.
to
(
gt_bboxes
.
device
)
if
num_gts
==
0
:
if
num_gts
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment