import paddle
import paddle.nn.functional as F


class Loss(object):
    """
    Cross-entropy loss with optional label smoothing.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d must be larger than 1" % class_dim
        self._class_dim = class_dim
        if epsilon is not None and 0.0 <= epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True
        else:
            self._epsilon = None
            self._label_smoothing = False

    def _labelsmoothing(self, target):
        # Convert integer labels to one-hot if they are not already one-hot.
        if target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(target, self._class_dim)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target

    def _crossentropy(self, input, target, use_pure_fp16=False):
        if self._label_smoothing:
            # Soft targets require an explicit log-softmax cross-entropy.
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)
            cost = paddle.sum(target * input, axis=-1)
        else:
            cost = F.cross_entropy(input=input, label=target)
        if use_pure_fp16:
            # Pure fp16 training uses a sum reduction instead of a mean.
            avg_cost = paddle.sum(cost)
        else:
            avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        return self._crossentropy(input, target)
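
# A minimal usage sketch (shapes and values here are hypothetical,
# for illustration only):
#
#     loss_fn = Loss(class_dim=10, epsilon=0.1)
#     logits = paddle.randn([4, 10])
#     labels = paddle.randint(0, 10, shape=[4])
#     avg_loss = loss_fn(logits, labels)  # scalar, label-smoothed cross-entropy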


def build_loss(config, epsilon=None):
    class_dim = config['class_dim']
    loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
    return loss_func


class LossDistill(Loss):
    def __init__(self, model_name_list, class_dim=1000, epsilon=None):
        super(LossDistill, self).__init__(class_dim=class_dim, epsilon=epsilon)
        self.model_name_list = model_name_list
        assert len(self.model_name_list) > 1, \
            "model_name_list must contain at least two model names"

    def __call__(self, input, target):
        losses = {}
        for k in self.model_name_list:
            inp = input[k]
            losses[k] = self._crossentropy(inp, target)
        return losses
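
# A minimal usage sketch (the model names are hypothetical):
#
#     distill_loss = LossDistill(model_name_list=['student', 'teacher'],
#                                class_dim=10)
#     predicts = {'student': paddle.randn([4, 10]),
#                 'teacher': paddle.randn([4, 10])}
#     labels = paddle.randint(0, 10, shape=[4])
#     losses = distill_loss(predicts, labels)  # one loss per model name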


class KLJSLoss(object):
    def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'], \
            "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        p1 = F.softmax(p1, axis=-1)
        p2 = F.softmax(p2, axis=-1)

        # KL(p2 || p1); the small epsilons guard against log(0) and
        # division by zero.
        loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))

        if self.mode.lower() == "js":
            # Symmetrize with KL(p1 || p2) and average the two directions.
            loss += paddle.multiply(
                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
            loss *= 0.5
        if reduction == "mean":
            loss = paddle.mean(loss)
        elif reduction == "none" or reduction is None:
            return loss
        else:
            loss = paddle.sum(loss)
        return loss
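
# A minimal usage sketch (inputs are raw logits; softmax is applied
# internally):
#
#     kljs = KLJSLoss(mode='js')
#     logits_a = paddle.randn([4, 10])
#     logits_b = paddle.randn([4, 10])
#     div = kljs(logits_a, logits_b)  # scalar under the default "mean" reduction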


class DMLLoss(object):
    def __init__(self, model_name_pairs, mode='js'):
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.kljs_loss = KLJSLoss(mode=mode)

    def _check_model_name_pairs(self, model_name_pairs):
        # Normalize the input to a list of [name_a, name_b] pairs.
        if not isinstance(model_name_pairs, list):
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]

    def __call__(self, predicts, target=None):
        loss_dict = dict()
        for pairs in self.model_name_pairs:
            p1 = predicts[pairs[0]]
            p2 = predicts[pairs[1]]

            loss_dict[pairs[0] + "_" + pairs[1]] = self.kljs_loss(p1, p2)

        return loss_dict
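
# A minimal usage sketch (the pair names are hypothetical):
#
#     dml = DMLLoss(model_name_pairs=['student', 'teacher'], mode='js')
#     predicts = {'student': paddle.randn([4, 10]),
#                 'teacher': paddle.randn([4, 10])}
#     loss_dict = dml(predicts)  # {'student_teacher': scalar JS divergence}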


# def build_distill_loss(config, epsilon=None):
#     class_dim = config['class_dim']
#     loss_func = LossDistill(
#         model_name_list=['student', 'student1'],
#         class_dim=class_dim,
#         epsilon=epsilon)
#     return loss_func