Commit 5db1ead3 authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

[Refactor] Base + AnchorFreeMono3DHead + FCOSMono3DHead model interface

parent a79b105b
...@@ -400,15 +400,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -400,15 +400,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
bbox_preds, bbox_preds,
dir_cls_preds, dir_cls_preds,
attr_preds, attr_preds,
gt_bboxes, batch_gt_instances_3d,
gt_labels, batch_img_metas,
gt_bboxes_3d, batch_gt_instances_ignore=None):
gt_labels_3d,
centers2d,
depths,
attr_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute loss of the head. """Compute loss of the head.
Args: Args:
...@@ -424,20 +418,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -424,20 +418,16 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
attr_preds (list[Tensor]): Box scores for each scale level, attr_preds (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is each is a 4D-tensor, the channel number is
num_points * num_attrs. num_points * num_attrs.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels (list[Tensor]): class indices corresponding to each box 、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
gt_bboxes_3d (list[Tensor]): 3D Ground truth bboxes for each attributes.
image with shape (num_gts, bbox_code_size). batch_img_metas (list[dict]): Meta information of each image, e.g.,
gt_labels_3d (list[Tensor]): 3D class indices of each box.
centers2d (list[Tensor]): Projected 3D centers onto 2D images.
depths (list[Tensor]): Depth of projected centers on 2D images.
attr_labels (list[Tensor], optional): Attribute indices
corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -474,29 +464,17 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead): ...@@ -474,29 +464,17 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_targets(self, points, gt_bboxes_list, gt_labels_list, def get_targets(self, points, batch_gt_instances_3d):
gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list,
depths_list, attr_labels_list):
"""Compute regression, classification and centerss targets for points """Compute regression, classification and centerss targets for points
in multiple images. in multiple images.
Args: Args:
points (list[Tensor]): Points of each fpn level, each has shape points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2). (num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image, batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
each has shape (num_gt, 4). gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels_list (list[Tensor]): Ground truth labels of each box, 、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
each has shape (num_gt,). attributes.
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
""" """
raise NotImplementedError raise NotImplementedError
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import List, Optional
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmengine.config import ConfigDict
from torch import Tensor
from mmdet3d.core import Det3DDataSample
class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta): class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Base class for Monocular 3D DenseHeads.""" """Base class for Monocular 3D DenseHeads."""
def __init__(self, init_cfg=None): def __init__(self, init_cfg: Optional[dict] = None) -> None:
super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg) super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg)
@abstractmethod @abstractmethod
...@@ -15,64 +21,78 @@ class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta): ...@@ -15,64 +21,78 @@ class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Compute losses of the head.""" """Compute losses of the head."""
pass pass
def get_bboxes(self, *args, **kwargs):
warnings.warn('`get_bboxes` is deprecated and will be removed in '
'the future. Please use `get_results` instead.')
return self.get_results(*args, **kwargs)
@abstractmethod @abstractmethod
def get_bboxes(self, **kwargs): def get_results(self, *args, **kwargs):
"""Transform network output for a batch into bbox predictions.""" """Transform network outputs of a batch into 3D bbox results."""
pass pass
def forward_train(self, def forward_train(self,
x, x: List[Tensor],
img_metas, batch_data_samples: List[Det3DDataSample],
gt_bboxes, proposal_cfg: Optional[ConfigDict] = None,
gt_labels=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
centers2d=None,
depths=None,
attr_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
**kwargs): **kwargs):
""" """
Args: Args:
x (list[Tensor]): Features from FPN. x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g., batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
image size, scaling factor, etc. contains the meta information of each image and corresponding
gt_bboxes (list[Tensor]): Ground truth bboxes of the image, annotations.
shape (num_gts, 4). proposal_cfg (mmengine.Config, optional): Test / postprocessing
gt_labels (list[Tensor]): Ground truth labels of each box, configuration, if None, test_cfg would be used.
shape (num_gts,). Defaults to None.
gt_bboxes_3d (list[Tensor]): 3D ground truth bboxes of the image,
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
attr_labels (list[Tensor]): Attribute labels of each box,
shape (num_gts,).
gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns: Returns:
tuple: tuple or Tensor: When `proposal_cfg` is None, the detector is a \
losses: (dict[str, Tensor]): A dictionary of loss components. normal one-stage detector, The return value is the losses.
proposal_list (list[Tensor]): Proposals of each image.
- losses: (dict[str, Tensor]): A dictionary of loss components.
When the `proposal_cfg` is not None, the head is used as a
`rpn_head`, the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- results_list (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
with shape (num_instances, C), the last dimension C of a
3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where
C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2
dims of velocity.
""" """
outs = self(x) outs = self(x)
if gt_labels is None: batch_gt_instances_3d = []
loss_inputs = outs + (gt_bboxes, gt_bboxes_3d, centers2d, depths, batch_gt_instances_ignore = []
attr_labels, img_metas) batch_img_metas = []
else: for data_sample in batch_data_samples:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_bboxes_3d, batch_img_metas.append(data_sample.metainfo)
gt_labels_3d, centers2d, depths, attr_labels, batch_gt_instances_3d.append(data_sample.gt_instances_3d)
img_metas) if 'ignored_instances' in data_sample:
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
loss_inputs = outs + (batch_gt_instances_3d, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss(*loss_inputs)
if proposal_cfg is None: if proposal_cfg is None:
return losses return losses
else: else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) batch_img_metas = [
return losses, proposal_list data_sample.metainfo for data_sample in batch_data_samples
]
results_list = self.get_results(
*outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
return losses, results_list
...@@ -259,15 +259,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -259,15 +259,9 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dir_cls_preds, dir_cls_preds,
attr_preds, attr_preds,
centernesses, centernesses,
gt_bboxes, batch_gt_instances_3d,
gt_labels, batch_img_metas,
gt_bboxes_3d, batch_gt_instances_ignore=None):
gt_labels_3d,
centers2d,
depths,
attr_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute loss of the head. """Compute loss of the head.
Args: Args:
...@@ -285,21 +279,16 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -285,21 +279,16 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
num_points * num_attrs. num_points * num_attrs.
centernesses (list[Tensor]): Centerness for each scale level, each centernesses (list[Tensor]): Centerness for each scale level, each
is a 4D-tensor, the channel number is num_points * 1. is a 4D-tensor, the channel number is num_points * 1.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels (list[Tensor]): class indices corresponding to each box 、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
gt_bboxes_3d (list[Tensor]): 3D boxes ground truth with shape of attributes.
(num_gts, code_size). batch_img_metas (list[dict]): Meta information of each image, e.g.,
gt_labels_3d (list[Tensor]): same as gt_labels
centers2d (list[Tensor]): 2D centers on the image with shape of
(num_gts, 2).
depths (list[Tensor]): Depth ground truth with shape of
(num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor]): specify which bounding batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns: Returns:
dict[str, Tensor]: A dictionary of loss components. dict[str, Tensor]: A dictionary of loss components.
...@@ -310,9 +299,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -310,9 +299,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device) bbox_preds[0].device)
labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \ labels_3d, bbox_targets_3d, centerness_targets, attr_targets = \
self.get_targets( self.get_targets(all_level_points, batch_gt_instances_3d)
all_level_points, gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels)
num_imgs = cls_scores[0].size(0) num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds, dir_cls_preds and centerness # flatten cls_scores, bbox_preds, dir_cls_preds and centerness
...@@ -742,29 +729,17 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -742,29 +729,17 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
dim=-1) + stride // 2 dim=-1) + stride // 2
return points return points
def get_targets(self, points, gt_bboxes_list, gt_labels_list, def get_targets(self, points, batch_gt_instances_3d):
gt_bboxes_3d_list, gt_labels_3d_list, centers2d_list,
depths_list, attr_labels_list):
"""Compute regression, classification and centerss targets for points """Compute regression, classification and centerss targets for points
in multiple images. in multiple images.
Args: Args:
points (list[Tensor]): Points of each fpn level, each has shape points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2). (num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image, batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
each has shape (num_gt, 4). gt_instance_3d. It usually includes ``bboxes``、``labels``
gt_labels_list (list[Tensor]): Ground truth labels of each box, 、``bboxes_3d``、``labels3d``、``depths``、``centers2d`` and
each has shape (num_gt,). attributes.
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
attr_labels_list (list[Tensor]): Attribute labels of each box,
each has shape (num_gt,).
Returns: Returns:
tuple: tuple:
...@@ -786,23 +761,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -786,23 +761,11 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
# the number of points per img, per lvl # the number of points per img, per lvl
num_points = [center.size(0) for center in points] num_points = [center.size(0) for center in points]
if attr_labels_list is None:
attr_labels_list = [
gt_labels.new_full(gt_labels.shape, self.attr_background_label)
for gt_labels in gt_labels_list
]
# get labels and bbox_targets of each image # get labels and bbox_targets of each image
_, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \ _, _, labels_3d_list, bbox_targets_3d_list, centerness_targets_list, \
attr_targets_list = multi_apply( attr_targets_list = multi_apply(
self._get_target_single, self._get_target_single,
gt_bboxes_list, batch_gt_instances_3d,
gt_labels_list,
gt_bboxes_3d_list,
gt_labels_3d_list,
centers2d_list,
depths_list,
attr_labels_list,
points=concat_points, points=concat_points,
regress_ranges=concat_regress_ranges, regress_ranges=concat_regress_ranges,
num_points_per_lvl=num_points) num_points_per_lvl=num_points)
...@@ -850,12 +813,19 @@ class FCOSMono3DHead(AnchorFreeMono3DHead): ...@@ -850,12 +813,19 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \ return concat_lvl_labels_3d, concat_lvl_bbox_targets_3d, \
concat_lvl_centerness_targets, concat_lvl_attr_targets concat_lvl_centerness_targets, concat_lvl_attr_targets
def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d, def _get_target_single(self, gt_instances_3d, points, regress_ranges,
gt_labels_3d, centers2d, depths, attr_labels, num_points_per_lvl):
points, regress_ranges, num_points_per_lvl):
"""Compute regression and classification targets for a single image.""" """Compute regression and classification targets for a single image."""
num_points = points.size(0) num_points = points.size(0)
num_gts = gt_labels.size(0) num_gts = len(gt_instances_3d)
gt_bboxes = gt_instances_3d.bboxes
gt_labels = gt_instances_3d.labels
gt_bboxes_3d = gt_instances_3d.bboxes_3d
gt_labels_3d = gt_instances_3d.labels_3d
centers2d = gt_instances_3d.centers2d
depths = gt_instances_3d.depths
attr_labels = gt_instances_3d.attr_labels
if not isinstance(gt_bboxes_3d, torch.Tensor): if not isinstance(gt_bboxes_3d, torch.Tensor):
gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device) gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
if num_gts == 0: if num_gts == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment