Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmgeneration
Commits
c9a48a52
Commit
c9a48a52
authored
Jun 16, 2025
by
limm
Browse files
add tests code
parent
b7536f78
Pipeline
#2778
canceled with stages
Changes
64
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3291 additions
and
0 deletions
+3291
-0
tests/test_datasets/test_unconditional_image_dataset.py
tests/test_datasets/test_unconditional_image_dataset.py
+24
-0
tests/test_datasets/test_unpaired_image_dataset.py
tests/test_datasets/test_unpaired_image_dataset.py
+60
-0
tests/test_losses/test_ddpm_loss.py
tests/test_losses/test_ddpm_loss.py
+186
-0
tests/test_losses/test_disc_auxilary_loss.py
tests/test_losses/test_disc_auxilary_loss.py
+185
-0
tests/test_losses/test_gan_loss.py
tests/test_losses/test_gan_loss.py
+78
-0
tests/test_losses/test_gen_auxiliary_loss.py
tests/test_losses/test_gen_auxiliary_loss.py
+200
-0
tests/test_losses/test_pixelwise_loss.py
tests/test_losses/test_pixelwise_loss.py
+196
-0
tests/test_models/test_base_ddpm.py
tests/test_models/test_base_ddpm.py
+19
-0
tests/test_models/test_base_gan.py
tests/test_models/test_base_gan.py
+1
-0
tests/test_models/test_basic_conditional_gan.py
tests/test_models/test_basic_conditional_gan.py
+181
-0
tests/test_models/test_cyclegan.py
tests/test_models/test_cyclegan.py
+327
-0
tests/test_models/test_ddpm.py
tests/test_models/test_ddpm.py
+613
-0
tests/test_models/test_mspie_styelgan2.py
tests/test_models/test_mspie_styelgan2.py
+81
-0
tests/test_models/test_pggan.py
tests/test_models/test_pggan.py
+195
-0
tests/test_models/test_pix2pix.py
tests/test_models/test_pix2pix.py
+252
-0
tests/test_models/test_sagan.py
tests/test_models/test_sagan.py
+93
-0
tests/test_models/test_singan.py
tests/test_models/test_singan.py
+182
-0
tests/test_models/test_sngan_proj.py
tests/test_models/test_sngan_proj.py
+91
-0
tests/test_models/test_static_unconditional_gan.py
tests/test_models/test_static_unconditional_gan.py
+236
-0
tests/test_models/test_stylegan1.py
tests/test_models/test_stylegan1.py
+91
-0
No files found.
tests/test_datasets/test_unconditional_image_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
mmgen.datasets
import
UnconditionalImageDataset
class TestUnconditionalImageDataset(object):
    """Smoke tests for ``UnconditionalImageDataset``."""

    @classmethod
    def setup_class(cls):
        # Fixture images live in ``tests/data/image`` relative to this file.
        cls.imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        cls.default_pipeline = [
            dict(type='LoadImageFromFile', io_backend='disk', key='real_img')
        ]

    def test_unconditional_imgs_dataset(self):
        """Load the fixture folder and check length, sample shape and repr."""
        ds = UnconditionalImageDataset(
            self.imgs_root, pipeline=self.default_pipeline)
        # The fixture directory is expected to hold exactly 6 images
        # — TODO confirm against tests/data/image contents.
        assert len(ds) == 6
        sample = ds[2]['real_img']
        # Loaded images are HWC arrays, hence 3 dimensions.
        assert sample.ndim == 3
        assert repr(ds) == (f'dataset_name: {ds.__class__}, '
                            f'total {6} images in imgs_root: {self.imgs_root}')
tests/test_datasets/test_unpaired_image_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
mmgen.datasets
import
UnpairedImageDataset
class TestUnpairedImageDataset(object):
    """Smoke tests for ``UnpairedImageDataset`` (CycleGAN-style data)."""

    @classmethod
    def setup_class(cls):
        # Fixture data lives in ``tests/data/unpaired``.
        cls.imgs_root = osp.join(
            osp.dirname(osp.dirname(__file__)), 'data/unpaired')
        img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # Full train-style pipeline: load both domains, resize, random crop,
        # flip each domain independently, rescale + normalize, tensorize.
        cls.default_pipeline = [
            dict(
                type='LoadImageFromFile',
                io_backend='disk',
                key='img_a',
                flag='color'),
            dict(
                type='LoadImageFromFile',
                io_backend='disk',
                key='img_b',
                flag='color'),
            dict(
                type='Resize',
                keys=['img_a', 'img_b'],
                scale=(286, 286),
                interpolation='bicubic'),
            dict(
                type='Crop',
                keys=['img_a', 'img_b'],
                crop_size=(256, 256),
                random_crop=True),
            dict(type='Flip', keys=['img_a'], direction='horizontal'),
            dict(type='Flip', keys=['img_b'], direction='horizontal'),
            dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
            dict(
                type='Normalize',
                keys=['img_a', 'img_b'],
                to_rgb=True,
                **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img_a', 'img_b']),
            dict(
                type='Collect',
                keys=['img_a', 'img_b'],
                meta_keys=['img_a_path', 'img_b_path'])
        ]

    def test_unpaired_image_dataset(self):
        """Build the dataset and check its length and sample dimensionality."""
        dataset = UnpairedImageDataset(
            self.imgs_root,
            pipeline=self.default_pipeline,
            domain_a='a',
            domain_b='b')
        # The fixture folder is expected to yield 2 samples
        # — TODO confirm against tests/data/unpaired contents.
        assert len(dataset) == 2
        img = dataset[0]['img_a']
        assert img.ndim == 3
        img = dataset[0]['img_b']
        assert img.ndim == 3
tests/test_losses/test_ddpm_loss.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
torch
from
mmgen.models.builder
import
build_module
from
mmgen.models.losses.pixelwise_loss
import
(
DiscretizedGaussianLogLikelihoodLoss
,
GaussianKLDLoss
,
MSELoss
)
class TestDDPMVLBLoss:
    """Tests for ``DDPMVLBLoss`` built via ``build_module``.

    The expected loss is computed manually in ``setup_class`` by combining
    ``DiscretizedGaussianLogLikelihoodLoss`` at t == 0 with
    ``GaussianKLDLoss`` at t > 0, matching the VLB decomposition.
    """

    @classmethod
    def setup_class(cls):
        # Key remapping for the KLD term (used for all t > 0).
        cls.gaussian_kld_data_info = dict(
            mean_pred='mean_pred',
            mean_target='mean_posterior',
            logvar_pred='logvar_pred',
            logvar_target='logvar_posterior')
        # Key remapping for the discretized log-likelihood term (t == 0).
        cls.disc_log_likelihood_data_info = dict(
            x='real_imgs', mean='mean_pred', logvar='logvar_pred')
        cls.config = dict(
            type='DDPMVLBLoss',
            rescale_mode='constant',
            rescale_cfg=dict(scale=4),
            data_info=cls.gaussian_kld_data_info,
            data_info_t_0=cls.disc_log_likelihood_data_info,
            log_cfgs=[
                dict(
                    type='quartile',
                    prefix_name='loss_vlb',
                    total_timesteps=4),
                dict(type='name')
            ])
        # One sample per timestep: t = [0, 1, 2, 3].
        cls.t = torch.LongTensor([0, 1, 2, 3])
        cls.tar_shape = [4, 2, 4, 4]
        cls.mean_pred = torch.randn(cls.tar_shape)
        cls.logvar_pred = torch.randn(cls.tar_shape)
        cls.mean_posterior = torch.randn(cls.tar_shape)
        cls.logvar_posterior = torch.randn(cls.tar_shape)
        cls.real_imgs = torch.randn(cls.tar_shape)
        cls.label = [0, 18, 1, 5]
        cls.output_dict = dict(
            mean_pred=cls.mean_pred,
            logvar_pred=cls.logvar_pred,
            mean_posterior=cls.mean_posterior,
            logvar_posterior=cls.logvar_posterior,
            real_imgs=cls.real_imgs,
            label=cls.label,
            meta_info=None,
            timesteps=cls.t)
        # calculate loss manually: per-sample ('flatmean') terms, then
        # combine as VLB = (-log-likelihood at t==0 + KLD for t>0) / batch.
        cls.loss_gaussian_kld = GaussianKLDLoss(
            data_info=cls.gaussian_kld_data_info,
            reduction='flatmean',
            base='2')(cls.output_dict)
        cls.loss_disc_likelihood = DiscretizedGaussianLogLikelihoodLoss(
            data_info=cls.disc_log_likelihood_data_info,
            reduction='flatmean',
            base='2')(cls.output_dict)
        cls.loss_manually = (-cls.loss_disc_likelihood[0] +
                             cls.loss_gaussian_kld[1:].sum()) / 4
        # TODO: unit test for sampler would be add later
        cls.weight = torch.rand(4, )

    def test_vlb_loss(self):
        """Exercise forward, log_cfgs variants, rescale modes and t-edge case."""
        # test forward (rescale_mode='constant' scales the manual loss by 4)
        config = deepcopy(self.config)
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        # NOTE(review): np.allclose's return value is not asserted here,
        # so this only checks that the call runs — TODO assert the result.
        np.allclose(loss, self.loss_manually * 4)

        # test log_cfgs --> dict input (should be normalized to a list)
        config = deepcopy(self.config)
        config['log_cfgs'] = dict(type='name')
        loss_fn = build_module(config)
        assert isinstance(loss_fn.log_fn_list, list)

        # test log_cfgs --> no log_cfgs (no log_vars are collected)
        config = deepcopy(self.config)
        config['log_cfgs'] = None
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        assert not loss_fn.log_vars

        # test rescale_cfg --> rescale is None (loss equals the manual one)
        config = deepcopy(self.config)
        config['rescale_mode'] = None
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        np.allclose(loss, self.loss_manually)

        # TODO: test rescale_cfg --> test sampler
        # test rescale_cfg --> test weight (per-timestep reweighting)
        config = deepcopy(self.config)
        config['rescale_mode'] = 'timestep_weight'
        weight = self.weight.clone()
        loss_fn = build_module(config, default_args=dict(weight=weight))
        loss = loss_fn(self.output_dict)
        loss_weighted_manually = (
            -(self.loss_disc_likelihood * weight)[0] +
            (self.loss_gaussian_kld * weight)[1:].sum()) / 4
        np.allclose(loss, loss_weighted_manually)

        # test rescale_cfg --> change weight (the module must see the
        # mutated tensor, i.e. it holds a reference, not a copy)
        weight[0] += 1
        loss = loss_fn(self.output_dict)
        loss_weighted_manually = (
            -(self.loss_disc_likelihood * weight)[0] +
            (self.loss_gaussian_kld * weight)[1:].sum()) / 4
        np.allclose(loss, loss_weighted_manually)

        # test the case where no sample has t == 0: both the first-quartile
        # log entry and the log-likelihood term must be zero
        config = deepcopy(self.config)
        output_dict = deepcopy(self.output_dict)
        output_dict['timesteps'][0] = 1
        loss_fn = build_module(config)
        loss = loss_fn(output_dict)
        assert loss_fn.log_vars['loss_vlb_quartile_0'] == 0
        assert loss_fn.log_vars['loss_DiscGaussianLogLikelihood'] == 0
class TestDDPMMSELoss:
    """Tests for ``DDPMMSELoss`` with timestep-weighted rescaling."""

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(pred='eps_t_pred', target='noise')
        cls.config = dict(
            type='DDPMMSELoss',
            data_info=cls.data_info,
            log_cfgs=dict(
                type='quartile', prefix_name='loss_mse', total_timesteps=4))
        # One sample per timestep: t = [0, 1, 2, 3].
        cls.t = torch.LongTensor([0, 1, 2, 3])
        cls.tar_shape = [4, 2, 4, 4]
        cls.eps_t_pred = torch.randn(cls.tar_shape)
        cls.noise = torch.randn(cls.tar_shape)
        cls.output_dict = dict(
            eps_t_pred=cls.eps_t_pred,
            noise=cls.noise,
            meta_info=None,
            timesteps=cls.t)
        cls.weight = torch.rand(4, )
        # calculate loss manually: per-sample MSE scaled by the weight of
        # its timestep, averaged over the batch. Samples are indexed by t,
        # which equals idx for this fixture since t = [0, 1, 2, 3].
        cls.loss_manually = 0
        for idx in range(cls.tar_shape[0]):
            t = cls.t[idx]
            weight = cls.weight[t]
            output_dict_ = dict(
                eps_t_pred=cls.eps_t_pred[t], noise=cls.noise[t])
            cls.loss_manually += MSELoss(
                data_info=cls.data_info)(output_dict_) * weight
        cls.loss_manually /= 4

    def test_mse_loss(self):
        """Check forward with timestep weights, bad reduction, and loss_name."""
        # test forward
        config = deepcopy(self.config)
        config['rescale_mode'] = 'timestep_weight'
        loss_fn = build_module(config, default_args=dict(weight=self.weight))
        loss = loss_fn(self.output_dict)
        # NOTE(review): np.allclose's return value is not asserted here,
        # so this only checks that the call runs — TODO assert the result.
        np.allclose(loss, self.loss_manually)

        # test reduction raise error
        config = deepcopy(self.config)
        config['reduction'] = 'reduction'
        with pytest.raises(ValueError):
            loss_fn = build_module(config)

        # test return loss name
        config = deepcopy(self.config)
        config['loss_name'] = 'loss_name'
        loss_fn = build_module(config)
        assert loss_fn.loss_name() == 'loss_name'
tests/test_losses/test_disc_auxilary_loss.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
functools
import
partial
import
pytest
import
torch
from
mmgen.models.architectures
import
DCGANDiscriminator
from
mmgen.models.architectures.pggan.generator_discriminator
import
\
PGGANDiscriminator
from
mmgen.models.losses
import
(
DiscShiftLoss
,
GradientPenaltyLoss
,
disc_shift_loss
,
gradient_penalty_loss
)
from
mmgen.models.losses.disc_auxiliary_loss
import
(
R1GradientPenalty
,
r1_gradient_penalty_loss
)
class TestDiscShiftLoss(object):
    """Tests for the functional ``disc_shift_loss`` and its module wrapper."""

    @classmethod
    def setup_class(cls):
        cls.input_tensor = torch.randn((2, 10))
        cls.default_cfg = dict(
            loss_weight=0.1, data_info=dict(pred='disc_pred'))
        cls.default_input_dict = dict(disc_pred=cls.input_tensor)

    def test_disc_shift_loss(self):
        """Check reductions, weighting, and the avg_factor/reduction clash."""
        # default reduction: a non-negative scalar
        loss = disc_shift_loss(self.input_tensor)
        assert loss.ndim == 0
        assert loss >= 0
        # a negative weight flips the sign
        loss = disc_shift_loss(self.input_tensor, weight=-0.1)
        assert loss.ndim == 0
        assert loss <= 0
        # 'none' keeps the elementwise (2, 10) shape
        loss = disc_shift_loss(self.input_tensor, reduction='none')
        assert loss.ndim == 2
        assert (loss >= 0).all()
        # 'sum' vs. averaging over a large avg_factor
        loss_sum = disc_shift_loss(self.input_tensor, reduction='sum')
        loss_avg = disc_shift_loss(self.input_tensor, avg_factor=1000)
        assert loss_avg.ndim == 0 and loss_sum.ndim == 0
        assert loss_sum > loss_avg
        # avg_factor together with reduction='sum' is rejected
        with pytest.raises(ValueError):
            _ = disc_shift_loss(
                self.input_tensor, reduction='sum', avg_factor=100)

    def test_module_wrapper(self):
        """Check the dict-driven calling convention of ``DiscShiftLoss``."""
        # test with default config
        loss_module = DiscShiftLoss(**self.default_cfg)
        loss = loss_module(self.default_input_dict)
        assert loss.ndim == 0
        # extra positional args are not supported
        with pytest.raises(NotImplementedError):
            _ = loss_module(self.default_input_dict, 1)
        # positional arg plus outputs_dict keyword is rejected
        with pytest.raises(AssertionError):
            _ = loss_module(1, outputs_dict=self.default_input_dict)
        # passing the dict via the outputs_dict keyword also works
        input_ = dict(outputs_dict=self.default_input_dict)
        loss = loss_module(**input_)
        assert loss.ndim == 0
        # a raw tensor is rejected when data_info is set
        with pytest.raises(AssertionError):
            _ = loss_module(self.input_tensor)
        # test without data_info: raw tensor input is allowed
        loss_module = DiscShiftLoss(data_info=None)
        loss = loss_module(self.input_tensor)
        assert loss.ndim == 0
class TestGradientPenalty:
    """Tests for WGAN-GP style ``gradient_penalty_loss``/``GradientPenaltyLoss``."""

    @classmethod
    def setup_class(cls):
        cls.input_img = torch.randn((2, 3, 8, 8))
        cls.disc = DCGANDiscriminator(
            input_scale=8, output_scale=4, out_channels=5)
        cls.pggan_disc = PGGANDiscriminator(
            in_scale=8, base_channels=32, max_channels=32)
        # Key remapping used by the module wrapper.
        cls.data_info = dict(
            discriminator='disc',
            real_data='real_imgs',
            fake_data='fake_imgs')

    def test_gp_loss(self):
        """Check norm modes, weight/mask arguments and the module wrapper."""
        loss = gradient_penalty_loss(
            self.disc, self.input_img, self.input_img)
        assert loss > 0
        loss = gradient_penalty_loss(
            self.disc, self.input_img, self.input_img, norm_mode='HWC')
        assert loss > 0
        # unknown norm_mode is rejected
        with pytest.raises(NotImplementedError):
            _ = gradient_penalty_loss(
                self.disc, self.input_img, self.input_img, norm_mode='xxx')
        loss = gradient_penalty_loss(
            self.disc,
            self.input_img,
            self.input_img,
            norm_mode='HWC',
            weight=10)
        assert loss > 0
        # an all-ones mask should not change positivity
        loss = gradient_penalty_loss(
            self.disc,
            self.input_img,
            self.input_img,
            norm_mode='HWC',
            mask=torch.ones_like(self.input_img),
            weight=10)
        assert loss > 0

        # module wrapper driven by an output dict; the PGGAN discriminator
        # needs its extra call kwargs bound via partial
        data_dict = dict(
            real_imgs=self.input_img,
            fake_imgs=self.input_img,
            disc=partial(
                self.pggan_disc, transition_weight=0.5, curr_scale=8))
        gp_loss = GradientPenaltyLoss(
            loss_weight=10, norm_mode='pixel', data_info=self.data_info)
        loss = gp_loss(data_dict)
        assert loss > 0
        loss = gp_loss(outputs_dict=data_dict)
        assert loss > 0
        # unsupported keyword, raw positional, and mixed calls are rejected
        with pytest.raises(NotImplementedError):
            _ = gp_loss(asdf=1.)
        with pytest.raises(AssertionError):
            _ = gp_loss(1.)
        with pytest.raises(AssertionError):
            _ = gp_loss(1., 2, outputs_dict=data_dict)
class TestR1GradientPenalty:
    """Tests for ``r1_gradient_penalty_loss`` and the ``R1GradientPenalty`` module.

    Mirrors ``TestGradientPenalty`` but the R1 penalty only needs real data,
    so ``data_info`` has no ``fake_data`` key.
    """

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(discriminator='disc', real_data='real_imgs')
        cls.disc = DCGANDiscriminator(
            input_scale=8, output_scale=4, out_channels=5)
        cls.pggan_disc = PGGANDiscriminator(
            in_scale=8, base_channels=32, max_channels=32)
        cls.input_img = torch.randn((2, 3, 8, 8))

    def test_r1_regularizer(self):
        """Check norm modes, weight/mask arguments and the module wrapper."""
        loss = r1_gradient_penalty_loss(self.disc, self.input_img)
        assert loss > 0
        loss = r1_gradient_penalty_loss(
            self.disc, self.input_img, norm_mode='HWC')
        assert loss > 0
        # unknown norm_mode is rejected
        with pytest.raises(NotImplementedError):
            _ = r1_gradient_penalty_loss(
                self.disc, self.input_img, norm_mode='xxx')
        loss = r1_gradient_penalty_loss(
            self.disc, self.input_img, norm_mode='HWC', weight=10)
        assert loss > 0
        loss = r1_gradient_penalty_loss(
            self.disc,
            self.input_img,
            norm_mode='HWC',
            mask=torch.ones_like(self.input_img),
            weight=10)
        assert loss > 0

        # module wrapper with the PGGAN discriminator's kwargs pre-bound
        data_dict = dict(
            real_imgs=self.input_img,
            disc=partial(
                self.pggan_disc, transition_weight=0.5, curr_scale=8))
        gp_loss = R1GradientPenalty(
            loss_weight=10, norm_mode='pixel', data_info=self.data_info)
        loss = gp_loss(data_dict)
        assert loss > 0
        loss = gp_loss(outputs_dict=data_dict)
        assert loss > 0
        # unsupported keyword, raw positional, and mixed calls are rejected
        with pytest.raises(NotImplementedError):
            _ = gp_loss(asdf=1.)
        with pytest.raises(AssertionError):
            _ = gp_loss(1.)
        with pytest.raises(AssertionError):
            _ = gp_loss(1., 2, outputs_dict=data_dict)
tests/test_losses/test_gan_loss.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy.testing
as
npt
import
pytest
import
torch
from
mmgen.models.losses.gan_loss
import
GANLoss
def test_gan_losses():
    """Test gan losses.

    For each supported ``gan_type`` the loss is evaluated on constant
    inputs and checked against hand-computed values; ``loss_weight`` (2.0)
    is applied only when ``is_disc`` is False.
    """
    # unknown gan_type is rejected
    with pytest.raises(NotImplementedError):
        GANLoss(
            'xixihaha',
            loss_weight=1.0,
            real_label_val=1.0,
            fake_label_val=0.0)

    input_1 = torch.ones(1, 1)
    input_2 = torch.ones(1, 3, 6, 6) * 2

    # vanilla (BCE-with-logits; 0.3132616 = softplus(-1))
    gan_loss = GANLoss(
        'vanilla', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
    loss = gan_loss(input_1, True, is_disc=False)
    npt.assert_almost_equal(loss.item(), 0.6265233)
    loss = gan_loss(input_1, False, is_disc=False)
    npt.assert_almost_equal(loss.item(), 2.6265232)
    loss = gan_loss(input_1, True, is_disc=True)
    npt.assert_almost_equal(loss.item(), 0.3132616)
    loss = gan_loss(input_1, False, is_disc=True)
    npt.assert_almost_equal(loss.item(), 1.3132616)

    # lsgan (mean squared error against the label values)
    gan_loss = GANLoss(
        'lsgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
    loss = gan_loss(input_2, True, is_disc=False)
    npt.assert_almost_equal(loss.item(), 2.0)
    loss = gan_loss(input_2, False, is_disc=False)
    npt.assert_almost_equal(loss.item(), 8.0)
    loss = gan_loss(input_2, True, is_disc=True)
    npt.assert_almost_equal(loss.item(), 1.0)
    loss = gan_loss(input_2, False, is_disc=True)
    npt.assert_almost_equal(loss.item(), 4.0)

    # wgan (signed mean of the prediction)
    gan_loss = GANLoss(
        'wgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
    loss = gan_loss(input_2, True, is_disc=False)
    npt.assert_almost_equal(loss.item(), -4.0)
    loss = gan_loss(input_2, False, is_disc=False)
    npt.assert_almost_equal(loss.item(), 4)
    loss = gan_loss(input_2, True, is_disc=True)
    npt.assert_almost_equal(loss.item(), -2.0)
    loss = gan_loss(input_2, False, is_disc=True)
    npt.assert_almost_equal(loss.item(), 2.0)

    # wgan-logistic-ns (only positivity is checked)
    gan_loss = GANLoss(
        'wgan-logistic-ns',
        loss_weight=2.0,
        real_label_val=1.0,
        fake_label_val=0.0)
    loss = gan_loss(input_2, True, is_disc=False)
    assert loss.item() > 0
    loss = gan_loss(input_2, False, is_disc=False)
    assert loss.item() > 0

    # hinge
    gan_loss = GANLoss(
        'hinge', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
    loss = gan_loss(input_2, True, is_disc=False)
    npt.assert_almost_equal(loss.item(), -4.0)
    loss = gan_loss(input_2, False, is_disc=False)
    npt.assert_almost_equal(loss.item(), -4.0)
    loss = gan_loss(input_2, True, is_disc=True)
    npt.assert_almost_equal(loss.item(), 0.0)
    loss = gan_loss(input_2, False, is_disc=True)
    npt.assert_almost_equal(loss.item(), 3.0)
tests/test_losses/test_gen_auxiliary_loss.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
pytest
import
torch
from
mmgen.models.architectures.pix2pix
import
UnetGenerator
from
mmgen.models.architectures.stylegan
import
StyleGANv2Generator
from
mmgen.models.losses
import
GeneratorPathRegularizer
,
PerceptualLoss
from
mmgen.models.losses.pixelwise_loss
import
l1_loss
,
mse_loss
class TestPathRegularizer:
    """Tests for ``GeneratorPathRegularizer`` on CPU and (optionally) CUDA."""

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(generator='generator', num_batches='num_batches')
        cls.gen = StyleGANv2Generator(32, 10, num_mlps=2)

    def test_path_regularizer_cpu(self):
        """Check forward, interval skipping, and input validation on CPU."""
        gen = self.gen
        output_dict = dict(generator=gen, num_batches=2)
        pl = GeneratorPathRegularizer(data_info=self.data_info)
        pl_loss = pl(output_dict)
        assert pl_loss > 0
        # with interval=2 and iteration=3 the regularizer is skipped
        # and returns None
        output_dict = dict(generator=gen, num_batches=2, iteration=3)
        pl = GeneratorPathRegularizer(data_info=self.data_info, interval=2)
        pl_loss = pl(outputs_dict=output_dict)
        assert pl_loss is None
        # unsupported keyword, raw positional, and mixed calls are rejected
        with pytest.raises(NotImplementedError):
            _ = pl(asdf=1.)
        with pytest.raises(AssertionError):
            _ = pl(1.)
        with pytest.raises(AssertionError):
            _ = pl(1., 2, outputs_dict=output_dict)

    @pytest.mark.skipif(
        not torch.cuda.is_available()
        or not hasattr(torch.backends.cudnn, 'allow_tf32'),
        reason='requires cuda')
    def test_path_regularizer_cuda(self):
        """Same checks as the CPU variant with the generator on CUDA."""
        gen = self.gen.cuda()
        output_dict = dict(generator=gen, num_batches=2)
        pl = GeneratorPathRegularizer(data_info=self.data_info).cuda()
        pl_loss = pl(output_dict)
        assert pl_loss > 0
        output_dict = dict(generator=gen, num_batches=2, iteration=3)
        pl = GeneratorPathRegularizer(
            data_info=self.data_info, interval=2).cuda()
        pl_loss = pl(outputs_dict=output_dict)
        assert pl_loss is None
        with pytest.raises(NotImplementedError):
            _ = pl(asdf=1.)
        with pytest.raises(AssertionError):
            _ = pl(1.)
        with pytest.raises(AssertionError):
            _ = pl(1., 2, outputs_dict=output_dict)
class TestPerceptualLoss:
    """Tests for ``PerceptualLoss`` configuration variants.

    Each test builds a 64x64 single-channel pair where the top-left 32x32
    quadrant differs (pred = 1, target = 2) and checks that the loss
    reduces to a scalar under the given configuration.
    """

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(pred='fake_imgs', target='real_imgs')
        cls.gen = UnetGenerator(3, 3)

    def test_perceptual_loss_cpu(self):
        """Default config: scalar loss, criterion defaults to l1_loss."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(data_info=self.data_info)
        loss_perceptual = perceptual_loss(
            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
        # the default criterion must be the l1_loss function object itself
        assert id(perceptual_loss.criterion) == id(l1_loss)

    def test_only_perceptual_loss(self):
        """style_weight=0 disables the style term."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info, style_weight=0)
        loss_percep = perceptual_loss(
            dict(fake_imgs=pred, real_imgs=target))
        assert loss_percep.shape == ()
        assert perceptual_loss.style_weight == 0

    def test_only_style_loss(self):
        """perceptual_weight=0 disables the perceptual term."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info, perceptual_weight=0)
        loss_style = perceptual_loss(dict(fake_imgs=pred, real_imgs=target))
        assert loss_style.shape == ()
        assert perceptual_loss.perceptual_weight == 0

    def test_with_different_layer_weights(self):
        """Custom layer_weights are applied to both terms by default."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        layer_weights = {'1': 1., '2': 2., '3': 3.}
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info, layer_weights=layer_weights)
        loss_perceptual = perceptual_loss(
            dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
        # style weights default to the perceptual layer weights
        assert perceptual_loss.layer_weights == layer_weights and \
            perceptual_loss.layer_weights_style == layer_weights

    def test_with_different_perceptual_and_style_layers(self):
        """Separate layer weights for the perceptual and style terms."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        layer_weights = {'1': 1., '2': 2., '3': 3.}
        layer_weights_style = {'4': 4., '5': 5., '6': 6.}
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info,
            layer_weights=layer_weights,
            layer_weights_style=layer_weights_style)
        loss_perceptual = perceptual_loss(
            dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
        assert perceptual_loss.layer_weights == layer_weights and \
            perceptual_loss.layer_weights_style == layer_weights_style

    def test_MSE_critierion(self):
        """criterion='mse' switches the criterion to mse_loss."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info, criterion='mse')
        loss_perceptual = perceptual_loss(
            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
        assert id(perceptual_loss.criterion) == id(mse_loss)

    def test_VGG_16(self):
        """vgg_type='vgg16' with torchvision pretrained weights."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info,
            vgg_type='vgg16',
            pretrained='torchvision://vgg16')
        loss_perceptual = perceptual_loss(
            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
        # TODO need to check whether vgg16 is loaded
        # assert perceptual_loss.vgg

    def test_split_style_loss(self):
        """split_style_loss=True returns (perceptual, style) separately."""
        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2
        perceptual_loss = PerceptualLoss(
            data_info=self.data_info, split_style_loss=True)
        loss_percep, loss_style = perceptual_loss(
            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss_percep.shape == () and loss_style.shape == ()

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires cuda')
    def test_perceptual_loss_cuda(self):
        """Default config on CUDA with random RGB inputs."""
        pred = torch.rand([2, 3, 256, 256]).cuda()
        target = torch.rand_like(pred).cuda()
        perceptual_loss = PerceptualLoss(data_info=self.data_info).cuda()
        loss_perceptual = perceptual_loss(
            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss_perceptual.shape == ()
tests/test_losses/test_pixelwise_loss.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
pytest
import
torch
from
torch.distributions.normal
import
Normal
from
mmgen.models.architectures.pix2pix
import
UnetGenerator
from
mmgen.models.losses
import
L1Loss
,
MSELoss
from
mmgen.models.losses.pixelwise_loss
import
(
DiscretizedGaussianLogLikelihoodLoss
,
GaussianKLDLoss
,
approx_gaussian_cdf
)
class TestPixelwiseLosses:
    """Tests for ``L1Loss`` and ``MSELoss``.

    The fixture is a 64x64 single-channel pair where only the top-left
    32x32 quadrant differs by 1 (pred = 1, target = 2), so the mean
    absolute/squared error over the full map is 0.25.
    """

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(pred='fake_imgs', target='real_imgs')
        cls.gen = UnetGenerator(3, 3)

    def test_pixelwise_losses(self):
        """Check reductions, dict-calling conventions and loss weights."""
        with pytest.raises(ValueError):
            # only 'none', 'mean' and 'sum' are supported
            L1Loss(reduction='InvalidValue')

        unknown_h, unknown_w = (32, 32)
        weight = torch.zeros(1, 1, 64, 64)
        weight[0, 0, :unknown_h, :unknown_w] = 1
        pred = weight.clone()
        target = weight.clone() * 2

        # test l1 loss: mean reduction via outputs_dict keyword
        l1_loss = L1Loss(
            loss_weight=1.0, reduction='mean', data_info=self.data_info)
        loss = l1_loss(outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss.shape == ()
        assert loss.item() == 0.25
        # same, with the dict passed positionally
        l1_loss = L1Loss(
            loss_weight=1.0, reduction='mean', data_info=self.data_info)
        loss = l1_loss(dict(fake_imgs=pred, real_imgs=target))
        assert loss.shape == ()
        assert loss.item() == 0.25
        # 'none' keeps the map shape; nonzero only on the quadrant
        l1_loss = L1Loss(loss_weight=0.5, reduction='none')
        loss = l1_loss(pred, target)
        assert loss.shape == (1, 1, 64, 64)
        assert (loss == torch.ones(1, 1, 64, 64) * weight * 0.5).all()
        # 'sum': 32*32 differing pixels * 0.5 weight = 512
        l1_loss = L1Loss(loss_weight=0.5, reduction='sum')
        loss = l1_loss(pred, target)
        assert loss.shape == ()
        assert loss.item() == 512

        # test MSE loss
        mse_loss = MSELoss(loss_weight=1.0, data_info=self.data_info)
        loss = mse_loss(outputs_dict=dict(fake_imgs=pred, real_imgs=target))
        assert loss.shape == ()
        assert loss.item() == 0.25
        mse_loss = MSELoss(loss_weight=1.0, data_info=self.data_info)
        loss = mse_loss(dict(fake_imgs=pred, real_imgs=target))
        assert loss.shape == ()
        assert loss.item() == 0.25
        # without data_info raw tensors are accepted; 0.25 * 0.5 = 0.125
        mse_loss = MSELoss(loss_weight=0.5)
        loss = mse_loss(pred, target)
        assert loss.shape == ()
        assert loss.item() == 0.1250
class TestGaussianKLDLoss:
    """Tests for ``GaussianKLDLoss``.

    With mean_pred=0, mean_target=1, logvar_pred=0, logvar_target=1 every
    element has the same closed-form KLD, precomputed as ``gt_loss``.
    """

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(
            mean_pred='mean_pred',
            mean_target='mean_target',
            logvar_pred='logvar_pred',
            logvar_target='logvar_target')
        cls.tar_shape = [2, 2, 4, 4]
        cls.mean_pred = torch.zeros(cls.tar_shape)
        cls.mean_target = torch.ones(cls.tar_shape)
        cls.logvar_pred = torch.zeros(cls.tar_shape)
        cls.logvar_target = torch.ones(cls.tar_shape)
        cls.output_dict = dict(
            mean_pred=cls.mean_pred,
            mean_target=cls.mean_target,
            logvar_pred=cls.logvar_pred,
            logvar_target=cls.logvar_target)
        # per-element KLD for these constants: (e - 1) / 2
        # NOTE(review): this matches the tested implementation's convention;
        # verify against GaussianKLDLoss's formula if it changes.
        cls.gt_loss = ((torch.exp(torch.ones(1)) - 1) / 2).item()

    def test_gaussian_kld_loss(self):
        """Check mean/batchmean reductions and scalar/tensor loss weights."""
        # test reduction --> mean
        gaussian_kld_loss = GaussianKLDLoss(
            data_info=self.data_info, reduction='mean')
        loss = gaussian_kld_loss(self.output_dict)
        assert (loss == self.gt_loss).all()

        # test reduction --> batchmean: sum over all non-batch dims,
        # then mean over the batch
        gaussian_kld_loss = GaussianKLDLoss(
            data_info=self.data_info, reduction='batchmean')
        loss = gaussian_kld_loss(self.output_dict)
        num_elements = self.tar_shape[1] * self.tar_shape[2] * \
            self.tar_shape[3]
        assert (loss == (self.gt_loss * num_elements)).all()

        # test weight --> int
        gaussian_kld_loss = GaussianKLDLoss(
            loss_weight=2, data_info=self.data_info, reduction='mean')
        loss = gaussian_kld_loss(self.output_dict)
        assert (loss == self.gt_loss * 2).all()

        # test weight --> tensor (elementwise weight, then mean)
        weight = torch.randn(*self.tar_shape)
        gaussian_kld_loss = GaussianKLDLoss(
            loss_weight=weight, data_info=self.data_info, reduction='mean')
        loss = gaussian_kld_loss(self.output_dict)
        assert torch.allclose(
            loss, weight.mean() * self.gt_loss, atol=1e-6)

        # test weight --> tensor & batchmean
        weight = torch.randn(*self.tar_shape)
        gaussian_kld_loss = GaussianKLDLoss(
            loss_weight=weight,
            data_info=self.data_info,
            reduction='batchmean')
        loss = gaussian_kld_loss(self.output_dict)
        assert torch.allclose(
            loss,
            weight.sum([1, 2, 3]).mean() * self.gt_loss,
            atol=1e-6)
def test_approx_gaussian_cdf():
    """The tanh-based CDF approximation should track the exact normal CDF."""
    samples = torch.rand(2, 2)
    std_normal = Normal(0, 1)
    approx = approx_gaussian_cdf(samples)
    exact = std_normal.cdf(samples)
    # 1e-3 absolute tolerance is loose enough for the approximation error.
    assert torch.allclose(approx, exact, atol=1e-3)
class TestDistLoss():
    """Tests for ``DiscretizedGaussianLogLikelihoodLoss``.

    All inputs are zero, so every element's log-likelihood equals
    ``log(Phi(1/255) - Phi(-1/255))`` for a standard normal, precomputed
    as ``gt_loss``.
    """

    @classmethod
    def setup_class(cls):
        cls.data_info = dict(
            mean='mean_pred', logvar='logvar_pred', x='real_imgs')
        cls.tar_shape = [2, 2, 4, 4]
        cls.mean_pred = torch.zeros(cls.tar_shape)
        cls.logvar_pred = torch.zeros(cls.tar_shape)
        cls.real_imgs = torch.zeros(cls.tar_shape)
        cls.output_dict = dict(
            mean_pred=cls.mean_pred,
            logvar_pred=cls.logvar_pred,
            real_imgs=cls.real_imgs)
        # ground truth: probability mass of one discretization bin
        # (width 2/255 around 0) under the standard normal
        norm_dist = Normal(0, 1)
        cls.gt_loss = torch.log(
            norm_dist.cdf(torch.FloatTensor([1 / 255])) -
            norm_dist.cdf(torch.FloatTensor([-1 / 255])))

    def test_disc_gaussian_log_likelihood_loss(self):
        """Check mean/batchmean reductions and scalar/tensor loss weights."""
        # test reduction --> mean
        disc_gaussian_loss = DiscretizedGaussianLogLikelihoodLoss(
            data_info=self.data_info, reduction='mean')
        loss = disc_gaussian_loss(self.output_dict)
        assert (loss == self.gt_loss).all()

        # test reduction --> batchmean: sum over all non-batch dims,
        # then mean over the batch
        disc_gaussian_loss = DiscretizedGaussianLogLikelihoodLoss(
            data_info=self.data_info, reduction='batchmean')
        loss = disc_gaussian_loss(self.output_dict)
        num_elements = self.tar_shape[1] * self.tar_shape[2] * \
            self.tar_shape[3]
        assert (loss == (self.gt_loss * num_elements)).all()

        # test weight --> int
        disc_gaussian_loss = DiscretizedGaussianLogLikelihoodLoss(
            loss_weight=2, data_info=self.data_info, reduction='mean')
        loss = disc_gaussian_loss(self.output_dict)
        assert (loss == self.gt_loss * 2).all()

        # test weight --> tensor (elementwise weight, then mean)
        weight = torch.randn(*self.tar_shape)
        disc_gaussian_loss = DiscretizedGaussianLogLikelihoodLoss(
            loss_weight=weight, data_info=self.data_info, reduction='mean')
        loss = disc_gaussian_loss(self.output_dict)
        assert torch.allclose(
            loss, weight.mean() * self.gt_loss, atol=1e-6)

        # test weight --> tensor & batchmean
        weight = torch.randn(*self.tar_shape)
        disc_gaussian_loss = DiscretizedGaussianLogLikelihoodLoss(
            loss_weight=weight,
            data_info=self.data_info,
            reduction='batchmean')
        loss = disc_gaussian_loss(self.output_dict)
        assert torch.allclose(
            loss,
            weight.sum([1, 2, 3]).mean() * self.gt_loss,
            atol=1e-6)
tests/test_models/test_base_ddpm.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
mmgen.models.diffusions
import
UniformTimeStepSampler
def test_uniform_sampler():
    """``UniformTimeStepSampler`` must draw valid timesteps in [0, 10)."""
    sampler = UniformTimeStepSampler(10)
    # Exercise both the implicit call and the explicit __call__ form.
    for draw in (sampler(2), sampler.__call__(2)):
        # Each draw is a 1-D batch of two timestep indices.
        assert draw.shape == torch.Size([2, ])
        assert draw.max() < 10 and draw.min() >= 0
tests/test_models/test_base_gan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
tests/test_models/test_basic_conditional_gan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
pytest
import
torch
import
torch.nn
as
nn
from
mmgen.models
import
BasicConditionalGAN
,
build_model
class TestBasicConditionalGAN(object):
    """Tests for ``BasicConditionalGAN`` built from SNGAN / SAGAN configs.

    ``setup_class`` prepares a default SNGAN-projection config plus
    standalone SAGAN generator/discriminator/loss configs reused by the
    CPU and CUDA test methods.
    """

    @classmethod
    def setup_class(cls):
        # Default end-to-end config: SNGAN generator + projection
        # discriminator, hinge GAN loss, no auxiliary losses.
        cls.default_config = dict(
            type='BasicConditionalGAN',
            generator=dict(
                type='SNGANGenerator',
                output_scale=32,
                base_channels=256,
                num_classes=10),
            discriminator=dict(
                type='ProjDiscriminator',
                input_scale=32,
                base_channels=128,
                num_classes=10),
            gan_loss=dict(type='GANLoss', gan_type='hinge'),
            disc_auxiliary_loss=None,
            gen_auxiliary_loss=None,
            train_cfg=None,
            test_cfg=None)
        # Standalone sub-configs for constructing BasicConditionalGAN
        # directly (the "customized config" path).
        cls.generator_cfg = dict(
            type='SAGANGenerator',
            output_scale=32,
            num_classes=10,
            base_channels=256,
            attention_after_nth_block=2,
            with_spectral_norm=True)
        cls.disc_cfg = dict(
            type='SAGANDiscriminator',
            input_scale=32,
            num_classes=10,
            base_channels=128,
            attention_after_nth_block=1)
        cls.gan_loss = dict(type='GANLoss', gan_type='hinge')
        cls.disc_auxiliary_loss = [
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_real'))
        ]

    def test_default_dcgan_model_cpu(self):
        # NOTE(review): despite the method name, this exercises the
        # SNGAN/SAGAN-based BasicConditionalGAN, not DCGAN. Name kept so
        # test selection/CI references stay valid.
        sngan = build_model(self.default_config)
        assert isinstance(sngan, BasicConditionalGAN)
        assert not sngan.with_disc_auxiliary_loss
        assert sngan.with_disc

        # test forward train: direct forward with return_loss is unsupported
        with pytest.raises(NotImplementedError):
            _ = sngan(None, return_loss=True)

        # test forward test
        imgs = sngan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 32, 32)

        # test train step
        data = torch.randn((2, 3, 32, 32))
        label = torch.randint(0, 10, (2, ))
        data_input = dict(img=data, gt_label=label)
        optimizer_g = torch.optim.SGD(sngan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            sngan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = sngan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # more tests for different configs with heavy computation
        # test disc_steps
        config_ = deepcopy(self.default_config)
        config_['train_cfg'] = dict(disc_steps=2)
        sngan = build_model(config_)
        model_outputs = sngan.train_step(data_input, optim_dict)
        # Iteration 0: only the discriminator is updated.
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' not in model_outputs['log_vars']
        assert sngan.disc_steps == 2

        # Iteration 1 (odd): generator is updated as well.
        model_outputs = sngan.train_step(
            data_input, optim_dict, running_status=dict(iteration=1))
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' in model_outputs['log_vars']

        # test customized config
        sagan = BasicConditionalGAN(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
        )
        # test train step
        data = torch.randn((2, 3, 32, 32))
        data_input = dict(img=data, gt_label=label)
        # BUGFIX: the optimizers must wrap the parameters of the model
        # being trained (``sagan``); previously they were built from the
        # stale ``sngan`` instance, so the optimizer steps never updated
        # the model under test.
        optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            sagan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
        model_outputs = sagan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # Auxiliary losses given as single dicts must be wrapped into
        # ModuleLists by the model.
        sagan = BasicConditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            dict(type='GeneratorPathRegularizer'))
        assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)

        # Same, with the generator auxiliary loss given as a list.
        sagan = BasicConditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            [dict(type='GeneratorPathRegularizer')])
        assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_default_dcgan_model_cuda(self):
        # CUDA mirror of the default-config checks above.
        sngan = build_model(self.default_config).cuda()
        assert isinstance(sngan, BasicConditionalGAN)
        assert not sngan.with_disc_auxiliary_loss
        assert sngan.with_disc

        # test forward train
        with pytest.raises(NotImplementedError):
            _ = sngan(None, return_loss=True)

        # test forward test
        imgs = sngan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 32, 32)

        # test train step
        data = torch.randn((2, 3, 32, 32)).cuda()
        label = torch.randint(0, 10, (2, )).cuda()
        data_input = dict(img=data, gt_label=label)
        optimizer_g = torch.optim.SGD(sngan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            sngan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = sngan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # more tests for different configs with heavy computation
        # test disc_steps
        config_ = deepcopy(self.default_config)
        config_['train_cfg'] = dict(disc_steps=2)
        sngan = build_model(config_).cuda()
        model_outputs = sngan.train_step(data_input, optim_dict)
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' not in model_outputs['log_vars']
        assert sngan.disc_steps == 2

        model_outputs = sngan.train_step(
            data_input, optim_dict, running_status=dict(iteration=1))
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' in model_outputs['log_vars']
tests/test_models/test_cyclegan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
torch
from
mmcv.runner
import
obj_from_dict
from
mmgen.models
import
(
GANLoss
,
L1Loss
,
PatchDiscriminator
,
ResnetGenerator
,
build_model
)
def test_cyclegan():
    """Integration test for the CycleGAN model.

    Covers model construction from config, attribute types, forward passes
    in test/train mode on CPU and (if available) GPU, ``train_step`` log
    and result contents, ``disc_steps``/``disc_init_steps`` scheduling,
    and training with a zero-sized GAN image buffer.
    """
    model_cfg = dict(
        type='CycleGAN',
        default_domain='photo',
        reachable_domains=['photo', 'mask'],
        related_domains=['photo', 'mask'],
        generator=dict(
            type='ResnetGenerator',
            in_channels=3,
            out_channels=3,
            base_channels=64,
            norm_cfg=dict(type='IN'),
            use_dropout=False,
            num_blocks=9,
            padding_mode='reflect',
            init_cfg=dict(type='normal', gain=0.02)),
        discriminator=dict(
            type='PatchDiscriminator',
            in_channels=3,
            base_channels=64,
            num_conv=3,
            norm_cfg=dict(type='IN'),
            init_cfg=dict(type='normal', gain=0.02)),
        gan_loss=dict(
            type='GANLoss',
            gan_type='lsgan',
            real_label_val=1.0,
            fake_label_val=0.0,
            loss_weight=1.0),
        gen_auxiliary_loss=[
            # Cycle-consistency losses (weight 10.0) ...
            dict(
                type='L1Loss',
                loss_weight=10.0,
                data_info=dict(pred='cycle_photo', target='real_photo'),
                reduction='mean'),
            dict(
                type='L1Loss',
                loss_weight=10.0,
                data_info=dict(
                    pred='cycle_mask',
                    target='real_mask',
                ),
                reduction='mean'),
            # ... and identity losses (weight 0.5).
            dict(
                type='L1Loss',
                loss_weight=0.5,
                data_info=dict(pred='identity_photo', target='real_photo'),
                reduction='mean'),
            dict(
                type='L1Loss',
                loss_weight=0.5,
                data_info=dict(pred='identity_mask', target='real_mask'),
                reduction='mean')
        ])

    train_cfg = None
    test_cfg = None

    # build synthesizer
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert synthesizer.__class__.__name__ == 'CycleGAN'
    assert isinstance(synthesizer.generators['photo'], ResnetGenerator)
    assert isinstance(synthesizer.generators['mask'], ResnetGenerator)
    assert isinstance(synthesizer.discriminators['photo'], PatchDiscriminator)
    assert isinstance(synthesizer.discriminators['mask'], PatchDiscriminator)
    assert isinstance(synthesizer.gan_loss, GANLoss)
    for loss_module in synthesizer.gen_auxiliary_losses:
        assert isinstance(loss_module, L1Loss)

    # prepare data
    inputs = torch.rand(1, 3, 64, 64)
    targets = torch.rand(1, 3, 64, 64)
    data_batch = {'img_mask': inputs, 'img_photo': targets}

    # prepare optimizer
    optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # test forward_test
    with torch.no_grad():
        outputs = synthesizer(inputs, target_domain='photo', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    with torch.no_grad():
        outputs = synthesizer(targets, target_domain='mask', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_photo'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    # test forward_train
    # NOTE(review): despite the comment above, test_mode=True is passed
    # here, so this repeats the forward_test checks — confirm intent.
    with torch.no_grad():
        outputs = synthesizer(inputs, target_domain='photo', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    with torch.no_grad():
        outputs = synthesizer(targets, target_domain='mask', test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_photo'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 64, 64)

    # test train_step
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.equal(outputs['results']['real_mask'],
                       data_batch['img_mask'])
    assert torch.is_tensor(outputs['results']['fake_mask'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
    assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)

    # test train_step and forward_test (gpu)
    if torch.cuda.is_available():
        synthesizer = synthesizer.cuda()
        # Rebuild optimizers over the now-CUDA parameters.
        optimizer = {
            'generators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(synthesizer, 'generators').parameters())),
            'discriminators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(
                    params=getattr(synthesizer,
                                   'discriminators').parameters()))
        }
        data_batch_cuda = copy.deepcopy(data_batch)
        data_batch_cuda['img_mask'] = inputs.cuda()
        data_batch_cuda['img_photo'] = targets.cuda()

        # forward_test
        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_mask'],
                target_domain='photo',
                test_mode=True)
        # Test-mode outputs are compared on CPU (hence the .cpu()).
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_photo'],
                target_domain='mask',
                test_mode=True)
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        # test forward_train
        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_mask'],
                target_domain='photo',
                test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_mask'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_photo'],
                target_domain='mask',
                test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_photo'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 64, 64)

        # train_step
        outputs = synthesizer.train_step(data_batch_cuda, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        print(outputs['log_vars'].keys())
        assert isinstance(outputs['results'], dict)
        for v in [
                'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
                'loss_gan_g_photo', 'loss_l1'
        ]:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)

    # test disc_steps and disc_init_steps
    data_batch['img_mask'] = inputs.cpu()
    data_batch['img_photo'] = targets.cpu()
    train_cfg = dict(disc_steps=2, disc_init_steps=2)
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }

    # iter 0, 1
    # Inside disc_init_steps: only discriminator losses should be logged.
    for i in range(2):
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        for v in ['loss_gan_g_mask', 'loss_gan_g_photo', 'loss_l1']:
            assert outputs['log_vars'].get(v) is None
        assert isinstance(outputs['log_vars']['loss_gan_d_mask'], float)
        assert isinstance(outputs['log_vars']['loss_gan_d_photo'], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
        assert synthesizer.iteration == i + 1

    # iter 2, 3, 4, 5
    # Past disc_init_steps: generator losses appear only every
    # disc_steps-th iteration.
    for i in range(2, 6):
        assert synthesizer.iteration == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        log_check_list = [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
        ]
        if i % 2 == 1:
            log_None_list = [
                'loss_gan_g_mask', 'loss_gan_g_photo', 'loss_l1'
            ]
            for v in log_None_list:
                assert outputs['log_vars'].get(v) is None
                log_check_list.remove(v)
        for v in log_check_list:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_mask'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
        assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
        assert synthesizer.iteration == i + 1

    # test GAN image buffer size = 0
    data_batch['img_mask'] = inputs.cpu()
    data_batch['img_photo'] = targets.cpu()
    train_cfg = dict(buffer_size=0)
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    optimizer = {
        'generators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'generators').parameters())),
        'discriminators':
        obj_from_dict(
            optim_cfg, torch.optim,
            dict(params=getattr(synthesizer, 'discriminators').parameters()))
    }
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in [
            'loss_gan_d_mask', 'loss_gan_d_photo', 'loss_gan_g_mask',
            'loss_gan_g_photo', 'loss_l1'
    ]:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_mask'],
                       data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_mask'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_mask'].size() == (1, 3, 64, 64)
    assert outputs['results']['fake_photo'].size() == (1, 3, 64, 64)
    assert synthesizer.iteration == 1
tests/test_models/test_ddpm.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
torch
from
mmgen.models.builder
import
build_model
from
mmgen.models.diffusions
import
(
BasicGaussianDiffusion
,
UniformTimeStepSampler
)
from
mmgen.models.diffusions.utils
import
_get_label_batch
,
_get_noise_batch
class TestBasicGaussianDiffusion:
    """Tests for ``BasicGaussianDiffusion`` (DDPM) built via ``build_model``.

    ``setup_class`` prepares a 10-timestep diffusion config, a DenoisingUnet
    config, a uniform timestep sampler config and a VLB+MSE loss list; the
    single test method walks through construction variants, sampling modes
    and train steps.
    """

    @classmethod
    def setup_class(cls):
        cls.config = dict(
            type='BasicGaussianDiffusion',
            num_timesteps=10,
            betas_cfg=dict(type='cosine'),
            train_cfg=None,
            test_cfg=None)
        cls.denoising = dict(
            type='DenoisingUnet',
            image_size=32,
            in_channels=3,
            base_channels=128,
            resblocks_per_downsample=1,
            attention_res=[16, 8],
            use_scale_shift_norm=True,
            dropout=0.3,
            num_heads=1,
            use_rescale_timesteps=True,
            output_cfg=dict(mean='eps', var='learned_range'),
        )
        cls.sampler = dict(type='UniformTimeStepSampler')
        cls.ddpm_loss = [
            dict(
                type='DDPMVLBLoss',
                rescale_mode='constant',
                rescale_cfg=dict(scale=1),
                data_info=dict(
                    mean_pred='mean_pred',
                    mean_target='mean_posterior',
                    logvar_pred='logvar_pred',
                    logvar_target='logvar_posterior'),
                log_cfgs=[
                    dict(
                        type='quartile',
                        prefix_name='loss_vlb',
                        total_timesteps=1000),
                    dict(type='name')
                ]),
            dict(
                type='DDPMMSELoss',
                log_cfgs=dict(
                    type='quartile',
                    prefix_name='loss_mse',
                    total_timesteps=1000),
            )
        ]

    def test_diffusion(self):
        # test build model
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert not diffusion.use_ema

        # test build model --> parse train_cfg with ema
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        cfg['train_cfg'] = dict(use_ema=True)
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert diffusion.use_ema

        # test build_model --> parse train_cfg, without ema
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        cfg['train_cfg'] = dict(use_ema=False)
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert not diffusion.use_ema

        # test build_model --> parse test_cfg, with ema
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        cfg['test_cfg'] = dict(use_ema=True)
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert diffusion.use_ema

        # test build_model --> parse test_cfg, without ema
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        cfg['test_cfg'] = dict(use_ema=False)
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert not diffusion.use_ema

        # test sampler is None
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = None
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        assert diffusion.sampler is None

        # test build model --> betas type = linear
        cfg = deepcopy(self.config)
        cfg['betas_cfg'] = dict(type='linear')
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        assert isinstance(diffusion, BasicGaussianDiffusion)
        assert isinstance(diffusion.sampler, UniformTimeStepSampler)
        assert not diffusion.use_ema

        # test build model --> wrong beta cfgs
        cfg = deepcopy(self.config)
        cfg['betas_cfg'] = dict(type='sine')
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        with pytest.raises(AttributeError):
            diffusion = build_model(cfg)

        # test forward train --> raise error
        # NOTE(review): ``diffusion`` here is the last successfully built
        # model (the 'linear' betas one), since the 'sine' build raised.
        with pytest.raises(NotImplementedError):
            diffusion(None, return_loss=True)

        # test forward test
        imgs = diffusion(
            None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 32, 32)

        # test forward test --> wrong mode
        with pytest.raises(NotImplementedError):
            diffusion(
                None, return_loss=False, mode='generation', num_batches=2)

        # test reconstruction step --> given timestep
        cfg = deepcopy(self.config)
        cfg['denoising'] = self.denoising
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        data_batch = dict(real_img=torch.randn(2, 3, 32, 32))
        fake_imgs = diffusion(
            data_batch, timesteps=[0, 5], mode='reconstruction')
        assert fake_imgs.shape == (2, 3, 32, 32)

        # test reconstruction step --> timestep = None
        # With no timesteps given, every timestep (10) is reconstructed
        # for each of the 2 images --> batch of 20.
        output_dict = diffusion(
            data_batch, mode='reconstruction', return_noise=True)
        assert output_dict['fake_img'].shape == (20, 3, 32, 32)
        timestep = torch.cat([torch.LongTensor([i, i]) for i in range(10)])
        assert (output_dict['timesteps'] == timestep).all()

        # test reconstruction step --> noise in input
        noise = torch.randn(2, 3, 32, 32)
        output_dict = diffusion(
            data_batch, noise=noise, mode='reconstruction',
            return_noise=True)
        assert output_dict['noise'].shape == (20, 3, 32, 32)
        assert (output_dict['noise'] == torch.cat(
            [noise for _ in range(10)], dim=0)).all()

        # test reconstruction step --> noise in data_batch
        data_batch = dict(real_img=torch.randn(2, 3, 32, 32), noise=noise)
        output_dict = diffusion(
            data_batch, mode='reconstruction', return_noise=True)
        assert output_dict['noise'].shape == (20, 3, 32, 32)
        assert (output_dict['noise'] == torch.cat(
            [noise for _ in range(10)], dim=0)).all()

        # test reconstruction step --> noise in data_batch and input (error)
        data_batch = dict(
            real_img=torch.randn(2, 3, 32, 32),
            noise=torch.randn(2, 3, 32, 32))
        with pytest.raises(AssertionError):
            output_dict = diffusion(
                data_batch,
                noise=torch.randn(2, 3, 32, 32),
                mode='reconstruction')

        # test sample from noise
        fake_imgs = diffusion.sample_from_noise(None, num_batches=2)
        assert fake_imgs.shape == (2, 3, 32, 32)

        # test sample from noise --> save_intermedia
        # Keys are the remaining timesteps, counting down 10..0.
        output_dict = diffusion.sample_from_noise(
            None, num_batches=2, save_intermedia=True)
        assert list(output_dict.keys()) == [i for i in range(10, -1, -1)]

        # test sample from noise -->
        # sample model == ema/orig and use_ema == False
        output_dict = diffusion.sample_from_noise(
            None, num_batches=2, save_intermedia=True)

        # test sample from noise --> sample model == ema but use_ema = False
        with pytest.raises(AssertionError):
            diffusion.sample_from_noise(None, num_batches=2,
                                        sample_model='ema')

        # test sample from noise --> wrong sample method
        diffusion.sample_method = 'dk_method'
        with pytest.raises(AttributeError):
            diffusion.sample_from_noise(
                None, num_batches=2, save_intermedia=True)

        # test sample from noise --> sample model == orig/ema
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(mean='start_x', var='learned')
        cfg['train_cfg'] = dict(use_ema=True)
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        # 'ema/orig' doubles the batch: 2 ema samples + 2 orig samples.
        output = diffusion.sample_from_noise(None, num_batches=2)
        assert output.shape == (4, 3, 32, 32)

        # test sample from noise --> ema
        fake_img = diffusion.sample_from_noise(
            None, num_batches=2, sample_model='ema')
        assert fake_img.shape == (2, 3, 32, 32)

        # test sample from noise --> orig/ema + save_intermedia
        output_dict = diffusion.sample_from_noise(
            None, num_batches=2, save_intermedia=True)
        assert list(output_dict.keys()) == [i for i in range(10, -1, -1)]
        assert all([v.shape == (4, 3, 32, 32) for v in output_dict.values()])

        # test sample from noise -->
        # sample model == ema/orig return noise = True
        output_dict = diffusion.sample_from_noise(
            None,
            num_batches=2,
            save_intermedia=True,
            return_noise=True,
            sample_model='ema/orig')
        assert list(output_dict.keys()) == [i for i in range(10, -1, -1)]
        assert all([v.shape == (4, 3, 32, 32) for v in output_dict.values()])

        # test sample from noise -->
        # sample model == ema/orig return noise = True, timesteps_noise
        output_dict = diffusion.sample_from_noise(
            None,
            num_batches=2,
            save_intermedia=True,
            return_noise=True,
            timesteps_noise=torch.randn(10, 3, 32, 32),
            sample_model='ema/orig')
        assert list(output_dict.keys()) == [i for i in range(10, -1, -1)]
        assert all([v.shape == (4, 3, 32, 32) for v in output_dict.values()])

        # test denoising_var_mode = 'LEARNED' & denoising_mean_mode = 'start_x'
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(mean='start_x', var='learned')
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        output_dict = diffusion.sample_from_noise(None, num_batches=2)

        # test denoising_var_mode = 'fixed_large' &
        # denoising_mean_mode = 'previous_x'
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(
            mean='previous_x', var='fixed_large')
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        output_dict = diffusion.sample_from_noise(None, num_batches=2)

        # test denoising_var_mode = 'fixed_small' &
        # denoising_mean_mode = 'previous_x'
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(
            mean='previous_x', var='fixed_small')
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        output_dict = diffusion.sample_from_noise(None, num_batches=2)

        # test output_cfg --> error denoising_mean_mode
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(mean='x_0', var='fixed_small')
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        with pytest.raises(AttributeError):
            output_dict = diffusion.sample_from_noise(None, num_batches=2)

        # test output_cfg --> error denoising_var_mode
        cfg = deepcopy(self.config)
        denoising_cfg = deepcopy(self.denoising)
        denoising_cfg['output_cfg'] = dict(mean='previous_x', var='fixex')
        cfg['denoising'] = denoising_cfg
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = None
        diffusion = build_model(cfg)
        with pytest.raises(AttributeError):
            output_dict = diffusion.sample_from_noise(None, num_batches=2)

        # test train step --> no running status but have diffusion.iteration
        cfg = deepcopy(self.config)
        cfg['denoising'] = deepcopy(self.denoising)
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = deepcopy(self.ddpm_loss)
        diffusion = build_model(cfg)
        setattr(diffusion, 'iteration', 1)
        data = dict(real_img=torch.randn(2, 3, 32, 32))
        optimizer = dict(
            denoising=torch.optim.SGD(
                diffusion.denoising.parameters(), lr=0.01))
        model_outputs = diffusion.train_step(data, optimizer)
        assert 'log_vars' in model_outputs
        assert 'results' in model_outputs

        # test train step --> running status
        cfg = deepcopy(self.config)
        cfg['denoising'] = deepcopy(self.denoising)
        cfg['timestep_sampler'] = self.sampler
        cfg['ddpm_loss'] = deepcopy(self.ddpm_loss)
        diffusion = build_model(cfg)
        data = dict(real_img=torch.randn(2, 3, 32, 32))
        optimizer = dict(
            denoising=torch.optim.SGD(
                diffusion.denoising.parameters(), lr=0.01))
        model_outputs = diffusion.train_step(
            data, optimizer, running_status=dict(iteration=1))
        assert 'log_vars' in model_outputs
        assert 'results' in model_outputs
def test_ddpm_noise_batch_utils():
    """Unit tests for ``_get_noise_batch``.

    Covers the three accepted noise inputs (``None``, callable, tensor),
    both with and without per-timestep noise, the tensor-rank promotion
    rules, and the size-mismatch error paths.

    Leftover debug ``print`` calls from the original version were removed;
    the assertions alone carry the test.
    """
    image_shape = (3, 32, 32)
    num_batches = 2

    # noise is None, timestep is False
    noise_out = _get_noise_batch(None, image_shape, num_batches=num_batches)
    assert noise_out.shape == (2, 3, 32, 32)

    # noise is None, timestep is True
    noise_out = _get_noise_batch(None, image_shape, 4, num_batches, True)
    assert noise_out.shape == (4, 2, 3, 32, 32)

    # noise is callable, timestep is False
    noise_out = _get_noise_batch(
        lambda shape: torch.randn(*shape),
        image_shape,
        num_batches=num_batches)
    assert noise_out.shape == (2, 3, 32, 32)

    # noise is callable, timestep is True
    noise_out = _get_noise_batch(
        lambda shape: torch.randn(*shape), image_shape, 4, num_batches, True)
    assert noise_out.shape == (4, 2, 3, 32, 32)

    # noise is Tensor, timestep is False, noise dim = 3
    # A single image gets a batch dimension prepended.
    noise_inp = torch.randn(3, 32, 32)
    noise_out = _get_noise_batch(noise_inp, image_shape)
    assert noise_out.shape == (1, 3, 32, 32)

    # noise is Tensor, timestep is False, noise dim = 4
    noise_inp = torch.randn(2, 3, 32, 32)
    noise_out = _get_noise_batch(noise_inp, image_shape)
    assert noise_out.shape == (2, 3, 32, 32)

    # noise is Tensor, timestep is False, noise dim = 5 --> invalid rank
    noise_inp = torch.randn(1, 2, 3, 32, 32)
    with pytest.raises(ValueError):
        _get_noise_batch(noise_inp, image_shape)

    # noise is Tensor, timestep is True, noise dim = 4
    # noise.size(0) == num_batches --> the batch is repeated per timestep
    noise_inp = torch.randn(4, 3, 32, 32)
    noise_out = _get_noise_batch(
        noise_inp,
        image_shape,
        num_timesteps=6,
        num_batches=4,
        timesteps_noise=True)
    assert noise_out.shape == (6, 4, 3, 32, 32)
    assert all([(noise_inp == noise).all() for noise in noise_out])

    # noise is Tensor, timestep is True, noise dim = 4
    # noise.size(0) == num_timesteps --> per-timestep noise repeated per
    # batch element
    noise_inp = torch.randn(6, 3, 32, 32)
    noise_out = _get_noise_batch(
        noise_inp,
        image_shape,
        num_timesteps=6,
        num_batches=4,
        timesteps_noise=True)
    assert noise_out.shape == (6, 4, 3, 32, 32)
    assert all([(noise_inp == noise_out[:, idx, ...]).all()
                for idx in range(4)])

    # noise is Tensor, timestep is True, noise dim = 4
    # noise.size(0) == num_timesteps * num_batches --> reshaped row-major
    noise_inp = torch.randn(24, 3, 32, 32)
    noise_out = _get_noise_batch(
        noise_inp,
        image_shape,
        num_timesteps=6,
        num_batches=4,
        timesteps_noise=True)
    assert noise_out.shape == (6, 4, 3, 32, 32)
    assert all([(noise_inp[idx] == noise_out[idx // 4][idx % 4]).all()
                for idx in range(24)])

    # noise is Tensor, timestep is True, noise dim = 4
    # noise_out.size(0) != num_batches * num_timesteps --> error
    noise_inp = torch.randn(25, 3, 32, 32)
    with pytest.raises(ValueError):
        _get_noise_batch(
            noise_inp,
            image_shape,
            num_timesteps=6,
            num_batches=4,
            timesteps_noise=True)

    # noise is Tensor, timestep is True, noise dim = 5 --> passed through
    noise_inp = torch.randn(6, 4, 3, 32, 32)
    noise_out = _get_noise_batch(
        noise_inp,
        image_shape,
        num_timesteps=6,
        num_batches=4,
        timesteps_noise=True)
    assert noise_out.shape == (6, 4, 3, 32, 32)
    assert (noise_out == noise_inp).all()

    # noise is Tensor, timestep is True, noise dim = 6 --> invalid rank
    noise_inp = torch.randn(1, 6, 4, 3, 32, 32)
    with pytest.raises(ValueError):
        noise_out = _get_noise_batch(
            noise_inp,
            image_shape,
            num_timesteps=6,
            num_batches=4,
            timesteps_noise=True)
def test_ddpm_label_batch_utils():
    """Exercise ``_get_label_batch`` over every supported ``label`` kind.

    Covers label given as None / callable / Tensor (dim 0, 1, 2, 3),
    with and without per-timestep expansion (``timesteps_noise``), and
    the error paths (``num_classes == 0`` with a label, and size
    mismatches that must raise ``ValueError``).
    """
    # num_classes = 0 -> unconditional, no label batch is produced
    label_out = _get_label_batch(
        label=None, num_timesteps=2, num_classes=0, num_batches=2)
    assert label_out is None
    # num_classes = 0, label is not None -> inconsistent inputs must fail
    with pytest.raises(AssertionError):
        label_out = _get_label_batch(
            label=torch.randint(0, 10, (2, )),
            num_timesteps=2,
            num_classes=0,
            num_batches=2)
    # label is None, timestep is False -> random labels of shape (num_batches,)
    label_out = _get_label_batch(None, num_classes=10, num_batches=2)
    assert label_out.shape == (2, )
    # sampled labels must lie in [0, num_classes)
    assert torch.logical_and(label_out >= 0, label_out < 10).all()
    # label is None, timestep is True -> shape (num_timesteps, num_batches)
    label_out = _get_label_batch(
        None,
        num_classes=10,
        num_batches=2,
        num_timesteps=4,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    assert torch.logical_and(label_out >= 0, label_out < 10).all()
    # label is callable, timestep is False -> generator callable is invoked
    label_out = _get_label_batch(
        lambda shape: torch.randint(0, 10, shape),
        num_classes=10,
        num_batches=2)
    assert label_out.shape == (2, )
    assert torch.logical_and(label_out >= 0, label_out < 10).all()
    # label is callable, timestep is True
    label_out = _get_label_batch(
        lambda shape: torch.randint(0, 10, shape),
        num_classes=10,
        num_timesteps=4,
        num_batches=2,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    assert torch.logical_and(label_out >= 0, label_out < 10).all()
    # label is tensor, timestep is False, label dim = 1 -> passed through
    label_inp = torch.LongTensor([4, 3])
    label_out = _get_label_batch(label_inp, num_classes=10)
    assert label_out.shape == (2, )
    assert (label_out == label_inp).all()
    # label is tensor, timestep is False, label dim = 0 -> expanded to (1, )
    label_inp = torch.from_numpy(np.array(10))
    label_out = _get_label_batch(label_inp, num_classes=10)
    assert label_out.shape == (1, )
    # label is tensor, timestep is False, label dim = 2 -> unsupported rank
    label_inp = torch.randint(0, 10, (4, 2))
    with pytest.raises(ValueError):
        _get_label_batch(label_inp, num_classes=10, num_batches=2)
    # label is tensor, timestep is True, label dim = 1
    # label.size(0) == num_batches -> label is repeated along the time axis
    label_inp = torch.randint(0, 10, (2, ))
    label_out = _get_label_batch(
        label_inp,
        num_timesteps=4,
        num_batches=2,
        num_classes=10,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    assert all([(label_out[idx] == label_inp).all() for idx in range(4)])
    # label is tensor, timestep is True, label dim = 1
    # label.size(0) == num_timesteps -> label is repeated along the batch axis
    label_inp = torch.randint(0, 10, (4, ))
    label_out = _get_label_batch(
        label_inp,
        num_timesteps=4,
        num_batches=2,
        num_classes=10,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    assert all([(label_inp == label_out[:, idx]).all() for idx in range(2)])
    # label is tensor, timestep is True, label dim = 1
    # label.size(0) == num_timesteps * num_batches -> reshaped to (T, B)
    label_inp = torch.randint(0, 10, (8, ))
    label_out = _get_label_batch(
        label_inp,
        num_timesteps=4,
        num_batches=2,
        num_classes=10,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    # flat index idx maps to (idx // num_batches, idx % num_batches)
    assert all([(label_inp[idx] == label_out[idx // 2][idx % 2]).all()
                for idx in range(8)])
    # label is tensor, timestep is True, label dim = 1
    # label.size(0) != num_timesteps * num_batches -> cannot be factored
    label_inp = torch.randint(0, 10, (9, ))
    with pytest.raises(ValueError):
        _get_label_batch(
            label_inp,
            num_timesteps=4,
            num_batches=2,
            num_classes=10,
            timesteps_noise=True)
    # label is tensor, timestep is True, label dim = 2 -> already (T, B)
    label_inp = torch.randint(0, 10, (4, 2))
    label_out = _get_label_batch(
        label_inp,
        num_timesteps=4,
        num_batches=2,
        num_classes=10,
        timesteps_noise=True)
    assert label_out.shape == (4, 2)
    assert (label_out == label_inp).all()
    # label is tensor, timestep is True, label dim = 3 -> unsupported rank
    label_inp = torch.randint(0, 10, (4, 2, 1))
    with pytest.raises(ValueError):
        _get_label_batch(
            label_inp,
            num_timesteps=4,
            num_batches=2,
            num_classes=10,
            timesteps_noise=True)
tests/test_models/test_mspie_styelgan2.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
torch
from
mmgen.models.gans.mspie_stylegan2
import
MSPIEStyleGAN2
class TestMSStyleGAN2:
    """Smoke tests for the ``MSPIEStyleGAN2`` model wrapper.

    The original test duplicated the optimizer/train-step/assert
    scaffolding verbatim for each train config; it is factored into
    ``_run_train_step`` so each config variant stays one call.
    """

    @classmethod
    def setup_class(cls):
        # Shared sub-module configs used by every test in this class.
        cls.generator_cfg = dict(
            type='MSStyleGANv2Generator', out_size=32, style_channels=16)
        cls.disc_cfg = dict(
            type='MSStyleGAN2Discriminator',
            in_size=32,
            with_adaptive_pool=True)
        cls.gan_loss = dict(type='GANLoss', gan_type='vanilla')
        cls.disc_auxiliary_loss = dict(
            type='R1GradientPenalty',
            loss_weight=10. / 2.,
            interval=1,
            norm_mode='HWC',
            data_info=dict(real_data='real_imgs', discriminator='disc'))
        cls.train_cfg = dict(
            use_ema=True,
            num_upblocks=3,
            multi_input_scales=[0, 2, 4],
            multi_scale_probability=[0.5, 0.25, 0.25])

    def _run_train_step(self, stylegan2):
        # Helper: build fresh SGD optimizers, run one train step on a
        # random (2, 3, 16, 16) batch and check the returned dict.
        optimizer_g = torch.optim.SGD(
            stylegan2.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            stylegan2.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
        data = torch.randn((2, 3, 16, 16))
        data_input = dict(real_img=data)
        model_outputs = stylegan2.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

    def test_msstylegan2_cpu(self):
        # default train config
        stylegan2 = MSPIEStyleGAN2(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
            None,
            train_cfg=self.train_cfg,
            test_cfg=None)
        self._run_train_step(stylegan2)

        # run the discriminator twice per generator update
        cfg_ = deepcopy(self.train_cfg)
        cfg_['disc_steps'] = 2
        stylegan2 = MSPIEStyleGAN2(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
            None,
            train_cfg=cfg_,
            test_cfg=None)
        self._run_train_step(stylegan2)
tests/test_models/test_pggan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
torch
import
torch.nn
as
nn
from
mmgen.models.gans
import
ProgressiveGrowingGAN
class TestPGGAN:
    """Tests for ``ProgressiveGrowingGAN`` scale scheduling and sampling.

    Fix over the original: the customized-config section assigned
    ``train_cfg_['interp_real_cfg']`` twice with identical values; the
    redundant second assignment is removed.
    """

    @classmethod
    def setup_class(cls):
        cls.generator_cfg = dict(
            type='PGGANGenerator',
            noise_size=8,
            out_scale=16,
            base_channels=32,
            max_channels=32)
        cls.discriminator_cfg = dict(
            type='PGGANDiscriminator', in_scale=16, label_size=0)
        cls.gan_loss = dict(type='GANLoss', gan_type='vanilla')
        cls.disc_auxiliary_loss = [
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_real'))
        ]
        # Tiny nkimgs budgets so each scale transition happens within a
        # handful of 3-image batches.
        cls.train_cfg = dict(
            use_ema=True,
            nkimgs_per_scale={
                '4': 0.004,
                '8': 0.008,
                '16': 0.016
            },
            optimizer_cfg=dict(
                generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
                discriminator=dict(
                    type='Adam', lr=0.0002, betas=(0.5, 0.999))),
            g_lr_base=0.0001,
            d_lr_base=0.0001,
            g_lr_schedule={'16': 0.00005})

    def test_pggan_cpu(self):
        # test default config
        pggan = ProgressiveGrowingGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=self.disc_auxiliary_loss,
            train_cfg=self.train_cfg)
        data_batch = dict(real_img=torch.randn(3, 3, 16, 16))
        # Step through 6 iterations and check the current output scale
        # and accumulated nkimgs at the expected transition points.
        for iter_num in range(6):
            outputs = pggan.train_step(
                data_batch,
                None,
                running_status=dict(iteration=iter_num, batch_size=3))
            results = outputs['results']
            if iter_num == 1:
                assert results['fake_imgs'].shape == (3, 3, 4, 4)
            elif iter_num == 2:
                assert results['fake_imgs'].shape == (3, 3, 8, 8)
                assert np.isclose(
                    pggan._actual_nkimgs[0], 0.006, atol=1e-8)
            elif iter_num == 3:
                assert results['fake_imgs'].shape == (3, 3, 8, 8)
                assert np.isclose(
                    pggan._actual_nkimgs[0], 0.006, atol=1e-8)
                # base lr still in effect at the 8x8 scale
                assert np.isclose(
                    pggan.optimizer['generator'].defaults['lr'],
                    0.0001,
                    atol=1e-8)
            elif iter_num == 5:
                assert results['fake_imgs'].shape == (3, 3, 16, 16)
                assert np.isclose(
                    pggan._actual_nkimgs[-1], 0.012, atol=1e-8)
                # g_lr_schedule kicks in at the 16x16 scale
                assert np.isclose(
                    pggan.optimizer['generator'].defaults['lr'],
                    0.00005,
                    atol=1e-8)

        # test sample from noise
        # NOTE(review): with num_batches=2 the default call returns a
        # batch of 4 -- presumably ema and orig samples concatenated;
        # confirm against ProgressiveGrowingGAN.sample_from_noise.
        outputs = pggan.sample_from_noise(None, num_batches=2)
        assert outputs.shape == (4, 3, 16, 16)
        outputs = pggan.sample_from_noise(
            None,
            num_batches=2,
            return_noise=True,
            transition_weight=0.2,
            sample_model='ema')
        assert outputs['fake_img'].shape == (2, 3, 16, 16)
        outputs = pggan.sample_from_noise(
            None, num_batches=2, return_noise=True, sample_model='orig')
        assert outputs['fake_img'].shape == (2, 3, 16, 16)

        # non-square real images cannot be interpolated to the current
        # scale and must raise
        with pytest.raises(RuntimeError):
            data_batch = dict(real_img=torch.randn(3, 3, 4, 32))
            _ = pggan.train_step(
                data_batch, None, running_status=dict(iteration=5))

        # test customized config
        train_cfg_ = deepcopy(self.train_cfg)
        train_cfg_['use_ema'] = False
        train_cfg_['interp_real_cfg'] = dict(
            mode='bilinear', align_corners=False)
        train_cfg_['reset_optim_for_new_scale'] = False
        pggan = ProgressiveGrowingGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=None,
            train_cfg=train_cfg_)
        data_batch = dict(real_img=torch.randn(3, 3, 16, 16))
        outputs = pggan.train_step(
            data_batch,
            None,
            running_status=dict(iteration=0, batch_size=3))
        results = outputs['results']
        assert results['fake_imgs'].shape == (3, 3, 4, 4)
        assert not pggan.with_gen_auxiliary_loss
        assert not pggan.with_disc_auxiliary_loss
        assert not pggan.use_ema

        data_batch = dict(real_img=torch.randn(3, 3, 16, 16))
        for iter_num in range(1, 3):
            outputs = pggan.train_step(
                data_batch,
                None,
                running_status=dict(iteration=iter_num, batch_size=3))
            results = outputs['results']
            if iter_num == 1:
                assert results['fake_imgs'].shape == (3, 3, 4, 4)
            elif iter_num == 2:
                assert results['fake_imgs'].shape == (3, 3, 8, 8)
                assert np.isclose(
                    pggan._actual_nkimgs[0], 0.006, atol=1e-8)

        # generator-only optimizer config, no discriminator / gan loss
        train_cfg_ = deepcopy(self.train_cfg)
        train_cfg_['optimizer_cfg'] = dict(
            generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
        pggan = ProgressiveGrowingGAN(
            self.generator_cfg,
            None,
            None,
            disc_auxiliary_loss=dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            gen_auxiliary_loss=dict(type='GeneratorPathRegularizer'),
            train_cfg=train_cfg_)
        assert pggan.with_gen_auxiliary_loss
        assert isinstance(pggan.disc_auxiliary_losses, nn.ModuleList)
        assert pggan.gan_loss is None
        assert pggan.discriminator is None

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires cuda')
    def test_pggan_cuda(self):
        # test default config
        pggan = ProgressiveGrowingGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=self.disc_auxiliary_loss,
            train_cfg=self.train_cfg).cuda()
        data_batch = dict(real_img=torch.randn(3, 3, 32, 32).cuda())
        for iter_num in range(6):
            outputs = pggan.train_step(
                data_batch,
                None,
                running_status=dict(iteration=iter_num, batch_size=3))
            results = outputs['results']
            if iter_num == 1:
                assert results['fake_imgs'].shape == (3, 3, 4, 4)
            elif iter_num == 2:
                assert results['fake_imgs'].shape == (3, 3, 8, 8)
                assert np.isclose(
                    pggan._actual_nkimgs[0], 0.006, atol=1e-8)
            elif iter_num == 3:
                assert results['fake_imgs'].shape == (3, 3, 8, 8)
                assert np.isclose(
                    pggan._actual_nkimgs[0], 0.006, atol=1e-8)
            elif iter_num == 5:
                assert results['fake_imgs'].shape == (3, 3, 16, 16)
                assert np.isclose(
                    pggan._actual_nkimgs[-1], 0.012, atol=1e-8)
tests/test_models/test_pix2pix.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
torch
from
mmcv.runner
import
obj_from_dict
from
mmgen.models
import
GANLoss
,
L1Loss
,
build_model
from
mmgen.models.architectures.pix2pix
import
(
PatchDiscriminator
,
UnetGenerator
)
def test_pix2pix():
    """End-to-end checks for the Pix2Pix translation model.

    Covers model construction, forward_test/forward_train, train_step
    on CPU (and CUDA when available), the ``disc_steps`` /
    ``disc_init_steps`` schedule, and a config without the pixel loss.

    Fix over the original: the two-entry optimizer dict literal was
    built four times verbatim; it is factored into the nested
    ``_build_optimizer`` helper.
    """
    # model settings
    model_cfg = dict(
        type='Pix2Pix',
        generator=dict(
            type='UnetGenerator',
            in_channels=3,
            out_channels=3,
            num_down=8,
            base_channels=64,
            norm_cfg=dict(type='BN'),
            use_dropout=True,
            init_cfg=dict(type='normal', gain=0.02)),
        discriminator=dict(
            type='PatchDiscriminator',
            in_channels=6,
            base_channels=64,
            num_conv=3,
            norm_cfg=dict(type='BN'),
            init_cfg=dict(type='normal', gain=0.02)),
        gan_loss=dict(
            type='GANLoss',
            gan_type='vanilla',
            real_label_val=1.0,
            fake_label_val=0.0,
            loss_weight=1.0),
        default_domain='photo',
        reachable_domains=['photo'],
        related_domains=['photo', 'mask'],
        gen_auxiliary_loss=dict(
            type='L1Loss',
            loss_weight=100.0,
            data_info=dict(pred='fake_photo', target='real_photo'),
            reduction='mean'))
    train_cfg = None
    test_cfg = None

    optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))

    def _build_optimizer(model):
        # Fresh Adam optimizers over the generator/discriminator
        # containers of ``model`` (closure over ``optim_cfg``).
        return {
            'generators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(model, 'generators').parameters())),
            'discriminators':
            obj_from_dict(
                optim_cfg, torch.optim,
                dict(params=getattr(model, 'discriminators').parameters()))
        }

    # build synthesizer
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert synthesizer.__class__.__name__ == 'Pix2Pix'
    assert isinstance(synthesizer.generators['photo'], UnetGenerator)
    assert isinstance(synthesizer.discriminators['photo'],
                      PatchDiscriminator)
    assert isinstance(synthesizer.gan_loss, GANLoss)
    assert isinstance(synthesizer.gen_auxiliary_losses[0], L1Loss)
    assert synthesizer.test_cfg is None

    # prepare data
    img_mask = torch.rand(1, 3, 256, 256)
    img_photo = torch.rand(1, 3, 256, 256)
    data_batch = {'img_mask': img_mask, 'img_photo': img_photo}

    # prepare optimizer
    optimizer = _build_optimizer(synthesizer)

    # test forward_test
    domain = 'photo'
    with torch.no_grad():
        outputs = synthesizer(img_mask, target_domain=domain, test_mode=True)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 256, 256)

    # test forward_train
    outputs = synthesizer(img_mask, target_domain=domain, test_mode=False)
    assert torch.equal(outputs['source'], data_batch['img_mask'])
    assert torch.is_tensor(outputs['target'])
    assert outputs['target'].size() == (1, 3, 256, 256)

    # test train_step
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1']:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_mask'],
                       data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)

    # test cuda
    if torch.cuda.is_available():
        synthesizer = synthesizer.cuda()
        optimizer = _build_optimizer(synthesizer)
        data_batch_cuda = copy.deepcopy(data_batch)
        data_batch_cuda['img_mask'] = img_mask.cuda()
        data_batch_cuda['img_photo'] = img_photo.cuda()

        # forward_test; in test mode outputs are moved back to cpu,
        # hence the .cpu() on the expected tensor
        with torch.no_grad():
            outputs = synthesizer(
                data_batch_cuda['img_mask'],
                target_domain=domain,
                test_mode=True)
        assert torch.equal(outputs['source'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 256, 256)

        # test forward_train; train-mode outputs stay on device
        outputs = synthesizer(
            data_batch_cuda['img_mask'],
            target_domain=domain,
            test_mode=False)
        assert torch.equal(outputs['source'], data_batch_cuda['img_mask'])
        assert torch.is_tensor(outputs['target'])
        assert outputs['target'].size() == (1, 3, 256, 256)

        # train_step
        outputs = synthesizer.train_step(data_batch_cuda, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        for v in [
                'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1'
        ]:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch_cuda['img_mask'].cpu())
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch_cuda['img_photo'].cpu())
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)

    # test disc_steps and disc_init_steps
    data_batch['img_mask'] = img_mask.cpu()
    data_batch['img_photo'] = img_photo.cpu()
    train_cfg = dict(disc_steps=2, disc_init_steps=2)
    synthesizer = build_model(
        model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    optimizer = _build_optimizer(synthesizer)

    # iter 0, 1: only the discriminator is updated (disc_init_steps)
    for i in range(2):
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        assert outputs['log_vars'].get('loss_gan_g') is None
        assert outputs['log_vars'].get('loss_l1') is None
        for v in ['loss_gan_d_fake', 'loss_gan_d_real']:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
        assert synthesizer.iteration == i + 1

    # iter 2, 3, 4, 5: generator updated every disc_steps-th iteration
    for i in range(2, 6):
        assert synthesizer.iteration == i
        outputs = synthesizer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['results'], dict)
        log_check_list = [
            'loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1'
        ]
        if i % 2 == 1:
            assert outputs['log_vars'].get('loss_gan_g') is None
            assert outputs['log_vars'].get('loss_pixel') is None
            log_check_list.remove('loss_gan_g')
            log_check_list.remove('loss_l1')
        for v in log_check_list:
            assert isinstance(outputs['log_vars'][v], float)
        assert outputs['num_samples'] == 1
        assert torch.equal(outputs['results']['real_mask'],
                           data_batch['img_mask'])
        assert torch.equal(outputs['results']['real_photo'],
                           data_batch['img_photo'])
        assert torch.is_tensor(outputs['results']['fake_photo'])
        assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
        assert synthesizer.iteration == i + 1

    # test without pixel loss
    model_cfg_ = copy.deepcopy(model_cfg)
    model_cfg_.pop('gen_auxiliary_loss')
    synthesizer = build_model(model_cfg_, train_cfg=None, test_cfg=None)
    optimizer = _build_optimizer(synthesizer)
    data_batch['img_mask'] = img_mask.cpu()
    data_batch['img_photo'] = img_photo.cpu()
    outputs = synthesizer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['results'], dict)
    assert outputs['log_vars'].get('loss_pixel') is None
    for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g']:
        assert isinstance(outputs['log_vars'][v], float)
    assert outputs['num_samples'] == 1
    assert torch.equal(outputs['results']['real_mask'],
                       data_batch['img_mask'])
    assert torch.equal(outputs['results']['real_photo'],
                       data_batch['img_photo'])
    assert torch.is_tensor(outputs['results']['fake_photo'])
    assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
tests/test_models/test_sagan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
pytest
import
torch
from
mmgen.models.gans
import
BasicConditionalGAN
class TestSAGAN:
    """Smoke tests for SAGAN built on ``BasicConditionalGAN``."""

    @classmethod
    def setup_class(cls):
        # Generator/discriminator configs with self-attention inserted
        # after specific blocks; both are 10-class conditional at 32x32.
        cls.generator_cfg = dict(
            type='SAGANGenerator',
            output_scale=32,
            base_channels=256,
            attention_cfg=dict(type='SelfAttentionBlock'),
            attention_after_nth_block=2,
            num_classes=10)
        cls.discriminator_cfg = dict(
            type='SAGANDiscriminator',
            input_scale=32,
            base_channels=128,
            attention_cfg=dict(type='SelfAttentionBlock'),
            attention_after_nth_block=1,
            num_classes=10)
        cls.disc_auxiliary_loss = None
        cls.gan_loss = dict(type='GANLoss', gan_type='hinge')
        cls.train_cfg = None

    def test_sagan_cpu(self):
        # build the model with the default (None) train config
        model = BasicConditionalGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=None,
            train_cfg=self.train_cfg)

        # sampling: plain tensor output ...
        sample = model.sample_from_noise(None, num_batches=2)
        assert sample.shape == (2, 3, 32, 32)
        # ... and dict output when the noise is returned as well
        sample = model.sample_from_noise(
            None, num_batches=2, return_noise=True, sample_model='orig')
        assert sample['fake_img'].shape == (2, 3, 32, 32)

        # one optimization step on a random labelled batch
        batch_imgs = torch.randn((2, 3, 32, 32))
        batch_labels = torch.randint(0, 10, (2, ))
        inputs = dict(img=batch_imgs, gt_label=batch_labels)
        g_opt = torch.optim.SGD(model.generator.parameters(), lr=0.01)
        d_opt = torch.optim.SGD(model.discriminator.parameters(), lr=0.01)
        optimizers = dict(generator=g_opt, discriminator=d_opt)
        step_out = model.train_step(inputs, optimizers)
        assert 'results' in step_out
        assert 'log_vars' in step_out
        assert step_out['num_samples'] == 2

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires cuda')
    def test_sagan_cuda(self):
        # same checks as the cpu test, with the model and data on cuda
        model = BasicConditionalGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=self.disc_auxiliary_loss,
            train_cfg=self.train_cfg).cuda()

        sample = model.sample_from_noise(None, num_batches=2)
        assert sample.shape == (2, 3, 32, 32)
        sample = model.sample_from_noise(
            None, num_batches=2, return_noise=True, sample_model='orig')
        assert sample['fake_img'].shape == (2, 3, 32, 32)

        batch_imgs = torch.randn((2, 3, 32, 32)).cuda()
        batch_labels = torch.randint(0, 10, (2, )).cuda()
        inputs = dict(img=batch_imgs, gt_label=batch_labels)
        g_opt = torch.optim.SGD(model.generator.parameters(), lr=0.01)
        d_opt = torch.optim.SGD(model.discriminator.parameters(), lr=0.01)
        optimizers = dict(generator=g_opt, discriminator=d_opt)
        step_out = model.train_step(inputs, optimizers)
        assert 'results' in step_out
        assert 'log_vars' in step_out
        assert step_out['num_samples'] == 2
tests/test_models/test_singan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
mmgen.models.gans.singan
import
PESinGAN
,
SinGAN
class TestSinGAN:
    """Multi-scale training tests for the single-image ``SinGAN`` model."""

    @classmethod
    def setup_class(cls):
        cls.generator = dict(
            type='SinGANMultiScaleGenerator',
            in_channels=3,
            out_channels=3,
            num_scales=3)
        cls.disc = dict(
            type='SinGANMultiScaleDiscriminator', in_channels=3, num_scales=3)
        cls.gan_loss = dict(type='GANLoss', gan_type='wgan', loss_weight=1)
        cls.disc_auxiliary_loss = [
            dict(
                type='GradientPenaltyLoss',
                loss_weight=0.1,
                norm_mode='pixel',
                data_info=dict(
                    discriminator='disc_partial',
                    real_data='real_imgs',
                    fake_data='fake_imgs'))
        ]
        cls.gen_auxiliary_loss = dict(
            type='MSELoss',
            loss_weight=10,
            data_info=dict(pred='recon_imgs', target='real_imgs'),
        )
        # iters_per_scale=2 so 6 train steps sweep all three scales
        cls.train_cfg = dict(
            noise_weight_init=0.1,
            iters_per_scale=2,
            curr_scale=-1,
            disc_steps=3,
            generator_steps=3,
            lr_d=0.0005,
            lr_g=0.0005,
            lr_scheduler_args=dict(milestones=[1600], gamma=0.1))
        # one real image pyramid: 25x25 -> 30x30 -> 32x32
        cls.data_batch = dict(
            real_scale0=torch.randn(1, 3, 25, 25),
            real_scale1=torch.randn(1, 3, 30, 30),
            real_scale2=torch.randn(1, 3, 32, 32),
        )
        cls.data_batch['input_sample'] = torch.zeros_like(
            cls.data_batch['real_scale0'])

    def test_singan_cpu(self):
        singan = SinGAN(self.generator, self.disc, self.gan_loss,
                        self.disc_auxiliary_loss, self.gen_auxiliary_loss,
                        self.train_cfg, None)
        # fake image resolution must follow the active scale:
        # steps 0-1 -> 25x25, 2-3 -> 30x30, 4-5 -> 32x32
        for i in range(6):
            output = singan.train_step(self.data_batch, None)
            if i == 0:
                assert output['results']['fake_imgs'].shape[-2:] == (25, 25)
            elif i == 2:
                assert output['results']['fake_imgs'].shape[-2:] == (30, 30)
            elif i == 5:
                assert output['results']['fake_imgs'].shape[-2:] == (32, 32)

        # same schedule with both auxiliary losses disabled
        singan = SinGAN(self.generator, self.disc, self.gan_loss, None, None,
                        self.train_cfg, None)
        for i in range(6):
            output = singan.train_step(self.data_batch, None)
            if i == 0:
                assert output['results']['fake_imgs'].shape[-2:] == (25, 25)
            elif i == 2:
                assert output['results']['fake_imgs'].shape[-2:] == (30, 30)
            elif i == 5:
                assert output['results']['fake_imgs'].shape[-2:] == (32, 32)

        # test sample from noise (final scale resolution expected)
        img = singan.sample_from_noise(None, num_batches=1)
        assert img.shape == (1, 3, 32, 32)
class TestPESinGAN:
    """Tests for the positional-encoding SinGAN variant ``PESinGAN``."""

    @classmethod
    def setup_class(cls):
        # generator with interpolation padding and padded noise inputs
        cls.generator = dict(
            type='SinGANMSGeneratorPE',
            in_channels=3,
            out_channels=3,
            num_scales=3,
            interp_pad=True,
            noise_with_pad=True)
        cls.disc = dict(
            type='SinGANMultiScaleDiscriminator', in_channels=3, num_scales=3)
        cls.gan_loss = dict(type='GANLoss', gan_type='wgan', loss_weight=1)
        cls.disc_auxiliary_loss = [
            dict(
                type='GradientPenaltyLoss',
                loss_weight=0.1,
                norm_mode='pixel',
                data_info=dict(
                    discriminator='disc_partial',
                    real_data='real_imgs',
                    fake_data='fake_imgs'))
        ]
        cls.gen_auxiliary_loss = dict(
            type='MSELoss',
            loss_weight=10,
            data_info=dict(pred='recon_imgs', target='real_imgs'),
        )
        cls.train_cfg = dict(
            noise_weight_init=0.1,
            iters_per_scale=2,
            curr_scale=-1,
            disc_steps=3,
            generator_steps=3,
            lr_d=0.0005,
            lr_g=0.0005,
            lr_scheduler_args=dict(milestones=[1600], gamma=0.1),
            fixed_noise_with_pad=True)
        # one real image pyramid: 25x25 -> 30x30 -> 32x32
        cls.data_batch = dict(
            real_scale0=torch.randn(1, 3, 25, 25),
            real_scale1=torch.randn(1, 3, 30, 30),
            real_scale2=torch.randn(1, 3, 32, 32),
        )
        cls.data_batch['input_sample'] = torch.zeros_like(
            cls.data_batch['real_scale0'])

    def test_pesingan_cpu(self):
        singan = PESinGAN(self.generator, self.disc, self.gan_loss,
                          self.disc_auxiliary_loss, self.gen_auxiliary_loss,
                          self.train_cfg, None)
        for i in range(6):
            output = singan.train_step(self.data_batch, None)
            if i == 0:
                # with fixed_noise_with_pad the scale-0 fixed noise is
                # padded: 25 + 2 * 5 = 35 -- presumably 5 px per side;
                # confirm against SinGANMSGeneratorPE's padding logic
                assert singan.fixed_noises[0].shape[-2] == 35
                assert output['results']['fake_imgs'].shape[-2:] == (25, 25)
            elif i == 2:
                assert output['results']['fake_imgs'].shape[-2:] == (30, 30)
            elif i == 5:
                assert output['results']['fake_imgs'].shape[-2:] == (32, 32)

        # variant without noise/fixed-noise padding, no auxiliary losses
        singan = PESinGAN(
            dict(
                type='SinGANMSGeneratorPE',
                in_channels=3,
                out_channels=3,
                num_scales=3,
                interp_pad=True,
                noise_with_pad=False), self.disc, self.gan_loss, None, None,
            dict(
                noise_weight_init=0.1,
                iters_per_scale=2,
                curr_scale=-1,
                disc_steps=3,
                generator_steps=3,
                lr_d=0.0005,
                lr_g=0.0005,
                lr_scheduler_args=dict(milestones=[1600], gamma=0.1),
                fixed_noise_with_pad=False), None)
        for i in range(6):
            output = singan.train_step(self.data_batch, None)
            if i == 0:
                assert output['results']['fake_imgs'].shape[-2:] == (25, 25)
            elif i == 2:
                assert output['results']['fake_imgs'].shape[-2:] == (30, 30)
            elif i == 5:
                assert output['results']['fake_imgs'].shape[-2:] == (32, 32)

        # test sample from noise (final scale resolution expected)
        img = singan.sample_from_noise(None, num_batches=1)
        assert img.shape == (1, 3, 32, 32)
tests/test_models/test_sngan_proj.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
pytest
import
torch
from
mmgen.models.gans
import
BasicConditionalGAN
class TestSNGAN_PROJ:
    """Tests for the SNGAN generator + projection discriminator GAN."""

    @classmethod
    def setup_class(cls):
        cls.generator_cfg = dict(
            type='SNGANGenerator',
            output_scale=32,
            base_channels=256,
            num_classes=10)
        cls.discriminator_cfg = dict(
            type='ProjDiscriminator',
            input_scale=32,
            base_channels=128,
            num_classes=10)
        cls.disc_auxiliary_loss = None
        cls.gan_loss = dict(type='GANLoss', gan_type='hinge')
        cls.train_cfg = None

    @staticmethod
    def _train_step_outputs(model, img, lab):
        """Build SGD optimizers for both submodules and run one train step."""
        opt_g = torch.optim.SGD(model.generator.parameters(), lr=0.01)
        opt_d = torch.optim.SGD(model.discriminator.parameters(), lr=0.01)
        optimizers = dict(generator=opt_g, discriminator=opt_d)
        return model.train_step(dict(img=img, gt_label=lab), optimizers)

    def test_sngan_proj_cpu(self):
        """Sampling and a single train step on CPU with the default config."""
        model = BasicConditionalGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=None,
            train_cfg=self.train_cfg)

        # test sample from noise
        fake = model.sample_from_noise(None, num_batches=2)
        assert fake.shape == (2, 3, 32, 32)
        fake = model.sample_from_noise(
            None, num_batches=2, return_noise=True, sample_model='orig')
        assert fake['fake_img'].shape == (2, 3, 32, 32)

        # test train step
        img = torch.randn((2, 3, 32, 32))
        lab = torch.randint(0, 10, (2, ))
        outputs = self._train_step_outputs(model, img, lab)
        assert 'results' in outputs
        assert 'log_vars' in outputs
        assert outputs['num_samples'] == 2

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_sngan_proj_cuda(self):
        """Sampling and a single train step on GPU with the default config."""
        model = BasicConditionalGAN(
            self.generator_cfg,
            self.discriminator_cfg,
            self.gan_loss,
            disc_auxiliary_loss=self.disc_auxiliary_loss,
            train_cfg=self.train_cfg).cuda()

        # test sample from noise
        fake = model.sample_from_noise(None, num_batches=2)
        assert fake.shape == (2, 3, 32, 32)
        fake = model.sample_from_noise(
            None, num_batches=2, return_noise=True, sample_model='orig')
        assert fake['fake_img'].shape == (2, 3, 32, 32)

        # test train step
        img = torch.randn((2, 3, 32, 32)).cuda()
        lab = torch.randint(0, 10, (2, )).cuda()
        outputs = self._train_step_outputs(model, img, lab)
        assert 'results' in outputs
        assert 'log_vars' in outputs
        assert outputs['num_samples'] == 2
tests/test_models/test_static_unconditional_gan.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
pytest
import
torch
import
torch.nn
as
nn
from
mmgen.models
import
StaticUnconditionalGAN
,
build_model
class TestStaticUnconditionalGAN(object):
    """Tests for ``StaticUnconditionalGAN`` with DCGAN and StyleGANv3+ADA
    configurations.

    Covers model building via config, forward sampling, a training step,
    the ``disc_steps`` schedule and normalization of auxiliary losses.
    """

    @classmethod
    def setup_class(cls):
        """Prepare shared generator/discriminator/loss configs."""
        cls.default_config = dict(
            type='StaticUnconditionalGAN',
            generator=dict(
                type='DCGANGenerator', output_scale=16, base_channels=32),
            discriminator=dict(
                type='DCGANDiscriminator',
                input_scale=16,
                output_scale=4,
                out_channels=5),
            gan_loss=dict(type='GANLoss', gan_type='vanilla'),
            disc_auxiliary_loss=None,
            gen_auxiliary_loss=None,
            train_cfg=None,
            test_cfg=None)

        cls.generator_cfg = dict(
            type='DCGANGenerator', output_scale=16, base_channels=32)
        cls.disc_cfg = dict(
            type='DCGANDiscriminator',
            input_scale=16,
            output_scale=4,
            out_channels=5)
        cls.gan_loss = dict(type='GANLoss', gan_type='vanilla')
        cls.disc_auxiliary_loss = [
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_real'))
        ]

    def test_default_dcgan_model_cpu(self):
        """Build, sample and train a DCGAN on CPU; exercise config variants."""
        dcgan = build_model(self.default_config)
        assert isinstance(dcgan, StaticUnconditionalGAN)
        assert not dcgan.with_disc_auxiliary_loss
        assert dcgan.with_disc

        # test forward train: training-mode forward is not supported
        with pytest.raises(NotImplementedError):
            _ = dcgan(None, return_loss=True)

        # test forward test
        imgs = dcgan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 16, 16)

        # test train step
        data = torch.randn((2, 3, 16, 16))
        data_input = dict(real_img=data)
        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            dcgan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = dcgan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # more tests for different configs with heavy computation
        # test disc_steps
        config_ = deepcopy(self.default_config)
        config_['train_cfg'] = dict(disc_steps=2)
        dcgan = build_model(config_)
        # NOTE(review): optim_dict still references the previous model's
        # parameters; only the reported log_vars keys are checked here.
        model_outputs = dcgan.train_step(data_input, optim_dict)
        assert 'loss_disc_fake' in model_outputs['log_vars']
        # generator losses are absent on iteration 0 when disc_steps == 2
        assert 'loss_disc_fake_g' not in model_outputs['log_vars']
        assert dcgan.disc_steps == 2

        model_outputs = dcgan.train_step(
            data_input, optim_dict, running_status=dict(iteration=1))
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' in model_outputs['log_vars']

        # test customized config
        dcgan = StaticUnconditionalGAN(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
        )
        # test train step
        data = torch.randn((2, 3, 16, 16))
        data_input = dict(real_img=data)
        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            dcgan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = dcgan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # auxiliary losses passed as single dicts or lists should both be
        # normalized to nn.ModuleList
        dcgan = StaticUnconditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            dict(type='GeneratorPathRegularizer'))
        assert isinstance(dcgan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(dcgan.gen_auxiliary_losses, nn.ModuleList)

        dcgan = StaticUnconditionalGAN(
            self.generator_cfg, self.disc_cfg, self.gan_loss,
            dict(
                type='DiscShiftLoss',
                loss_weight=0.5,
                data_info=dict(pred='disc_pred_fake')),
            [dict(type='GeneratorPathRegularizer')])
        assert isinstance(dcgan.disc_auxiliary_losses, nn.ModuleList)
        assert isinstance(dcgan.gen_auxiliary_losses, nn.ModuleList)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_default_dcgan_model_cuda(self):
        """Build, sample and train a DCGAN on GPU; exercise disc_steps."""
        dcgan = build_model(self.default_config).cuda()
        assert isinstance(dcgan, StaticUnconditionalGAN)
        assert not dcgan.with_disc_auxiliary_loss
        assert dcgan.with_disc

        # test forward train
        with pytest.raises(NotImplementedError):
            _ = dcgan(None, return_loss=True)

        # test forward test
        imgs = dcgan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 16, 16)

        # test train step
        data = torch.randn((2, 3, 16, 16)).cuda()
        data_input = dict(real_img=data)
        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            dcgan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        model_outputs = dcgan.train_step(data_input, optim_dict)
        assert 'results' in model_outputs
        assert 'log_vars' in model_outputs
        assert model_outputs['num_samples'] == 2

        # more tests for different configs with heavy computation in GPU
        # test disc_steps
        config_ = deepcopy(self.default_config)
        config_['train_cfg'] = dict(disc_steps=2)
        dcgan = build_model(config_).cuda()
        model_outputs = dcgan.train_step(data_input, optim_dict)
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' not in model_outputs['log_vars']
        assert dcgan.disc_steps == 2

        model_outputs = dcgan.train_step(
            data_input, optim_dict, running_status=dict(iteration=1))
        assert 'loss_disc_fake' in model_outputs['log_vars']
        assert 'loss_disc_fake_g' in model_outputs['log_vars']

    @pytest.mark.skipif(
        torch.__version__ in ['1.5.1'], reason='avoid killing')
    def test_ada_stylegan2_model_cpu(self):
        """Build and train a StyleGANv3 generator with ADA discriminator."""
        synthesis_cfg = {
            'type': 'SynthesisNetwork',
            'channel_base': 1024,
            'channel_max': 16,
            'magnitude_ema_beta': 0.999
        }
        aug_kwargs = {
            'xflip': 1,
            'rotate90': 1,
            'xint': 1,
            'scale': 1,
            'rotate': 1,
            'aniso': 1,
            'xfrac': 1,
            'brightness': 1,
            'contrast': 1,
            'lumaflip': 1,
            'hue': 1,
            'saturation': 1
        }
        default_config = dict(
            type='StaticUnconditionalGAN',
            generator=dict(
                type='StyleGANv3Generator',
                out_size=8,
                style_channels=8,
                img_channels=3,
                rgb2bgr=True,
                synthesis_cfg=synthesis_cfg),
            discriminator=dict(
                type='ADAStyleGAN2Discriminator',
                in_size=8,
                input_bgr2rgb=True,
                data_aug=dict(
                    type='ADAAug',
                    update_interval=2,
                    aug_pipeline=aug_kwargs,
                    ada_kimg=100)),
            gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'))

        s3gan = build_model(default_config)
        assert isinstance(s3gan, StaticUnconditionalGAN)
        assert not s3gan.with_disc_auxiliary_loss
        assert s3gan.with_disc

        # test forward train
        with pytest.raises(NotImplementedError):
            _ = s3gan(None, return_loss=True)

        # test forward test
        imgs = s3gan(None, return_loss=False, mode='sampling', num_batches=2)
        assert imgs.shape == (2, 3, 8, 8)

        # test train step
        data = torch.randn((2, 3, 8, 8))
        data_input = dict(real_img=data)
        optimizer_g = torch.optim.SGD(s3gan.generator.parameters(), lr=0.01)
        optimizer_d = torch.optim.SGD(
            s3gan.discriminator.parameters(), lr=0.01)
        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

        _ = s3gan.train_step(
            data_input, optim_dict, running_status=dict(iteration=1))
        # Fix: the original line was a bare comparison expression (a no-op
        # that checked nothing); assert that the ADA augmentation
        # probability remains float32 after a training step.
        assert s3gan.discriminator.ada_aug.aug_pipeline.p.dtype == \
            torch.float32
tests/test_models/test_stylegan1.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
import
torch
from
mmgen.models
import
build_model
# from mmgen.models.gans import StyleGANV1
class TestStyleGANV1:
    """Tests for StyleGANv1 trained via the progressive-growing wrapper."""

    @classmethod
    def setup_class(cls):
        cls.generator_cfg = dict(
            type='StyleGANv1Generator', out_size=32, style_channels=512)
        cls.discriminator_cfg = dict(
            type='StyleGAN1Discriminator', in_size=32)
        cls.gan_loss = dict(type='GANLoss', gan_type='wgan')
        cls.disc_auxiliary_loss = [
            dict(
                type='R1GradientPenalty',
                loss_weight=10,
                norm_mode='HWC',
                data_info=dict(
                    discriminator='disc_partial', real_data='real_imgs'))
        ]
        cls.train_cfg = dict(
            use_ema=True,
            nkimgs_per_scale={
                '8': 0.006,
                '16': 0.006,
                '32': 0.012
            },
            optimizer_cfg=dict(
                generator=dict(type='Adam', lr=0.003, betas=(0.0, 0.99)),
                discriminator=dict(type='Adam', lr=0.003, betas=(0.0, 0.99))),
            g_lr_base=0.003,
            d_lr_base=0.003)
        cls.stylegan_cfg = dict(
            type='ProgressiveGrowingGAN',
            generator=cls.generator_cfg,
            discriminator=cls.discriminator_cfg,
            gan_loss=cls.gan_loss,
            disc_auxiliary_loss=cls.disc_auxiliary_loss,
            train_cfg=cls.train_cfg)

    def _check_progressive_training(self, stylegan, data_batch):
        """Run five train steps and verify the progressive scale schedule."""
        # expected spatial size of fake images at each iteration
        scale_table = {1: 8, 2: 16, 3: 16, 4: 32}
        for iter_num in range(5):
            outputs = stylegan.train_step(
                data_batch,
                None,
                running_status=dict(iteration=iter_num, batch_size=3))
            fake_imgs = outputs['results']['fake_imgs']
            if iter_num in scale_table:
                scale = scale_table[iter_num]
                assert fake_imgs.shape == (3, 3, scale, scale)
            # nkimgs bookkeeping is finalized when a scale is left behind
            if iter_num == 2:
                assert np.isclose(
                    stylegan._actual_nkimgs[0], 0.006, atol=1e-8)
            elif iter_num == 4:
                assert np.isclose(
                    stylegan._actual_nkimgs[1], 0.012, atol=1e-8)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_stylegan1_cuda(self):
        """Progressive training schedule on GPU with the default config."""
        stylegan = build_model(self.stylegan_cfg).cuda()
        data_batch = dict(real_img=torch.randn(3, 3, 32, 32).cuda())
        self._check_progressive_training(stylegan, data_batch)

    def test_stylegan1_cpu(self):
        """Progressive training schedule on CPU with the default config."""
        stylegan = build_model(self.stylegan_cfg)
        data_batch = dict(real_img=torch.randn(3, 3, 32, 32))
        self._check_progressive_training(stylegan, data_batch)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment