"requirements/requirements-sparse_attn.txt" did not exist on "838f53b7613d65fd8260bb76331ce5b8b0911848"
imvoxelnet.py 8.91 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch

from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, InstanceList, OptConfigType
from mmdet.models.detectors import BaseDetector


@MODELS.register_module()
class ImVoxelNet(BaseDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    A monocular 3D detector: 2D image features from ``backbone`` + ``neck``
    are sampled into a 3D voxel volume (via :func:`point_sample`), which is
    then processed by ``neck_3d`` and an anchor-based ``bbox_head``.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        n_voxels (list): Number of voxels along x, y, z axis.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 n_voxels: List,
                 anchor_generator: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # The head needs both cfgs at build time (e.g. to construct its
        # assigner/sampler), so inject them before building.
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.n_voxels = n_voxels
        self.anchor_generator = TASK_UTILS.build(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Args:
            results_list (list[:obj:`InstanceData`]): 3D Detection results of
                each image.

        Returns:
            list[:obj:`Det3DDataSample`]: 3D Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        out_results_list = []
        # Iterate predictions directly instead of indexing by position.
        for pred_instances_3d in results_list:
            result = Det3DDataSample()
            result.pred_instances_3d = pred_instances_3d
            out_results_list.append(result)
        return out_results_list

    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        img = batch_inputs_dict['imgs']
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        x = self.backbone(img)
        # Only the highest-resolution FPN level is lifted into the volume.
        x = self.neck(x)[0]
        # Voxel-center coordinates; ``n_voxels`` is reversed here and again
        # in the reshape below, so both presumably use (z, y, x) order —
        # keep the two in sync.
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=img.device)[0][:, :3]
        volumes = []
        for feature, img_meta in zip(x, batch_img_metas):
            # Augmentation metadata is optional; fall back to identity
            # values when a transform was not applied to this sample.
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta['flip'] if 'flip' in img_meta else False
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            lidar2img = points.new_tensor(img_meta['lidar2img'])
            # Project every voxel center into the image and sample features.
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=lidar2img,
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # (n_points, C) -> (N_z, N_y, N_x, C) -> (C, N_x, N_y, N_z)
            volumes.append(
                volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        return losses

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
        predictions = self.convert_to_datasample(results_list)
        return predictions

    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward.
        """
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results = self.bbox_head.forward(x)
        return results