# detr_loss.py
import mmcv
import torch
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from torch import nn as nn
from torch.nn import functional as F


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Element-wise smooth L1 loss.

    Absolute errors below ``beta`` are penalised quadratically, all other
    errors linearly. No reduction is performed inside this function; callers
    in this file pass ``weight``/``reduction``/``avg_factor`` through the
    ``weighted_loss`` wrapper.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction. Must
            have the same size as ``pred``.
        beta (float, optional): The threshold in the piecewise function.
            Must be positive. Defaults to 1.0.

    Returns:
        torch.Tensor: The per-element loss, or a zero scalar computed from
        ``pred`` when ``target`` has no elements.
    """
    assert beta > 0
    if target.numel() == 0:
        # Empty target: hand back a zero that is still computed from `pred`.
        return pred.sum() * 0

    assert pred.shape == target.shape
    abs_err = (pred - target).abs()
    # Quadratic branch for small errors, linear branch for the rest.
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear)


@LOSSES.register_module()
class LinesLoss(nn.Module):
    """Smooth L1 loss module built on ``smooth_l1_loss``.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
        beta (float, optional): The threshold of the piecewise smooth L1
            function, forwarded to ``smooth_l1_loss``. Defaults to 0.5.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
                shape: [bs, ...]
            target (torch.Tensor): The learning target of the prediction.
                shape: [bs, ...]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
                It's useful when the predictions are not all valid.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The smooth L1 loss multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the reduction chosen at
        # construction time.
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss = smooth_l1_loss(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            beta=self.beta)

        return loss * self.loss_weight
# Binary cross-entropy (on logits) loss, used by MasksLoss.


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """Element-wise binary cross-entropy computed on logits.

    Args:
        pred (torch.Tensor): The prediction logits. The original notes give
            the shape as (B, nquery, npts).
        label (torch.Tensor): The targets, same size as ``pred``. They are
            cast to float before the loss is evaluated.
        class_weight (torch.Tensor, optional): Passed to
            ``F.binary_cross_entropy_with_logits`` as ``pos_weight``.
            Defaults to None.

    Returns:
        torch.Tensor: The unreduced loss, or a zero scalar computed from
        ``pred`` when ``label`` has no elements.
    """
    if label.numel() == 0:
        # Empty label: hand back a zero that is still computed from `pred`.
        return pred.sum() * 0
    assert pred.shape == label.shape

    float_label = label.float()
    return F.binary_cross_entropy_with_logits(
        pred, float_label, pos_weight=class_weight, reduction='none')


@LOSSES.register_module()
class MasksLoss(nn.Module):
    """Binary cross-entropy loss module built on ``bce``.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction logits, forwarded to ``bce``
                (documented there as B, nquery, npts).
            target (torch.Tensor): The targets, same size as ``pred``.
            weight (torch.Tensor, optional): Forwarded as the third
                positional argument of the ``weighted_loss``-wrapped ``bce``;
                presumably the per-element loss weight (confirm against
                mmdet's ``weighted_loss``). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The loss multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the reduction chosen at
        # construction time.
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss = bce(pred, target, weight, reduction=reduction,
                   avg_factor=avg_factor)

        return loss * self.loss_weight

# Cross-entropy loss, used by LenLoss.

@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """Per-sample cross-entropy loss.

    Args:
        pred (torch.Tensor): The prediction. The original notes give the
            shape as (B*nquery, npts).
        label (torch.Tensor): The targets. The original notes give the shape
            as (B*nquery,).
        class_weight (torch.Tensor, optional): Passed to ``F.cross_entropy``
            as ``weight``. Defaults to None.

    Returns:
        torch.Tensor: The unreduced loss, or a zero scalar computed from
        ``pred`` when ``label`` has no elements.
    """
    if label.numel() == 0:
        # Empty label: hand back a zero that is still computed from `pred`.
        return pred.sum() * 0

    return F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')


@LOSSES.register_module()
class LenLoss(nn.Module):
    """Cross-entropy loss module built on ``ce``.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, forwarded to ``ce``
                (documented there as B*nquery, npts).
            target (torch.Tensor): The targets, forwarded to ``ce``
                (documented there as B*nquery,).
            weight (torch.Tensor, optional): Forwarded as the third
                positional argument of the ``weighted_loss``-wrapped ``ce``;
                presumably the per-sample loss weight (confirm against
                mmdet's ``weighted_loss``). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The loss multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the reduction chosen at
        # construction time.
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss = ce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)

        return loss * self.loss_weight