# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

4
import torch
5
from mmdet.models.losses.utils import weighted_loss
6
from torch import Tensor
7
8
from torch import nn as nn

9
from mmdet3d.registry import MODELS
10
11
12


@weighted_loss
def uncertain_smooth_l1_loss(pred: Tensor,
                             target: Tensor,
                             sigma: Tensor,
                             alpha: float = 1.0,
                             beta: float = 1.0) -> Tensor:
    """Smooth L1 loss with uncertainty.

    Computes, element-wise,
    ``exp(-sigma) * smooth_l1(pred - target) + alpha * sigma``.
    Because the residual term is scaled by ``exp(-sigma)``, ``sigma`` acts as
    a log-scale uncertainty: larger values down-weight the residual, while the
    ``alpha * sigma`` term penalizes large uncertainties.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction.
        sigma (Tensor): The (log-scale) uncertainty. Must have the same size
            as ``pred`` and ``target``.
        alpha (float): The coefficient of the ``sigma`` regularization term.
            Defaults to 1.0.
        beta (float): The threshold in the piecewise function. Must be
            positive. Defaults to 1.0.

    Returns:
        Tensor: Calculated element-wise loss. Weighting and reduction
        (``weight``, ``reduction``, ``avg_factor``) are applied by the
        ``weighted_loss`` decorator, not here.
    """
    assert beta > 0
    assert target.numel() > 0
    assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \
        f'{pred.size()}, target {target.size()}, and sigma {sigma.size()} ' \
        'are inconsistent.'
    diff = torch.abs(pred - target)
    # Quadratic below ``beta`` and linear above it; both pieces equal
    # ``0.5 * beta`` at ``diff == beta``, so the loss is continuous.
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    # Attenuate the residual by the predicted uncertainty. Without the
    # ``alpha * sigma`` term the loss would shrink monotonically as sigma
    # grows, so it acts as the counterweight.
    loss = torch.exp(-sigma) * loss + alpha * sigma

    return loss


@weighted_loss
def uncertain_l1_loss(pred: Tensor,
                      target: Tensor,
                      sigma: Tensor,
                      alpha: float = 1.0) -> Tensor:
    """L1 loss with uncertainty.

    Computes, element-wise, ``exp(-sigma) * |pred - target| + alpha * sigma``.
    Because the residual term is scaled by ``exp(-sigma)``, ``sigma`` acts as
    a log-scale uncertainty: larger values down-weight the residual, while the
    ``alpha * sigma`` term penalizes large uncertainties.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction.
        sigma (Tensor): The (log-scale) uncertainty. Must have the same size
            as ``pred`` and ``target``.
        alpha (float): The coefficient of the ``sigma`` regularization term.
            Defaults to 1.0.

    Returns:
        Tensor: Calculated element-wise loss. Weighting and reduction
        (``weight``, ``reduction``, ``avg_factor``) are applied by the
        ``weighted_loss`` decorator, not here.
    """
    assert target.numel() > 0
    assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \
        f'{pred.size()}, target {target.size()}, and sigma {sigma.size()} ' \
        'are inconsistent.'
    loss = torch.abs(pred - target)
    # Attenuate the residual by the predicted uncertainty; ``alpha * sigma``
    # keeps the optimum from drifting to arbitrarily large sigma.
    loss = torch.exp(-sigma) * loss + alpha * sigma
    return loss


@MODELS.register_module()
class UncertainSmoothL1Loss(nn.Module):
    r"""Smooth L1 loss with uncertainty.

    Please refer to `PGD <https://arxiv.org/abs/2107.14160>`_ and
    `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry
    and Semantics <https://arxiv.org/abs/1705.07115>`_ for more details.

    Args:
        alpha (float): The coefficient of the ``sigma`` regularization term
            (``sigma`` is treated as a log-scale uncertainty).
            Defaults to 1.0.
        beta (float): The threshold in the piecewise function.
            Defaults to 1.0.
        reduction (str): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
        loss_weight (float): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 beta: float = 1.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                sigma: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            sigma (Tensor): The (log-scale) uncertainty, same size as
                ``pred``.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
            **kwargs: Extra keyword arguments forwarded to
                ``uncertain_smooth_l1_loss``.

        Returns:
            Tensor: Calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # ``weight`` is the third positional argument of the
        # ``weighted_loss`` wrapper (not of the undecorated function, whose
        # third parameter is ``sigma``), hence ``sigma`` is passed by keyword.
        loss_bbox = self.loss_weight * uncertain_smooth_l1_loss(
            pred,
            target,
            weight,
            sigma=sigma,
            alpha=self.alpha,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox


@MODELS.register_module()
class UncertainL1Loss(nn.Module):
    """L1 loss with uncertainty.

    Args:
        alpha (float): The coefficient of the ``sigma`` regularization term
            (``sigma`` is treated as a log-scale uncertainty).
            Defaults to 1.0.
        reduction (str): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
        loss_weight (float): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                sigma: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            sigma (Tensor): The (log-scale) uncertainty, same size as
                ``pred``.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
            **kwargs: Extra keyword arguments forwarded to
                ``uncertain_l1_loss`` (mirrors ``UncertainSmoothL1Loss``).

        Returns:
            Tensor: Calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # ``weight`` is the third positional argument of the
        # ``weighted_loss`` wrapper; ``sigma`` therefore goes by keyword.
        loss_bbox = self.loss_weight * uncertain_l1_loss(
            pred,
            target,
            weight,
            sigma=sigma,
            alpha=self.alpha,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox