"vscode:/vscode.git/clone" did not exist on "30ba39c9ade866feea982ff1992bf808195b7bc3"
Commit 5db1ead3 authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

[Refactor] Base + AnchorFreeMono3DHead + FCOSMono3DHead model interface

parent a79b105b
......@@ -400,15 +400,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
bbox_preds,
dir_cls_preds,
attr_preds,
gt_bboxes,
gt_labels,
gt_bboxes_3d,
gt_labels_3d,
centers2d,
depths,
attr_labels,
img_metas,
gt_bboxes_ignore=None):
batch_gt_instances_3d,
batch_img_metas,
batch_gt_instances_ignore=None):
"""Compute loss of the head.
Args:
......@@ -424,20 +418,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
attr_preds (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_points * num_attrs.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
gt_bboxes_3d (list[Tensor]): 3D Ground truth bboxes for each
image with shape (num_gts, bbox_code_size).
gt_labels_3d (list[Tensor]): 3D class indices of each box.
centers2d (list[Tensor]): Projected 3D centers onto 2D images.
depths (list[Tensor]): Depth of projected centers on 2D images.
attr_labels (list[Tensor], optional): Attribute indices
corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers2d`` and
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
"""
raise NotImplementedError
......@@ -474,29 +464,17 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
raise NotImplementedError
@abstractmethod
def get_targets(self, points, gt_bboxes_list, gt_labels_list,
gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list,
depths_list, attr_labels_list):
def get_targets(self, points, batch_gt_instances_3d):
"""Compute regression, classification and centerness targets for points
in multiple images.
Args:
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
each has shape (num_gt, 4).
gt_labels_list (list[Tensor]): Ground truth labels of each box,
each has shape (num_gt,).
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers2d`` and
attributes.
"""
raise NotImplementedError
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Optional
from mmcv.runner import BaseModule
from mmengine.config import ConfigDict
from torch import Tensor
from mmdet3d.core import Det3DDataSample
class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Base class for Monocular 3D DenseHeads."""
def __init__(self, init_cfg=None):
def __init__(self, init_cfg: Optional[dict] = None) -> None:
super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg)
@abstractmethod
......@@ -15,64 +21,78 @@ class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Compute losses of the head."""
pass
def get_bboxes(self, *args, **kwargs):
warnings.warn('`get_bboxes` is deprecated and will be removed in '
'the future. Please use `get_results` instead.')
return self.get_results(*args, **kwargs)
@abstractmethod
def get_bboxes(self, **kwargs):
"""Transform network output for a batch into bbox predictions."""
def get_results(self, *args, **kwargs):
"""Transform network outputs of a batch into 3D bbox results."""
pass
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
centers2d=None,
depths=None,
attr_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
x: List[Tensor],
batch_data_samples: List[Det3DDataSample],
proposal_cfg: Optional[ConfigDict] = None,
**kwargs):
"""
Args:
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (list[Tensor]): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_3d (list[Tensor]): 3D ground truth bboxes of the image,
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
attr_labels (list[Tensor]): Attribute labels of each box,
shape (num_gts,).
gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each image and corresponding
annotations.
proposal_cfg (mmengine.Config, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
Returns:
tuple:
losses: (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
tuple or Tensor: When `proposal_cfg` is None, the detector is a \
normal one-stage detector, The return value is the losses.
- losses: (dict[str, Tensor]): A dictionary of loss components.
When the `proposal_cfg` is not None, the head is used as a
`rpn_head`, the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- results_list (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
with shape (num_instances, C), the last dimension C of a
3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where
C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2
dims of velocity.
"""
outs = self(x)
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, gt_bboxes_3d, centers2d, depths,
attr_labels, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels,
img_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
loss_inputs = outs + (batch_gt_instances_3d, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss(*loss_inputs)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list
batch_img_metas = [
data_sample.metainfo for data_sample in batch_data_samples
]
results_list = self.get_results(
*outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
return losses, results_list
......@@ -259,15 +259,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dir_cls_preds,
attr_preds,
centernesses,
gt_bboxes,
gt_labels,
gt_bboxes_3d,
gt_labels_3d,
centers2d,
depths,
attr_labels,
img_metas,
gt_bboxes_ignore=None):
batch_gt_instances_3d,
batch_img_metas,
batch_gt_instances_ignore=None):
"""Compute loss of the head.
Args:
......@@ -285,21 +279,16 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
num_points * num_attrs.
centernesses (list[Tensor]): Centerness for each scale level, each
is a 4D-tensor, the channel number is num_points * 1.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
gt_bboxes_3d (list[Tensor]): 3D boxes ground truth with shape of
(num_gts, code_size).
gt_labels_3d (list[Tensor]): same as gt_labels
centers2d (list[Tensor]): 2D centers on the image with shape of
(num_gts, 2).
depths (list[Tensor]): Depth ground truth with shape of
(num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers2d`` and
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
......@@ -310,9 +299,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \
self.get_targets(
all_level_points, gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels)
self.get_targets(all_level_points, batch_gt_instances_3d)
num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds, dir_cls_preds and centerness
......@@ -742,29 +729,17 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dim=-1) + stride // 2
return points
def get_targets(self, points, gt_bboxes_list, gt_labels_list,
gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list,
depths_list, attr_labels_list):
def get_targets(self, points, batch_gt_instances_3d):
"""Compute regression, classification and centerness targets for points
in multiple images.
Args:
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
each has shape (num_gt, 4).
gt_labels_list (list[Tensor]): Ground truth labels of each box,
each has shape (num_gt,).
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers2d`` and
attributes.
Returns:
tuple:
......@@ -786,23 +761,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
# the number of points per img, per lvl
num_points = [center.size(0) for center in points]
if attr_labels_list is None:
attr_labels_list = [
gt_labels.new_full(gt_labels.shape, self.attr_background_label)
for gt_labels in gt_labels_list
]
# get labels and bbox_targets of each image
_, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
attr_targets_list = multi_apply(
self._get_target_single,
gt_bboxes_list,
gt_labels_list,
gt_bboxes_3d_list,
gt_labels_3d_list,
centers2d_list,
depths_list,
attr_labels_list,
batch_gt_instances_3d,
points=concat_points,
regress_ranges=concat_regress_ranges,
num_points_per_lvl=num_points)
......@@ -850,12 +813,19 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \
concat_lvl_centerness_targets, concat_lvl_attr_targets
def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels,
points, regress_ranges, num_points_per_lvl):
def _get_target_single(self, gt_instances_3d, points, regress_ranges,
num_points_per_lvl):
"""Compute regression and classification targets for a single image."""
num_points = points.size(0)
num_gts = gt_labels.size(0)
num_gts = len(gt_instances_3d)
gt_bboxes = gt_instances_3d.bboxes
gt_labels = gt_instances_3d.labels
gt_bboxes_3d = gt_instances_3d.bboxes_3d
gt_labels_3d = gt_instances_3d.labels_3d
centers2d = gt_instances_3d.centers2d
depths = gt_instances_3d.depths
attr_labels = gt_instances_3d.attr_labels
if not isinstance(gt_bboxes_3d, torch.Tensor):
gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
if num_gts == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment