pillar_scatter.py 3.5 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
3
4
5
6
import torch
from torch import nn

from ..registry import MIDDLE_ENCODERS


7
@MIDDLE_ENCODERS.register_module()
zhangwenwei's avatar
zhangwenwei committed
8
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Places per-voxel feature vectors onto a dense, zero-initialised pseudo
    image of shape (in_channels, ny, nx). Cells that receive no voxel stay
    zero.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features,
            given as ``[ny, nx]``.
    """

    def __init__(self, in_channels, output_shape):
        super().__init__()
        self.output_shape = output_shape
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels

    def forward(self, voxel_features, coors, batch_size=None):
        """Forward function to scatter features.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C),
                with C equal to ``in_channels``.
            coors (torch.Tensor): Coordinates of each voxel.
            batch_size (int, optional): Number of samples in the current
                batch. When None, all voxels are treated as one sample.

        Returns:
            list[torch.Tensor] | torch.Tensor: A one-element list holding a
                (1, in_channels, ny, nx) tensor when ``batch_size`` is None,
                otherwise a (batch_size, in_channels, ny, nx) tensor.
        """
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is not None:
            return self.forward_batch(voxel_features, coors, batch_size)
        return self.forward_single(voxel_features, coors)

    def _scatter_to_canvas(self, voxel_features, flat_indices):
        """Scatter voxel features into one flattened, zero-filled canvas.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            flat_indices (torch.Tensor): Row-major cell index
                (``y * nx + x``) of each voxel, in shape (N, ).

        Returns:
            torch.Tensor: Canvas in shape (in_channels, ny * nx).
        """
        canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)
        # Index tensors must be int64; features are transposed to (C, N) so
        # that each column lands in its own canvas cell.
        canvas[:, flat_indices.long()] = voxel_features.t()
        return canvas

    def forward_single(self, voxel_features, coors):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel. The last two
                columns are read as (y, x); any leading columns (e.g. a
                sample ID) are ignored.

        Returns:
            list[torch.Tensor]: One-element list holding the pseudo image in
                shape (1, in_channels, ny, nx).
        """
        # Use the trailing (y, x) columns so that coordinates with and
        # without a leading sample-ID column are both indexed the same way
        # as in forward_batch.
        indices = coors[:, -2] * self.nx + coors[:, -1]
        canvas = self._scatter_to_canvas(voxel_features, indices)
        # Undo the column stacking to final 4-dim tensor
        canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
        return [canvas]

    def forward_batch(self, voxel_features, coors, batch_size):
        """Scatter features of a batch of samples.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID; columns 2 and 3
                are read as (y, x).
            batch_size (int): Number of samples in the current batch.

        Returns:
            torch.Tensor: Pseudo images in shape
                (batch_size, in_channels, ny, nx).
        """
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Only include the voxels that belong to this sample
            batch_mask = coors[:, 0] == batch_itt
            this_coors = coors[batch_mask, :]
            indices = this_coors[:, 2] * self.nx + this_coors[:, 3]
            voxels = voxel_features[batch_mask, :]

            # Scatter the blob onto a fresh canvas and keep it for stacking.
            batch_canvas.append(self._scatter_to_canvas(voxels, indices))

        # Stack to 3-dim tensor (batch-size, in_channels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        batch_canvas = batch_canvas.view(batch_size, self.in_channels, self.ny,
                                         self.nx)

        return batch_canvas