Commit c9a48a52 authored by limm's avatar limm
Browse files

add tests code

parent b7536f78
Pipeline #2778 canceled with stages
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
from mmgen.models.architectures.cyclegan import ResnetGenerator
class TestResnetGenerator:
@classmethod
def setup_class(cls):
cls.default_cfg = dict(
in_channels=3,
out_channels=3,
base_channels=64,
norm_cfg=dict(type='IN'),
use_dropout=False,
num_blocks=9,
padding_mode='reflect',
init_cfg=dict(type='normal', gain=0.02))
def test_cyclegan_generator_cpu(self):
# test with default cfg
real_a = torch.randn((2, 3, 256, 256))
gen = ResnetGenerator(**self.default_cfg)
fake_b = gen(real_a)
assert fake_b.shape == (2, 3, 256, 256)
# test args system
cfg = deepcopy(self.default_cfg)
cfg['num_blocks'] = 8
gen = ResnetGenerator(**cfg)
fake_b = gen(real_a)
assert fake_b.shape == (2, 3, 256, 256)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_cyclegan_generator_cuda(self):
# test with default cfg
real_a = torch.randn((2, 3, 256, 256)).cuda()
gen = ResnetGenerator(**self.default_cfg).cuda()
fake_b = gen(real_a)
assert fake_b.shape == (2, 3, 256, 256)
# test args system
cfg = deepcopy(self.default_cfg)
cfg['num_blocks'] = 8
gen = ResnetGenerator(**cfg).cuda()
fake_b = gen(real_a)
assert fake_b.shape == (2, 3, 256, 256)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment