# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from torch import Tensor, nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Converts learned features from dense tensor to sparse pseudo image.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features.
    """

    def __init__(self, in_channels: int, output_shape: List[int]):
        super().__init__()
        self.output_shape = output_shape
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels

    def forward(self,
                voxel_features: Tensor,
                coors: Tensor,
                batch_size: Optional[int] = None) -> Tensor:
        """Foraward function to scatter features."""
        # TODO: rewrite the function in a batch manner so that the
        # single-sample case does not need separate handling
        if batch_size is not None:
            return self.forward_batch(voxel_features, coors, batch_size)
        else:
            return self.forward_single(voxel_features, coors)

    def forward_single(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID.
        """
        # Create the canvas for this sample
        canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)

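        # coors columns are (batch_idx, z, y, x); each pillar's flattened
        # position on the (ny, nx) canvas is therefore y * nx + x.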
        indices = coors[:, 2] * self.nx + coors[:, 3]
        indices = indices.long()
        voxels = voxel_features.t()
        # Now scatter the blob back to the canvas.
        canvas[:, indices] = voxels
        # Undo the column stacking to final 4-dim tensor
        canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
        return canvas

    def forward_batch(self, voxel_features: Tensor, coors: Tensor,
                      batch_size: int) -> Tensor:
        """Scatter features of single sample.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID.
            batch_size (int): Number of samples in the current batch.
        """
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Create the canvas for this sample
            canvas = torch.zeros(
                self.in_channels,
                self.nx * self.ny,
                dtype=voxel_features.dtype,
                device=voxel_features.device)

            # Only include non-empty pillars
            batch_mask = coors[:, 0] == batch_itt
            this_coors = coors[batch_mask, :]
            indices = this_coors[:, 2] * self.nx + this_coors[:, 3]
            indices = indices.long()
            voxels = voxel_features[batch_mask, :]
            voxels = voxels.t()

            # Now scatter the blob back to the canvas.
            canvas[:, indices] = voxels

            # Append to a list for later stacking.
            batch_canvas.append(canvas)

        # Stack to 3-dim tensor (batch-size, in_channels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        batch_canvas = batch_canvas.view(batch_size, self.in_channels, self.ny,
                                         self.nx)

        return batch_canvas
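

# A minimal usage sketch (shapes and values are illustrative only): scatter
# 16-channel features of six non-empty pillars from two samples onto a 4 x 4
# pseudo image, following the forward_batch contract documented above.
if __name__ == '__main__':
    scatter = PointPillarsScatter(in_channels=16, output_shape=[4, 4])
    voxel_features = torch.rand(6, 16)
    # coors columns: (batch_idx, z, y, x)
    coors = torch.tensor([
        [0, 0, 0, 1],
        [0, 0, 2, 3],
        [0, 0, 3, 0],
        [1, 0, 1, 1],
        [1, 0, 2, 2],
        [1, 0, 3, 3],
    ])
    pseudo_image = scatter(voxel_features, coors, batch_size=2)
    assert pseudo_image.shape == (2, 16, 4, 4)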