test_gan_loss.py 2.7 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) OpenMMLab. All rights reserved.
import numpy.testing as npt
import pytest
import torch

from mmgen.models.losses.gan_loss import GANLoss


def test_gan_losses():
    """Test ``GANLoss`` for each supported ``gan_type``.

    Every loss exercised below is built with ``loss_weight=2.0``. The
    expected values show that this weight scales the generator loss
    (``is_disc=False``), while the discriminator loss (``is_disc=True``)
    is left unweighted.
    """
    # An unsupported gan_type is rejected at construction time.
    with pytest.raises(NotImplementedError):
        GANLoss(
            'xixihaha',
            loss_weight=1.0,
            real_label_val=1.0,
            fake_label_val=0.0)

    input_1 = torch.ones(1, 1)
    input_2 = torch.ones(1, 3, 6, 6) * 2

    def build(gan_type):
        # Shared weight/label settings for every gan_type tested below.
        return GANLoss(
            gan_type, loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)

    def check(gan_type, inputs, cases):
        # Each case is (target_is_real, is_disc, expected_loss). err_msg
        # names the failing case, since all cases share one assert line.
        gan_loss = build(gan_type)
        for target_is_real, is_disc, expected in cases:
            loss = gan_loss(inputs, target_is_real, is_disc=is_disc)
            npt.assert_almost_equal(
                loss.item(),
                expected,
                err_msg='gan_type={!r}, target_is_real={}, is_disc={}'.format(
                    gan_type, target_is_real, is_disc))

    # vanilla, on an input of 1: 0.3132616 == log(1 + exp(-1)) and
    # 1.3132616 == log(1 + exp(1)); the generator values are twice these.
    check('vanilla', input_1, [
        (True, False, 0.6265233),
        (False, False, 2.6265232),
        (True, True, 0.3132616),
        (False, True, 1.3132616),
    ])

    # lsgan, on inputs of 2: the discriminator values match
    # (2 - 1) ** 2 = 1 for real and (2 - 0) ** 2 = 4 for fake.
    check('lsgan', input_2, [
        (True, False, 2.0),
        (False, False, 8.0),
        (True, True, 1.0),
        (False, True, 4.0),
    ])

    # wgan, on inputs of 2: the discriminator values match -mean(input)
    # for real and +mean(input) for fake.
    check('wgan', input_2, [
        (True, False, -4.0),
        (False, False, 4.0),
        (True, True, -2.0),
        (False, True, 2.0),
    ])

    # wgan-logistic-ns: only the sign of the generator loss is checked.
    gan_loss = build('wgan-logistic-ns')
    for target_is_real in (True, False):
        loss = gan_loss(input_2, target_is_real, is_disc=False)
        assert loss.item() > 0

    # hinge, on inputs of 2: the generator value matches 2 * -mean(input)
    # regardless of the target; the discriminator values match
    # max(0, 1 - 2) = 0 for real and max(0, 1 + 2) = 3 for fake.
    check('hinge', input_2, [
        (True, False, -4.0),
        (False, False, -4.0),
        (True, True, 0.0),
        (False, True, 3.0),
    ])