Commit c9a48a52 authored by limm's avatar limm
Browse files

add tests code

parent b7536f78
Pipeline #2778 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import os
import mmcv
import pytest
import torch
from mmgen.apis import (init_model, sample_ddpm_model, sample_img2img_model,
sample_unconditional_model)
class TestSampleUnconditionalModel:
    """Smoke tests for ``sample_unconditional_model`` using a DCGAN config."""

    @classmethod
    def setup_class(cls):
        # Resolve the repository root relative to this test file.
        project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
        config = mmcv.Config.fromfile(
            os.path.join(
                project_dir,
                'configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py'))
        cls.model = init_model(config, checkpoint=None, device='cpu')

    def test_sample_unconditional_model_cpu(self):
        # A sample count that is not a multiple of num_batches.
        outputs = sample_unconditional_model(
            self.model, 5, num_batches=2, sample_model='orig')
        assert outputs.shape == (5, 3, 64, 64)
        # A sample count that is an exact multiple of num_batches.
        outputs = sample_unconditional_model(
            self.model, 4, num_batches=2, sample_model='orig')
        assert outputs.shape == (4, 3, 64, 64)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_sample_unconditional_model_cuda(self):
        model = self.model.cuda()
        outputs = sample_unconditional_model(
            model, 5, num_batches=2, sample_model='orig')
        assert outputs.shape == (5, 3, 64, 64)
        outputs = sample_unconditional_model(
            model, 4, num_batches=2, sample_model='orig')
        assert outputs.shape == (4, 3, 64, 64)
class TestSampleTranslationModel:
    """Smoke tests for ``sample_img2img_model`` on pix2pix and CycleGAN."""

    @classmethod
    def setup_class(cls):
        # Resolve the repository root relative to this test file.
        project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
        pix2pix_config = mmcv.Config.fromfile(
            os.path.join(
                project_dir,
                'configs/pix2pix/pix2pix_vanilla_unet_bn_facades_b1x1_80k.py'))
        cls.pix2pix = init_model(pix2pix_config, checkpoint=None, device='cpu')
        cyclegan_config = mmcv.Config.fromfile(
            os.path.join(
                project_dir,
                'configs/cyclegan/cyclegan_lsgan_resnet_in_facades_b1x1_80k.py'
            ))
        cls.cyclegan = init_model(
            cyclegan_config, checkpoint=None, device='cpu')
        # Shared input image for both translation models.
        cls.img_path = os.path.join(
            os.path.dirname(__file__), '..', 'data/unpaired/testA/5.jpg')

    def test_translation_model_cpu(self):
        for model in (self.pix2pix, self.cyclegan):
            outputs = sample_img2img_model(
                model, self.img_path, target_domain='photo')
            assert outputs.shape == (1, 3, 256, 256)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_translation_model_cuda(self):
        for model in (self.pix2pix, self.cyclegan):
            outputs = sample_img2img_model(
                model.cuda(), self.img_path, target_domain='photo')
            assert outputs.shape == (1, 3, 256, 256)
class TestDiffusionModel:
    """Smoke tests for ``sample_ddpm_model`` with a shortened DDPM config."""

    @classmethod
    def setup_class(cls):
        # Resolve the repository root relative to this test file.
        project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
        ddpm_config = mmcv.Config.fromfile(
            os.path.join(
                project_dir, 'configs/improved_ddpm/'
                'ddpm_cosine_hybird_timestep-4k_drop0.3_'
                'cifar10_32x32_b8x16_500k.py'))
        # change timesteps to speed up test process
        ddpm_config.model['num_timesteps'] = 10
        cls.model = init_model(ddpm_config, checkpoint=None, device='cpu')

    def test_diffusion_model_cpu(self):
        # save_intermedia is False -> a single stacked tensor is returned
        outputs = sample_ddpm_model(
            self.model, num_samples=3, num_batches=2, same_noise=True)
        assert outputs.shape == (3, 3, 32, 32)
        # save_intermedia is True -> a dict keyed by timestep is returned
        outputs = sample_ddpm_model(
            self.model,
            num_samples=2,
            num_batches=2,
            same_noise=True,
            save_intermedia=True)
        assert isinstance(outputs, dict)
        assert all(t in outputs for t in range(10))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_diffusion_model_cuda(self):
        model = self.model.cuda()
        # save_intermedia is False -> a single stacked tensor is returned
        outputs = sample_ddpm_model(
            model, num_samples=3, num_batches=2, same_noise=True)
        assert outputs.shape == (3, 3, 32, 32)
        # save_intermedia is True -> a dict keyed by timestep is returned
        outputs = sample_ddpm_model(
            model,
            num_samples=2,
            num_batches=2,
            same_noise=True,
            save_intermedia=True)
        assert isinstance(outputs, dict)
        assert all(t in outputs for t in range(10))
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from mmgen.core.hooks import ExponentialMovingAverageHook
class SimpleModule(nn.Module):
    """Minimal module with one parameter, one persistent buffer ``b`` and one
    value ``c`` that must stay out of ``state_dict()``.

    ``register_buffer(..., persistent=...)`` only exists from PyTorch 1.7;
    on older versions ``c`` is kept as a plain attribute, which likewise
    excludes it from the state dict.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor([1., 2.]))
        # BUG FIX: the original compared version *strings* lexicographically
        # (torch.__version__ >= '1.7.0'), which is wrong for e.g. '1.10.0'
        # since '1.10.0' < '1.7.0' as strings. Compare numeric (major, minor).
        version = tuple(
            int(p) for p in torch.__version__.split('+')[0].split('.')[:2])
        if version >= (1, 7):
            self.register_buffer('b', torch.tensor([2., 3.]), persistent=True)
            self.register_buffer('c', torch.tensor([0., 1.]), persistent=False)
        else:
            self.register_buffer('b', torch.tensor([2., 3.]))
            self.c = torch.tensor([0., 1.])
class SimpleModel(nn.Module):
    """Model holding two source submodules plus their EMA counterparts."""

    def __init__(self) -> None:
        super().__init__()
        # Registering via setattr keeps the four submodules symmetric.
        for attr in ('module_a', 'module_b', 'module_a_ema', 'module_b_ema'):
            setattr(self, attr, SimpleModule())
class SimpleModelNoEMA(nn.Module):
    """Model with source submodules only; EMA copies are added by the hook."""

    def __init__(self) -> None:
        super().__init__()
        for attr in ('module_a', 'module_b'):
            setattr(self, attr, SimpleModule())
class SimpleRunner:
    """Stand-in for an mmcv runner exposing only ``model`` and ``iter``."""

    def __init__(self):
        self.model = SimpleModel()
        self.iter = 0
class TestEMA:
    """Behavioral tests for ``ExponentialMovingAverageHook``."""

    @classmethod
    def setup_class(cls):
        cls.default_config = dict(
            module_keys=('module_a_ema', 'module_b_ema'),
            interval=1,
            interp_cfg=dict(momentum=0.5))
        cls.runner = SimpleRunner()

    @torch.no_grad()
    def test_ema_hook(self):
        # interval=-1 disables updating: EMA weights stay at their init value
        cfg = deepcopy(self.default_config)
        cfg['interval'] = -1
        hook = ExponentialMovingAverageHook(**cfg)
        hook.before_run(self.runner)
        hook.after_train_iter(self.runner)
        src_module = self.runner.model.module_a
        ema_module = self.runner.model.module_a_ema
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        # default config: first update at iter 0 simply copies the weights
        hook = ExponentialMovingAverageHook(**self.default_config)
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        # halve the source weights/buffers and run one more EMA step
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        self.runner.iter += 1
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(self.runner.model.module_a.a,
                           torch.tensor([0.5, 1.]))
        # momentum=0.5 -> EMA is the midpoint of old and new weights
        assert torch.equal(states['a'], torch.tensor([0.75, 1.5]))
        assert torch.equal(states['b'], torch.tensor([1., 1.5]))
        # non-persistent buffer 'c' is never tracked
        assert 'c' not in states
        # check for the validity of args
        with pytest.raises(AssertionError):
            _ = ExponentialMovingAverageHook(module_keys=['a'])
        with pytest.raises(AssertionError):
            _ = ExponentialMovingAverageHook(module_keys=('a'))
        with pytest.raises(AssertionError):
            _ = ExponentialMovingAverageHook(
                module_keys=('module_a_ema'), interp_mode='xxx')
        # before_run must create the missing EMA submodules
        hook = ExponentialMovingAverageHook(**self.default_config)
        self.runner.model = SimpleModelNoEMA()
        self.runner.iter = 0
        hook.before_run(self.runner)
        assert hasattr(self.runner.model, 'module_a_ema')
        src_module = self.runner.model.module_a
        ema_module = self.runner.model.module_a_ema
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        self.runner.iter += 1
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(self.runner.model.module_a.a,
                           torch.tensor([0.5, 1.]))
        assert torch.equal(states['a'], torch.tensor([0.75, 1.5]))
        assert torch.equal(states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in states
        # before start_iter the EMA is a plain copy (simple warm up)
        runner = SimpleRunner()
        cfg = deepcopy(self.default_config)
        cfg.update(dict(start_iter=3, interval=1))
        hook = ExponentialMovingAverageHook(**cfg)
        hook.before_run(runner)
        src_module = runner.model.module_a
        ema_module = runner.model.module_a_ema
        src_module.a.data /= 2.
        runner.iter += 1
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
        assert torch.equal(states['a'], torch.tensor([0.5, 1.]))
        # past start_iter the usual momentum update applies
        src_module.a.data /= 2
        runner.iter += 2
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.25, 0.5]))
        assert torch.equal(states['a'], torch.tensor([0.375, 0.75]))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_ema_hook_cuda(self):
        hook = ExponentialMovingAverageHook(**self.default_config)
        cuda_runner = SimpleRunner()
        cuda_runner.model = cuda_runner.model.cuda()
        hook.after_train_iter(cuda_runner)
        src_module = cuda_runner.model.module_a
        ema_module = cuda_runner.model.module_a_ema
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]).cuda())
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        cuda_runner.iter += 1
        hook.after_train_iter(cuda_runner)
        states = ema_module.state_dict()
        assert torch.equal(cuda_runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(states['b'], torch.tensor([1., 1.5]).cuda())
        assert 'c' not in states
        # before_run with a DataParallel-wrapped model
        hook = ExponentialMovingAverageHook(**self.default_config)
        self.runner.model = SimpleModelNoEMA().cuda()
        self.runner.model = DataParallel(self.runner.model)
        self.runner.iter = 0
        hook.before_run(self.runner)
        assert hasattr(self.runner.model.module, 'module_a_ema')
        src_module = self.runner.model.module.module_a
        ema_module = self.runner.model.module.module_a_ema
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]).cuda())
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        self.runner.iter += 1
        hook.after_train_iter(self.runner)
        states = ema_module.state_dict()
        assert torch.equal(self.runner.model.module.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(states['b'], torch.tensor([1., 1.5]).cuda())
        assert 'c' not in states
        # simple warm up on GPU
        runner = SimpleRunner()
        runner.model = runner.model.cuda()
        cfg = deepcopy(self.default_config)
        cfg.update(dict(start_iter=3, interval=1))
        hook = ExponentialMovingAverageHook(**cfg)
        hook.before_run(runner)
        src_module = runner.model.module_a
        ema_module = runner.model.module_a_ema
        src_module.a.data /= 2.
        runner.iter += 1
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(states['a'], torch.tensor([0.5, 1.]).cuda())
        src_module.a.data /= 2
        runner.iter += 2
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.25, 0.5]).cuda())
        assert torch.equal(states['a'], torch.tensor([0.375, 0.75]).cuda())

    def test_dynamic_ema(self):
        # shared config for the 'rampup' momentum policy; each sub-case gets
        # its own deepcopy so hook-side mutation cannot leak across cases
        rampup_cfg = dict(
            module_keys=('module_a_ema', 'module_b_ema'),
            interp_cfg=dict(momentum=0.9),
            interval=1,
            start_iter=0,
            momentum_policy='rampup',
            momentum_cfg=dict(
                ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-8))

        # case 1: still within the rampup phase
        runner = SimpleRunner()
        hook = ExponentialMovingAverageHook(**deepcopy(rampup_cfg))
        hook.before_run(runner)
        src_module = runner.model.module_a
        ema_module = runner.model.module_a_ema
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        runner.iter += 19
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
        assert torch.equal(states['a'], torch.tensor([0.75, 1.5]))
        assert torch.equal(states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in states

        # case 2: iteration count exceeds the rampup phase
        runner = SimpleRunner()
        hook = ExponentialMovingAverageHook(**deepcopy(rampup_cfg))
        hook.before_run(runner)
        src_module = runner.model.module_a
        ema_module = runner.model.module_a_ema
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        runner.iter += 49999
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
        # saturated momentum: 0.5 ** (batch_size / (ema_kimg * 1000))
        expected_m = 0.5**(4 / 10000)
        assert torch.equal(
            states['a'],
            expected_m * torch.tensor([1.0, 2.0]) +
            (1. - expected_m) * torch.tensor([0.5, 1.0]))
        assert torch.equal(states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in states

        # case 3: far beyond the rampup phase -> same saturated momentum
        runner = SimpleRunner()
        hook = ExponentialMovingAverageHook(**deepcopy(rampup_cfg))
        hook.before_run(runner)
        src_module = runner.model.module_a
        ema_module = runner.model.module_a_ema
        states = ema_module.state_dict()
        assert torch.equal(states['a'], torch.tensor([1., 2.]))
        src_module.b /= 2.
        src_module.a.data /= 2.
        src_module.c /= 2.
        runner.iter += 79999
        hook.after_train_iter(runner)
        states = ema_module.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
        expected_m = 0.5**(4 / 10000)
        assert torch.equal(
            states['a'],
            expected_m * torch.tensor([1.0, 2.0]) +
            (1. - expected_m) * torch.tensor([0.5, 1.0]))
        assert torch.equal(states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in states
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION
from mmgen.core.runners.fp16_utils import (auto_fp16, cast_tensor_type,
nan_to_num)
def test_nan_to_num():
    """``nan_to_num`` replaces NaN/inf values and rejects non-tensor input."""
    src = torch.tensor([float('inf'), float('nan'), 2.])
    out = nan_to_num(src, posinf=255., neginf=-255.)
    assert (out == torch.tensor([255., 0., 2.])).all()
    # default replacement values keep the shape intact
    out = nan_to_num(src)
    assert out.shape == (3, )
    # non-tensor input must raise
    with pytest.raises(TypeError):
        nan_to_num(1)
def test_cast_tensor_type():
    """``cast_tensor_type`` converts tensors and recurses into containers."""
    # single tensor: dtype is converted
    out = cast_tensor_type(
        torch.FloatTensor([5.]), torch.float32, torch.int32)
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.int32
    # non-tensor leaves pass through unchanged
    assert isinstance(cast_tensor_type('tensor', str, str), str)
    out = cast_tensor_type(np.array([5.]), np.ndarray, np.ndarray)
    assert isinstance(out, np.ndarray)
    # dict values are cast recursively
    mapping = dict(
        tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
    out = cast_tensor_type(mapping, torch.float32, torch.int32)
    assert isinstance(out, dict)
    assert out['tensor_a'].dtype == torch.int32
    assert out['tensor_b'].dtype == torch.int32
    # list elements are cast recursively
    seq = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
    out = cast_tensor_type(seq, torch.float32, torch.int32)
    assert isinstance(out, list)
    assert out[0].dtype == torch.int32
    assert out[1].dtype == torch.int32
    # scalars and modules are returned as-is
    assert isinstance(cast_tensor_type(5, None, None), int)
    module = nn.Sequential(nn.Conv2d(2, 2, 3), nn.ReLU())
    assert isinstance(cast_tensor_type(module, None, None), nn.Module)
@pytest.mark.skipif(
    not TORCH_VERSION >= '1.6.0', reason='Lower PyTorch version')
def test_auto_fp16_func():
    """Exercise the casting rules of the ``auto_fp16`` decorator."""
    # the decorator refuses to run on anything but an nn.Module subclass
    with pytest.raises(TypeError):

        class ExampleObject(object):

            @auto_fp16()
            def __call__(self, x):
                return x

        net = ExampleObject()
        x_in = torch.ones(1, dtype=torch.float32)
        net(x_in)

    # case 1: default apply_to (all input args)
    class ExampleModule(nn.Module):

        @auto_fp16()
        def forward(self, x, y):
            return x, y

    net = ExampleModule()
    x_in = torch.ones(1, dtype=torch.float32)
    y_in = torch.ones(1, dtype=torch.float32)
    x_out, y_out = net(x_in, y_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    net.fp16_enabled = True
    x_out, y_out = net(x_in, y_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    if torch.cuda.is_available():
        net.cuda()
        x_out, y_out = net(x_in.cuda(), y_in.cuda())
        assert x_out.dtype == torch.float32
        assert y_out.dtype == torch.float32

    # case 2: cast only the args named in apply_to
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    net = ExampleModule()
    x_in = torch.ones(1, dtype=torch.float32)
    y_in = torch.ones(1, dtype=torch.float32)
    x_out, y_out = net(x_in, y_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    net.fp16_enabled = True
    x_out, y_out = net(x_in, y_in)
    # only 'x' is cast to half precision
    assert x_out.dtype == torch.half
    assert y_out.dtype == torch.float32
    if torch.cuda.is_available():
        net.cuda()
        x_out, y_out = net(x_in.cuda(), y_in.cuda())
        assert x_out.dtype == torch.half
        assert y_out.dtype == torch.float32

    # case 3: apply_to covering optional keyword args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    net = ExampleModule()
    x_in = torch.ones(1, dtype=torch.float32)
    y_in = torch.ones(1, dtype=torch.float32)
    z_in = torch.ones(1, dtype=torch.float32)
    x_out, y_out, z_out = net(x_in, y=y_in, z=z_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    assert z_out.dtype == torch.float32
    net.fp16_enabled = True
    x_out, y_out, z_out = net(x_in, y=y_in, z=z_in)
    assert x_out.dtype == torch.half
    assert y_out.dtype == torch.half
    # 'z' is not listed in apply_to, so it stays fp32
    assert z_out.dtype == torch.float32
    if torch.cuda.is_available():
        net.cuda()
        x_out, y_out, z_out = net(
            x_in.cuda(), y=y_in.cuda(), z=z_in.cuda())
        assert x_out.dtype == torch.half
        assert y_out.dtype == torch.half
        assert z_out.dtype == torch.float32

    # case 4: out_fp32 set as an instance attribute
    class ExampleModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.out_fp32 = True

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    net = ExampleModule()
    net.fp16_enabled = True
    x_in = torch.ones(1, dtype=torch.half)
    y_in = torch.ones(1, dtype=torch.float32)
    z_in = torch.ones(1, dtype=torch.float32)
    x_out, y_out, z_out = net(x_in, y=y_in, z=z_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    assert z_out.dtype == torch.float32

    # case 5: out_fp32 passed to the decorator itself
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    net = ExampleModule()
    x_in = torch.ones(1, dtype=torch.half)
    y_in = torch.ones(1, dtype=torch.float32)
    z_in = torch.ones(1, dtype=torch.float32)
    x_out, y_out, z_out = net(x_in, y=y_in, z=z_in)
    assert x_out.dtype == torch.half
    assert y_out.dtype == torch.float32
    assert z_out.dtype == torch.float32
    net.fp16_enabled = True
    x_out, y_out, z_out = net(x_in, y=y_in, z=z_in)
    assert x_out.dtype == torch.float32
    assert y_out.dtype == torch.float32
    assert z_out.dtype == torch.float32
    if torch.cuda.is_available():
        net.cuda()
        x_out, y_out, z_out = net(
            x_in.cuda(), y=y_in.cuda(), z=z_in.cuda())
        assert x_out.dtype == torch.float32
        assert y_out.dtype == torch.float32
        assert z_out.dtype == torch.float32
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import numpy as np
import pytest
import torch
from mmgen.core.evaluation.metric_utils import extract_inception_features
from mmgen.core.evaluation.metrics import (FID, IS, MS_SSIM, PPL, PR, SWD,
GaussianKLD)
from mmgen.datasets import UnconditionalImageDataset, build_dataloader
from mmgen.models import build_model
from mmgen.models.architectures import InceptionV3
from mmgen.utils import download_from_url
# def test_inception_download():
# from mmgen.core.evaluation.metrics import load_inception
# from mmgen.utils import MMGEN_CACHE_DIR
# args_FID_pytorch = dict(type='pytorch', normalize_input=False)
# args_FID_tero = dict(type='StyleGAN', inception_path='')
# args_IS_pytorch = dict(type='pytorch')
# args_IS_tero = dict(
# type='StyleGAN',
# inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt'))
# tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN']
# for inception_args, metric, tar_style in zip(
# [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero],
# ['FID', 'FID', 'IS', 'IS'], tar_style_list):
# model, style = load_inception(inception_args, metric)
# assert style == tar_style
# args_empty = ''
# with pytest.raises(TypeError) as exc_info:
# load_inception(args_empty, 'FID')
# args_error_path = dict(type='StyleGAN', inception_path='error-path')
# with pytest.raises(RuntimeError) as exc_info:
# load_inception(args_error_path, 'FID')
def test_swd_metric():
    """SWD over random images should roughly match the reference values."""
    reals = torch.rand((100, 3, 32, 32))
    fakes = torch.rand((100, 3, 32, 32))
    metric = SWD(100, (3, 32, 32))
    metric.prepare()
    metric.feed(reals, 'reals')
    metric.feed(fakes, 'fakes')
    expected = [16.495922580361366, 24.15413036942482, 20.325026474893093]
    output = metric.summary()
    # rescale both sides so the decimal=1 tolerance is effectively ~10
    expected = [item / 100 for item in expected]
    output = [item / 100 for item in output]
    np.testing.assert_almost_equal(output, expected, decimal=1)
def test_ms_ssim():
    """MS-SSIM between independent random batches must stay below 1."""
    batch_one = torch.rand((100, 3, 32, 32))
    batch_two = torch.rand((100, 3, 32, 32))
    metric = MS_SSIM(100)
    metric.prepare()
    metric.feed(batch_one, 'reals')
    metric.feed(batch_two, 'fakes')
    ssim_result = metric.summary()
    assert ssim_result < 1
class TestExtractInceptionFeat:
    """Tests for ``extract_inception_features`` with an InceptionV3 backbone."""

    @classmethod
    def setup_class(cls):
        cls.inception = InceptionV3(
            load_fid_inception=False, resize_input=True)
        # Pipeline producing 299x299 normalized tensors keyed as 'real_img'.
        img_pipeline = [
            dict(type='LoadImageFromFile', key='real_img'),
            dict(
                type='Resize',
                keys=['real_img'],
                scale=(299, 299),
                keep_ratio=False,
            ),
            dict(
                type='Normalize',
                keys=['real_img'],
                mean=[127.5] * 3,
                std=[127.5] * 3,
                to_rgb=False),
            dict(type='Collect', keys=['real_img'], meta_keys=[]),
            dict(type='ImageToTensor', keys=['real_img'])
        ]
        dataset = UnconditionalImageDataset(
            osp.join(osp.dirname(__file__), '..', 'data'), img_pipeline)
        cls.data_loader = build_dataloader(dataset, 3, 0, dist=False)

    def test_extr_inception_feat(self):
        # request exactly 5 feature vectors
        feat = extract_inception_features(self.data_loader, self.inception, 5)
        assert feat.shape[0] == 5

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_extr_inception_feat_cuda(self):
        wrapped = torch.nn.DataParallel(self.inception)
        feat = extract_inception_features(self.data_loader, wrapped, 5)
        assert feat.shape[0] == 5

    @torch.no_grad()
    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_with_tero_implement(self):
        self.inception = InceptionV3(
            load_fid_inception=True, resize_input=False)
        img = torch.randn((2, 3, 1024, 1024))
        feature_ours = self.inception(img)[0].view(img.shape[0], -1)
        # compare against the scripted StyleGAN2-ADA inception network
        download_from_url(
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt',  # noqa
            dest_dir='./work_dirs/cache')
        net = torch.jit.load(
            './work_dirs/cache/inception-2015-12-05.pt').eval().cuda()
        net = torch.nn.DataParallel(net)
        feature_tero = net(img, return_features=True)
        print(feature_ours.shape)
        print((feature_tero.cpu() - feature_ours).abs().mean())
class TestFID:
    """FID over random feature batches should produce positive statistics."""

    @classmethod
    def setup_class(cls):
        cls.reals = [torch.randn(2, 3, 128, 128) for _ in range(5)]
        cls.fakes = [torch.randn(2, 3, 128, 128) for _ in range(5)]

    def test_fid(self):
        fid = FID(
            3,
            inception_args=dict(
                normalize_input=False, load_fid_inception=False))
        for batch in self.reals:
            fid.feed(batch, 'reals')
        for batch in self.fakes:
            fid.feed(batch, 'fakes')
        fid_score, mean, cov = fid.summary()
        assert fid_score > 0 and mean > 0 and cov > 0
        # To reduce the size of git repo, we remove the following test
        # fid = FID(
        #     3,
        #     inception_args=dict(
        #         normalize_input=False, load_fid_inception=False),
        #     inception_pkl=osp.join(
        #         osp.dirname(__file__), '..', 'data', 'test_dirty.pkl'))
        # assert fid.num_real_feeded == 3
        # for b in self.reals:
        #     fid.feed(b, 'reals')
        # for b in self.fakes:
        #     fid.feed(b, 'fakes')
        # fid_score, mean, cov = fid.summary()
        # assert fid_score > 0 and mean > 0 and cov > 0
class TestPR:
    """Precision/recall metric over random image batches."""

    @classmethod
    def setup_class(cls):
        cls.reals = [torch.rand(2, 3, 128, 128) for _ in range(5)]
        cls.fakes = [torch.rand(2, 3, 128, 128) for _ in range(5)]

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_pr_cuda(self):
        pr = PR(10)
        pr.prepare()
        for batch in self.fakes:
            pr.feed(batch.cuda(), 'fakes')
        for batch in self.reals:
            pr.feed(batch.cuda(), 'reals')
        pr_score = pr.summary()
        print(pr_score)
        assert pr_score['precision'] >= 0 and pr_score['recall'] >= 0

    def test_pr_cpu(self):
        pr = PR(10)
        pr.prepare()
        for batch in self.fakes:
            pr.feed(batch, 'fakes')
        for batch in self.reals:
            pr.feed(batch, 'reals')
        pr_score = pr.summary()
        assert pr_score['precision'] >= 0 and pr_score['recall'] >= 0
class TestIS:
    """Inception Score over random image batches."""

    @classmethod
    def setup_class(cls):
        cls.reals = [torch.randn(2, 3, 128, 128) for _ in range(5)]
        cls.fakes = [torch.randn(2, 3, 128, 128) for _ in range(5)]

    def test_is_cpu(self):
        inception_score = IS(10, resize=True, splits=10)
        inception_score.prepare()
        for batch in self.reals:
            inception_score.feed(batch, 'reals')
        for batch in self.fakes:
            inception_score.feed(batch, 'fakes')
        score, std = inception_score.summary()
        assert score > 0 and std >= 0

    @torch.no_grad()
    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_is_cuda(self):
        inception_score = IS(10, resize=True, splits=10)
        inception_score.prepare()
        for batch in self.reals:
            inception_score.feed(batch.cuda(), 'reals')
        for batch in self.fakes:
            inception_score.feed(batch.cuda(), 'fakes')
        score, std = inception_score.summary()
        assert score > 0 and std >= 0
class TestPPL:
    """Perceptual path length metric on a StyleGANv2 model."""

    @classmethod
    def setup_class(cls):
        cls.model_cfg = dict(
            type='StaticUnconditionalGAN',
            generator=dict(
                type='StyleGANv2Generator',
                out_size=256,
                style_channels=512,
            ),
            discriminator=dict(
                type='StyleGAN2Discriminator',
                in_size=256,
            ),
            gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
            train_cfg=dict(use_ema=True))

    def test_ppl_cpu(self):
        self.model = build_model(self.model_cfg)
        ppl = PPL(10)
        ppl_iterator = iter(ppl.get_sampler(self.model, 2, 'ema'))
        ppl.prepare()
        for batch in ppl_iterator:
            ppl.feed(batch, 'fakes')
        score = ppl.summary()
        assert score >= 0

    @torch.no_grad()
    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_ppl_cuda(self):
        self.model = build_model(self.model_cfg).cuda()
        ppl = PPL(10)
        ppl_iterator = iter(ppl.get_sampler(self.model, 2, 'ema'))
        ppl.prepare()
        for batch in ppl_iterator:
            ppl.feed(batch, 'fakes')
        score = ppl.summary()
        assert score >= 0
def test_kld_gaussian():
    """Compare ``GaussianKLD`` against a numerically integrated KL value."""
    # keep the tensor small: the numerical integration below is memory-hungry
    tar_shape = [2, 3, 4, 4]
    mean1, mean2 = torch.rand(*tar_shape, 1), torch.rand(*tar_shape, 1)
    # integer variances keep the numerical integration well conditioned
    var1 = torch.randint(1, 3, (*tar_shape, 1)).float()
    var2 = torch.randint(1, 3, (*tar_shape, 1)).float()

    def gauss_pdf(x, mean, var):
        return (1 / np.sqrt(2 * np.pi * var) * torch.exp(-(x - mean)**2 /
                                                         (2 * var)))

    delta = 0.0001
    grid = torch.arange(-5, 5, delta).repeat(*mean1.shape)
    p = gauss_pdf(grid, mean1, var1)  # pdf of target distribution
    q = gauss_pdf(grid, mean2, var2)  # pdf of predicted distribution
    kld_manually = (p * torch.log(p / q) * delta).sum(dim=(1, 2, 3, 4)).mean()
    data = dict(
        mean_pred=mean2,
        mean_target=mean1,
        logvar_pred=torch.log(var2),
        logvar_target=torch.log(var1))
    metric = GaussianKLD(2)
    metric.prepare()
    metric.feed(data, 'reals')
    kld = metric.summary()
    # loose tolerance: delta cannot be made small enough for a precise value
    np.testing.assert_almost_equal(kld, kld_manually, decimal=1)
    # base-2 result must equal the natural-log result divided by ln(2)
    metric_base_2 = GaussianKLD(2, base='2')
    metric_base_2.prepare()
    metric_base_2.feed(data, 'reals')
    kld_base_2 = metric_base_2.summary()
    np.testing.assert_almost_equal(kld_base_2, kld / np.log(2), decimal=4)
    # unsupported log base
    with pytest.raises(AssertionError):
        GaussianKLD(2, base='10')
    # reduction='mean' runs without error
    metric = GaussianKLD(2, reduction='mean')
    metric.prepare()
    metric.feed(data, 'reals')
    kld = metric.summary()
    # reduction='sum' runs without error
    metric = GaussianKLD(2, reduction='sum')
    metric.prepare()
    metric.feed(data, 'reals')
    kld = metric.summary()
    # invalid reduction
    with pytest.raises(AssertionError):
        metric = GaussianKLD(2, reduction='none')
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmgen.core import build_optimizers
class ExampleModel(nn.Module):
    """Toy model with two conv branches; forward is the identity.

    The convs exist only to provide two named parameter groups
    (``model1``, ``model2``) for the optimizer-building tests.
    """

    def __init__(self):
        super().__init__()
        self.model1 = nn.Conv2d(3, 8, kernel_size=3)
        self.model2 = nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x):
        return x
def test_build_optimizers():
    """``build_optimizers`` returns a dict of optimizers for a dict-of-configs
    and a single optimizer for a plain config, for both bare and
    ``DataParallel``-wrapped models."""
    base_lr = 0.0001
    base_wd = 0.0002
    momentum = 0.9

    # dict config with ExampleModel -> one optimizer per sub-module
    optimizer_cfg = dict(
        model1=dict(
            type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum),
        model2=dict(
            type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
    model = ExampleModel()
    optimizers = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizers, dict)
    for i in range(2):
        optimizer = optimizers[f'model{i+1}']
        param_groups = optimizer.param_groups[0]
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == base_lr
        assert optimizer.defaults['momentum'] == momentum
        assert optimizer.defaults['weight_decay'] == base_wd
        assert len(param_groups['params']) == 2
        assert torch.equal(param_groups['params'][0],
                           param_dict[f'model{i+1}.weight'])
        assert torch.equal(param_groups['params'][1],
                           param_dict[f'model{i+1}.bias'])

    # dict config with DataParallel-wrapped model
    model = torch.nn.DataParallel(ExampleModel())
    optimizers = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizers, dict)
    for i in range(2):
        optimizer = optimizers[f'model{i+1}']
        param_groups = optimizer.param_groups[0]
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == base_lr
        assert optimizer.defaults['momentum'] == momentum
        assert optimizer.defaults['weight_decay'] == base_wd
        assert len(param_groups['params']) == 2
        assert torch.equal(param_groups['params'][0],
                           param_dict[f'module.model{i+1}.weight'])
        assert torch.equal(param_groups['params'][1],
                           param_dict[f'module.model{i+1}.bias'])

    # plain config with ExampleModel (one optimizer over all params)
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    model = ExampleModel()
    optimizer = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    # BUG FIX: the original re-asserted `isinstance(optimizers, dict)` here,
    # testing the *stale* dict from the previous case. A plain config must
    # yield a single optimizer, not a dict.
    assert not isinstance(optimizer, dict)
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['momentum'] == momentum
    assert optimizer.defaults['weight_decay'] == base_wd
    assert len(param_groups['params']) == 4
    assert torch.equal(param_groups['params'][0], param_dict['model1.weight'])
    assert torch.equal(param_groups['params'][1], param_dict['model1.bias'])
    assert torch.equal(param_groups['params'][2], param_dict['model2.weight'])
    assert torch.equal(param_groups['params'][3], param_dict['model2.bias'])

    # plain config with DataParallel-wrapped model (one optimizer)
    model = torch.nn.DataParallel(ExampleModel())
    optimizer = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    # BUG FIX: same stale-variable assertion as above.
    assert not isinstance(optimizer, dict)
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['momentum'] == momentum
    assert optimizer.defaults['weight_decay'] == base_wd
    assert len(param_groups['params']) == 4
    assert torch.equal(param_groups['params'][0],
                       param_dict['module.model1.weight'])
    assert torch.equal(param_groups['params'][1],
                       param_dict['module.model1.bias'])
    assert torch.equal(param_groups['params'][2],
                       param_dict['module.model2.weight'])
    assert torch.equal(param_groups['params'][3],
                       param_dict['module.model2.bias'])
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call
import torch
import torch.nn as nn
from mmcv.runner import PaviLoggerHook, build_runner
from torch.utils.data import DataLoader
def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None):
    """Create a runner around a tiny linear model for logger-hook tests.

    Args:
        runner_type (str): Registry name of the runner to build.
        max_epochs (int): Epoch budget for epoch-based runners.
        max_iters (int | None): Iteration budget for iter-based runners.

    Returns:
        The runner produced by ``mmcv.runner.build_runner`` with a
        checkpoint hook and a text logger hook already registered.
    """

    class Model(nn.Module):
        """Minimal 2-in / 1-out linear model with train/val steps."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            # The raw network output doubles as the "loss" in these tests.
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    demo_model = Model()
    work_dir = tempfile.mkdtemp()
    sgd = torch.optim.SGD(demo_model.parameters(), lr=0.02, momentum=0.95)

    runner = build_runner(
        dict(type=runner_type),
        default_args=dict(
            model=demo_model,
            work_dir=work_dir,
            optimizer=sgd,
            logger=logging.getLogger(),
            max_epochs=max_epochs,
            max_iters=max_iters))
    runner.register_checkpoint_hook(dict(interval=1))
    runner.register_logger_hooks(
        dict(interval=1, hooks=[dict(type='TextLoggerHook')]))
    return runner
def test_linear_lr_updater_scheduler():
    """Linear LR decay values should be logged to pavi at steps 2/4/6."""
    # Stub out pavi so PaviLoggerHook can be instantiated offline.
    sys.modules['pavi'] = MagicMock()
    data_loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # Register a linear learning-rate schedule (momentum stays fixed).
    runner.register_lr_hook(
        dict(policy='Linear', by_epoch=False, target_lr=0, start=0,
             interval=1))
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))

    pavi_hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(pavi_hook)
    runner.run([data_loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(pavi_hook, 'writer')
    expected_calls = [
        call('train', {
            'learning_rate': 0.018000000000000002,
            'momentum': 0.95
        }, 2),
        call('train', {
            'learning_rate': 0.014,
            'momentum': 0.95
        }, 4),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 6),
    ]
    pavi_hook.writer.add_scalars.assert_has_calls(
        expected_calls, any_order=True)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torchvision.utils import make_grid
from mmgen.models.misc import tensor2img
def test_tensor2img():
    """Exercise ``tensor2img`` over 4D/3D/2D tensors, tensor lists and both
    uint8/float32 output types, comparing against manual conversions."""
    # Random inputs covering every supported rank / channel count, in [0, 1].
    tensor_4d_1 = torch.FloatTensor(2, 3, 4, 4).uniform_(0, 1)
    tensor_4d_2 = torch.FloatTensor(1, 3, 4, 4).uniform_(0, 1)
    tensor_4d_3 = torch.FloatTensor(3, 1, 4, 4).uniform_(0, 1)
    tensor_4d_4 = torch.FloatTensor(1, 1, 4, 4).uniform_(0, 1)
    tensor_3d_1 = torch.FloatTensor(3, 4, 4).uniform_(0, 1)
    tensor_3d_2 = torch.FloatTensor(3, 6, 6).uniform_(0, 1)
    tensor_3d_3 = torch.FloatTensor(1, 6, 6).uniform_(0, 1)
    tensor_2d = torch.FloatTensor(4, 4).uniform_(0, 1)
    with pytest.raises(TypeError):
        # input is not a tensor
        tensor2img(4)
    with pytest.raises(TypeError):
        # input is not a list of tensors
        tensor2img([tensor_3d_1, 4])
    with pytest.raises(ValueError):
        # unsupported 5D tensor
        tensor2img(torch.FloatTensor(2, 2, 3, 4, 4).uniform_(0, 1))
    # 4d
    # Multi-image batch: reference result is built by tiling with make_grid,
    # swapping channels RGB->BGR ([2, 1, 0]) and moving channels last.
    rlt = tensor2img(tensor_4d_1, out_type=np.uint8, min_max=(0, 1))
    tensor_4d_1_np = make_grid(tensor_4d_1, nrow=1, normalize=False).numpy()
    tensor_4d_1_np = np.transpose(tensor_4d_1_np[[2, 1, 0], :, :], (1, 2, 0))
    np.testing.assert_almost_equal(rlt, (tensor_4d_1_np * 255).round())
    # Single-image batch: the batch dim is squeezed instead of gridded.
    rlt = tensor2img(tensor_4d_2, out_type=np.uint8, min_max=(0, 1))
    tensor_4d_2_np = tensor_4d_2.squeeze().numpy()
    tensor_4d_2_np = np.transpose(tensor_4d_2_np[[2, 1, 0], :, :], (1, 2, 0))
    np.testing.assert_almost_equal(rlt, (tensor_4d_2_np * 255).round())
    # Single-channel batch; make_grid output is still indexed with 3
    # channels here (make_grid expands 1-channel images — see torchvision).
    rlt = tensor2img(tensor_4d_3, out_type=np.uint8, min_max=(0, 1))
    tensor_4d_3_np = make_grid(tensor_4d_3, nrow=1, normalize=False).numpy()
    tensor_4d_3_np = np.transpose(tensor_4d_3_np[[2, 1, 0], :, :], (1, 2, 0))
    np.testing.assert_almost_equal(rlt, (tensor_4d_3_np * 255).round())
    # 1x1xHxW collapses to a plain grayscale HxW image.
    rlt = tensor2img(tensor_4d_4, out_type=np.uint8, min_max=(0, 1))
    tensor_4d_4_np = tensor_4d_4.squeeze().numpy()
    np.testing.assert_almost_equal(rlt, (tensor_4d_4_np * 255).round())
    # 3d
    # A list input yields a list of converted images, one per tensor.
    rlt = tensor2img([tensor_3d_1, tensor_3d_2],
                     out_type=np.uint8,
                     min_max=(0, 1))
    tensor_3d_1_np = tensor_3d_1.numpy()
    tensor_3d_1_np = np.transpose(tensor_3d_1_np[[2, 1, 0], :, :], (1, 2, 0))
    tensor_3d_2_np = tensor_3d_2.numpy()
    tensor_3d_2_np = np.transpose(tensor_3d_2_np[[2, 1, 0], :, :], (1, 2, 0))
    np.testing.assert_almost_equal(rlt[0], (tensor_3d_1_np * 255).round())
    np.testing.assert_almost_equal(rlt[1], (tensor_3d_2_np * 255).round())
    # Single-channel 3D tensor becomes grayscale (no channel swap).
    rlt = tensor2img(tensor_3d_3, out_type=np.uint8, min_max=(0, 1))
    tensor_3d_3_np = tensor_3d_3.squeeze().numpy()
    np.testing.assert_almost_equal(rlt, (tensor_3d_3_np * 255).round())
    # 2d
    rlt = tensor2img(tensor_2d, out_type=np.uint8, min_max=(0, 1))
    tensor_2d_np = tensor_2d.numpy()
    np.testing.assert_almost_equal(rlt, (tensor_2d_np * 255).round())
    # float32 output keeps values in [0, 1] without the *255 scaling.
    rlt = tensor2img(tensor_2d, out_type=np.float32, min_max=(0, 1))
    np.testing.assert_almost_equal(rlt, tensor_2d_np)
    # A narrower min_max clips then rescales the input to [0, 1].
    # NOTE: tensor_2d_np is intentionally re-derived here; order matters.
    rlt = tensor2img(tensor_2d, out_type=np.float32, min_max=(0.1, 0.5))
    tensor_2d_np = (np.clip(tensor_2d_np, 0.1, 0.5) - 0.1) / 0.4
    np.testing.assert_almost_equal(rlt, tensor_2d_np)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest.mock import MagicMock
import mmcv
import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mmgen.core import VisualizationHook
from mmgen.utils import get_root_logger
class ExampleDataset(Dataset):
    """Single-item dataset yielding a fixed 3x10x10 image tensor.

    The image is zero everywhere except rows 2..8 (set to one), so the
    visualization output downstream is fully deterministic.
    """

    def __getitem__(self, idx):
        image = torch.zeros((3, 10, 10))
        image[:, 2:9, :] = 1.
        return dict(imgs=image)

    def __len__(self):
        return 1
class ExampleModel(nn.Module):
    """Stub model whose train step simply echoes the batch images back."""

    def __init__(self):
        super().__init__()
        # Placeholder config attribute; presumably read by runner/hook
        # code elsewhere — kept as in the original.
        self.test_cfg = None

    def train_step(self, data_batch, optimizer):
        # Echo the input so the hook has deterministic results to save.
        return dict(results=dict(img=data_batch['imgs']))
def test_visual_hook():
    """VisualizationHook should dump the tracked image every `interval`."""
    # Non-string result keys are rejected up front.
    with pytest.raises(AssertionError):
        VisualizationHook('temp', [1, 2, 3])

    dataset = ExampleDataset()
    dataset.evaluate = MagicMock(return_value=dict(test='success'))
    loader = DataLoader(
        dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    expected = torch.zeros((1, 3, 10, 10))
    expected[:, :, 2:9, :] = 1.

    with tempfile.TemporaryDirectory() as tmpdir:
        hook = VisualizationHook('visual', ['img'], interval=8)
        runner = mmcv.runner.IterBasedRunner(
            model=ExampleModel(), work_dir=tmpdir, logger=get_root_logger())
        runner.register_hook(hook)
        runner.run([loader], [('train', 10)], 10)
        saved = mmcv.imread(
            osp.join(tmpdir, 'visual', 'iter_8.png'), flag='unchanged')
        # Compare with the tensor mapped through x * 127 + 128 (HWC order).
        np.testing.assert_almost_equal(
            saved, expected[0].permute(1, 2, 0) * 127 + 128)
# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import Dataset
from mmgen.datasets import RepeatDataset
def test_repeat_dataset():
    """RepeatDataset should tile the wrapped dataset ``times`` times."""

    class ToyDataset(Dataset):
        """Five-element dataset returning its members cyclically."""

        def __init__(self):
            super().__init__()
            self.members = [1, 2, 3, 4, 5]

        def __len__(self):
            return len(self.members)

        def __getitem__(self, idx):
            return self.members[idx % 5]

    repeated = RepeatDataset(ToyDataset(), 2)
    # Length is multiplied and indices wrap back into the base dataset.
    assert len(repeated) == 10
    assert repeated[2] == 3
    assert repeated[8] == 4
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pytest
from mmgen.datasets import GrowScaleImgDataset
class TestGrowScaleImgDataset:
    """Tests for ``GrowScaleImgDataset``, which switches image roots and
    per-GPU sample counts as the training resolution grows."""

    @classmethod
    def setup_class(cls):
        cls.imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        # Mapping of resolution (as a string key) -> image directory.
        cls.imgs_roots = {
            '4': cls.imgs_root,
            '8': osp.join(cls.imgs_root, 'img_root'),
            '32': osp.join(cls.imgs_root, 'img_root', 'grass')
        }
        cls.default_pipeline = [
            dict(type='LoadImageFromFile', io_backend='disk', key='real_img')
        ]
        cls.len_per_stage = 10
        cls.gpu_samples_base = 2

    def test_dynamic_unconditional_img_dataset(self):
        # Default config: fixed length per stage, base per-GPU batch size.
        dataset = GrowScaleImgDataset(
            self.imgs_roots,
            self.default_pipeline,
            self.len_per_stage,
            gpu_samples_base=self.gpu_samples_base)
        assert len(dataset) == 10
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (
            f'dataset_name: {dataset.__class__}, '
            f'total {10} images in imgs_root: {self.imgs_root}')
        assert dataset.samples_per_gpu == 2
        # Growing to scale 8 switches to the matching image root.
        dataset.update_annotations(8)
        assert len(dataset) == 10
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (f'dataset_name: {dataset.__class__}, '
                                 f'total {10} images in imgs_root:'
                                 f' {osp.join(self.imgs_root, "img_root")}')
        assert dataset.samples_per_gpu == 2
        # Per-scale sample overrides: scale '4' uses 10 samples per GPU.
        dataset = GrowScaleImgDataset(
            self.imgs_roots,
            self.default_pipeline,
            20,
            gpu_samples_base=self.gpu_samples_base,
            gpu_samples_per_scale={
                '4': 10,
                '16': 13
            })
        assert len(dataset) == 20
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (
            f'dataset_name: {dataset.__class__}, '
            f'total {20} images in imgs_root: {self.imgs_root}')
        assert dataset.samples_per_gpu == 10
        # Scale 8 has no override, so it falls back to gpu_samples_base.
        dataset.update_annotations(8)
        assert len(dataset) == 20
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (f'dataset_name: {dataset.__class__}, '
                                 f'total {20} images in imgs_root:'
                                 f' {osp.join(self.imgs_root, "img_root")}')
        assert dataset.samples_per_gpu == 2
        # test_mode dataset with a short fixed length.
        dataset = GrowScaleImgDataset(
            self.imgs_roots, self.default_pipeline, 5, test_mode=True)
        assert len(dataset) == 5
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (
            f'dataset_name: {dataset.__class__}, '
            f'total {5} images in imgs_root: {self.imgs_root}')
        # Scale 24 is between keys; the dataset resolves to the '32' root.
        dataset.update_annotations(24)
        assert len(dataset) == 5
        img = dataset[2]['real_img']
        assert img.ndim == 3
        _path_str = osp.join(self.imgs_root, 'img_root', 'grass')
        assert repr(dataset) == (f'dataset_name: {dataset.__class__}, '
                                 f'total {5} images in imgs_root: {_path_str}')
        # Invalid configs: non-dict gpu_samples_per_scale / non-dict roots
        # and a float length are rejected.
        with pytest.raises(AssertionError):
            _ = GrowScaleImgDataset(
                self.imgs_root,
                self.default_pipeline,
                10,
                gpu_samples_per_scale=10)
        with pytest.raises(AssertionError):
            _ = GrowScaleImgDataset(10, self.default_pipeline, 10.)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmgen.datasets import PairedImageDataset
class TestPairedImageDataset(object):
    """Tests for ``PairedImageDataset`` with a standard paired pipeline."""

    @classmethod
    def setup_class(cls):
        cls.imgs_root = osp.join(
            osp.dirname(osp.dirname(__file__)), 'data/paired')
        img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # Pipeline: load pair -> resize -> fixed crop -> flip -> rescale
        # to [0, 1] -> normalize -> to tensor -> collect.
        cls.default_pipeline = [
            dict(
                type='LoadPairedImageFromFile',
                io_backend='disk',
                key='pair',
                domain_a='a',
                domain_b='b'),
            dict(
                type='Resize',
                keys=['img_a', 'img_b'],
                scale=(286, 286),
                interpolation='bicubic'),
            dict(
                type='FixedCrop',
                keys=['img_a', 'img_b'],
                crop_size=(256, 256)),
            dict(type='Flip', keys=['img_a', 'img_b'], direction='horizontal'),
            dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
            dict(
                type='Normalize',
                keys=['img_a', 'img_b'],
                to_rgb=True,
                **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img_a', 'img_b']),
            dict(
                type='Collect',
                keys=['img_a', 'img_b'],
                meta_keys=['img_a_path', 'img_b_path'])
        ]

    def test_paired_image_dataset(self):
        dataset = PairedImageDataset(
            self.imgs_root, pipeline=self.default_pipeline)
        assert len(dataset) == 2
        # Both domains of a pair come back as 3-dim tensors.
        img = dataset[0]['img_a']
        assert img.ndim == 3
        img = dataset[0]['img_b']
        assert img.ndim == 3
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmgen.datasets.builder import build_dataloader, build_dataset
class TestPersistentWorker(object):
    """Smoke test: a dataloader builds with ``persistent_workers=True``."""

    @classmethod
    def setup_class(cls):
        imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        train_pipeline = [
            dict(type='LoadImageFromFile', io_backend='disk', key='real_img')
        ]
        # Dataloader config; note persistent_workers is enabled here.
        cls.config = dict(
            samples_per_gpu=1,
            workers_per_gpu=4,
            drop_last=True,
            persistent_workers=True)
        cls.data_cfg = dict(
            type='UnconditionalImageDataset',
            imgs_root=imgs_root,
            pipeline=train_pipeline,
            test_mode=False)

    def test_persistent_worker(self):
        # Build with persistent workers enabled (NOTE(review): the original
        # comment said "non-persistent-worker", contradicting the config
        # above — only construction is exercised, no assertion is made).
        dataset = build_dataset(self.data_cfg)
        build_dataloader(dataset, **self.config)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
import pytest
import torch
from mmgen.datasets.pipelines import (CenterCropLongEdge, Flip, NumpyPad,
RandomCropLongEdge, RandomImgNoise,
Resize)
class TestAugmentations(object):
    """Tests for the ``Flip`` and ``Resize`` pipeline transforms."""

    @classmethod
    def setup_class(cls):
        # First assignment is immediately overwritten below; kept as-is.
        cls.results = dict()
        cls.img_gt = np.random.rand(256, 128, 3).astype(np.float32)
        cls.img_lq = np.random.rand(64, 32, 3).astype(np.float32)
        cls.results = dict(
            lq=cls.img_lq,
            gt=cls.img_gt,
            scale=4,
            lq_path='fake_lq_path',
            gt_path='fake_gt_path')
        cls.results['img'] = np.random.rand(256, 256, 3).astype(np.float32)
        cls.results['mask'] = np.random.rand(256, 256, 1).astype(np.float32)
        cls.results['img_tensor'] = torch.rand((3, 256, 256))
        cls.results['mask_tensor'] = torch.zeros((1, 256, 256))
        cls.results['mask_tensor'][:, 50:150, 40:140] = 1.

    @staticmethod
    def assert_img_equal(img, ref_img, ratio_thr=0.999):
        """Check if img and ref_img are matched approximately.

        At least ``ratio_thr`` of the pixels must differ by <= 1 after
        integer conversion.
        """
        assert img.shape == ref_img.shape
        assert img.dtype == ref_img.dtype
        area = ref_img.shape[-1] * ref_img.shape[-2]
        diff = np.abs(img.astype('int32') - ref_img.astype('int32'))
        assert np.sum(diff <= 1) / float(area) > ratio_thr

    @staticmethod
    def check_keys_contain(result_keys, target_keys):
        """Check if all elements in target_keys is in result_keys."""
        return set(target_keys).issubset(set(result_keys))

    @staticmethod
    def check_flip(origin_img, result_img, flip_type):
        """Check if the origin_img are flipped correctly into result_img in
        different flip_types."""
        h, w, c = origin_img.shape
        if flip_type == 'horizontal':
            # Pixel-by-pixel mirror check along the width axis.
            # yapf: disable
            for i in range(h):
                for j in range(w):
                    for k in range(c):
                        if result_img[i, j, k] != origin_img[i, w - 1 - j, k]:
                            return False
            # yapf: enable
        else:
            # Vertical flip: mirror along the height axis.
            # yapf: disable
            for i in range(h):
                for j in range(w):
                    for k in range(c):
                        if result_img[i, j, k] != origin_img[h - 1 - i, j, k]:
                            return False
            # yapf: enable
        return True

    def test_flip(self):
        results = copy.deepcopy(self.results)
        # Only 'horizontal' / 'vertical' are valid directions.
        with pytest.raises(ValueError):
            Flip(keys=['lq', 'gt'], direction='vertically')
        # horizontal
        np.random.seed(1)
        target_keys = ['lq', 'gt', 'flip', 'flip_direction']
        flip = Flip(keys=['lq', 'gt'], flip_ratio=1, direction='horizontal')
        results = flip(results)
        assert self.check_keys_contain(results.keys(), target_keys)
        assert self.check_flip(self.img_lq, results['lq'],
                               results['flip_direction'])
        assert self.check_flip(self.img_gt, results['gt'],
                               results['flip_direction'])
        assert results['lq'].shape == self.img_lq.shape
        assert results['gt'].shape == self.img_gt.shape
        # vertical
        results = copy.deepcopy(self.results)
        flip = Flip(keys=['lq', 'gt'], flip_ratio=1, direction='vertical')
        results = flip(results)
        assert self.check_keys_contain(results.keys(), target_keys)
        assert self.check_flip(self.img_lq, results['lq'],
                               results['flip_direction'])
        assert self.check_flip(self.img_gt, results['gt'],
                               results['flip_direction'])
        assert results['lq'].shape == self.img_lq.shape
        assert results['gt'].shape == self.img_gt.shape
        assert repr(flip) == flip.__class__.__name__ + (
            f"(keys={['lq', 'gt']}, flip_ratio=1, "
            f"direction={results['flip_direction']})")
        # flip a list
        # horizontal
        flip = Flip(keys=['lq', 'gt'], flip_ratio=1, direction='horizontal')
        results = dict(
            lq=[self.img_lq, np.copy(self.img_lq)],
            gt=[self.img_gt, np.copy(self.img_gt)],
            scale=4,
            lq_path='fake_lq_path',
            gt_path='fake_gt_path')
        flip_rlt = flip(copy.deepcopy(results))
        assert self.check_keys_contain(flip_rlt.keys(), target_keys)
        assert self.check_flip(self.img_lq, flip_rlt['lq'][0],
                               flip_rlt['flip_direction'])
        assert self.check_flip(self.img_gt, flip_rlt['gt'][0],
                               flip_rlt['flip_direction'])
        # Every entry of a list value must be flipped identically.
        np.testing.assert_almost_equal(flip_rlt['gt'][0], flip_rlt['gt'][1])
        np.testing.assert_almost_equal(flip_rlt['lq'][0], flip_rlt['lq'][1])
        # vertical
        flip = Flip(keys=['lq', 'gt'], flip_ratio=1, direction='vertical')
        flip_rlt = flip(copy.deepcopy(results))
        assert self.check_keys_contain(flip_rlt.keys(), target_keys)
        assert self.check_flip(self.img_lq, flip_rlt['lq'][0],
                               flip_rlt['flip_direction'])
        assert self.check_flip(self.img_gt, flip_rlt['gt'][0],
                               flip_rlt['flip_direction'])
        np.testing.assert_almost_equal(flip_rlt['gt'][0], flip_rlt['gt'][1])
        np.testing.assert_almost_equal(flip_rlt['lq'][0], flip_rlt['lq'][1])
        # no flip: flip_ratio=0 must leave the images unchanged.
        flip = Flip(keys=['lq', 'gt'], flip_ratio=0, direction='vertical')
        results = flip(copy.deepcopy(results))
        assert self.check_keys_contain(results.keys(), target_keys)
        np.testing.assert_almost_equal(results['gt'][0], self.img_gt)
        np.testing.assert_almost_equal(results['lq'][0], self.img_lq)
        np.testing.assert_almost_equal(results['gt'][0], results['gt'][1])
        np.testing.assert_almost_equal(results['lq'][0], results['lq'][1])

    def test_resize(self):
        # Invalid constructor combinations.
        with pytest.raises(AssertionError):
            Resize([], scale=0.5)
        with pytest.raises(AssertionError):
            Resize(['gt_img'], size_factor=32, scale=0.5)
        with pytest.raises(AssertionError):
            Resize(['gt_img'], size_factor=32, keep_ratio=True)
        with pytest.raises(AssertionError):
            Resize(['gt_img'], max_size=32, size_factor=None)
        with pytest.raises(ValueError):
            Resize(['gt_img'], scale=-0.5)
        with pytest.raises(TypeError):
            Resize(['gt_img'], (0.4, 0.2))
        with pytest.raises(TypeError):
            Resize(['gt_img'], dict(test=None))
        # size_factor rounds each side down to a multiple of 32.
        target_keys = ['alpha']
        alpha = np.random.rand(240, 320).astype(np.float32)
        results = dict(alpha=alpha)
        resize = Resize(keys=['alpha'], size_factor=32, max_size=None)
        resize_results = resize(results)
        assert self.check_keys_contain(resize_results.keys(), target_keys)
        assert resize_results['alpha'].shape == (224, 320, 1)
        resize = Resize(keys=['alpha'], size_factor=32, max_size=320)
        resize_results = resize(results)
        assert self.check_keys_contain(resize_results.keys(), target_keys)
        assert resize_results['alpha'].shape == (224, 320, 1)
        # max_size caps the longer side (then rounds to the size factor).
        resize = Resize(keys=['alpha'], size_factor=32, max_size=200)
        resize_results = resize(results)
        assert self.check_keys_contain(resize_results.keys(), target_keys)
        assert resize_results['alpha'].shape == (192, 192, 1)
        # scale=(-1, n) pins one side to n and scales the other by ratio.
        resize = Resize(['gt_img'], (-1, 200))
        results = dict(gt_img=self.results['gt'].copy())
        resize_results = resize(results)
        assert resize.scale == (np.inf, 200)
        assert resize_results['gt_img'].shape == (400, 200, 3)
        resize = Resize(['gt_img'], (-1, 200))
        results = dict(gt_img=self.results['gt'].copy().transpose(1, 0, 2))
        resize_results = resize(results)
        assert resize.scale == (np.inf, 200)
        assert resize_results['gt_img'].shape == (200, 400, 3)
        # Fractional scale with keep_ratio.
        results = dict(gt_img=self.results['img'].copy())
        resize_keep_ratio = Resize(['gt_img'], scale=0.5, keep_ratio=True)
        results = resize_keep_ratio(results)
        assert results['gt_img'].shape[:2] == (128, 128)
        assert results['scale_factor'] == 0.5
        # Explicit target size without keeping the aspect ratio.
        results = dict(gt_img=self.results['img'].copy())
        resize_keep_ratio = Resize(['gt_img'],
                                   scale=(128, 128),
                                   keep_ratio=False)
        results = resize_keep_ratio(results)
        assert results['gt_img'].shape[:2] == (128, 128)
        # test input with shape (256, 256)
        results = dict(gt_img=self.results['img'][..., 0].copy())
        resize = Resize(['gt_img'], scale=(128, 128), keep_ratio=False)
        results = resize(results)
        assert results['gt_img'].shape == (128, 128, 1)
        name_ = str(resize_keep_ratio)
        # NOTE: no space after 'max_size=None,' — must match the transform's
        # repr byte-for-byte.
        assert name_ == resize_keep_ratio.__class__.__name__ + (
            f"(keys={['gt_img']}, scale=(128, 128), "
            f'keep_ratio={False}, size_factor=None, '
            'max_size=None,interpolation=bilinear)')
def test_random_img_noise():
    """RandomImgNoise should add bounded noise and expose a stable repr."""
    src = np.random.randn(256, 128, 3).astype(np.float32)
    uniform_op = RandomImgNoise(['img'], 1, 2, distribution='uniform')
    outputs = uniform_op(dict(img=copy.deepcopy(src)))
    # Uniform noise stays within the configured [1, 2] bounds.
    delta = outputs['img'] - src
    assert (delta <= 2).all()
    assert (delta >= 1).all()
    expected = uniform_op.__class__.__name__
    expected += (f'(keys={uniform_op.keys}, '
                 f'lower_bound={uniform_op.lower_bound}, '
                 f'upper_bound={uniform_op.upper_bound})')
    assert str(uniform_op) == expected

    src = np.random.randn(256, 128, 3).astype(np.float32)
    normal_op = RandomImgNoise(['img'], distribution='normal')
    outputs = normal_op(dict(img=copy.deepcopy(src)))
    # With default bounds, normal noise is kept within [0, 1/128].
    delta = outputs['img'] - src
    assert (delta <= 1 / 128.).all()
    assert (delta >= 0).all()
    expected = normal_op.__class__.__name__
    expected += (f'(keys={normal_op.keys}, '
                 f'lower_bound={normal_op.lower_bound}, '
                 f'upper_bound={normal_op.upper_bound})')
    assert str(normal_op) == expected

    # Empty key lists and unknown distributions are rejected.
    with pytest.raises(AssertionError):
        RandomImgNoise([])
    with pytest.raises(KeyError):
        RandomImgNoise(['img'], distribution='test')
def test_random_long_edge_crop():
    """RandomCropLongEdge crops the longer edge to match the shorter one."""
    crop_op = RandomCropLongEdge(['img'])
    outputs = crop_op(dict(img=np.random.rand(256, 128, 3).astype(np.float32)))
    # 256x128 input -> square 128x128 crop.
    assert outputs['img'].shape == (128, 128, 3)
    expected_repr = crop_op.__class__.__name__ + f'(keys={crop_op.keys})'
    assert str(crop_op) == expected_repr
def test_center_long_edge_crop():
    """CenterCropLongEdge center-crops the longer edge to the shorter one."""
    crop_op = CenterCropLongEdge(['img'])
    outputs = crop_op(dict(img=np.random.rand(256, 128, 3).astype(np.float32)))
    # 256x128 input -> square 128x128 crop.
    assert outputs['img'].shape == (128, 128, 3)
    expected_repr = crop_op.__class__.__name__ + f'(keys={crop_op.keys})'
    assert str(crop_op) == expected_repr
def test_numpy_pad():
    """NumpyPad pads each axis by the configured (before, after) widths."""
    pad_op = NumpyPad(['img'], ((2, 2), (0, 0), (0, 0)))
    outputs = pad_op(dict(img=np.zeros((5, 5, 1))))
    # Two rows added on each side of axis 0: 5 -> 9.
    assert outputs['img'].shape == (9, 5, 1)
    expected_repr = pad_op.__class__.__name__ + (
        f'(keys={pad_op.keys}, padding={pad_op.padding}, '
        f'kwargs={pad_op.kwargs})')
    assert str(pad_op) == expected_repr
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmgen.datasets.pipelines import Compose, ImageToTensor
def check_keys_equal(result_keys, target_keys):
    """Return True when both collections contain exactly the same keys."""
    return set(result_keys) == set(target_keys)
def test_compose():
    """Compose should chain transforms, pass None through, and repr nicely."""
    # Non-list pipeline specs are rejected.
    with pytest.raises(TypeError):
        Compose('LoadAlpha')

    expected_keys = ['img', 'meta']
    inputs = dict(
        img=np.random.randn(256, 256, 3),
        abandoned_key=None,
        img_name='test_image.png')
    pipeline_cfg = [
        dict(type='Collect', keys=['img'], meta_keys=['img_name']),
        dict(type='ImageToTensor', keys=['img'])
    ]
    pipeline = Compose(pipeline_cfg)
    outputs = pipeline(inputs)
    # Collect drops unlisted keys and packs the meta keys under 'meta'.
    assert check_keys_equal(outputs.keys(), expected_keys)
    assert check_keys_equal(outputs['meta'].data.keys(), ['img_name'])

    # A None input short-circuits the whole pipeline.
    identity = ImageToTensor(keys=[])
    pipeline = Compose([identity])
    assert pipeline(None) is None
    assert repr(pipeline) == (
        pipeline.__class__.__name__ + f'(\n    {identity}\n)')
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
import pytest
from mmgen.datasets.pipelines import Crop, FixedCrop
class TestAugmentations(object):
    """Tests for the ``Crop`` and ``FixedCrop`` pipeline transforms."""

    @classmethod
    def setup_class(cls):
        # First assignment is immediately overwritten below; kept as-is.
        cls.results = dict()
        cls.img_gt = np.random.rand(256, 128, 3).astype(np.float32)
        cls.img_lq = np.random.rand(64, 32, 3).astype(np.float32)
        cls.results = dict(
            lq=cls.img_lq,
            gt=cls.img_gt,
            scale=4,
            lq_path='fake_lq_path',
            gt_path='fake_gt_path')
        cls.results['img'] = np.random.rand(256, 256, 3).astype(np.float32)
        # Paired 286x286 images, the canonical pix2pix pre-crop size.
        cls.results['img_a'] = np.random.rand(286, 286, 3).astype(np.float32)
        cls.results['img_b'] = np.random.rand(286, 286, 3).astype(np.float32)

    @staticmethod
    def check_crop(result_img_shape, result_bbox):
        """Check whether result_bbox corresponds to result_img_shape."""
        # NOTE(review): this helper is unused by the tests below, and the
        # bboxes asserted in test_crop look like [x, y, w, h], so the
        # x2 - x1 arithmetic here may be stale — confirm before reusing.
        crop_w = result_bbox[2] - result_bbox[0]
        crop_h = result_bbox[3] - result_bbox[1]
        crop_shape = (crop_h, crop_w)
        return result_img_shape == crop_shape

    @staticmethod
    def check_crop_around_semi(alpha):
        # True when alpha contains semi-transparent values (strictly
        # between 0 and 255).
        return ((alpha > 0) & (alpha < 255)).any()

    @staticmethod
    def check_keys_contain(result_keys, target_keys):
        """Check if all elements in target_keys is in result_keys."""
        return set(target_keys).issubset(set(result_keys))

    def test_crop(self):
        # Non-tuple-of-int crop sizes are rejected.
        with pytest.raises(TypeError):
            Crop(['img'], (0.23, 0.1))
        # test center crop
        results = copy.deepcopy(self.results)
        center_crop = Crop(['img'], crop_size=(128, 128), random_crop=False)
        results = center_crop(results)
        # bbox layout is [x_offset, y_offset, crop_w, crop_h].
        assert results['img_crop_bbox'] == [64, 64, 128, 128]
        assert np.array_equal(self.results['img'][64:192, 64:192, :],
                              results['img'])
        # test random crop
        results = copy.deepcopy(self.results)
        random_crop = Crop(['img'], crop_size=(128, 128), random_crop=True)
        results = random_crop(results)
        assert 0 <= results['img_crop_bbox'][0] <= 128
        assert 0 <= results['img_crop_bbox'][1] <= 128
        assert results['img_crop_bbox'][2] == 128
        assert results['img_crop_bbox'][3] == 128
        # test random crop for lager size than the original shape
        results = copy.deepcopy(self.results)
        random_crop = Crop(['img'], crop_size=(512, 512), random_crop=True)
        results = random_crop(results)
        # Oversized crops leave the image untouched.
        assert np.array_equal(self.results['img'], results['img'])
        assert str(random_crop) == (
            random_crop.__class__.__name__ +
            "keys=['img'], crop_size=(512, 512), random_crop=True")

    def test_fixed_crop(self):
        # Invalid crop_size / crop_pos types are rejected.
        with pytest.raises(TypeError):
            FixedCrop(['img_a', 'img_b'], (0.23, 0.1))
        with pytest.raises(TypeError):
            FixedCrop(['img_a', 'img_b'], (256, 256), (0, 0.1))
        # test shape consistency
        # Keys with mismatched image shapes must raise.
        results = copy.deepcopy(self.results)
        fixed_crop = FixedCrop(['img_a', 'img'], crop_size=(128, 128))
        with pytest.raises(ValueError):
            results = fixed_crop(results)
        # test given pos crop
        results = copy.deepcopy(self.results)
        given_pos_crop = FixedCrop(['img_a', 'img_b'],
                                   crop_size=(256, 256),
                                   crop_pos=(1, 1))
        results = given_pos_crop(results)
        assert results['img_a_crop_bbox'] == [1, 1, 256, 256]
        assert results['img_b_crop_bbox'] == [1, 1, 256, 256]
        assert np.array_equal(self.results['img_a'][1:257, 1:257, :],
                              results['img_a'])
        assert np.array_equal(self.results['img_b'][1:257, 1:257, :],
                              results['img_b'])
        # test given pos crop if pos > suitable pos
        # The crop is clamped to whatever remains past the position.
        results = copy.deepcopy(self.results)
        given_pos_crop = FixedCrop(['img_a', 'img_b'],
                                   crop_size=(256, 256),
                                   crop_pos=(280, 280))
        results = given_pos_crop(results)
        assert results['img_a_crop_bbox'] == [280, 280, 6, 6]
        assert results['img_b_crop_bbox'] == [280, 280, 6, 6]
        assert np.array_equal(self.results['img_a'][280:, 280:, :],
                              results['img_a'])
        assert np.array_equal(self.results['img_b'][280:, 280:, :],
                              results['img_b'])
        assert str(given_pos_crop) == (
            given_pos_crop.__class__.__name__ +
            "keys=['img_a', 'img_b'], crop_size=(256, 256), " +
            'crop_pos=(280, 280)')
        # test random initialized fixed crop
        # crop_pos=None draws one position and applies it to every key.
        results = copy.deepcopy(self.results)
        random_fixed_crop = FixedCrop(['img_a', 'img_b'],
                                      crop_size=(256, 256),
                                      crop_pos=None)
        results = random_fixed_crop(results)
        assert 0 <= results['img_a_crop_bbox'][0] <= 30
        assert 0 <= results['img_a_crop_bbox'][1] <= 30
        assert results['img_a_crop_bbox'][2] == 256
        assert results['img_a_crop_bbox'][3] == 256
        x_offset, y_offset, crop_w, crop_h = results['img_a_crop_bbox']
        assert x_offset == results['img_b_crop_bbox'][0]
        assert y_offset == results['img_b_crop_bbox'][1]
        assert crop_w == results['img_b_crop_bbox'][2]
        assert crop_h == results['img_b_crop_bbox'][3]
        assert np.array_equal(
            self.results['img_a'][y_offset:y_offset + crop_h,
                                  x_offset:x_offset + crop_w, :],
            results['img_a'])
        assert np.array_equal(
            self.results['img_b'][y_offset:y_offset + crop_h,
                                  x_offset:x_offset + crop_w, :],
            results['img_b'])
        # test given pos crop for lager size than the original shape
        results = copy.deepcopy(self.results)
        given_pos_crop = FixedCrop(['img_a', 'img_b'],
                                   crop_size=(512, 512),
                                   crop_pos=(1, 1))
        results = given_pos_crop(results)
        assert results['img_a_crop_bbox'] == [1, 1, 285, 285]
        assert results['img_b_crop_bbox'] == [1, 1, 285, 285]
        assert np.array_equal(self.results['img_a'][1:, 1:, :],
                              results['img_a'])
        assert np.array_equal(self.results['img_b'][1:, 1:, :],
                              results['img_b'])
        assert str(given_pos_crop) == (
            given_pos_crop.__class__.__name__ +
            "keys=['img_a', 'img_b'], crop_size=(512, 512), crop_pos=(1, 1)")
        # test random initialized fixed crop for lager size
        # than the original shape
        results = copy.deepcopy(self.results)
        random_fixed_crop = FixedCrop(['img_a', 'img_b'],
                                      crop_size=(512, 512),
                                      crop_pos=None)
        results = random_fixed_crop(results)
        # Oversized random crops degenerate to the full image.
        assert results['img_a_crop_bbox'] == [0, 0, 286, 286]
        assert results['img_b_crop_bbox'] == [0, 0, 286, 286]
        assert np.array_equal(self.results['img_a'], results['img_a'])
        assert np.array_equal(self.results['img_b'], results['img_b'])
        assert str(random_fixed_crop) == (
            random_fixed_crop.__class__.__name__ +
            "keys=['img_a', 'img_b'], crop_size=(512, 512), crop_pos=None")
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmgen.datasets.pipelines import Collect, ImageToTensor, ToTensor
def check_keys_contain(result_keys, target_keys):
    """Return True when every key in target_keys appears in result_keys."""
    return set(target_keys) <= set(result_keys)
def test_to_tensor():
    """ToTensor should convert the configured keys to torch.Tensor and
    reject values it cannot convert."""
    # A string value under a configured key raises TypeError.
    to_tensor = ToTensor(['str'])
    with pytest.raises(TypeError):
        results = dict(str='0')
        to_tensor(results)
    target_keys = ['tensor', 'numpy', 'sequence', 'int', 'float']
    to_tensor = ToTensor(target_keys)
    ori_results = dict(
        tensor=torch.randn(2, 3),
        numpy=np.random.randn(2, 3),
        sequence=list(range(10)),
        int=1,
        float=0.1)
    results = to_tensor(ori_results)
    assert check_keys_contain(results.keys(), target_keys)
    for key in target_keys:
        assert isinstance(results[key], torch.Tensor)
        # NOTE(review): the transform presumably mutates the dict in place
        # (results is ori_results), so both sides here are the converted
        # tensor — confirm against ToTensor's implementation.
        assert torch.equal(results[key].data, ori_results[key])
    # Add an additional key which is not in keys.
    ori_results = dict(
        tensor=torch.randn(2, 3),
        numpy=np.random.randn(2, 3),
        sequence=list(range(10)),
        int=1,
        float=0.1,
        str='test')
    results = to_tensor(ori_results)
    assert check_keys_contain(results.keys(), target_keys)
    for key in target_keys:
        assert isinstance(results[key], torch.Tensor)
        assert torch.equal(results[key].data, ori_results[key])
    assert repr(
        to_tensor) == to_tensor.__class__.__name__ + f'(keys={target_keys})'
def test_image_to_tensor():
    """ImageToTensor should yield (C, H, W) float32 tensors from HWC or
    2-dim inputs."""
    ori_results = dict(img=np.random.randn(256, 256, 3))
    keys = ['img']
    # NOTE(review): this False value is dead — only the reassignment to
    # True below feeds the final repr assertion.
    to_float32 = False
    image_to_tensor = ImageToTensor(keys)
    results = image_to_tensor(ori_results)
    # HWC numpy input becomes a CHW tensor.
    assert results['img'].shape == torch.Size([3, 256, 256])
    assert isinstance(results['img'], torch.Tensor)
    # NOTE(review): the transform presumably mutates the dict in place, so
    # this compares the converted tensor with itself — confirm.
    assert torch.equal(results['img'].data, ori_results['img'])
    assert results['img'].dtype == torch.float32
    # 2-dim input gains a leading singleton channel dimension.
    ori_results = dict(img=np.random.randint(256, size=(256, 256)))
    keys = ['img']
    to_float32 = True
    image_to_tensor = ImageToTensor(keys)
    results = image_to_tensor(ori_results)
    assert results['img'].shape == torch.Size([1, 256, 256])
    assert isinstance(results['img'], torch.Tensor)
    assert torch.equal(results['img'].data, ori_results['img'])
    assert results['img'].dtype == torch.float32
    assert repr(image_to_tensor) == (
        image_to_tensor.__class__.__name__ +
        f'(keys={keys}, to_float32={to_float32})')
def test_collect():
    """Collect should keep ``keys``, pack ``meta_keys`` under 'meta', and
    drop everything else."""
    inputs = dict(
        img=np.random.randn(256, 256, 3),
        label=[1],
        img_name='test_image.png',
        ori_shape=(256, 256, 3),
        img_shape=(256, 256, 3),
        pad_shape=(256, 256, 3),
        flip_direction='vertical',
        img_norm_cfg=dict(to_bgr=False))
    keys = ['img', 'label']
    meta_keys = ['img_shape', 'img_name', 'ori_shape']
    collect = Collect(keys, meta_keys=meta_keys)
    results = collect(inputs)
    assert set(list(results.keys())) == set(['img', 'label', 'meta'])
    # 'img' is not a meta key; removing it keeps the equality loop below
    # limited to directly comparable (non-array) values.
    inputs.pop('img')
    assert set(results['meta'].data.keys()) == set(meta_keys)
    for key in results['meta'].data:
        assert results['meta'].data[key] == inputs[key]
    assert repr(collect) == (
        collect.__class__.__name__ +
        f'(keys={keys}, meta_keys={collect.meta_keys})')
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
import mmcv
import numpy as np
from mmgen.datasets import LoadImageFromFile
def test_load_image_from_file():
    """LoadImageFromFile should accept both Path and str inputs."""
    path_baboon = Path(
        __file__).parent / '..' / '..' / 'data' / 'image' / 'baboon.png'
    img_baboon = mmcv.imread(str(path_baboon), flag='color')
    image_loader = LoadImageFromFile(io_backend='disk', key='gt')

    # read gt image from a Path object; the stored path is normalised
    # back to a plain string.
    results = image_loader(dict(gt_path=path_baboon))
    assert results['gt'].shape == (480, 500, 3)
    np.testing.assert_almost_equal(results['gt'], img_baboon)
    assert results['gt_path'] == str(path_baboon)

    # the same file given as a str path.
    results = image_loader(dict(gt_path=str(path_baboon)))
    assert results['gt'].shape == (480, 500, 3)
    np.testing.assert_almost_equal(results['gt'], img_baboon)
    assert results['gt_path'] == str(path_baboon)

    assert repr(image_loader) == (
        image_loader.__class__.__name__ +
        ('(io_backend=disk, key=gt, '
         'flag=color, save_original_img=False)'))
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmgen.datasets.pipelines import Normalize
class TestAugmentations(object):
@staticmethod
def assert_img_equal(img, ref_img, ratio_thr=0.999):
"""Check if img and ref_img are matched approximately."""
assert img.shape == ref_img.shape
assert img.dtype == ref_img.dtype
area = ref_img.shape[-1] * ref_img.shape[-2]
diff = np.abs(img.astype('int32') - ref_img.astype('int32'))
assert np.sum(diff <= 1) / float(area) > ratio_thr
@staticmethod
def check_keys_contain(result_keys, target_keys):
"""Check if all elements in target_keys is in result_keys."""
return set(target_keys).issubset(set(result_keys))
def check_normalize(self, origin_img, result_img, norm_cfg):
"""Check if the origin_img are normalized correctly into result_img in
a given norm_cfg."""
target_img = result_img.copy()
target_img *= norm_cfg['std'][None, None, :]
target_img += norm_cfg['mean'][None, None, :]
if norm_cfg['to_rgb']:
target_img = target_img[:, ::-1, ...].copy()
self.assert_img_equal(origin_img, target_img)
def test_normalize(self):
with pytest.raises(TypeError):
Normalize(['alpha'], dict(mean=[123.675, 116.28, 103.53]),
[58.395, 57.12, 57.375])
with pytest.raises(TypeError):
Normalize(['alpha'], [123.675, 116.28, 103.53],
dict(std=[58.395, 57.12, 57.375]))
target_keys = ['merged', 'img_norm_cfg']
merged = np.random.rand(240, 320, 3).astype(np.float32)
results = dict(merged=merged)
config = dict(
keys=['merged'],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=False)
normalize = Normalize(**config)
normalize_results = normalize(results)
assert self.check_keys_contain(normalize_results.keys(), target_keys)
self.check_normalize(merged, normalize_results['merged'],
normalize_results['img_norm_cfg'])
merged = np.random.rand(240, 320, 3).astype(np.float32)
results = dict(merged=merged)
config = dict(
keys=['merged'],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
normalize = Normalize(**config)
normalize_results = normalize(results)
assert self.check_keys_contain(normalize_results.keys(), target_keys)
self.check_normalize(merged, normalize_results['merged'],
normalize_results['img_norm_cfg'])
assert normalize.__repr__() == (
normalize.__class__.__name__ +
f"(keys={ ['merged']}, mean={np.array([123.675, 116.28, 103.53])},"
f' std={np.array([58.395, 57.12, 57.375])}, to_rgb=True)')
# input is an image list
merged = np.random.rand(240, 320, 3).astype(np.float32)
merged_2 = np.random.rand(240, 320, 3).astype(np.float32)
results = dict(merged=[merged, merged_2])
config = dict(
keys=['merged'],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=False)
normalize = Normalize(**config)
normalize_results = normalize(results)
assert self.check_keys_contain(normalize_results.keys(), target_keys)
self.check_normalize(merged, normalize_results['merged'][0],
normalize_results['img_norm_cfg'])
self.check_normalize(merged_2, normalize_results['merged'][1],
normalize_results['img_norm_cfg'])
merged = np.random.rand(240, 320, 3).astype(np.float32)
merged_2 = np.random.rand(240, 320, 3).astype(np.float32)
results = dict(merged=[merged, merged_2])
config = dict(
keys=['merged'],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
normalize = Normalize(**config)
normalize_results = normalize(results)
assert self.check_keys_contain(normalize_results.keys(), target_keys)
self.check_normalize(merged, normalize_results['merged'][0],
normalize_results['img_norm_cfg'])
self.check_normalize(merged_2, normalize_results['merged'][1],
normalize_results['img_norm_cfg'])
# Copyright (c) OpenMMLab. All rights reserved.
from mmgen.datasets.quick_test_dataset import QuickTestImageDataset
class TestQuickTest:
    """Smoke tests for ``QuickTestImageDataset``."""

    @classmethod
    def setup_class(cls):
        # Dataset of fixed-size images built once for all tests.
        cls.dataset = QuickTestImageDataset(size=(256, 256))

    def test_quicktest_dataset(self):
        dataset = self.dataset
        assert len(dataset) == 10000
        # Each sample carries a CHW 'real_img' tensor-like entry.
        sample = dataset[2]
        assert sample['real_img'].shape == (3, 256, 256)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmgen.datasets import SinGANDataset
class TestSinGANDataset(object):
    """Tests for ``SinGANDataset`` constructed from a single image."""

    @classmethod
    def setup_class(cls):
        # SinGAN trains on one image; point the root at a single file.
        data_dir = osp.dirname(osp.dirname(__file__))
        cls.imgs_root = osp.join(data_dir, 'data/image/baboon.png')
        cls.min_size = 25
        cls.max_size = 250
        cls.scale_factor_init = 0.75

    def test_singan_dataset(self):
        dataset = SinGANDataset(
            self.imgs_root,
            min_size=self.min_size,
            max_size=self.max_size,
            scale_factor_init=self.scale_factor_init)
        assert len(dataset) == 1000000
        sample = dataset[0]
        # One entry per pyramid scale is expected in each sample.
        assert all(f'real_scale{scale}' in sample for scale in range(10))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment