# Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod from typing import Optional, Tuple from mmcv.runner import BaseModule from mmengine.config import ConfigDict from torch import Tensor from mmdet3d.core.utils import InstanceList, OptMultiConfig, SampleList class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta): """Base class for Monocular 3D DenseHeads. 1. The ``loss`` method is used to calculate the loss of densehead, which includes two steps: (1) the densehead model performs forward propagation to obtain the feature maps (2) The ``loss_by_feat`` method is called based on the feature maps to calculate the loss. .. code:: text loss(): forward() -> loss_by_feat() 2. The ``predict`` method is used to predict detection results, which includes two steps: (1) the densehead model performs forward propagation to obtain the feature maps (2) The ``predict_by_feat`` method is called based on the feature maps to predict detection results including post-processing. .. code:: text predict(): forward() -> predict_by_feat() 3. The ``loss_and_predict`` method is used to return loss and detection results at the same time. It will call densehead's ``forward``, ``loss_by_feat`` and ``predict_by_feat`` methods in order. If one-stage is used as RPN, the densehead needs to return both losses and predictions. This predictions is used as the proposal of roihead. .. code:: text loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat() """ def __init__(self, init_cfg: OptMultiConfig = None) -> None: super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg) def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, **kwargs) -> dict: """ Args: x (list[Tensor]): Features from FPN. batch_data_samples (list[:obj:`Det3DDataSample`]): Each item contains the meta information of each image and corresponding annotations. Returns: tuple or Tensor: When `proposal_cfg` is None, the detector is a \ normal one-stage detector, The return value is the losses. - losses: (dict[str, Tensor]): A dictionary of loss components. When the `proposal_cfg` is not None, the head is used as a `rpn_head`, the return value is a tuple contains: - losses: (dict[str, Tensor]): A dictionary of loss components. - results_list (list[:obj:`InstanceData`]): Detection results of each image after the post process. Each item usually contains following keys. - scores (Tensor): Classification scores, has a shape (num_instance, ) - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor with shape (num_instances, C), the last dimension C of a 3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2 dims of velocity. """ outs = self(x) batch_gt_instances_3d = [] batch_gt_instances = [] batch_gt_instances_ignore = [] batch_img_metas = [] for data_sample in batch_data_samples: batch_img_metas.append(data_sample.metainfo) batch_gt_instances_3d.append(data_sample.gt_instances_3d) batch_gt_instances.append(data_sample.gt_instances) batch_gt_instances_ignore.append( data_sample.get('ignored_instances', None)) loss_inputs = outs + (batch_gt_instances_3d, batch_gt_instances, batch_img_metas, batch_gt_instances_ignore) losses = self.loss_by_feat(*loss_inputs) return losses @abstractmethod def loss_by_feat(self, **kwargs) -> dict: """Calculate the loss based on the features extracted by the detection head.""" pass def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList, proposal_cfg: Optional[ConfigDict] = None, **kwargs) -> Tuple[dict, InstanceList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. Args: x (tuple[Tensor]): Features from FPN. batch_data_samples (list[:obj:`Det3DDataSample`]): Each item contains the meta information of each image and corresponding annotations. proposal_cfg (ConfigDict, optional): Test / postprocessing configuration, if None, test_cfg would be used. Defaults to None. Returns: tuple: the return value is a tuple contains: - losses: (dict[str, Tensor]): A dictionary of loss components. - predictions (list[:obj:`InstanceData`]): Detection results of each image after the post process. """ batch_gt_instances_3d = [] batch_gt_instances = [] batch_gt_instances_ignore = [] batch_img_metas = [] for data_sample in batch_data_samples: batch_img_metas.append(data_sample.metainfo) batch_gt_instances_3d.append(data_sample.gt_instances_3d) batch_gt_instances.append(data_sample.gt_instances) batch_gt_instances_ignore.append( data_sample.get('ignored_instances', None)) outs = self(x) loss_inputs = outs + (batch_gt_instances_3d, batch_gt_instances, batch_img_metas, batch_gt_instances_ignore) losses = self.loss_by_feat(*loss_inputs) predictions = self.predict_by_feat( *outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg) return losses, predictions def predict(self, x: Tuple[Tensor], batch_data_samples: SampleList, rescale: bool = False) -> InstanceList: """Perform forward propagation of the detection head and predict detection results on the features of the upstream network. Args: x (tuple[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data Samples. It usually includes information such as `gt_instance_3d`, `gt_pts_panoptic_seg` and `gt_pts_sem_seg`. rescale (bool, optional): Whether to rescale the results. Defaults to False. Returns: list[obj:`InstanceData`]: Detection results of each image after the post process. """ batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] outs = self(x) predictions = self.predict_by_feat( *outs, batch_img_metas=batch_img_metas, rescale=rescale) return predictions @abstractmethod def predict_by_feat(self, **kwargs) -> InstanceList: """Transform a batch of output features extracted from the head into bbox results.""" pass