test_losses.py 2.59 KB
Newer Older
mashun1's avatar
anytext  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pytest
import torch

from basicsr.losses.basic_loss import CharbonnierLoss, L1Loss, MSELoss, WeightedTVLoss


@pytest.mark.parametrize('loss_class', [L1Loss, MSELoss, CharbonnierLoss])
def test_pixellosses(loss_class):
    """Check output type and shape of the pixel losses for every reduction mode.

    Each loss is built with ``loss_weight=1.0`` and evaluated on random
    ``(1, 3, 4, 4)`` float32 tensors:

    * ``'mean'`` and ``'sum'`` must produce a 0-d tensor;
    * ``'none'`` must keep the input shape, both without a weight and with a
      per-element weight tensor of the same shape;
    * an unknown reduction must be rejected by the constructor with
      ``ValueError``.
    """
    shape = (1, 3, 4, 4)
    scalar = torch.Size([])
    pred = torch.rand(shape, dtype=torch.float32)
    target = torch.rand(shape, dtype=torch.float32)

    def _check(loss_fn, expected_shape, weight=None):
        # Evaluate on the shared pred/target pair and validate the result.
        result = loss_fn(pred, target, weight=weight)
        assert isinstance(result, torch.Tensor)
        assert result.shape == expected_shape

    # reduction = mean -> reduced to a scalar
    _check(loss_class(loss_weight=1.0, reduction='mean'), scalar)

    # reduction = none -> element-wise output, first unweighted, then weighted
    elementwise = loss_class(loss_weight=1.0, reduction='none')
    _check(elementwise, shape)
    _check(elementwise, shape, weight=torch.rand(shape, dtype=torch.float32))

    # reduction = sum -> reduced to a scalar
    _check(loss_class(loss_weight=1.0, reduction='sum'), scalar)

    # an unsupported reduction must fail already at construction time
    with pytest.raises(ValueError):
        loss_class(loss_weight=1.0, reduction='unknown')


def test_weightedtvloss():
    """Check WeightedTVLoss for its supported and rejected reduction modes.

    For both ``'mean'`` and ``'sum'`` the loss is evaluated on a random
    ``(1, 3, 4, 4)`` float32 prediction, once without a weight and once with
    a random weight tensor of the same shape; every call must return a 0-d
    tensor. Constructing the loss with ``'unknown'`` or ``'none'`` must raise
    ``ValueError``.
    """
    shape = (1, 3, 4, 4)
    pred = torch.rand(shape, dtype=torch.float32)

    for mode in ('mean', 'sum'):
        tv_loss = WeightedTVLoss(loss_weight=1.0, reduction=mode)
        # Unweighted first, then with a fresh random weight map.
        for weight in (None, torch.rand(shape, dtype=torch.float32)):
            result = tv_loss(pred, weight=weight)
            assert isinstance(result, torch.Tensor)
            assert result.shape == torch.Size([])

    # Only 'mean' and 'sum' are accepted; everything else is rejected up front.
    for bad_mode in ('unknown', 'none'):
        with pytest.raises(ValueError):
            WeightedTVLoss(loss_weight=1.0, reduction=bad_mode)