# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmgen.models import LSGANDiscriminator, LSGANGenerator, build_module


class TestLSGANGenerator(object):
    """Output-shape tests for ``LSGANGenerator`` built via ``build_module``."""

    @classmethod
    def setup_class(cls):
        """Prepare the default builder config and a fixed noise batch."""
        cls.default_config = dict(
            type='LSGANGenerator', noise_size=128, output_scale=128)
        cls.noise = torch.randn((3, 128))

    def _check_generator(self, place):
        """Run the generator assertions shared by both test methods.

        Args:
            place (callable): Applied to every built generator and to the
                stored noise tensor before use. One test passes an identity
                function, the other passes a function calling ``.cuda()``.
        """
        expected_noise_shape = (3, 128)

        # default setting, constructed through the builder
        generator = place(build_module(self.default_config))
        assert isinstance(generator, LSGANGenerator)

        images = generator(None, num_batches=3)
        assert images.shape == (3, 3, 128, 128)

        # noise passed as None, as a tensor and as a callable; with
        # return_noise=True the result is indexed by 'noise_batch'
        outputs = generator(None, num_batches=3, return_noise=True)
        assert outputs['noise_batch'].shape == expected_noise_shape

        outputs = generator(place(self.noise), return_noise=True)
        assert outputs['noise_batch'].shape == expected_noise_shape

        outputs = generator(torch.randn, num_batches=3, return_noise=True)
        assert outputs['noise_batch'].shape == expected_noise_shape

        # a different output_scale changes the spatial size of the output
        small_config = dict(
            type='LSGANGenerator', noise_size=128, output_scale=64)
        generator = place(build_module(small_config))
        assert isinstance(generator, LSGANGenerator)

        images = generator(None, num_batches=3)
        assert images.shape == (3, 3, 64, 64)

    def test_lsgan_generator(self):
        """Check the generator without moving anything to another device."""
        self._check_generator(lambda obj: obj)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lsgan_generator_cuda(self):
        """Check the generator after ``.cuda()``; skipped without CUDA."""
        self._check_generator(lambda obj: obj.cuda())


class TestLSGANDiscriminator(object):
    """Score-shape tests for ``LSGANDiscriminator`` built via the builder."""

    @classmethod
    def setup_class(cls):
        """Prepare the default builder config and a fixed input batch."""
        cls.default_config = dict(
            type='LSGANDiscriminator', in_channels=3, input_scale=128)
        cls.x = torch.randn((2, 3, 128, 128))

    def _check_discriminator(self, place):
        """Run the discriminator assertions shared by both test methods.

        Args:
            place (callable): Applied to every built discriminator and to
                every input tensor before use. One test passes an identity
                function, the other passes a function calling ``.cuda()``.
        """
        expected_score_shape = (2, 1)

        # default setting, constructed through the builder
        disc = place(build_module(self.default_config))
        assert isinstance(disc, LSGANDiscriminator)
        scores = disc(place(self.x))
        assert scores.shape == expected_score_shape

        # input_scale=64, first with default options and then with a
        # Sigmoid output activation
        extra_options = (
            dict(),
            dict(out_act_cfg=dict(type='Sigmoid')),
        )
        for extra in extra_options:
            config = dict(
                type='LSGANDiscriminator',
                in_channels=3,
                input_scale=64,
                **extra)
            disc = place(build_module(config))
            assert isinstance(disc, LSGANDiscriminator)
            images = torch.randn((2, 3, 64, 64))
            scores = disc(place(images))
            assert scores.shape == expected_score_shape

    def test_lsgan_discriminator(self):
        """Check the discriminator without moving anything to a device."""
        self._check_discriminator(lambda obj: obj)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lsgan_discriminator_cuda(self):
        """Check the discriminator after ``.cuda()``; skipped without CUDA."""
        self._check_discriminator(lambda obj: obj.cuda())
