import torch
import torch.nn.functional as F
import torch.nn as nn

__all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'SegmentationLosses']

class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        # x: (N, C) raw logits; target: (N,) integer class indices.
        logprobs = F.log_softmax(x, dim=-1)

        # Negative log-likelihood of the labeled class for each sample.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Uniform smoothing term: mean negative log-probability over all classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
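
# Hedged usage sketch (not part of the original file): LabelSmoothing expects
# raw logits of shape (N, C) and integer class indices of shape (N,). The
# demo helper below is hypothetical and is never called at import time.
def _demo_label_smoothing():
    criterion = LabelSmoothing(smoothing=0.1)
    logits = torch.randn(4, 10)            # batch of 4 samples, 10 classes
    target = torch.randint(0, 10, (4,))    # hard integer labels
    return criterion(logits, target)       # scalar loss tensor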

class NLLMultiLabelSmooth(nn.Module):
    """
    NLL loss with label smoothing for soft (e.g. one-hot or mixup) targets.
    """
    def __init__(self, smoothing=0.1):
        super(NLLMultiLabelSmooth, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        if self.training:
            # x: (N, C) raw logits; target: (N, C) soft label distribution.
            x = x.float()
            target = target.float()
            logprobs = F.log_softmax(x, dim=-1)

            # Cross entropy against the soft target distribution.
            nll_loss = -logprobs * target
            nll_loss = nll_loss.sum(-1)

            smooth_loss = -logprobs.mean(dim=-1)

            loss = self.confidence * nll_loss + self.smoothing * smooth_loss

            return loss.mean()
        else:
            # At eval time, fall back to standard cross entropy with hard labels.
            return F.cross_entropy(x, target)
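
# Hedged usage sketch (not part of the original file): in training mode the
# loss expects a soft (N, C) target such as a one-hot or mixup mixture; in
# eval mode it expects hard integer labels. The helper below is hypothetical
# and is never called at import time.
def _demo_nll_multilabel_smooth():
    criterion = NLLMultiLabelSmooth(smoothing=0.1)
    logits = torch.randn(4, 10)
    soft_target = F.one_hot(torch.randint(0, 10, (4,)), num_classes=10).float()
    criterion.train()                      # soft-target branch
    loss_train = criterion(logits, soft_target)
    criterion.eval()                       # plain cross-entropy branch
    loss_eval = criterion(logits, torch.randint(0, 10, (4,)))
    return loss_train, loss_eval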

class SegmentationLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Auxilary Loss"""
    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
                 aux=False, aux_weight=0.4, weight=None,
                 ignore_index=-1):
        super(SegmentationLosses, self).__init__(weight, None, ignore_index)
        self.se_loss = se_loss
        self.aux = aux
        self.nclass = nclass
        self.se_weight = se_weight
        self.aux_weight = aux_weight
        self.bceloss = nn.BCELoss(weight)

    def forward(self, *inputs):
        # Plain 2D cross entropy when neither SE-loss nor auxiliary loss is on.
        if not self.se_loss and not self.aux:
            return super(SegmentationLosses, self).forward(*inputs)
        # Auxiliary loss only: inputs are (pred1, pred2, target).
        elif not self.se_loss:
            pred1, pred2, target = tuple(inputs)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            return loss1 + self.aux_weight * loss2
        # SE-loss only: inputs are (pred, se_pred, target).
        elif not self.aux:
            pred, se_pred, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
            loss1 = super(SegmentationLosses, self).forward(pred, target)
            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.se_weight * loss2
        # Both auxiliary and SE-loss: inputs are (pred1, se_pred, pred2, target).
        else:
            pred1, se_pred, pred2, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3

    @staticmethod
    def _get_batch_label_vector(target, nclass):
        # target is a 3D tensor of shape BxHxW; the output is a 2D BxnClass
        # presence vector: entry (i, c) is 1 if class c appears in image i.
        batch = target.size(0)
        tvect = torch.zeros(batch, nclass)
        for i in range(batch):
            hist = torch.histc(target[i].detach().cpu().float(),
                               bins=nclass, min=0,
                               max=nclass - 1)
            vect = hist > 0
            tvect[i] = vect
        return tvect
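
# Hedged usage sketch (not part of the original file): with se_loss=True and
# aux=True, forward expects (pred1, se_pred, pred2, target), where pred1 and
# pred2 are per-pixel logits of shape (B, nclass, H, W) and se_pred holds
# image-level class logits of shape (B, nclass). The helper below is
# hypothetical and is never called at import time.
def _demo_segmentation_losses():
    nclass = 5
    criterion = SegmentationLosses(se_loss=True, aux=True, nclass=nclass)
    pred1 = torch.randn(2, nclass, 8, 8)          # main head logits
    pred2 = torch.randn(2, nclass, 8, 8)          # auxiliary head logits
    se_pred = torch.randn(2, nclass)              # semantic-encoding logits
    target = torch.randint(0, nclass, (2, 8, 8))  # per-pixel class labels
    return criterion(pred1, se_pred, pred2, target)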