"sgl-router/src/routers/grpc/harmony/processor.rs" did not exist on "9ff9fa7f95bef8de135e9eb567a1ee9a17f0db47"
imvoxelnet.py 10.3 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
Tai-Wang's avatar
Tai-Wang committed
2
3
from typing import List, Tuple, Union

4
import torch
5
from mmengine.structures import InstanceData
6

7
from mmdet3d.models.detectors import Base3DDetector
zhangshilong's avatar
zhangshilong committed
8
9
10
from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
11
from mmdet3d.utils import ConfigType, OptConfigType, OptInstanceList
12
13


14
@MODELS.register_module()
15
class ImVoxelNet(Base3DDetector):
Tai-Wang's avatar
Tai-Wang committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        n_voxels (list): Number of voxels along x, y, z axis.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """
36
37

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 n_voxels: List,
                 anchor_generator: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        """Build the sub-modules of the detector from their configs.

        Args:
            backbone (:obj:`ConfigDict` or dict): Config built through
                ``MODELS`` into ``self.backbone``.
            neck (:obj:`ConfigDict` or dict): Config built through
                ``MODELS`` into ``self.neck``.
            neck_3d (:obj:`ConfigDict` or dict): Config built through
                ``MODELS`` into ``self.neck_3d``.
            bbox_head (:obj:`ConfigDict` or dict): Config built through
                ``MODELS`` into ``self.bbox_head``. It is updated in place
                with ``train_cfg`` and ``test_cfg`` before being built.
            n_voxels (list): Number of voxels along x, y, z axis. Stored
                unchanged as ``self.n_voxels``.
            anchor_generator (:obj:`ConfigDict` or dict): Config built
                through ``TASK_UTILS`` into ``self.anchor_generator``.
            train_cfg (:obj:`ConfigDict` or dict, optional): Training
                settings, forwarded to the head config and stored on the
                detector. Defaults to None.
            test_cfg (:obj:`ConfigDict` or dict, optional): Test settings,
                forwarded to the head config and stored on the detector.
                Defaults to None.
            data_preprocessor (:obj:`ConfigDict` or dict, optional):
                Passed through to the parent constructor. Defaults to None.
            init_cfg (:obj:`ConfigDict` or dict, optional): Passed through
                to the parent constructor. Defaults to None.
        """
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        # Build order is kept as backbone -> neck -> neck_3d -> head so the
        # order in which sub-modules are attached does not change.
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)

        # The head receives the train/test settings through its own config.
        # NOTE: this mutates the ``bbox_head`` config passed by the caller.
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)

        self.n_voxels = n_voxels
        self.anchor_generator = TASK_UTILS.build(anchor_generator)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

Tai-Wang's avatar
Tai-Wang committed
61
62
    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList) -> torch.Tensor:
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. Only their ``metainfo`` is read here: the
                keys ``lidar2img`` and ``img_shape`` are required, while
                ``scale_factor``, ``flip`` and ``img_crop_offset`` are
                optional and fall back to ``1``, ``False`` and ``0``.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        img = batch_inputs_dict['imgs']
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        x = self.backbone(img)
        # Only the first output of the neck is projected into 3D.
        x = self.neck(x)[0]

        # The voxel grid is handled in reversed axis order; it is the same
        # for every sample, so it is computed once outside the loop.
        reversed_n_voxels = self.n_voxels[::-1]
        # Copy into a list so the reshape target can also be built when
        # ``n_voxels`` is a tuple (``tuple + list`` raises ``TypeError``).
        volume_shape = list(reversed_n_voxels) + [-1]
        points = self.anchor_generator.grid_anchors(
            [reversed_n_voxels], device=img.device)[0][:, :3]

        volumes = []
        # NOTE: ``zip`` stops at the shorter input, so a mismatch between
        # the feature batch and the meta list is silently truncated.
        for feature, img_meta in zip(x, batch_img_metas):
            img_scale_factor = (
                points.new_tensor(img_meta['scale_factor'][:2])
                if 'scale_factor' in img_meta else 1)
            img_flip = img_meta.get('flip', False)
            img_crop_offset = (
                points.new_tensor(img_meta['img_crop_offset'])
                if 'img_crop_offset' in img_meta else 0)
            lidar2img = points.new_tensor(img_meta['lidar2img'])
            volume = point_sample(
                img_meta,
                img_features=feature[None, ...],
                points=points,
                proj_mat=lidar2img,
                coord_type='LIDAR',
                img_scale_factor=img_scale_factor,
                img_crop_offset=img_crop_offset,
                img_flip=img_flip,
                img_pad_shape=img.shape[-2:],
                img_shape=img_meta['img_shape'][:2],
                aligned=False)
            # Fold the flat per-point features back onto the voxel grid and
            # reverse the axis order so channels come first.
            volumes.append(volume.reshape(volume_shape).permute(3, 2, 1, 0))
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x

Tai-Wang's avatar
Tai-Wang committed
113
114
115
    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. They are passed both to the feature
                extraction step and to the head's ``loss``.
            **kwargs: Forwarded unchanged to ``self.bbox_head.loss``.

        Returns:
            dict: A dictionary of loss components, exactly as returned by
            ``self.bbox_head.loss``.
        """
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        return self.bbox_head.loss(feats, batch_data_samples, **kwargs)

Tai-Wang's avatar
Tai-Wang committed
134
135
136
137
    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
            **kwargs: Forwarded unchanged to ``self.bbox_head.predict``.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input images. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >=7.
        """
        feats = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results_list = self.bbox_head.predict(feats, batch_data_samples,
                                              **kwargs)
        # NOTE(review): ``add_pred_to_datasample`` is not defined in this
        # class; presumably it is inherited from ``Base3DDetector`` - confirm.
        return self.add_pred_to_datasample(batch_data_samples, results_list)
167

Tai-Wang's avatar
Tai-Wang committed
168
169
170
171
    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
172
173

        Args:
Tai-Wang's avatar
Tai-Wang committed
174
175
176
177
178
179
180
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
181
182

        Returns:
Tai-Wang's avatar
Tai-Wang committed
183
            tuple[list]: A tuple of features from ``bbox_head`` forward.
184
        """
Tai-Wang's avatar
Tai-Wang committed
185
186
187
        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
        results = self.bbox_head.forward(x)
        return results
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248

    def convert_to_datasample(
        self,
        data_samples: SampleList,
        data_instances_3d: OptInstanceList = None,
        data_instances_2d: OptInstanceList = None,
    ) -> SampleList:
        """Attach per-sample prediction results to the given data samples.

        Subclasses could override it to be compatible for some multi-modality
        3D detectors.

        At least one of ``data_instances_3d`` and ``data_instances_2d`` must
        be given. The missing one is filled with empty ``InstanceData``
        objects, one per entry of the list that was provided.

        Args:
            data_samples (list[:obj:`Det3DDataSample`]): The input data.
                Each sample is modified in place.
            data_instances_3d (list[:obj:`InstanceData`], optional): 3D
                Detection results of each sample. Defaults to None.
            data_instances_2d (list[:obj:`InstanceData`], optional): 2D
                Detection results of each sample. Defaults to None.

        Returns:
            list[:obj:`Det3DDataSample`]: The same ``data_samples`` list.
            Sample ``i`` has ``pred_instances_3d`` set to
            ``data_instances_3d[i]`` and ``pred_instances`` set to
            ``data_instances_2d[i]``.

            ``pred_instances_3d`` normally contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >=7.

            ``pred_instances`` (image predictions) normally contains
            following keys.

            - scores (Tensor): Classification scores of image, has a shape
              (num_instance, )
            - labels (Tensor): Predict Labels of 2D bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Contains a tensor with shape
              (num_instances, 4).
        """
        has_3d = data_instances_3d is not None
        has_2d = data_instances_2d is not None
        assert has_2d or has_3d, \
            'please pass at least one type of data_samples'

        # Whichever modality is missing gets empty placeholders, sized after
        # the modality that was actually supplied.
        if not has_2d:
            num_results = len(data_instances_3d)
            data_instances_2d = [InstanceData() for _ in range(num_results)]
        if not has_3d:
            num_results = len(data_instances_2d)
            data_instances_3d = [InstanceData() for _ in range(num_results)]

        for idx, sample in enumerate(data_samples):
            sample.pred_instances_3d = data_instances_3d[idx]
            sample.pred_instances = data_instances_2d[idx]
        return data_samples