vote_module.py 5.27 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
import torch
from mmcv.cnn import ConvModule
zhangwenwei's avatar
zhangwenwei committed
3
from torch import nn as nn
wuyuefeng's avatar
Votenet  
wuyuefeng committed
4
5

from mmdet3d.models.builder import build_loss
wuyuefeng's avatar
wuyuefeng committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


class VoteModule(nn.Module):
    """Vote module.

    Generate votes from seed point features: for every seed point, predict
    ``vote_per_seed`` coordinate offsets and residual features, and add them
    to the seed coordinates and seed features respectively.

    Args:
        in_channels (int): Number of channels of seed point features.
        vote_per_seed (int): Number of votes generated from each seed point.
            Default: 1.
        gt_per_seed (int): Number of ground truth votes generated
            from each seed point. Default: 3.
        conv_channels (tuple[int]): Out channels of vote
            generating convolution. Default: (16, 16).
        conv_cfg (dict): Config of convolution.
            Default: dict(type='Conv1d').
        norm_cfg (dict): Config of normalization.
            Default: dict(type='BN1d').
        norm_feats (bool): Whether to L2-normalize the voted features along
            the channel dimension. Default: True.
        vote_loss (dict): Config of vote loss, passed to ``build_loss``.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 vote_per_seed=1,
                 gt_per_seed=3,
                 conv_channels=(16, 16),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 norm_feats=True,
                 vote_loss=None):
        super().__init__()
        self.in_channels = in_channels
        self.vote_per_seed = vote_per_seed
        self.gt_per_seed = gt_per_seed
        self.norm_feats = norm_feats
        self.vote_loss = build_loss(vote_loss)

        # Stack of 1x1 ConvModules mapping in_channels -> conv_channels[-1]
        # (or leaving in_channels unchanged if conv_channels is empty).
        vote_conv_list = []
        prev_channels = in_channels
        for out_channels in conv_channels:
            vote_conv_list.append(
                ConvModule(
                    prev_channels,
                    out_channels,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channels = out_channels
        self.vote_conv = nn.Sequential(*vote_conv_list)

        # conv_out predicts, for each vote, a 3-d coordinate offset followed
        # by an in_channels-d residual feature.
        out_channel = (3 + in_channels) * self.vote_per_seed
        self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)

    def forward(self, seed_points, seed_feats):
        """forward.

        Args:
            seed_points (torch.Tensor): Coordinate of the seed
                points in shape (B, N, 3).
            seed_feats (torch.Tensor): Features of the seed points in shape
                (B, C, N). C must match ``in_channels`` so the predicted
                residual features can be added to them.

        Returns:
            tuple[torch.Tensor]:

                - vote_points: Voted xyz based on the seed points \
                    with shape (B, M, 3), ``M=num_seed*vote_per_seed``.
                - vote_features: Voted features based on the seed points with \
                    shape (B, C, M) where ``M=num_seed*vote_per_seed`` and \
                    ``C`` is the seed feature channel count. If \
                    ``norm_feats`` is True each (B, :, M) column has unit \
                    L2 norm (all-zero columns stay zero).
        """
        batch_size, feat_channels, num_seed = seed_feats.shape
        num_vote = num_seed * self.vote_per_seed
        x = self.vote_conv(seed_feats)
        # (batch_size, (3 + in_channels) * vote_per_seed, num_seed)
        votes = self.conv_out(x)

        # -> (batch_size, num_seed, vote_per_seed, 3 + in_channels)
        votes = votes.transpose(2, 1).view(batch_size, num_seed,
                                           self.vote_per_seed, -1)
        offset = votes[:, :, :, 0:3].contiguous()
        res_feats = votes[:, :, :, 3:].contiguous()

        # Each vote position = its seed's coordinates + predicted offset.
        vote_points = seed_points.unsqueeze(2) + offset
        vote_points = vote_points.view(batch_size, num_vote, 3)

        # Each vote feature = its seed's feature + predicted residual.
        vote_feats = seed_feats.permute(
            0, 2, 1).unsqueeze(2).contiguous() + res_feats
        vote_feats = vote_feats.view(batch_size, num_vote,
                                     feat_channels).transpose(2,
                                                              1).contiguous()

        if self.norm_feats:
            # L2-normalize over the channel dim. normalize() clamps the norm
            # at eps=1e-12, so an all-zero feature vector yields zeros rather
            # than NaN from 0 / 0 (identical result for non-zero norms).
            vote_feats = torch.nn.functional.normalize(vote_feats, p=2, dim=1)
        return vote_points, vote_feats

    def get_loss(self, seed_points, vote_points, seed_indices,
                 vote_targets_mask, vote_targets):
        """Calculate loss of voting module.

        Args:
            seed_points (torch.Tensor): Coordinate of the seed points in
                shape (B, N, 3).
            vote_points (torch.Tensor): Coordinate of the vote points;
                reshaped to (B * N, -1, 3) here.
            seed_indices (torch.Tensor): Indices of seed points in raw points,
                shape (B, N); used as an integer ``torch.gather`` index on
                dim 1.
            vote_targets_mask (torch.Tensor): Mask of valid vote targets,
                indexed per raw point along dim 1.
            vote_targets (torch.Tensor): Targets of votes, stored as offsets
                from the raw points (the seed coordinates are added back
                here). Presumably shape (B, num_points, 3 * gt_per_seed) —
                confirm against the caller.

        Returns:
            torch.Tensor: Weighted vote loss.
        """
        batch_size, num_seed = seed_points.shape[:2]

        # Validity of each seed's targets, as float weights.
        seed_gt_votes_mask = torch.gather(vote_targets_mask, 1,
                                          seed_indices).float()

        # Gather the gt_per_seed target offsets of each seed and turn them
        # into absolute coordinates by adding the seed position.
        seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
            1, 1, 3 * self.gt_per_seed)
        seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
        seed_gt_votes += seed_points.repeat(1, 1, self.gt_per_seed)

        # Normalize weights by the number of valid seeds (eps avoids 0 / 0).
        weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6)
        # Only the second element of the loss's return value is used;
        # presumably per-target weighted distances — confirm against the
        # configured loss.
        distance = self.vote_loss(
            vote_points.view(batch_size * num_seed, -1, 3),
            seed_gt_votes.view(batch_size * num_seed, -1, 3),
            dst_weight=weight.view(batch_size * num_seed, 1))[1]
        # Keep the minimum distance along dim 1 per seed, then sum.
        vote_loss = torch.sum(torch.min(distance, dim=1)[0])

        return vote_loss