pillar_scatter.py 3.53 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
3
4
import torch
from torch import nn

5
from mmdet3d.registry import MODELS
zhangwenwei's avatar
zhangwenwei committed
6
7


8
@MODELS.register_module()
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Converts learned per-pillar features into a dense pseudo image: every
    pillar's feature vector is written into a zero-filled
    ``(in_channels, ny, nx)`` canvas at the grid cell given by its
    coordinates. Cells without a pillar stay zero.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features, given
            as ``[ny, nx]`` (number of canvas rows, number of canvas
            columns).
    """

    def __init__(self, in_channels, output_shape):
        super().__init__()
        self.output_shape = output_shape
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels

    def forward(self, voxel_features, coors, batch_size=None):
        """Forward function to scatter features.

        Args:
            voxel_features (torch.Tensor): Pillar features in shape (N, C).
            coors (torch.Tensor): Coordinates of each pillar in shape
                (N, 4). Column 0 is the sample ID; columns 2 and 3 are the
                row and column of the pillar on the canvas.
            batch_size (int, optional): Number of samples in the batch.
                When given, pillars are scattered per sample according to
                column 0 of ``coors``. When ``None``, all pillars are
                scattered onto a single canvas. Defaults to None.

        Returns:
            torch.Tensor: Pseudo image in shape (B, C, ny, nx), where B is
            ``batch_size`` if given, otherwise 1.
        """
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is not None:
            return self.forward_batch(voxel_features, coors, batch_size)
        return self.forward_single(voxel_features, coors)

    def _scatter_to_canvas(self, voxel_features, coors):
        """Scatter pillar features onto one flat (C, ny * nx) canvas.

        Args:
            voxel_features (torch.Tensor): Pillar features in shape (N, C).
            coors (torch.Tensor): Pillar coordinates in shape (N, 4); only
                columns 2 (row) and 3 (column) are read here.

        Returns:
            torch.Tensor: Zero-filled canvas in shape (C, ny * nx) with each
            pillar's features written into the column of its grid cell.
        """
        # Create the canvas for this sample
        canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)
        # Flatten (row, col) grid coordinates into a row-major index into
        # the flat canvas; indexing requires an int64 tensor.
        indices = (coors[:, 2] * self.nx + coors[:, 3]).long()
        # Features are (N, C); transpose to (C, N) so that each pillar
        # fills one column of the canvas.
        canvas[:, indices] = voxel_features.t()
        return canvas

    def forward_single(self, voxel_features, coors):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                Columns 2 and 3 give the row and column on the canvas. The
                sample ID in column 0 is not read here, so all voxels land
                on the same canvas.

        Returns:
            torch.Tensor: Pseudo image in shape (1, C, ny, nx).
        """
        canvas = self._scatter_to_canvas(voxel_features, coors)
        # Undo the column stacking to final 4-dim tensor
        return canvas.view(1, self.in_channels, self.ny, self.nx)

    def forward_batch(self, voxel_features, coors, batch_size):
        """Scatter features of a batch of samples.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID; columns 2 and 3
                give the row and column on the canvas.
            batch_size (int): Number of samples in the current batch. Voxels
                whose sample ID is not in ``range(batch_size)`` are not
                written to any canvas.

        Returns:
            torch.Tensor: Pseudo images in shape (batch_size, C, ny, nx).
        """
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Only include non-empty pillars of this sample.
            batch_mask = coors[:, 0] == batch_itt
            batch_canvas.append(
                self._scatter_to_canvas(voxel_features[batch_mask, :],
                                        coors[batch_mask, :]))

        # Stack to 3-dim tensor (batch-size, in_channels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        return batch_canvas.view(batch_size, self.in_channels, self.ny,
                                 self.nx)