dsvt_input_layer.py 17.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# modified from https://github.com/Haiyang-W/DSVT
from math import ceil

import torch
from torch import nn

from .utils import (PositionEmbeddingLearned, get_continous_inds,
                    get_inner_win_inds_cuda, get_pooling_index,
                    get_window_coors)


class DSVTInputLayer(nn.Module):
    '''
    This class converts the output of vfe to dsvt input.
    We do in this class:
    1. Window partition: partition voxels to non-overlapping windows.
    2. Set partition: generate non-overlapped and size-equivalent local sets
        within each window.
    3. Pre-compute the downsample information between two consecutive stages.
    4. Pre-compute the position embedding vectors.

    Args:
        sparse_shape (tuple[int, int, int]): Shape of input space
            (xdim, ydim, zdim).
        window_shape (list[list[int, int, int]]): Window shapes
            (winx, winy, winz) in different stages. Length: stage_num.
        downsample_stride (list[list[int, int, int]]): Downsample
            strides between two consecutive stages.
            Element i is [ds_x, ds_y, ds_z], which is used between stage_i and
            stage_{i+1}. Length: stage_num - 1.
        dim_model (list[int]): Number of input channels for each stage. Length:
            stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage.
            Element i contains [set_size, block_num], where set_size is the
            number of voxels in a set and block_num is the number of blocks
            for stage i. Length: stage_num.
        hybrid_factor (list[int, int, int]): Controls the window shape used by
            the two shifts of every stage. For stage_i, shift 0 uses
            [win_x, win_y, win_z] and shift 1 uses
            [win_x * h[0], win_y * h[1], win_z * h[2]].
        shift_list (list): Shift configuration per stage. Element i holds one
            entry per shift (two in total), each forwarded to
            `get_window_coors`. Length: stage_num.
        normalize_pos (bool): Whether to normalize coordinates in position
            embedding.
    '''

    def __init__(self, sparse_shape, window_shape, downsample_stride,
                 dim_model, set_info, hybrid_factor, shift_list,
                 normalize_pos):
        super().__init__()

        self.sparse_shape = sparse_shape
        self.window_shape = window_shape
        self.downsample_stride = downsample_stride
        self.dim_model = dim_model
        self.set_info = set_info
        self.stage_num = len(self.dim_model)

        self.hybrid_factor = hybrid_factor
        # Afterwards self.window_shape[stage][shift] is the window of that
        # shift: shift 0 keeps the configured shape, shift 1 scales it
        # element-wise by hybrid_factor.
        self.window_shape = [[
            self.window_shape[s_id],
            [
                self.window_shape[s_id][coord_id] *
                self.hybrid_factor[coord_id] for coord_id in range(3)
            ]
        ] for s_id in range(self.stage_num)]
        self.shift_list = shift_list
        self.normalize_pos = normalize_pos

        # every stage uses exactly two shifts (see window_partition / get_set)
        self.num_shifts = [2] * len(self.window_shape)

        self.sparse_shape_list = [self.sparse_shape]
        # compute sparse shapes for each stage
        for ds_stride in self.downsample_stride:
            last_sparse_shape = self.sparse_shape_list[-1]
            self.sparse_shape_list.append(
                (ceil(last_sparse_shape[0] / ds_stride[0]),
                 ceil(last_sparse_shape[1] / ds_stride[1]),
                 ceil(last_sparse_shape[2] / ds_stride[2])))

        # position embedding layers, indexed as [stage][block][shift]
        self.posembed_layers = nn.ModuleList()
        for i in range(len(self.set_info)):
            # NOTE(review): the embedding input dim is chosen from the stage's
            # sparse z size, whereas get_pos_embed picks 2D/3D from the window
            # z size; configs presumably keep these consistent - confirm.
            input_dim = 3 if self.sparse_shape_list[i][-1] > 1 else 2
            stage_posembed_layers = nn.ModuleList()
            for j in range(self.set_info[i][1]):
                block_posembed_layers = nn.ModuleList()
                for s in range(self.num_shifts[i]):
                    block_posembed_layers.append(
                        PositionEmbeddingLearned(input_dim, self.dim_model[i]))
                stage_posembed_layers.append(block_posembed_layers)
            self.posembed_layers.append(stage_posembed_layers)

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE
                    with shape (N, dim_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding
                    voxel coordinates of each voxel.
                    Each row is (batch_id, z, y, x).
                - ...

        Returns:
            voxel_info (dict):
                The dict contains the following keys
                - voxel_feats_stage0 (Tensor[float]): A copy of
                    batch_dict['voxel_features'].
                - voxel_coors_stage{i} (Tensor[int]): Shape of (N_i, 4). N_i is
                    the number of voxels in stage_i.
                    Each row is (batch_id, z, y, x).
                - batch_win_inds_stage{i}_shift{j},
                    coors_in_win_stage{i}_shift{j}: Window partition results
                    produced by 'window_partition'.
                - set_voxel_inds_stage{i}_shift{j} (Tensor[int]): Set partition
                    index with shape (2, set_num, set_info[i][0]).
                    Index 0 is the y-axis partition, index 1 the x-axis
                    partition.
                - set_voxel_mask_stage{i}_shift{j} (Tensor[bool]): Key mask
                    used in set attention with shape
                    (2, set_num, set_info[i][0]). True marks an entry equal to
                    the entry just before it in the same set.
                - pos_embed_stage{i}_block{j}_shift{k} (Tensor[float]):
                    Position embedding vectors with shape (N_i, dim_model[i]).
                    N_i is the number of remaining voxels in stage_i.
                - pooling_mapping_index_stage{i} (Tensor[int]): Pooling region
                    index used in pooling operation between stage_{i-1}
                    and stage_{i} with shape (N_{i-1}).
                - pooling_index_in_pool_stage{i} (Tensor[int]): Index inside
                    the pooling region with shape (N_{i-1}). Combined with
                    pooling_mapping_index_stage{i}, we can map each voxel in
                    stage_{i-1} to pooling_preholder_feats_stage{i}, which
                    are input of downsample operation.
                - pooling_preholder_feats_stage{i} (Tensor[float]):
                    Placeholder features initialized with zeros.
                    Shape of (N_{i}, prod(downsample_stride[i-1]),
                    dim_model[i-1]), where prod() returns the product of
                    all elements.
                - ...
        '''
        voxel_feats = batch_dict['voxel_features']
        voxel_coors = batch_dict['voxel_coords'].long()

        voxel_info = {}
        voxel_info['voxel_feats_stage0'] = voxel_feats.clone()
        voxel_info['voxel_coors_stage0'] = voxel_coors.clone()

        for stage_id in range(self.stage_num):
            # window partition of corresponding stage-map
            voxel_info = self.window_partition(voxel_info, stage_id)
            # generate set id of corresponding stage-map
            voxel_info = self.get_set(voxel_info, stage_id)
            for block_id in range(self.set_info[stage_id][1]):
                for shift_id in range(self.num_shifts[stage_id]):
                    layer_name = f'pos_embed_stage{stage_id}_block{block_id}_shift{shift_id}'  # noqa: E501
                    pos_name = f'coors_in_win_stage{stage_id}_shift{shift_id}'
                    voxel_info[layer_name] = self.get_pos_embed(
                        voxel_info[pos_name], stage_id, block_id, shift_id)

            # compute pooling information
            if stage_id < self.stage_num - 1:
                voxel_info = self.subm_pooling(voxel_info, stage_id)

        return voxel_info

    @torch.no_grad()
    def subm_pooling(self, voxel_info, stage_id):
        '''Pre-compute pooling indices from stage_{stage_id} to the next stage.

        Voxels are grouped into pooling windows of size
        downsample_stride[stage_id]; every non-empty pooling window becomes one
        voxel of stage_{stage_id+1}.

        Adds to voxel_info:
            - pooling_mapping_index_stage{stage_id+1}: contiguous pooling-window
                id of every voxel.
            - pooling_index_in_pool_stage{stage_id+1}: position of every voxel
                inside its pooling window (from get_pooling_index).
            - pooling_preholder_feats_stage{stage_id+1}: zero tensor of shape
                (num_windows, prod(stride), dim_model[stage_id]).
            - voxel_coors_stage{stage_id+1}: coordinates of the pooled voxels.

        Returns:
            The updated voxel_info dict.
        '''
        # x, y, z stride
        cur_stage_downsample = self.downsample_stride[stage_id]
        batch_win_inds, _, index_in_win, batch_win_coors = get_pooling_index(
            voxel_info[f'voxel_coors_stage{stage_id}'],
            self.sparse_shape_list[stage_id], cur_stage_downsample)
        # compute pooling mapping index
        unique_batch_win_inds, contiguous_batch_win_inds = torch.unique(
            batch_win_inds, return_inverse=True)
        voxel_info[f'pooling_mapping_index_stage{stage_id+1}'] = \
            contiguous_batch_win_inds

        # generate empty placeholder features
        placeholder_prepool_feats = voxel_info['voxel_feats_stage0'].new_zeros(
            (len(unique_batch_win_inds),
             torch.prod(torch.IntTensor(cur_stage_downsample)).item(),
             self.dim_model[stage_id]))
        voxel_info[f'pooling_index_in_pool_stage{stage_id+1}'] = index_in_win
        voxel_info[f'pooling_preholder_feats_stage{stage_id+1}'] = \
            placeholder_prepool_feats

        # Compute pooling coordinates by picking one voxel per pooling window.
        # The flip is intended to make the first voxel of each window the last
        # scatter write; on CUDA the winner among duplicate indices is
        # unspecified, which is presumably harmless because every voxel of a
        # window maps to the same batch_win_coors (TODO confirm in utils).
        # flip() returns new tensors, so the stored mapping index is untouched.
        perm = torch.arange(
            contiguous_batch_win_inds.size(0),
            dtype=contiguous_batch_win_inds.dtype,
            device=contiguous_batch_win_inds.device)
        inverse, perm = contiguous_batch_win_inds.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique_batch_win_inds.size(0)).scatter_(
            0, inverse, perm)
        pool_coors = batch_win_coors[perm]

        voxel_info[f'voxel_coors_stage{stage_id+1}'] = pool_coors

        return voxel_info

    def get_set(self, voxel_info, stage_id):
        '''
        This is one of the core operations of DSVT.
        Given voxels' window ids and relative coords inside the window, we
        partition them into window-bounded and size-equivalent local sets. To
        keep it clear and easy to follow, we do not use a loop to process the
        two shifts.

        Args:
            voxel_info (dict):
                The dict contains the following keys
                - batch_win_inds_stage{stage_id}_shift{j} (Tensor): Window
                    index of each voxel with shape (N), computed by
                    'window_partition'.
                - coors_in_win_stage{stage_id}_shift{j} (Tensor[int]):
                    Relative coords inside the window of each voxel with shape
                    (N, 3), computed by 'window_partition'. Each row is
                    (z, y, x).
                - ...
            stage_id (int): Index of the stage to process.

        Returns:
            voxel_info, updated with set_voxel_inds_stage{stage_id}_shift{j}
            and set_voxel_mask_stage{stage_id}_shift{j} for j in {0, 1}.
            See 'forward' for their description.
        '''
        batch_win_inds_shift0 = voxel_info[
            f'batch_win_inds_stage{stage_id}_shift0']
        coors_in_win_shift0 = voxel_info[
            f'coors_in_win_stage{stage_id}_shift0']
        set_voxel_inds_shift0 = self.get_set_single_shift(
            batch_win_inds_shift0,
            stage_id,
            shift_id=0,
            coors_in_win=coors_in_win_shift0)
        voxel_info[
            f'set_voxel_inds_stage{stage_id}_shift0'] = set_voxel_inds_shift0
        # compute key masks, voxel duplication must happen continuously:
        # an entry is masked when it equals its predecessor in the set
        prefix_set_voxel_inds_s0 = torch.roll(
            set_voxel_inds_shift0.clone(), shifts=1, dims=-1)
        prefix_set_voxel_inds_s0[:, :, 0] = -1
        set_voxel_mask_s0 = (set_voxel_inds_shift0 == prefix_set_voxel_inds_s0)
        voxel_info[
            f'set_voxel_mask_stage{stage_id}_shift0'] = set_voxel_mask_s0

        batch_win_inds_shift1 = voxel_info[
            f'batch_win_inds_stage{stage_id}_shift1']
        coors_in_win_shift1 = voxel_info[
            f'coors_in_win_stage{stage_id}_shift1']
        set_voxel_inds_shift1 = self.get_set_single_shift(
            batch_win_inds_shift1,
            stage_id,
            shift_id=1,
            coors_in_win=coors_in_win_shift1)
        voxel_info[
            f'set_voxel_inds_stage{stage_id}_shift1'] = set_voxel_inds_shift1
        # compute key masks, voxel duplication must happen continuously:
        # an entry is masked when it equals its predecessor in the set
        prefix_set_voxel_inds_s1 = torch.roll(
            set_voxel_inds_shift1.clone(), shifts=1, dims=-1)
        prefix_set_voxel_inds_s1[:, :, 0] = -1
        set_voxel_mask_s1 = (set_voxel_inds_shift1 == prefix_set_voxel_inds_s1)
        voxel_info[
            f'set_voxel_mask_stage{stage_id}_shift1'] = set_voxel_mask_s1

        return voxel_info

    def get_set_single_shift(self,
                             batch_win_inds,
                             stage_id,
                             shift_id=None,
                             coors_in_win=None):
        '''Partition the voxels of a single shift into equal-size sets.

        A window holding n voxels is split into ceil(n / set_size) sets of
        exactly set_size slots. The slots are filled by evenly sampling the
        window's voxels (Eq. 3 of DSVT), so some voxels appear in more than
        one slot when n is not a multiple of set_size. This is done twice:
        once with the voxels of each window ordered along the y axis and
        once along the x axis.

        Args:
            batch_win_inds (Tensor): Window index of every voxel, shape (N,).
            stage_id (int): Stage index; selects set_size and window shape.
            shift_id (int): Shift index (0 or 1); selects the window shape.
                Must be given despite the None default.
            coors_in_win (Tensor[int]): Coordinates of every voxel inside its
                window, shape (N, 3), each row (z, y, x). Must be given.

        Returns:
            Tensor[long]: Shape (2, set_num, set_size). Entry 0 is the y-axis
            ordered partition, entry 1 the x-axis ordered one; values index
            the N input voxels.
        '''
        device = batch_win_inds.device
        win_shape = self.window_shape[stage_id][shift_id]
        # the number of voxels assigned to a set
        voxel_num_set = self.set_info[stage_id][0]
        # max number of voxels in a window
        max_voxel = win_shape[0] * win_shape[1] * win_shape[2]
        # get unique set indices
        contiguous_win_inds = torch.unique(
            batch_win_inds, return_inverse=True)[1]
        voxelnum_per_win = torch.bincount(contiguous_win_inds)
        win_num = voxelnum_per_win.shape[0]
        setnum_per_win_float = voxelnum_per_win / voxel_num_set
        setnum_per_win = torch.ceil(setnum_per_win_float).long()
        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)

        # computation of Eq.3 in 'DSVT: Dynamic Sparse Voxel Transformer with
        # Rotated Sets' - https://arxiv.org/abs/2301.06051,
        # for each window, we can get voxel indices belong to different sets.
        # Slot k of a window with n voxels and s sets samples rank
        # floor(k * n / (s * set_size)), always < n.
        offset_idx = set_inds_in_win[:, None].repeat(
            1, voxel_num_set) * voxel_num_set
        base_idx = torch.arange(0, voxel_num_set, 1, device=device)
        base_select_idx = offset_idx + base_idx
        base_select_idx = base_select_idx * voxelnum_per_win[
            set_win_inds][:, None]
        base_select_idx = base_select_idx.double() / (
            setnum_per_win[set_win_inds] * voxel_num_set)[:, None].double()
        base_select_idx = torch.floor(base_select_idx)
        # obtain unique indices in whole space
        select_idx = base_select_idx
        select_idx = select_idx + set_win_inds.view(-1, 1) * max_voxel

        # this function will return unordered inner window indices of
        # each voxel
        inner_voxel_inds = get_inner_win_inds_cuda(contiguous_win_inds)
        global_voxel_inds = contiguous_win_inds * max_voxel + inner_voxel_inds
        _, order1 = torch.sort(global_voxel_inds)

        def _partition_along(sort_key):
            # Rank every voxel inside its window by sort_key. order1 walks the
            # voxels window by window in inner-index order (assumed to be
            # 0..n-1 per window) and sort_key also groups by window first, so
            # scattering inner_voxel_inds[order1] through order2 gives each
            # voxel its rank under sort_key within its own window.
            _, order2 = torch.sort(sort_key)
            inner_voxel_inds_sorted = -torch.ones_like(inner_voxel_inds)
            inner_voxel_inds_sorted.scatter_(
                dim=0, index=order2, src=inner_voxel_inds[order1])
            voxel_inds_in_batch = inner_voxel_inds_sorted + max_voxel * \
                contiguous_win_inds
            # dense (window, rank) -> voxel index lookup, -1 for empty slots
            voxel_inds_padding = -1 * torch.ones(
                (win_num * max_voxel), dtype=torch.long, device=device)
            voxel_inds_padding[voxel_inds_in_batch] = torch.arange(
                0,
                voxel_inds_in_batch.shape[0],
                dtype=torch.long,
                device=device)
            return voxel_inds_padding[select_idx.long()]

        # y-axis partition: order voxels by (y, x, z) inside each window
        set_voxel_inds_sorty = _partition_along(
            contiguous_win_inds * max_voxel +
            coors_in_win[:, 1] * win_shape[0] * win_shape[2] +
            coors_in_win[:, 2] * win_shape[2] + coors_in_win[:, 0])
        # x-axis partition: order voxels by (x, y, z) inside each window
        set_voxel_inds_sortx = _partition_along(
            contiguous_win_inds * max_voxel +
            coors_in_win[:, 2] * win_shape[1] * win_shape[2] +
            coors_in_win[:, 1] * win_shape[2] + coors_in_win[:, 0])

        all_set_voxel_inds = torch.stack(
            (set_voxel_inds_sorty, set_voxel_inds_sortx), dim=0)
        return all_set_voxel_inds

    @torch.no_grad()
    def window_partition(self, voxel_info, stage_id):
        '''Assign the voxels of a stage to windows, once per shift.

        For shift i in {0, 1}, calls get_window_coors with the stage's voxel
        coords, its sparse shape, window_shape[stage_id][i], the flag
        `i == 1` and shift_list[stage_id][i], then stores the results as
        batch_win_inds_stage{stage_id}_shift{i} and
        coors_in_win_stage{stage_id}_shift{i}.

        Returns:
            The updated voxel_info dict.
        '''
        for i in range(2):
            batch_win_inds, coors_in_win = get_window_coors(
                voxel_info[f'voxel_coors_stage{stage_id}'],
                self.sparse_shape_list[stage_id],
                self.window_shape[stage_id][i], i == 1,
                self.shift_list[stage_id][i])

            voxel_info[
                f'batch_win_inds_stage{stage_id}_shift{i}'] = batch_win_inds
            voxel_info[f'coors_in_win_stage{stage_id}_shift{i}'] = coors_in_win

        return voxel_info

    def get_pos_embed(self, coors_in_win, stage_id, block_id, shift_id):
        '''Compute the learned position embedding of voxels inside windows.

        Coordinates are centred on the window centre and, when
        `normalize_pos` is set, scaled to roughly [-pi, pi]. A window whose
        z size is 1 (or that has only two dims) is treated as 2D and only
        (x, y) is fed to the embedding layer.

        Args:
            coors_in_win (Tensor): Shape (N, 3), order: z, y, x.
            stage_id (int): Stage index.
            block_id (int): Block index within the stage.
            shift_id (int): Shift index; selects the window shape.

        Returns:
            Output of posembed_layers[stage_id][block_id][shift_id] applied
            to the (N, 2) or (N, 3) location tensor.
        '''
        window_shape = self.window_shape[stage_id][shift_id]

        embed_layer = self.posembed_layers[stage_id][block_id][shift_id]
        if len(window_shape) == 2:
            ndim = 2
            win_x, win_y = window_shape
            win_z = 0
        elif window_shape[-1] == 1:
            ndim = 2
            win_x, win_y = window_shape[:2]
            win_z = 0
        else:
            win_x, win_y, win_z = window_shape
            ndim = 3

        assert coors_in_win.size(1) == 3
        # shift coordinates so the window centre is the origin
        z, y, x = coors_in_win[:, 0] - win_z / 2, \
            coors_in_win[:, 1] - win_y / 2, \
            coors_in_win[:, 2] - win_x / 2

        if self.normalize_pos:
            x = x / win_x * 2 * 3.1415  # [-pi, pi]
            y = y / win_y * 2 * 3.1415  # [-pi, pi]
            if ndim == 3:
                # win_z is 0 for 2D windows; dividing by it would only
                # produce inf/NaN values that are never used.
                z = z / win_z * 2 * 3.1415  # [-pi, pi]

        if ndim == 2:
            location = torch.stack((x, y), dim=-1)
        else:
            location = torch.stack((x, y, z), dim=-1)
        pos_embed = embed_layer(location)

        return pos_embed