minkunet_head.py 2.37 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from .decode_head import Base3DDecodeHead


@MODELS.register_module()
class MinkUNetHead(Base3DDecodeHead):
    r"""MinkUNet decoder head with TorchSparse backend.

    The segmentation layer built by :meth:`build_conv_seg` is a single
    ``nn.Linear`` applied to the features produced by the backbone.

    Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.

    Args:
        channels (int): The input channel of conv_seg.
        num_classes (int): Number of classes.
        **kwargs: Forwarded unchanged to :class:`Base3DDecodeHead`.
    """

    def __init__(self, channels: int, num_classes: int, **kwargs) -> None:
        super().__init__(channels, num_classes, **kwargs)

    def build_conv_seg(self, channels: int, num_classes: int,
                       kernel_size: int) -> nn.Module:
        """Build the segmentation layer.

        Args:
            channels (int): Number of input feature channels.
            num_classes (int): Number of output classes.
            kernel_size (int): Unused here, since a linear layer has no
                kernel. Presumably kept so the override matches the
                signature the base class calls it with.

        Returns:
            nn.Module: ``nn.Linear(channels, num_classes)``.
        """
        return nn.Linear(channels, num_classes)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Concatenate voxel-wise Ground Truth of the whole batch.

        Args:
            batch_data_samples (List[:obj:`Det3DDataSample`]): Data samples
                whose ``gt_pts_seg.voxel_semantic_mask`` is read.

        Returns:
            Tensor: The per-sample ``voxel_semantic_mask`` tensors joined
            along dim 0, in batch order.
        """
        gt_semantic_segs = [
            data_sample.gt_pts_seg.voxel_semantic_mask
            for data_sample in batch_data_samples
        ]
        return torch.cat(gt_semantic_segs)

    def predict(self, inputs: Tensor,
                batch_data_samples: SampleList) -> List[Tensor]:
        """Forward function for testing.

        Args:
            inputs (Tensor): Features from backbone.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. Each one must provide ``batch_idx`` and
                ``point2voxel_map``.

        Returns:
            List[Tensor]: Segmentation logits for each sample in the batch.
            Each tensor has one row per entry of that sample's
            ``point2voxel_map``. These are raw class scores, not argmax-ed
            labels.
        """
        seg_logits = self.forward(inputs)

        # The mask below treats row r of ``seg_logits`` as belonging to
        # sample ``batch_idx[r]``.
        # NOTE(review): assumes the concatenation of the per-sample
        # ``batch_idx`` tensors lines up row-for-row with ``inputs``;
        # confirm against the data preprocessor.
        batch_idx = torch.cat(
            [data_sample.batch_idx for data_sample in batch_data_samples])

        seg_logit_list = []
        for i, data_sample in enumerate(batch_data_samples):
            # Keep only the rows (voxels) of the i-th sample, preserving
            # their relative order.
            seg_logit = seg_logits[batch_idx == i]
            # Gather voxel logits back onto points. Presumably
            # ``point2voxel_map[p]`` is the index of point p's voxel within
            # this sample (TODO confirm).
            seg_logit = seg_logit[data_sample.point2voxel_map]
            seg_logit_list.append(seg_logit)

        return seg_logit_list

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Features from backbone.

        Returns:
            Tensor: Segmentation logits of shape [N, C], produced by
                ``self.cls_seg``, which is defined in the base class.
                Rows cover every sample in the batch together;
                :meth:`predict` splits them per sample.
        """
        return self.cls_seg(x)