Commit c9a48a52 authored by limm's avatar limm
Browse files

add test code

parent b7536f78
Pipeline #2778 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmgen.models import WGANGPDiscriminator, WGANGPGenerator, build_module
class TestWGANGPGenerator(object):
    """Tests for ``WGANGPGenerator`` instances created via ``build_module``."""

    @classmethod
    def setup_class(cls):
        # Shared fixtures: a noise batch and the default generator config.
        cls.noise = torch.randn((2, 100))
        cls.default_config = dict(
            type='WGANGPGenerator', noise_size=128, out_scale=128)

    def test_wgangp_generator(self):
        # Default configuration built through the registry.
        gen = build_module(self.default_config)
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 128, 128)

        # A smaller output scale.
        cfg = dict(type='WGANGPGenerator', noise_size=128, out_scale=64)
        gen = build_module(cfg)
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 64, 64)

        # A customised conv-module configuration.
        cfg = dict(
            type='WGANGPGenerator',
            noise_size=128,
            out_scale=128,
            conv_module_cfg=dict(
                conv_cfg=None,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                norm_cfg=dict(type='BN'),
                order=('conv', 'norm', 'act')))
        gen = build_module(cfg)
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 128, 128)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_wgangp_generator_cuda(self):
        # Same three scenarios as the CPU test, with the model on GPU.
        gen = build_module(self.default_config).cuda()
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 128, 128)

        cfg = dict(type='WGANGPGenerator', noise_size=128, out_scale=64)
        gen = build_module(cfg).cuda()
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 64, 64)

        cfg = dict(
            type='WGANGPGenerator',
            noise_size=128,
            out_scale=128,
            conv_module_cfg=dict(
                conv_cfg=None,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                norm_cfg=dict(type='BN'),
                order=('conv', 'norm', 'act')))
        gen = build_module(cfg).cuda()
        assert isinstance(gen, WGANGPGenerator)
        fake = gen(None, num_batches=3)
        assert fake.shape == (3, 3, 128, 128)
class TestWGANGPDiscriminator(object):
    """Tests for ``WGANGPDiscriminator`` instances created via ``build_module``."""

    @classmethod
    def setup_class(cls):
        # Shared fixtures: an image batch plus default / LN / GN configs.
        cls.x = torch.randn((2, 3, 128, 128))
        cls.default_config = dict(
            type='WGANGPDiscriminator', in_channel=3, in_scale=128)
        cls.conv_ln_module_config = dict(
            conv_cfg=None,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            norm_cfg=dict(type='LN2d'),
            order=('conv', 'norm', 'act'))
        cls.conv_gn_module_config = dict(
            conv_cfg=None,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            norm_cfg=dict(type='GN'),
            order=('conv', 'norm', 'act'))

    def test_wgangp_discriminator(self):
        # Default configuration built through the registry.
        disc = build_module(self.default_config)
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x).shape == (2, 1)

        # A smaller input scale needs a matching smaller image.
        cfg = dict(type='WGANGPDiscriminator', in_channel=3, in_scale=64)
        disc = build_module(cfg)
        assert isinstance(disc, WGANGPDiscriminator)
        small = torch.randn((2, 3, 64, 64))
        assert disc(small).shape == (2, 1)

        # Layer-norm based conv modules.
        cfg = dict(
            type='WGANGPDiscriminator',
            in_channel=3,
            in_scale=128,
            conv_module_cfg=self.conv_ln_module_config)
        disc = build_module(cfg)
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x).shape == (2, 1)

        # Group-norm based conv modules.
        cfg = dict(
            type='WGANGPDiscriminator',
            in_channel=3,
            in_scale=128,
            conv_module_cfg=self.conv_gn_module_config)
        disc = build_module(cfg)
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x).shape == (2, 1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_wgangp_discriminator_cuda(self):
        # Same four scenarios as the CPU test, with model and data on GPU.
        disc = build_module(self.default_config).cuda()
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x.cuda()).shape == (2, 1)

        cfg = dict(type='WGANGPDiscriminator', in_channel=3, in_scale=64)
        disc = build_module(cfg).cuda()
        assert isinstance(disc, WGANGPDiscriminator)
        small = torch.randn((2, 3, 64, 64))
        assert disc(small.cuda()).shape == (2, 1)

        cfg = dict(
            type='WGANGPDiscriminator',
            in_channel=3,
            in_scale=128,
            conv_module_cfg=self.conv_ln_module_config)
        disc = build_module(cfg).cuda()
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x.cuda()).shape == (2, 1)

        cfg = dict(
            type='WGANGPDiscriminator',
            in_channel=3,
            in_scale=128,
            conv_module_cfg=self.conv_gn_module_config)
        disc = build_module(cfg).cuda()
        assert isinstance(disc, WGANGPDiscriminator)
        assert disc(self.x.cuda()).shape == (2, 1)
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import pytest
import torch
import torch.nn as nn
from torch.autograd import gradgradcheck
from mmgen.ops import conv2d, conv_transpose2d
class TestCond2d:
    """Tests for ``mmgen.ops.conv2d``.

    NOTE(review): class name looks like a typo for ``TestConv2d``; kept
    unchanged because pytest discovers it by this name.
    """

    @classmethod
    def setup_class(cls):
        # 1x3x32x32 input and a single 3-channel 3x3 filter.
        cls.input = torch.randn((1, 3, 32, 32))
        cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3))

    @pytest.mark.skipif(
        not torch.cuda.is_available()
        or not hasattr(torch.backends.cudnn, 'allow_tf32'),
        reason='requires cuda')
    def test_conv2d_cuda(self):
        inp = self.input.cuda()
        w = self.weight.cuda()
        out = conv2d(inp, w, None, 1, 1)
        assert out.shape == (1, 1, 32, 32)
        # Second-order gradients must flow through the op.
        gradgradcheck(partial(conv2d, weight=w, padding=1, stride=1), inp)
class TestCond2dTansposed:
    """Tests for ``mmgen.ops.conv_transpose2d``.

    NOTE(review): class name looks like a typo for ``TestConv2dTransposed``;
    kept unchanged because pytest discovers it by this name.
    """

    @classmethod
    def setup_class(cls):
        # 1x3x32x32 input and a transposed-conv filter (in=3, out=1).
        cls.input = torch.randn((1, 3, 32, 32))
        cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3))

    @pytest.mark.skipif(
        not torch.cuda.is_available()
        or not hasattr(torch.backends.cudnn, 'allow_tf32'),
        reason='requires cuda')
    def test_conv2d_transposed_cuda(self):
        inp = self.input.cuda()
        w = self.weight.cuda()
        out = conv_transpose2d(inp, w, None, 1, 1)
        assert out.shape == (1, 1, 32, 32)
        # Second-order gradients must flow through the op.
        gradgradcheck(
            partial(conv_transpose2d, weight=w, padding=1, stride=1), inp)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmgen.ops.stylegan3.ops import bias_act, upfirdn2d
class TestStyleGAN3Ops:
    """Shape checks for the StyleGAN3 ``upfirdn2d`` and ``bias_act`` ops."""

    @classmethod
    def setup_class(cls):
        # A small feature map, a per-channel bias and a fixed 3x3 FIR kernel.
        cls.input = torch.randn((1, 3, 16, 16))
        cls.bias = torch.randn(3)
        cls.kernel = nn.Parameter(torch.randn(3, 3), requires_grad=False)

    def test_s3_ops_cpu(self):
        # Plain filtering shrinks the map by (kernel - 1) on each side pair.
        res = upfirdn2d.upfirdn2d(self.input, self.kernel)
        assert res.shape == (1, 3, 14, 14)
        # Upsampling by 2 with padding 1.
        res = upfirdn2d.upfirdn2d(
            self.input, self.kernel, up=2, down=1, padding=1)
        assert res.shape == (1, 3, 32, 32)
        # Downsampling by 2 with padding 1.
        res = upfirdn2d.upfirdn2d(
            self.input, self.kernel, up=1, down=2, padding=1)
        assert res.shape == (1, 3, 8, 8)
        # bias_act keeps the input shape with no args ...
        res = bias_act.bias_act(self.input)
        assert res.shape == (1, 3, 16, 16)
        # test bias
        res = bias_act.bias_act(self.input, self.bias)
        assert res.shape == (1, 3, 16, 16)
        # test gain
        res = bias_act.bias_act(self.input, gain=0.5)
        assert res.shape == (1, 3, 16, 16)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_s3_ops_cuda(self):
        # Same six scenarios as the CPU test, with all tensors on GPU.
        res = upfirdn2d.upfirdn2d(self.input.cuda(), self.kernel.cuda())
        assert res.shape == (1, 3, 14, 14)
        res = upfirdn2d.upfirdn2d(
            self.input.cuda(), self.kernel.cuda(), up=2, down=1, padding=1)
        assert res.shape == (1, 3, 32, 32)
        res = upfirdn2d.upfirdn2d(
            self.input.cuda(), self.kernel.cuda(), up=1, down=2, padding=1)
        assert res.shape == (1, 3, 8, 8)
        res = bias_act.bias_act(self.input.cuda())
        assert res.shape == (1, 3, 16, 16)
        # test bias
        res = bias_act.bias_act(self.input.cuda(), self.bias.cuda())
        assert res.shape == (1, 3, 16, 16)
        # test gain
        res = bias_act.bias_act(self.input.cuda(), gain=0.5)
        assert res.shape == (1, 3, 16, 16)
# Copyright (c) OpenMMLab. All rights reserved.
import os
from tempfile import TemporaryDirectory
from mmgen.utils import download_from_url
def test_download_from_file():
    """Download a sample PNG into a temp dir and check the file exists.

    NOTE(review): requires network access; will fail offline.
    """
    img_url = 'https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png'  # noqa
    with TemporaryDirectory() as temp_dir:
        saved_path = download_from_url(url=img_url, dest_dir=temp_dir)
        assert os.path.exists(saved_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment