# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
from mmdet.models.detectors import BaseDetector

from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, InstanceList, OptConfigType

@MODELS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    A monocular/multi-view 3D detector that lifts 2D image features into a
    dense 3D voxel volume (via anchor-grid projection and bilinear sampling)
    and runs a 3D neck + head on the resulting volume.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        n_voxels (list): Number of voxels along x, y, z axis.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 n_voxels: List,
                 anchor_generator: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # The head must see the train/test configs at build time (e.g. for
        # assigner and NMS settings), so inject them before building.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = TASK_UTILS.build(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def convert_to_datasample(self, data_samples: SampleList,
                              data_instances: InstanceList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Args:
            data_samples (list[:obj:`Det3DDataSample`]): The input data.
            data_instances (list[:obj:`InstanceData`]): 3D Detection
                results of each image.

        Returns:
            list[:obj:`Det3DDataSample`]: 3D Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        # Attach each per-image prediction to its corresponding sample
        # in place; the same list is returned for convenience.
        for data_sample, pred_instances_3d in zip(data_samples,
                                                  data_instances):
            data_sample.pred_instances_3d = pred_instances_3d
        return data_samples

    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList) -> torch.Tensor:
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        img = batch_inputs_dict['imgs']
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        x = self.backbone(img)
        # Only the highest-resolution FPN level is used for lifting.
        x = self.neck(x)[0]
        # Voxel-center coordinates come from the anchor generator; anchors
        # are (x, y, z, ...) so keep only the first three columns.
        # NOTE: n_voxels is reversed because the generator expects the
        # featmap size in (z, y, x) order.
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=img.device)[0][:, :3]
        volumes = []
        for feature, img_meta in zip(x, batch_img_metas):
            # Per-image augmentation parameters; absent keys fall back to
            # identity transforms (no scaling / flip / crop).
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            lidar2img = points.new_tensor(img_meta['lidar2img'])
            # Project every voxel center into the image and bilinearly
            # sample the 2D feature map, yielding one feature per voxel.
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=lidar2img,
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Reshape the flat (n_points, C) samples back into a dense
            # (C, N_x, N_y, N_z) volume.
            volumes.append(
                volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        return losses

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
        predictions = self.convert_to_datasample(batch_data_samples,
                                                 results_list)
        return predictions

    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results = self.bbox_head.forward(x)
        return results