pillar_scatter.py
import torch
from mmcv.runner import auto_fp16
from torch import nn

from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Converts learned features from dense tensor to sparse pseudo image.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features.
    """

    def __init__(self, in_channels, output_shape):
        super().__init__()
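        # output_shape gives the bird's-eye-view canvas size as [ny, nx].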
        self.output_shape = output_shape
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels
        self.fp16_enabled = False

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size=None):
        """Foraward function to scatter features."""
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is not None:
            return self.forward_batch(voxel_features, coors, batch_size)
        else:
            return self.forward_single(voxel_features, coors)

    def forward_single(self, voxel_features, coors):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, M, C).
            coors (torch.Tensor): Coordinates of each voxel.
                The first column indicates the sample ID.
        """
        # Create the canvas for this sample
        canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)

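        # Flatten each pillar's 2D grid location into a column index so all
        # features can be scattered onto the canvas in one assignment.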
        indices = coors[:, 1] * self.nx + coors[:, 2]
        indices = indices.long()
        voxels = voxel_features.t()
        # Now scatter the blob back to the canvas.
        canvas[:, indices] = voxels
        # Undo the column stacking to final 4-dim tensor
        canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
        return [canvas]

    def forward_batch(self, voxel_features, coors, batch_size):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, M, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID.
            batch_size (int): Number of samples in the current batch.
        """
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Create the canvas for this sample
            canvas = torch.zeros(
                self.in_channels,
                self.nx * self.ny,
                dtype=voxel_features.dtype,
                device=voxel_features.device)

            # Only include non-empty pillars
            batch_mask = coors[:, 0] == batch_itt
            this_coors = coors[batch_mask, :]
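            # coors is assumed to hold (batch_idx, z, y, x), so columns 2 and
            # 3 give each pillar's row/column on the (ny, nx) canvas.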
            indices = this_coors[:, 2] * self.nx + this_coors[:, 3]
            indices = indices.type(torch.long)
            voxels = voxel_features[batch_mask, :]
            voxels = voxels.t()

            # Now scatter the blob back to the canvas.
            canvas[:, indices] = voxels

            # Append to a list for later stacking.
            batch_canvas.append(canvas)

        # Stack to 3-dim tensor (batch-size, in_channels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        batch_canvas = batch_canvas.view(batch_size, self.in_channels, self.ny,
                                         self.nx)

        return batch_canvas
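

# A minimal usage sketch (illustrative only; the channel count, grid size and
# coordinate layout below are assumptions based on common PointPillars
# configs, not values taken from this file):
#
#     scatter = PointPillarsScatter(in_channels=64, output_shape=[496, 432])
#     voxel_features = torch.rand(1000, 64)           # (N, C) pillar features
#     coors = torch.zeros(1000, 4, dtype=torch.long)  # (batch_idx, z, y, x)
#     canvas = scatter(voxel_features, coors, batch_size=1)
#     # canvas: (1, 64, 496, 432) pseudo image consumed by the 2D backbone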