imvoxelnet.py 7.85 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
Tai-Wang's avatar
Tai-Wang committed
2
3
from typing import List, Tuple, Union

4
5
import torch

6
from mmdet3d.models.detectors import Base3DDetector
zhangshilong's avatar
zhangshilong committed
7
8
9
from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
10
from mmdet3d.utils import ConfigType, OptConfigType
11
12


13
@MODELS.register_module()
class ImVoxelNet(Base3DDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Monocular/multi-view 3D detector that lifts 2D image features onto a
    3D voxel grid (see ``extract_feat``) and runs an anchor-based 3D head.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        n_voxels (list): Number of voxels along x, y, z axis.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 n_voxels: List,
                 anchor_generator: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        """Build all sub-modules from their registry configs."""
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # The head needs the train/test configs at build time (e.g. for
        # assigner / NMS settings), so inject them into its config dict
        # before building.  NOTE(review): this mutates the caller's
        # ``bbox_head`` dict in place.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = TASK_UTILS.build(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

Tai-Wang's avatar
Tai-Wang committed
60
61
    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        img = batch_inputs_dict['imgs']
        # Per-sample meta info (projection matrices, aug parameters, ...).
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        # 2D features: backbone -> neck; only the first (finest) neck level
        # is used for the 3D back-projection below.
        x = self.backbone(img)
        x = self.neck(x)[0]
        # Voxel-center coordinates from the anchor generator.  n_voxels is
        # reversed before being passed as the "featmap size", and only the
        # xyz part (first 3 columns) of each anchor is kept.
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=img.device)[0][:, :3]
        volumes = []
        for feature, img_meta in zip(x, batch_img_metas):
            # Augmentation parameters; fall back to identity values
            # (scale 1, no flip, zero crop offset) when a key is absent.
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta.keys() else 1)
            img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta.keys() else 0)
            lidar2img = points.new_tensor(img_meta['lidar2img'])
            # Sample a 2D feature vector for every voxel center by
            # projecting the centers into the image plane with lidar2img.
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=lidar2img,
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Fold the flat (num_points, C) samples back onto the voxel
            # grid: reshape to (N_z, N_y, N_x, C), then permute(3, 2, 1, 0)
            # to (C, N_x, N_y, N_z).
            volumes.append(
                volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
        # Stack per-sample volumes: (N, C, N_x, N_y, N_z), then refine with
        # the 3D neck.
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

Tai-Wang's avatar
Tai-Wang committed
112
113
114
    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Compute training losses for a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        # Lift images to voxel features, then delegate loss computation
        # to the bbox head.
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        return self.bbox_head.loss(feats, batch_data_samples, **kwargs)

Tai-Wang's avatar
Tai-Wang committed
133
134
135
136
    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict detection results with post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        # Extract voxel features, run the head's test-time path, then
        # attach the per-sample predictions back onto the data samples.
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(feats, batch_data_samples,
                                              **kwargs)
        return self.add_pred_to_datasample(batch_data_samples, results_list)
166

Tai-Wang's avatar
Tai-Wang committed
167
168
169
170
    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Raw network forward (backbone, neck, head) with no
        post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        # Plain forward of the head on the extracted voxel features; used
        # e.g. for FLOPs counting / export, where decoding is not wanted.
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        return self.bbox_head.forward(feats)