# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import auto_fp16
from torch import nn

from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Converts learned features from dense tensor to sparse pseudo image.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features.
    """

    def __init__(self, in_channels, output_shape):
        super().__init__()
        self.output_shape = output_shape
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels
        # Checked by mmcv's auto_fp16 decorator; fp16 training wrappers set
        # this to True to enable mixed-precision casting of the inputs.
        self.fp16_enabled = False

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size=None):
        """Foraward function to scatter features."""
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is not None:
            return self.forward_batch(voxel_features, coors, batch_size)
        else:
            return self.forward_single(voxel_features, coors)

    def forward_single(self, voxel_features, coors):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 3),
                in (z, y, x) order.

        Returns:
            list[torch.Tensor]: A single-element list containing the scattered
                canvas of shape (1, C, ny, nx).
        """
        # Create the canvas for this sample
        canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)

        # Flatten the BEV coordinates (y, x) into column indices of the
        # (C, ny * nx) canvas: index = y * nx + x.
        indices = coors[:, 1] * self.nx + coors[:, 2]
        indices = indices.long()
        voxels = voxel_features.t()
        # Now scatter the blob back to the canvas.
        canvas[:, indices] = voxels
        # Undo the column stacking to final 4-dim tensor
        canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
        return [canvas]

    def forward_batch(self, voxel_features, coors, batch_size):
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID.
            batch_size (int): Number of samples in the current batch.

        Returns:
            torch.Tensor: BEV pseudo image of shape
                (batch_size, C, ny, nx).
        """
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Create the canvas for this sample
            canvas = torch.zeros(
                self.in_channels,
                self.nx * self.ny,
                dtype=voxel_features.dtype,
                device=voxel_features.device)

            # Only include non-empty pillars
            batch_mask = coors[:, 0] == batch_itt
            this_coors = coors[batch_mask, :]
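            # With the leading batch column, (y, x) live in columns 2 and 3;
            # flatten them into column indices of the (C, ny * nx) canvas.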
            indices = this_coors[:, 2] * self.nx + this_coors[:, 3]
            indices = indices.type(torch.long)
            voxels = voxel_features[batch_mask, :]
            voxels = voxels.t()

            # Now scatter the blob back to the canvas.
            canvas[:, indices] = voxels

            # Append to a list for later stacking.
            batch_canvas.append(canvas)

        # Stack to 3-dim tensor (batch-size, in_channels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        batch_canvas = batch_canvas.view(batch_size, self.in_channels, self.ny,
                                         self.nx)

        return batch_canvas
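

# A minimal usage sketch (illustrative, not part of the original module; the
# 64-channel, 496 x 432 BEV grid below matches a common KITTI PointPillars
# config but is an assumption here). Each pillar feature vector is scattered
# to its flattened BEV cell ``y * nx + x`` and reshaped into a pseudo image:
#
#     scatter = PointPillarsScatter(in_channels=64, output_shape=[496, 432])
#     voxel_features = torch.rand(1200, 64)  # (N, C) pillar features
#     coors = torch.zeros(1200, 4, dtype=torch.long)  # (batch_idx, z, y, x)
#     coors[:, 2] = torch.randint(0, 496, (1200, ))  # y (row) index
#     coors[:, 3] = torch.randint(0, 432, (1200, ))  # x (column) index
#     pseudo_image = scatter(voxel_features, coors, batch_size=1)
#     assert pseudo_image.shape == (1, 64, 496, 432)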