paconv_regularization_loss.py 4.29 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import List, Optional

4
import torch
5
from mmdet.models.losses.utils import weight_reduce_loss
6
from torch import Tensor
7
8
from torch import nn as nn

9
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
10
from ..layers import PAConv, PAConvCUDA
11
12


def weight_correlation(conv: nn.Module) -> Tensor:
    """Compute the pairwise correlation of kernels in a PAConv weight bank.

    Each kernel stored in ``conv.weight_bank`` is flattened to a vector.
    The cosine similarity is taken between every pair of distinct kernels,
    and the squared similarities are summed. The result serves as a
    regularization loss on the kernel correlations.

    Args:
        conv (nn.Module): The conv module to regularize. Currently only
            `PAConv` and `PAConvCUDA` are supported.

    Returns:
        Tensor: Sum of squared cosine similarities over all pairs of
        distinct kernels in the weight bank.
    """
    assert isinstance(conv, (PAConv, PAConvCUDA)), \
        f'unsupported module type {type(conv)}'
    n_kernels = conv.num_kernels

    # weight_bank is laid out as [C_in, num_kernels * C_out]. Regroup it so
    # that each row holds one kernel flattened to length C_in * C_out.
    bank = conv.weight_bank.view(
        conv.in_channels, n_kernels, conv.out_channels)
    per_kernel = bank.permute(1, 0, 2).reshape(n_kernels, -1)

    # Pairwise dot products between kernels: [num_kernels, num_kernels].
    dots = per_kernel @ per_kernel.T
    # L2 norm of every kernel as a column vector: [num_kernels, 1].
    norms = per_kernel.pow(2).sum(dim=-1, keepdim=True).pow(0.5)
    # The outer product of the norms is the cosine-similarity denominator.
    similarity = dots / (norms @ norms.T)

    # Keep only the strict upper triangle so each pair of distinct kernels
    # is counted once. Squaring keeps the loss non-negative. This mirrors
    # the reference implementation:
    # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/tool/train.py#L208
    upper = torch.triu(similarity, diagonal=1)
    return upper.pow(2).sum()


def paconv_regularization_loss(modules: List[nn.Module],
                               reduction: str) -> Tensor:
    """Compute the correlation loss of PAConv weight kernels.

    Every `PAConv` / `PAConvCUDA` module found in ``modules`` contributes
    one correlation term (see ``weight_correlation``). All other modules
    are skipped. The per-module terms are then reduced with ``reduction``.

    Args:
        modules (List[nn.Module] | :obj:`generator`):
            A list or a python generator of torch.nn.Modules.
        reduction (str): Method to reduce losses among PAConv modules.
            The valid reduction methods are 'none', 'sum' or 'mean'.

    Returns:
        Tensor: Correlation loss of kernel weights, reduced according to
        ``reduction``.

    Raises:
        ValueError: If ``modules`` contains no `PAConv` or `PAConvCUDA`
            module.
    """
    # ``modules`` may be a one-shot generator, so it is consumed exactly once.
    corr_losses = [
        weight_correlation(module) for module in modules
        if isinstance(module, (PAConv, PAConvCUDA))
    ]
    if not corr_losses:
        # Without this check, torch.stack would fail with an opaque
        # "expects a non-empty TensorList" error.
        raise ValueError(
            'paconv_regularization_loss found no PAConv or PAConvCUDA '
            'module to regularize')

    # The reduction is performed across PAConv modules.
    return weight_reduce_loss(torch.stack(corr_losses), reduction=reduction)


@MODELS.register_module()
class PAConvRegularizationLoss(nn.Module):
    """Loss that penalizes correlation between PAConv weight-bank kernels.

    Intended as a regularization term in PAConv model training.

    Args:
        reduction (str): Method to reduce losses. The reduction is performed
            among all PAConv modules instead of prediction tensors.
            The valid reduction methods are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert reduction in ('none', 'sum', 'mean')
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                modules: List[nn.Module],
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Compute the weighted PAConv correlation loss.

        Args:
            modules (List[nn.Module] | :obj:`generator`):
                A list or a python generator of torch.nn.Modules.
            reduction_override (str, optional): Overrides ``self.reduction``
                for this call when given. The valid reduction methods are
                'none', 'sum' or 'mean'. Defaults to None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: Correlation loss of kernel weights scaled by
            ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Fall back to the configured reduction when no (truthy) override
        # is given.
        reduction = reduction_override or self.reduction
        loss = paconv_regularization_loss(modules, reduction=reduction)
        return self.loss_weight * loss