test_losses.py 12.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcls.models import build_loss


def test_asymmetric_loss():
    """Exercise AsymmetricLoss in sigmoid and softmax modes.

    Every section feeds the same 2x3 logits. Expected values are
    pre-computed references; a uniform 0.5 per-sample ``weight`` is
    expected to halve the ``mean``-reduced loss.
    """

    def make_logits():
        # Fresh copy of the shared logits for each section.
        return torch.Tensor([[5, -5, 0], [5, -5, 0]])

    def sigmoid_asl(clip):
        # Sigmoid (multi-label) variant; only ``clip`` varies between cases.
        return build_loss(
            dict(
                type='AsymmetricLoss',
                gamma_pos=1.0,
                gamma_neg=4.0,
                clip=clip,
                reduction='mean',
                loss_weight=1.0))

    def softmax_asl():
        # Softmax variant with both gammas set to 0 and no clipping.
        return build_loss(
            dict(
                type='AsymmetricLoss',
                gamma_pos=0.0,
                gamma_neg=0.0,
                clip=None,
                reduction='mean',
                loss_weight=1.0,
                use_sigmoid=False,
                eps=1e-8))

    # Sigmoid mode with clip=0.05: unweighted, then with per-sample weight.
    logits = make_logits()
    multi_hot = torch.Tensor([[1, 0, 1], [0, 1, 0]])
    half = torch.tensor([0.5, 0.5])
    criterion = sigmoid_asl(0.05)
    assert torch.allclose(
        criterion(logits, multi_hot), torch.tensor(3.80845 / 3))
    assert torch.allclose(
        criterion(logits, multi_hot, weight=half),
        torch.tensor(3.80845 / 6))

    # Sigmoid mode without clipping.
    criterion = sigmoid_asl(None)
    assert torch.allclose(
        criterion(logits, multi_hot), torch.tensor(5.1186 / 3))

    # Softmax mode: class-index targets and the equivalent one-hot targets
    # must produce the same loss.
    softmax_targets = (
        torch.Tensor([0, 1]),  # single-label task (class indices)
        torch.Tensor([[1, 0, 0], [0, 1, 0]]),  # soft (one-hot) labels
    )
    for target in softmax_targets:
        logits = make_logits()
        half = torch.tensor([0.5, 0.5])
        criterion = softmax_asl()
        assert torch.allclose(criterion(logits, target), torch.tensor(2.5045))
        assert torch.allclose(
            criterion(logits, target, weight=half),
            torch.tensor(2.5045 * 0.5))


def test_cross_entropy_loss():
    """Check CrossEntropyLoss in softmax, sigmoid (BCE) and soft modes.

    Each mode is checked without and with ``class_weight`` (plus
    ``pos_weight`` for the sigmoid mode), both unweighted and with a
    per-sample ``weight``. The logits are large enough for softmax/sigmoid
    to saturate, which is why the expected losses are round numbers.
    """
    # use_sigmoid and use_soft cannot be set simultaneously
    loss_cfg = dict(type='CrossEntropyLoss', use_sigmoid=True, use_soft=True)
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # test ce_loss
    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
    label = torch.Tensor([0, 1]).long()
    class_weight = [0.3, 0.7]  # class 0 : 0.3, class 1 : 0.7
    weight = torch.tensor([0.6, 0.4])

    # test ce_loss without class weight
    loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
    # test ce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(640.))

    # test ce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
    # test ce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(208.))

    # test bce_loss
    cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]])
    label = torch.Tensor([[1, 0], [0, 1], [1, 0]])
    weight = torch.Tensor([0.6, 0.4, 0.5])
    class_weight = [0.1, 0.9]  # class 0: 0.1, class 1: 0.9
    pos_weight = [0.1, 0.2]

    # test bce_loss without class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(300.))
    # test bce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(130.))

    # test bce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(176.667))
    # test bce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(74.333))

    # test bce loss with pos_weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0,
        pos_weight=pos_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(136.6667))

    # test soft_ce_loss
    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
    label = torch.Tensor([[1.0, 0.0], [0.0, 1.0]])
    class_weight = [0.3, 0.7]  # class 0 : 0.3, class 1 : 0.7
    weight = torch.tensor([0.6, 0.4])

    # test soft_ce_loss without class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_soft=True,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
    # test soft_ce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(640.))

    # test soft_ce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_soft=True,
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
    # test soft_ce_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(208.))


def test_focal_loss():
    """Check FocalLoss on multi-label and single-label targets.

    The same loss instance is reused for both target encodings; a uniform
    0.5 per-sample ``weight`` is expected to halve the ``mean`` loss.
    """
    focal = build_loss(
        dict(
            type='FocalLoss',
            gamma=2.0,
            alpha=0.25,
            reduction='mean',
            loss_weight=1.0))

    # (targets, expected unweighted loss) pairs for identical logits.
    cases = [
        # multi-label (multi-hot) targets
        (torch.Tensor([[1, 0, 1], [0, 1, 0]]), 0.8522),
        # single-label task (class indices)
        (torch.Tensor([0, 1]), 0.86664125),
    ]
    for target, expected in cases:
        scores = torch.Tensor([[5, -5, 0], [5, -5, 0]])
        half = torch.tensor([0.5, 0.5])
        assert torch.allclose(focal(scores, target), torch.tensor(expected))
        assert torch.allclose(
            focal(scores, target, weight=half), torch.tensor(expected / 2))


def test_label_smooth_loss():
    """Check LabelSmoothLoss config validation and each smoothing mode.

    Reference values for the ``original`` and ``classy_vision`` modes come
    from timm and ClassyVision respectively. They are compared two-sidedly
    with an absolute tolerance of 1e-4, so a loss that is too small fails
    just like one that is too large.
    """
    # test label_smooth_val assertion
    loss_cfg = dict(type='LabelSmoothLoss', label_smooth_val=1.0)
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    loss_cfg = dict(type='LabelSmoothLoss', label_smooth_val='str')
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # test reduction assertion
    loss_cfg = dict(
        type='LabelSmoothLoss', label_smooth_val=0.1, reduction='unknown')
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # test mode assertion
    loss_cfg = dict(
        type='LabelSmoothLoss', label_smooth_val=0.1, mode='unknown')
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # test original mode label smooth loss
    cls_score = torch.tensor([[1., -1.]])
    label = torch.tensor([0])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    correct = torch.tensor(0.2269)  # from timm
    # Two-sided comparison: the previous `loss - correct <= 1e-4` check
    # accepted any loss smaller than the reference.
    assert torch.allclose(loss(cls_score, label), correct, atol=1e-4)

    # test classy_vision mode label smooth loss
    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='classy_vision',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    correct = torch.tensor(0.2178)  # from ClassyVision
    assert torch.allclose(loss(cls_score, label), correct, atol=1e-4)

    # test multi_label mode label smooth loss
    cls_score = torch.tensor([[1., -1., 1]])
    label = torch.tensor([[1, 0, 1]])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='multi_label',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    smooth_label = torch.tensor([[0.9, 0.1, 0.9]])
    correct = torch.binary_cross_entropy_with_logits(cls_score,
                                                     smooth_label).mean()
    assert torch.allclose(loss(cls_score, label), correct)

    # test label linear combination smooth loss: smoothing a mixed label
    # must equal mixing the individually smoothed labels.
    cls_score = torch.tensor([[1., -1., 0.]])
    label1 = torch.tensor([[1., 0., 0.]])
    label2 = torch.tensor([[0., 0., 1.]])
    label_mix = label1 * 0.6 + label2 * 0.4

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        reduction='mean',
        num_classes=3,
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    smooth_label1 = loss.original_smooth_label(label1)
    smooth_label2 = loss.original_smooth_label(label2)
    label_smooth_mix = smooth_label1 * 0.6 + smooth_label2 * 0.4
    correct = (-torch.log_softmax(cls_score, -1) * label_smooth_mix).sum()

    assert torch.allclose(loss(cls_score, label_mix), correct, atol=1e-4)

    # test label smooth loss with weight
    cls_score = torch.tensor([[1., -1.], [1., -1.]])
    label = torch.tensor([0, 1])
    weight = torch.tensor([0.5, 0.5])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        reduction='mean',
        label_smooth_val=0.1,
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(
        loss(cls_score, label, weight=weight),
        loss(cls_score, label) / 2)


# Migrated from mmdetection with modifications
def test_seesaw_loss():
    """Check SeesawLoss argument validation and its output for p/q settings.

    Expected values are hard-coded references for two-class logits of
    +/-100 with ``num_classes=2``.
    """
    # only softmax version of Seesaw Loss is implemented
    loss_cfg = dict(type='SeesawLoss', use_sigmoid=True, loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(loss_cfg)

    # test that cls_score.size(-1) == num_classes (2 here); any other
    # length must be rejected
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_label = torch.Tensor([1]).long()
    # 3 logits != num_classes = 2
    fake_pred = torch.Tensor([[-100, 100, -100]])
    with pytest.raises(AssertionError):
        loss_cls(fake_pred, fake_label)
    # num_classes + 2 = 4 logits are rejected as well
    fake_pred = torch.Tensor([[-100, 100, -100, 100]])
    with pytest.raises(AssertionError):
        loss_cls(fake_pred, fake_label)

    # test the calculation without p and q (fresh instance)
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([1]).long()
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(0.))

    # test the calculation with p and without q
    loss_cls_cfg = dict(
        type='SeesawLoss', p=1.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([0]).long()
    # Pre-seed class 0's entry in ``cum_samples`` (presumably the running
    # per-class sample count) with e^20 before computing the loss.
    loss_cls.cum_samples[0] = torch.exp(torch.Tensor([20]))
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(180.))

    # test the calculation with q and without p
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=1.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([0]).long()
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(200.) + torch.tensor(100.).log())