base_mono3d_dense_head.py 3.82 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
twang's avatar
twang committed
3
from abc import ABCMeta, abstractmethod
4
from typing import List, Optional
5

6
from mmcv.runner import BaseModule
7
8
9
10
from mmengine.config import ConfigDict
from torch import Tensor

from mmdet3d.core import Det3DDataSample
twang's avatar
twang committed
11
12


13
class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
    """Base class for Monocular 3D DenseHeads.

    Subclasses must implement :meth:`loss` and :meth:`get_results`.
    :meth:`forward_train` additionally calls the instance itself on the
    input features (``self(x)``) and expects the result to be a tuple.

    Args:
        init_cfg (dict, optional): Initialization config, forwarded
            unchanged to the parent class constructor. Defaults to None.
    """

    def __init__(self, init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def loss(self, *args, **kwargs):
        """Compute losses of the head.

        :meth:`forward_train` invokes this method positionally with the
        head outputs followed by ``batch_gt_instances_3d``,
        ``batch_img_metas`` and ``batch_gt_instances_ignore``, so
        implementations must accept positional arguments.
        """
        pass

    def get_bboxes(self, *args, **kwargs):
        """Deprecated alias of :meth:`get_results`.

        Emits a warning and forwards all arguments to :meth:`get_results`.
        """
        # stacklevel=2 attributes the warning to the caller of `get_bboxes`
        # instead of to this wrapper.
        warnings.warn(
            '`get_bboxes` is deprecated and will be removed in '
            'the future. Please use `get_results` instead.',
            stacklevel=2)
        return self.get_results(*args, **kwargs)

    @abstractmethod
    def get_results(self, *args, **kwargs):
        """Transform network outputs of a batch into 3D bbox results."""
        pass

    def forward_train(self,
                      x: List[Tensor],
                      batch_data_samples: List[Det3DDataSample],
                      proposal_cfg: Optional[ConfigDict] = None,
                      **kwargs):
        """Run the head on the features and compute the training losses.

        Args:
            x (list[Tensor]): Features from FPN.
            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
                contains the meta information of each image and corresponding
                annotations.
            proposal_cfg (mmengine.Config, optional): Test / postprocessing
                configuration passed to :meth:`get_results` as ``cfg``.
                When None, no results are generated and only the losses
                are returned. Defaults to None.
            **kwargs: Accepted for interface compatibility; not used by
                this base implementation.

        Returns:
            tuple or dict: When `proposal_cfg` is None, the detector is a
            normal one-stage detector and the return value is the losses.

            - losses: (dict[str, Tensor]): A dictionary of loss components.

            When `proposal_cfg` is not None, the head is used as a
            `rpn_head` and the return value is a tuple containing:

            - losses: (dict[str, Tensor]): A dictionary of loss components.
            - results_list (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
              Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
                  with shape (num_instances, C), the last dimension C of a
                  3D box is (x, y, z, x_size, y_size, z_size, yaw, ...), where
                  C >= 7. C = 7 for kitti and C = 9 for nuscenes with extra 2
                  dims of velocity.
        """
        outs = self(x)

        batch_gt_instances_3d = []
        batch_gt_instances_ignore = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
            # A None placeholder keeps the three lists index-aligned for
            # samples that carry no ignored instances.
            if 'ignored_instances' in data_sample:
                batch_gt_instances_ignore.append(data_sample.ignored_instances)
            else:
                batch_gt_instances_ignore.append(None)

        # `outs` must be a tuple here: it is concatenated with the targets
        # and unpacked positionally into `loss`.
        loss_inputs = outs + (batch_gt_instances_3d, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss(*loss_inputs)

        if proposal_cfg is None:
            return losses

        # Reuse the meta information collected above rather than walking
        # `batch_data_samples` a second time.
        results_list = self.get_results(
            *outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
        return losses, results_list