map2bev.py 1.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# modified from https://github.com/Haiyang-W/DSVT
import torch
import torch.nn as nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class PointPillarsScatter3D(nn.Module):
    """Scatter sparse 3D voxel features into a dense BEV feature map.

    The difference between `PointPillarsScatter3D` and `PointPillarsScatter`
    is that the voxel in this module is along 3 dims: (x, y, z). Voxel
    features are scattered into a dense ``(C, nz, ny, nx)`` grid per sample,
    and the ``z`` axis is then folded into the channel axis, producing a BEV
    map with ``C * nz`` channels.

    Args:
        output_shape (Sequence[int]): Dense grid size given as
            ``(nx, ny, nz)``.
        num_bev_feats (int): Number of BEV channels after folding ``z`` into
            the channel axis. The per-voxel feature size used for scattering
            is ``num_bev_feats // nz``. NOTE(review): ``num_bev_feats`` is
            presumably meant to be divisible by ``nz``; otherwise the output
            has ``(num_bev_feats // nz) * nz`` channels, not
            ``num_bev_feats`` — confirm against the config.
    """

    def __init__(self, output_shape, num_bev_feats, **kwargs):
        super().__init__()
        self.nx, self.ny, self.nz = output_shape
        self.num_bev_feats = num_bev_feats
        # Per-voxel channel count before z is folded into the channels.
        self.num_bev_feats_ori = num_bev_feats // self.nz

    def forward(self, batch_dict, **kwargs):
        """Scatter ``batch_dict['pillar_features']`` into a dense BEV map.

        Args:
            batch_dict (dict): Must contain
                - ``'pillar_features'``: ``(N, num_bev_feats_ori)`` tensor of
                  per-voxel features.
                - ``'voxel_coords'``: ``(N, >=4)`` tensor whose columns 0-3
                  are interpreted as ``(batch_idx, z, y, x)``.

        Returns:
            dict: The same ``batch_dict`` with ``'spatial_features'`` set to
            a ``(B, num_bev_feats_ori * nz, ny, nx)`` tensor, where ``B`` is
            ``max(batch_idx) + 1``. Channel ``c * nz + z`` holds feature
            ``c`` of the voxel at height ``z``; cells with no voxel are zero.
        """
        pillar_features = batch_dict['pillar_features']
        # Cast *before* the index arithmetic so flattened indices are exact
        # int64 values: int32 coords could overflow and float coords lose
        # precision above 2**24 on large grids.
        coords = batch_dict['voxel_coords'].long()

        batch_size = coords[:, 0].max().int().item() + 1
        num_cells = self.nz * self.ny * self.nx

        spatial_features = torch.zeros(
            batch_size,
            self.num_bev_feats_ori,
            num_cells,
            dtype=pillar_features.dtype,
            device=pillar_features.device)

        # Row-major flattening of (z, y, x) inside one sample's grid.
        cell_indices = ((coords[:, 1] * self.ny + coords[:, 2]) * self.nx +
                        coords[:, 3])
        # The two advanced indices are separated by a slice, so the indexed
        # view has shape (N, C): row i receives voxel i's feature vector.
        # This replaces a per-sample mask/scatter loop plus torch.stack.
        spatial_features[coords[:, 0], :, cell_indices] = pillar_features

        # (B, C, nz*ny*nx) -> (B, C*nz, ny, nx); channel index is c*nz + z.
        batch_dict['spatial_features'] = spatial_features.view(
            batch_size, self.num_bev_feats_ori * self.nz, self.ny, self.nx)
        return batch_dict