Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmgeneration
Commits
c9a48a52
Commit
c9a48a52
authored
Jun 16, 2025
by
limm
Browse files
add tests code
parent
b7536f78
Pipeline
#2778
canceled with stages
Changes
64
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2334 additions
and
0 deletions
+2334
-0
tests/test_apis/test_inference.py
tests/test_apis/test_inference.py
+130
-0
tests/test_cores/test_ema_hooks.py
tests/test_cores/test_ema_hooks.py
+344
-0
tests/test_cores/test_fp16_utils.py
tests/test_cores/test_fp16_utils.py
+217
-0
tests/test_cores/test_metrics.py
tests/test_cores/test_metrics.py
+334
-0
tests/test_cores/test_optimizers.py
tests/test_cores/test_optimizers.py
+101
-0
tests/test_cores/test_scheduler.py
tests/test_cores/test_scheduler.py
+90
-0
tests/test_cores/test_tensor2img.py
tests/test_cores/test_tensor2img.py
+74
-0
tests/test_cores/test_visualization_hook.py
tests/test_cores/test_visualization_hook.py
+62
-0
tests/test_datasets/test_dataset_wrappers.py
tests/test_datasets/test_dataset_wrappers.py
+25
-0
tests/test_datasets/test_grow_scale_img_dataset.py
tests/test_datasets/test_grow_scale_img_dataset.py
+99
-0
tests/test_datasets/test_paired_image_dataset.py
tests/test_datasets/test_paired_image_dataset.py
+51
-0
tests/test_datasets/test_persistent_worker.py
tests/test_datasets/test_persistent_worker.py
+30
-0
tests/test_datasets/test_pipelines/test_augmentation.py
tests/test_datasets/test_pipelines/test_augmentation.py
+285
-0
tests/test_datasets/test_pipelines/test_compose.py
tests/test_datasets/test_pipelines/test_compose.py
+38
-0
tests/test_datasets/test_pipelines/test_crop.py
tests/test_datasets/test_pipelines/test_crop.py
+172
-0
tests/test_datasets/test_pipelines/test_formatting.py
tests/test_datasets/test_pipelines/test_formatting.py
+101
-0
tests/test_datasets/test_pipelines/test_loading.py
tests/test_datasets/test_pipelines/test_loading.py
+34
-0
tests/test_datasets/test_pipelines/test_normalize.py
tests/test_datasets/test_pipelines/test_normalize.py
+107
-0
tests/test_datasets/test_quicktest_dataset.py
tests/test_datasets/test_quicktest_dataset.py
+14
-0
tests/test_datasets/test_singan_dataset.py
tests/test_datasets/test_singan_dataset.py
+26
-0
No files found.
tests/test_apis/test_inference.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
mmcv
import
pytest
import
torch
from
mmgen.apis
import
(
init_model
,
sample_ddpm_model
,
sample_img2img_model
,
sample_unconditional_model
)
class
TestSampleUnconditionalModel
:
@
classmethod
def
setup_class
(
cls
):
project_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
'../../..'
))
config
=
mmcv
.
Config
.
fromfile
(
os
.
path
.
join
(
project_dir
,
'configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py'
))
cls
.
model
=
init_model
(
config
,
checkpoint
=
None
,
device
=
'cpu'
)
def
test_sample_unconditional_model_cpu
(
self
):
res
=
sample_unconditional_model
(
self
.
model
,
5
,
num_batches
=
2
,
sample_model
=
'orig'
)
assert
res
.
shape
==
(
5
,
3
,
64
,
64
)
res
=
sample_unconditional_model
(
self
.
model
,
4
,
num_batches
=
2
,
sample_model
=
'orig'
)
assert
res
.
shape
==
(
4
,
3
,
64
,
64
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_sample_unconditional_model_cuda
(
self
):
model
=
self
.
model
.
cuda
()
res
=
sample_unconditional_model
(
model
,
5
,
num_batches
=
2
,
sample_model
=
'orig'
)
assert
res
.
shape
==
(
5
,
3
,
64
,
64
)
res
=
sample_unconditional_model
(
model
,
4
,
num_batches
=
2
,
sample_model
=
'orig'
)
assert
res
.
shape
==
(
4
,
3
,
64
,
64
)
class
TestSampleTranslationModel
:
@
classmethod
def
setup_class
(
cls
):
project_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
'../../..'
))
pix2pix_config
=
mmcv
.
Config
.
fromfile
(
os
.
path
.
join
(
project_dir
,
'configs/pix2pix/pix2pix_vanilla_unet_bn_facades_b1x1_80k.py'
))
cls
.
pix2pix
=
init_model
(
pix2pix_config
,
checkpoint
=
None
,
device
=
'cpu'
)
cyclegan_config
=
mmcv
.
Config
.
fromfile
(
os
.
path
.
join
(
project_dir
,
'configs/cyclegan/cyclegan_lsgan_resnet_in_facades_b1x1_80k.py'
))
cls
.
cyclegan
=
init_model
(
cyclegan_config
,
checkpoint
=
None
,
device
=
'cpu'
)
cls
.
img_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'data/unpaired/testA/5.jpg'
)
def
test_translation_model_cpu
(
self
):
res
=
sample_img2img_model
(
self
.
pix2pix
,
self
.
img_path
,
target_domain
=
'photo'
)
assert
res
.
shape
==
(
1
,
3
,
256
,
256
)
res
=
sample_img2img_model
(
self
.
cyclegan
,
self
.
img_path
,
target_domain
=
'photo'
)
assert
res
.
shape
==
(
1
,
3
,
256
,
256
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_translation_model_cuda
(
self
):
res
=
sample_img2img_model
(
self
.
pix2pix
.
cuda
(),
self
.
img_path
,
target_domain
=
'photo'
)
assert
res
.
shape
==
(
1
,
3
,
256
,
256
)
res
=
sample_img2img_model
(
self
.
cyclegan
.
cuda
(),
self
.
img_path
,
target_domain
=
'photo'
)
assert
res
.
shape
==
(
1
,
3
,
256
,
256
)
class
TestDiffusionModel
:
@
classmethod
def
setup_class
(
cls
):
project_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
'../../..'
))
ddpm_config
=
mmcv
.
Config
.
fromfile
(
os
.
path
.
join
(
project_dir
,
'configs/improved_ddpm/'
'ddpm_cosine_hybird_timestep-4k_drop0.3_'
'cifar10_32x32_b8x16_500k.py'
))
# change timesteps to speed up test process
ddpm_config
.
model
[
'num_timesteps'
]
=
10
cls
.
model
=
init_model
(
ddpm_config
,
checkpoint
=
None
,
device
=
'cpu'
)
def
test_diffusion_model_cpu
(
self
):
# save_intermedia is False
res
=
sample_ddpm_model
(
self
.
model
,
num_samples
=
3
,
num_batches
=
2
,
same_noise
=
True
)
assert
res
.
shape
==
(
3
,
3
,
32
,
32
)
# save_intermedia is True
res
=
sample_ddpm_model
(
self
.
model
,
num_samples
=
2
,
num_batches
=
2
,
same_noise
=
True
,
save_intermedia
=
True
)
assert
isinstance
(
res
,
dict
)
assert
all
([
i
in
res
for
i
in
range
(
10
)])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_diffusion_model_cuda
(
self
):
model
=
self
.
model
.
cuda
()
# save_intermedia is False
res
=
sample_ddpm_model
(
model
,
num_samples
=
3
,
num_batches
=
2
,
same_noise
=
True
)
assert
res
.
shape
==
(
3
,
3
,
32
,
32
)
# save_intermedia is True
res
=
sample_ddpm_model
(
model
,
num_samples
=
2
,
num_batches
=
2
,
same_noise
=
True
,
save_intermedia
=
True
)
assert
isinstance
(
res
,
dict
)
assert
all
([
i
in
res
for
i
in
range
(
10
)])
tests/test_cores/test_ema_hooks.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
pytest
import
torch
import
torch.nn
as
nn
from
torch.nn.parallel
import
DataParallel
from
mmgen.core.hooks
import
ExponentialMovingAverageHook
class
SimpleModule
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
a
=
nn
.
Parameter
(
torch
.
tensor
([
1.
,
2.
]))
if
torch
.
__version__
>=
'1.7.0'
:
self
.
register_buffer
(
'b'
,
torch
.
tensor
([
2.
,
3.
]),
persistent
=
True
)
self
.
register_buffer
(
'c'
,
torch
.
tensor
([
0.
,
1.
]),
persistent
=
False
)
else
:
self
.
register_buffer
(
'b'
,
torch
.
tensor
([
2.
,
3.
]))
self
.
c
=
torch
.
tensor
([
0.
,
1.
])
class
SimpleModel
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
module_a
=
SimpleModule
()
self
.
module_b
=
SimpleModule
()
self
.
module_a_ema
=
SimpleModule
()
self
.
module_b_ema
=
SimpleModule
()
class
SimpleModelNoEMA
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
module_a
=
SimpleModule
()
self
.
module_b
=
SimpleModule
()
class
SimpleRunner
:
def
__init__
(
self
):
self
.
model
=
SimpleModel
()
self
.
iter
=
0
class
TestEMA
:
@
classmethod
def
setup_class
(
cls
):
cls
.
default_config
=
dict
(
module_keys
=
(
'module_a_ema'
,
'module_b_ema'
),
interval
=
1
,
interp_cfg
=
dict
(
momentum
=
0.5
))
cls
.
runner
=
SimpleRunner
()
@
torch
.
no_grad
()
def
test_ema_hook
(
self
):
cfg_
=
deepcopy
(
self
.
default_config
)
cfg_
[
'interval'
]
=
-
1
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
self
.
runner
)
ema
.
after_train_iter
(
self
.
runner
)
module_a
=
self
.
runner
.
model
.
module_a
module_a_ema
=
self
.
runner
.
model
.
module_a_ema
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
ema
=
ExponentialMovingAverageHook
(
**
self
.
default_config
)
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
self
.
runner
.
iter
+=
1
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
self
.
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.75
,
1.5
]))
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]))
assert
'c'
not
in
ema_states
# check for the validity of args
with
pytest
.
raises
(
AssertionError
):
_
=
ExponentialMovingAverageHook
(
module_keys
=
[
'a'
])
with
pytest
.
raises
(
AssertionError
):
_
=
ExponentialMovingAverageHook
(
module_keys
=
(
'a'
))
with
pytest
.
raises
(
AssertionError
):
_
=
ExponentialMovingAverageHook
(
module_keys
=
(
'module_a_ema'
),
interp_mode
=
'xxx'
)
# test before run
ema
=
ExponentialMovingAverageHook
(
**
self
.
default_config
)
self
.
runner
.
model
=
SimpleModelNoEMA
()
self
.
runner
.
iter
=
0
ema
.
before_run
(
self
.
runner
)
assert
hasattr
(
self
.
runner
.
model
,
'module_a_ema'
)
module_a
=
self
.
runner
.
model
.
module_a
module_a_ema
=
self
.
runner
.
model
.
module_a_ema
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
self
.
runner
.
iter
+=
1
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
self
.
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.75
,
1.5
]))
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]))
assert
'c'
not
in
ema_states
# test ema with simple warm up
runner
=
SimpleRunner
()
cfg_
=
deepcopy
(
self
.
default_config
)
cfg_
.
update
(
dict
(
start_iter
=
3
,
interval
=
1
))
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
runner
)
module_a
=
runner
.
model
.
module_a
module_a_ema
=
runner
.
model
.
module_a_ema
module_a
.
a
.
data
/=
2.
runner
.
iter
+=
1
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.5
,
1.
]))
module_a
.
a
.
data
/=
2
runner
.
iter
+=
2
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.25
,
0.5
]))
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.375
,
0.75
]))
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_ema_hook_cuda
(
self
):
ema
=
ExponentialMovingAverageHook
(
**
self
.
default_config
)
cuda_runner
=
SimpleRunner
()
cuda_runner
.
model
=
cuda_runner
.
model
.
cuda
()
ema
.
after_train_iter
(
cuda_runner
)
module_a
=
cuda_runner
.
model
.
module_a
module_a_ema
=
cuda_runner
.
model
.
module_a_ema
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]).
cuda
())
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
cuda_runner
.
iter
+=
1
ema
.
after_train_iter
(
cuda_runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
cuda_runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.75
,
1.5
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]).
cuda
())
assert
'c'
not
in
ema_states
# test before run
ema
=
ExponentialMovingAverageHook
(
**
self
.
default_config
)
self
.
runner
.
model
=
SimpleModelNoEMA
().
cuda
()
self
.
runner
.
model
=
DataParallel
(
self
.
runner
.
model
)
self
.
runner
.
iter
=
0
ema
.
before_run
(
self
.
runner
)
assert
hasattr
(
self
.
runner
.
model
.
module
,
'module_a_ema'
)
module_a
=
self
.
runner
.
model
.
module
.
module_a
module_a_ema
=
self
.
runner
.
model
.
module
.
module_a_ema
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]).
cuda
())
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
self
.
runner
.
iter
+=
1
ema
.
after_train_iter
(
self
.
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
self
.
runner
.
model
.
module
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.75
,
1.5
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]).
cuda
())
assert
'c'
not
in
ema_states
# test ema with simple warm up
runner
=
SimpleRunner
()
runner
.
model
=
runner
.
model
.
cuda
()
cfg_
=
deepcopy
(
self
.
default_config
)
cfg_
.
update
(
dict
(
start_iter
=
3
,
interval
=
1
))
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
runner
)
module_a
=
runner
.
model
.
module_a
module_a_ema
=
runner
.
model
.
module_a_ema
module_a
.
a
.
data
/=
2.
runner
.
iter
+=
1
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.5
,
1.
]).
cuda
())
module_a
.
a
.
data
/=
2
runner
.
iter
+=
2
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.25
,
0.5
]).
cuda
())
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.375
,
0.75
]).
cuda
())
def
test_dynamic_ema
(
self
):
# test within rampup phase
cfg_
=
dict
(
module_keys
=
(
'module_a_ema'
,
'module_b_ema'
),
interp_cfg
=
dict
(
momentum
=
0.9
),
interval
=
1
,
start_iter
=
0
,
momentum_policy
=
'rampup'
,
momentum_cfg
=
dict
(
ema_kimg
=
10
,
ema_rampup
=
0.05
,
batch_size
=
4
,
eps
=
1e-8
))
runner
=
SimpleRunner
()
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
runner
)
module_a
=
runner
.
model
.
module_a
module_a_ema
=
runner
.
model
.
module_a_ema
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
runner
.
iter
+=
19
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
0.75
,
1.5
]))
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]))
assert
'c'
not
in
ema_states
# test exceeds rampup phase
cfg_
=
dict
(
module_keys
=
(
'module_a_ema'
,
'module_b_ema'
),
interp_cfg
=
dict
(
momentum
=
0.9
),
interval
=
1
,
start_iter
=
0
,
momentum_policy
=
'rampup'
,
momentum_cfg
=
dict
(
ema_kimg
=
10
,
ema_rampup
=
0.05
,
batch_size
=
4
,
eps
=
1e-8
))
runner
=
SimpleRunner
()
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
runner
)
# modify module data
module_a
=
runner
.
model
.
module_a
module_a_ema
=
runner
.
model
.
module_a_ema
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
runner
.
iter
+=
49999
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
expected_m
=
0.5
**
(
4
/
10000
)
assert
torch
.
equal
(
ema_states
[
'a'
],
expected_m
*
torch
.
tensor
([
1.0
,
2.0
])
+
(
1.
-
expected_m
)
*
torch
.
tensor
([
0.5
,
1.0
]))
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]))
assert
'c'
not
in
ema_states
# test exceeds rampup phase
cfg_
=
dict
(
module_keys
=
(
'module_a_ema'
,
'module_b_ema'
),
interp_cfg
=
dict
(
momentum
=
0.9
),
interval
=
1
,
start_iter
=
0
,
momentum_policy
=
'rampup'
,
momentum_cfg
=
dict
(
ema_kimg
=
10
,
ema_rampup
=
0.05
,
batch_size
=
4
,
eps
=
1e-8
))
runner
=
SimpleRunner
()
ema
=
ExponentialMovingAverageHook
(
**
cfg_
)
ema
.
before_run
(
runner
)
# modify module data
module_a
=
runner
.
model
.
module_a
module_a_ema
=
runner
.
model
.
module_a_ema
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
ema_states
[
'a'
],
torch
.
tensor
([
1.
,
2.
]))
module_a
.
b
/=
2.
module_a
.
a
.
data
/=
2.
module_a
.
c
/=
2.
runner
.
iter
+=
79999
ema
.
after_train_iter
(
runner
)
ema_states
=
module_a_ema
.
state_dict
()
assert
torch
.
equal
(
runner
.
model
.
module_a
.
a
,
torch
.
tensor
([
0.5
,
1.
]))
expected_m
=
0.5
**
(
4
/
10000
)
assert
torch
.
equal
(
ema_states
[
'a'
],
expected_m
*
torch
.
tensor
([
1.0
,
2.0
])
+
(
1.
-
expected_m
)
*
torch
.
tensor
([
0.5
,
1.0
]))
assert
torch
.
equal
(
ema_states
[
'b'
],
torch
.
tensor
([
1.
,
1.5
]))
assert
'c'
not
in
ema_states
tests/test_cores/test_fp16_utils.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
import
torch
import
torch.nn
as
nn
from
mmcv.utils
import
TORCH_VERSION
from
mmgen.core.runners.fp16_utils
import
(
auto_fp16
,
cast_tensor_type
,
nan_to_num
)
def
test_nan_to_num
():
a
=
torch
.
tensor
([
float
(
'inf'
),
float
(
'nan'
),
2.
])
res
=
nan_to_num
(
a
,
posinf
=
255.
,
neginf
=-
255.
)
assert
(
res
==
torch
.
tensor
([
255.
,
0.
,
2.
])).
all
()
res
=
nan_to_num
(
a
)
assert
res
.
shape
==
(
3
,
)
with
pytest
.
raises
(
TypeError
):
nan_to_num
(
1
)
def
test_cast_tensor_type
():
inputs
=
torch
.
FloatTensor
([
5.
])
src_type
=
torch
.
float32
dst_type
=
torch
.
int32
outputs
=
cast_tensor_type
(
inputs
,
src_type
,
dst_type
)
assert
isinstance
(
outputs
,
torch
.
Tensor
)
assert
outputs
.
dtype
==
dst_type
inputs
=
'tensor'
src_type
=
str
dst_type
=
str
outputs
=
cast_tensor_type
(
inputs
,
src_type
,
dst_type
)
assert
isinstance
(
outputs
,
str
)
inputs
=
np
.
array
([
5.
])
src_type
=
np
.
ndarray
dst_type
=
np
.
ndarray
outputs
=
cast_tensor_type
(
inputs
,
src_type
,
dst_type
)
assert
isinstance
(
outputs
,
np
.
ndarray
)
inputs
=
dict
(
tensor_a
=
torch
.
FloatTensor
([
1.
]),
tensor_b
=
torch
.
FloatTensor
([
2.
]))
src_type
=
torch
.
float32
dst_type
=
torch
.
int32
outputs
=
cast_tensor_type
(
inputs
,
src_type
,
dst_type
)
assert
isinstance
(
outputs
,
dict
)
assert
outputs
[
'tensor_a'
].
dtype
==
dst_type
assert
outputs
[
'tensor_b'
].
dtype
==
dst_type
inputs
=
[
torch
.
FloatTensor
([
1.
]),
torch
.
FloatTensor
([
2.
])]
src_type
=
torch
.
float32
dst_type
=
torch
.
int32
outputs
=
cast_tensor_type
(
inputs
,
src_type
,
dst_type
)
assert
isinstance
(
outputs
,
list
)
assert
outputs
[
0
].
dtype
==
dst_type
assert
outputs
[
1
].
dtype
==
dst_type
inputs
=
5
outputs
=
cast_tensor_type
(
inputs
,
None
,
None
)
assert
isinstance
(
outputs
,
int
)
inputs
=
nn
.
Sequential
(
nn
.
Conv2d
(
2
,
2
,
3
),
nn
.
ReLU
())
outputs
=
cast_tensor_type
(
inputs
,
None
,
None
)
assert
isinstance
(
outputs
,
nn
.
Module
)
@
pytest
.
mark
.
skipif
(
not
TORCH_VERSION
>=
'1.6.0'
,
reason
=
'Lower PyTorch version'
)
def
test_auto_fp16_func
():
with
pytest
.
raises
(
TypeError
):
# ExampleObject is not a subclass of nn.Module
class
ExampleObject
(
object
):
@
auto_fp16
()
def
__call__
(
self
,
x
):
return
x
model
=
ExampleObject
()
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
model
(
input_x
)
# apply to all input args
class
ExampleModule
(
nn
.
Module
):
@
auto_fp16
()
def
forward
(
self
,
x
,
y
):
return
x
,
y
model
=
ExampleModule
()
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_y
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
output_x
,
output_y
=
model
(
input_x
,
input_y
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
model
.
fp16_enabled
=
True
output_x
,
output_y
=
model
(
input_x
,
input_y
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
output_x
,
output_y
=
model
(
input_x
.
cuda
(),
input_y
.
cuda
())
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
# apply to specified input args
class
ExampleModule
(
nn
.
Module
):
@
auto_fp16
(
apply_to
=
(
'x'
,
))
def
forward
(
self
,
x
,
y
):
return
x
,
y
model
=
ExampleModule
()
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_y
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
output_x
,
output_y
=
model
(
input_x
,
input_y
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
model
.
fp16_enabled
=
True
output_x
,
output_y
=
model
(
input_x
,
input_y
)
assert
output_x
.
dtype
==
torch
.
half
assert
output_y
.
dtype
==
torch
.
float32
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
output_x
,
output_y
=
model
(
input_x
.
cuda
(),
input_y
.
cuda
())
assert
output_x
.
dtype
==
torch
.
half
assert
output_y
.
dtype
==
torch
.
float32
# apply to optional input args
class
ExampleModule
(
nn
.
Module
):
@
auto_fp16
(
apply_to
=
(
'x'
,
'y'
))
def
forward
(
self
,
x
,
y
=
None
,
z
=
None
):
return
x
,
y
,
z
model
=
ExampleModule
()
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_y
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_z
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
output_x
,
output_y
,
output_z
=
model
(
input_x
,
y
=
input_y
,
z
=
input_z
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
assert
output_z
.
dtype
==
torch
.
float32
model
.
fp16_enabled
=
True
output_x
,
output_y
,
output_z
=
model
(
input_x
,
y
=
input_y
,
z
=
input_z
)
assert
output_x
.
dtype
==
torch
.
half
assert
output_y
.
dtype
==
torch
.
half
assert
output_z
.
dtype
==
torch
.
float32
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
output_x
,
output_y
,
output_z
=
model
(
input_x
.
cuda
(),
y
=
input_y
.
cuda
(),
z
=
input_z
.
cuda
())
assert
output_x
.
dtype
==
torch
.
half
assert
output_y
.
dtype
==
torch
.
half
assert
output_z
.
dtype
==
torch
.
float32
# out_fp32=True
class
ExampleModule
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
out_fp32
=
True
@
auto_fp16
(
apply_to
=
(
'x'
,
'y'
))
def
forward
(
self
,
x
,
y
=
None
,
z
=
None
):
return
x
,
y
,
z
model
=
ExampleModule
()
model
.
fp16_enabled
=
True
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
half
)
input_y
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_z
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
output_x
,
output_y
,
output_z
=
model
(
input_x
,
y
=
input_y
,
z
=
input_z
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
assert
output_z
.
dtype
==
torch
.
float32
# out_fp32=True
class
ExampleModule
(
nn
.
Module
):
@
auto_fp16
(
apply_to
=
(
'x'
,
'y'
),
out_fp32
=
True
)
def
forward
(
self
,
x
,
y
=
None
,
z
=
None
):
return
x
,
y
,
z
model
=
ExampleModule
()
input_x
=
torch
.
ones
(
1
,
dtype
=
torch
.
half
)
input_y
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
input_z
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
output_x
,
output_y
,
output_z
=
model
(
input_x
,
y
=
input_y
,
z
=
input_z
)
assert
output_x
.
dtype
==
torch
.
half
assert
output_y
.
dtype
==
torch
.
float32
assert
output_z
.
dtype
==
torch
.
float32
model
.
fp16_enabled
=
True
output_x
,
output_y
,
output_z
=
model
(
input_x
,
y
=
input_y
,
z
=
input_z
)
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
assert
output_z
.
dtype
==
torch
.
float32
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
output_x
,
output_y
,
output_z
=
model
(
input_x
.
cuda
(),
y
=
input_y
.
cuda
(),
z
=
input_z
.
cuda
())
assert
output_x
.
dtype
==
torch
.
float32
assert
output_y
.
dtype
==
torch
.
float32
assert
output_z
.
dtype
==
torch
.
float32
tests/test_cores/test_metrics.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
numpy
as
np
import
pytest
import
torch
from
mmgen.core.evaluation.metric_utils
import
extract_inception_features
from
mmgen.core.evaluation.metrics
import
(
FID
,
IS
,
MS_SSIM
,
PPL
,
PR
,
SWD
,
GaussianKLD
)
from
mmgen.datasets
import
UnconditionalImageDataset
,
build_dataloader
from
mmgen.models
import
build_model
from
mmgen.models.architectures
import
InceptionV3
from
mmgen.utils
import
download_from_url
# def test_inception_download():
# from mmgen.core.evaluation.metrics import load_inception
# from mmgen.utils import MMGEN_CACHE_DIR
# args_FID_pytorch = dict(type='pytorch', normalize_input=False)
# args_FID_tero = dict(type='StyleGAN', inception_path='')
# args_IS_pytorch = dict(type='pytorch')
# args_IS_tero = dict(
# type='StyleGAN',
# inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt'))
# tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN']
# for inception_args, metric, tar_style in zip(
# [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero],
# ['FID', 'FID', 'IS', 'IS'], tar_style_list):
# model, style = load_inception(inception_args, metric)
# assert style == tar_style
# args_empty = ''
# with pytest.raises(TypeError) as exc_info:
# load_inception(args_empty, 'FID')
# args_error_path = dict(type='StyleGAN', inception_path='error-path')
# with pytest.raises(RuntimeError) as exc_info:
# load_inception(args_error_path, 'FID')
def
test_swd_metric
():
img_nchw_1
=
torch
.
rand
((
100
,
3
,
32
,
32
))
img_nchw_2
=
torch
.
rand
((
100
,
3
,
32
,
32
))
metric
=
SWD
(
100
,
(
3
,
32
,
32
))
metric
.
prepare
()
metric
.
feed
(
img_nchw_1
,
'reals'
)
metric
.
feed
(
img_nchw_2
,
'fakes'
)
result
=
[
16.495922580361366
,
24.15413036942482
,
20.325026474893093
]
output
=
metric
.
summary
()
result
=
[
item
/
100
for
item
in
result
]
output
=
[
item
/
100
for
item
in
output
]
np
.
testing
.
assert_almost_equal
(
output
,
result
,
decimal
=
1
)
def
test_ms_ssim
():
img_nhwc_1
=
torch
.
rand
((
100
,
3
,
32
,
32
))
img_nhwc_2
=
torch
.
rand
((
100
,
3
,
32
,
32
))
metric
=
MS_SSIM
(
100
)
metric
.
prepare
()
metric
.
feed
(
img_nhwc_1
,
'reals'
)
metric
.
feed
(
img_nhwc_2
,
'fakes'
)
ssim_result
=
metric
.
summary
()
assert
ssim_result
<
1
class
TestExtractInceptionFeat
:
@
classmethod
def
setup_class
(
cls
):
cls
.
inception
=
InceptionV3
(
load_fid_inception
=
False
,
resize_input
=
True
)
pipeline
=
[
dict
(
type
=
'LoadImageFromFile'
,
key
=
'real_img'
),
dict
(
type
=
'Resize'
,
keys
=
[
'real_img'
],
scale
=
(
299
,
299
),
keep_ratio
=
False
,
),
dict
(
type
=
'Normalize'
,
keys
=
[
'real_img'
],
mean
=
[
127.5
]
*
3
,
std
=
[
127.5
]
*
3
,
to_rgb
=
False
),
dict
(
type
=
'Collect'
,
keys
=
[
'real_img'
],
meta_keys
=
[]),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'real_img'
])
]
dataset
=
UnconditionalImageDataset
(
osp
.
join
(
osp
.
dirname
(
__file__
),
'..'
,
'data'
),
pipeline
)
cls
.
data_loader
=
build_dataloader
(
dataset
,
3
,
0
,
dist
=
False
)
def
test_extr_inception_feat
(
self
):
feat
=
extract_inception_features
(
self
.
data_loader
,
self
.
inception
,
5
)
assert
feat
.
shape
[
0
]
==
5
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_extr_inception_feat_cuda
(
self
):
inception
=
torch
.
nn
.
DataParallel
(
self
.
inception
)
feat
=
extract_inception_features
(
self
.
data_loader
,
inception
,
5
)
assert
feat
.
shape
[
0
]
==
5
@
torch
.
no_grad
()
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_with_tero_implement
(
self
):
self
.
inception
=
InceptionV3
(
load_fid_inception
=
True
,
resize_input
=
False
)
img
=
torch
.
randn
((
2
,
3
,
1024
,
1024
))
feature_ours
=
self
.
inception
(
img
)[
0
].
view
(
img
.
shape
[
0
],
-
1
)
# Tero implementation
download_from_url
(
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
,
# noqa
dest_dir
=
'./work_dirs/cache'
)
net
=
torch
.
jit
.
load
(
'./work_dirs/cache/inception-2015-12-05.pt'
).
eval
().
cuda
()
net
=
torch
.
nn
.
DataParallel
(
net
)
feature_tero
=
net
(
img
,
return_features
=
True
)
print
(
feature_ours
.
shape
)
print
((
feature_tero
.
cpu
()
-
feature_ours
).
abs
().
mean
())
class
TestFID
:
@
classmethod
def
setup_class
(
cls
):
cls
.
reals
=
[
torch
.
randn
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
cls
.
fakes
=
[
torch
.
randn
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
def
test_fid
(
self
):
fid
=
FID
(
3
,
inception_args
=
dict
(
normalize_input
=
False
,
load_fid_inception
=
False
))
for
b
in
self
.
reals
:
fid
.
feed
(
b
,
'reals'
)
for
b
in
self
.
fakes
:
fid
.
feed
(
b
,
'fakes'
)
fid_score
,
mean
,
cov
=
fid
.
summary
()
assert
fid_score
>
0
and
mean
>
0
and
cov
>
0
# To reduce the size of git repo, we remove the following test
# fid = FID(
# 3,
# inception_args=dict(
# normalize_input=False, load_fid_inception=False),
# inception_pkl=osp.join(
# osp.dirname(__file__), '..', 'data', 'test_dirty.pkl'))
# assert fid.num_real_feeded == 3
# for b in self.reals:
# fid.feed(b, 'reals')
# for b in self.fakes:
# fid.feed(b, 'fakes')
# fid_score, mean, cov = fid.summary()
# assert fid_score > 0 and mean > 0 and cov > 0
class
TestPR
:
@
classmethod
def
setup_class
(
cls
):
cls
.
reals
=
[
torch
.
rand
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
cls
.
fakes
=
[
torch
.
rand
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_pr_cuda
(
self
):
pr
=
PR
(
10
)
pr
.
prepare
()
for
b
in
self
.
fakes
:
pr
.
feed
(
b
.
cuda
(),
'fakes'
)
for
b
in
self
.
reals
:
pr
.
feed
(
b
.
cuda
(),
'reals'
)
pr_score
=
pr
.
summary
()
print
(
pr_score
)
assert
pr_score
[
'precision'
]
>=
0
and
pr_score
[
'recall'
]
>=
0
def
test_pr_cpu
(
self
):
pr
=
PR
(
10
)
pr
.
prepare
()
for
b
in
self
.
fakes
:
pr
.
feed
(
b
,
'fakes'
)
for
b
in
self
.
reals
:
pr
.
feed
(
b
,
'reals'
)
pr_score
=
pr
.
summary
()
assert
pr_score
[
'precision'
]
>=
0
and
pr_score
[
'recall'
]
>=
0
class
TestIS
:
@
classmethod
def
setup_class
(
cls
):
cls
.
reals
=
[
torch
.
randn
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
cls
.
fakes
=
[
torch
.
randn
(
2
,
3
,
128
,
128
)
for
_
in
range
(
5
)]
def
test_is_cpu
(
self
):
inception_score
=
IS
(
10
,
resize
=
True
,
splits
=
10
)
inception_score
.
prepare
()
for
b
in
self
.
reals
:
inception_score
.
feed
(
b
,
'reals'
)
for
b
in
self
.
fakes
:
inception_score
.
feed
(
b
,
'fakes'
)
score
,
std
=
inception_score
.
summary
()
assert
score
>
0
and
std
>=
0
@
torch
.
no_grad
()
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_is_cuda
(
self
):
inception_score
=
IS
(
10
,
resize
=
True
,
splits
=
10
)
inception_score
.
prepare
()
for
b
in
self
.
reals
:
inception_score
.
feed
(
b
.
cuda
(),
'reals'
)
for
b
in
self
.
fakes
:
inception_score
.
feed
(
b
.
cuda
(),
'fakes'
)
score
,
std
=
inception_score
.
summary
()
assert
score
>
0
and
std
>=
0
class
TestPPL
:
@
classmethod
def
setup_class
(
cls
):
cls
.
model_cfg
=
dict
(
type
=
'StaticUnconditionalGAN'
,
generator
=
dict
(
type
=
'StyleGANv2Generator'
,
out_size
=
256
,
style_channels
=
512
,
),
discriminator
=
dict
(
type
=
'StyleGAN2Discriminator'
,
in_size
=
256
,
),
gan_loss
=
dict
(
type
=
'GANLoss'
,
gan_type
=
'wgan-logistic-ns'
),
train_cfg
=
dict
(
use_ema
=
True
))
def
test_ppl_cpu
(
self
):
self
.
model
=
build_model
(
self
.
model_cfg
)
ppl
=
PPL
(
10
)
ppl_iterator
=
iter
(
ppl
.
get_sampler
(
self
.
model
,
2
,
'ema'
))
ppl
.
prepare
()
for
b
in
ppl_iterator
:
ppl
.
feed
(
b
,
'fakes'
)
score
=
ppl
.
summary
()
assert
score
>=
0
@
torch
.
no_grad
()
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires cuda'
)
def
test_ppl_cuda
(
self
):
self
.
model
=
build_model
(
self
.
model_cfg
).
cuda
()
ppl
=
PPL
(
10
)
ppl_iterator
=
iter
(
ppl
.
get_sampler
(
self
.
model
,
2
,
'ema'
))
ppl
.
prepare
()
for
b
in
ppl_iterator
:
ppl
.
feed
(
b
,
'fakes'
)
score
=
ppl
.
summary
()
assert
score
>=
0
def
test_kld_gaussian
():
# we only test at bz = 1 to test the numerical accuracy
# due to the time and memory cost
tar_shape
=
[
2
,
3
,
4
,
4
]
mean1
,
mean2
=
torch
.
rand
(
*
tar_shape
,
1
),
torch
.
rand
(
*
tar_shape
,
1
)
# var1, var2 = torch.rand(2, 3, 4, 4, 1), torch.rand(2, 3, 4, 4, 1)
var1
=
torch
.
randint
(
1
,
3
,
(
*
tar_shape
,
1
)).
float
()
var2
=
torch
.
randint
(
1
,
3
,
(
*
tar_shape
,
1
)).
float
()
def
pdf
(
x
,
mean
,
var
):
return
(
1
/
np
.
sqrt
(
2
*
np
.
pi
*
var
)
*
torch
.
exp
(
-
(
x
-
mean
)
**
2
/
(
2
*
var
)))
delta
=
0.0001
indice
=
torch
.
arange
(
-
5
,
5
,
delta
).
repeat
(
*
mean1
.
shape
)
p
=
pdf
(
indice
,
mean1
,
var1
)
# pdf of target distribution
q
=
pdf
(
indice
,
mean2
,
var2
)
# pdf of predicted distribution
kld_manually
=
(
p
*
torch
.
log
(
p
/
q
)
*
delta
).
sum
(
dim
=
(
1
,
2
,
3
,
4
)).
mean
()
data
=
dict
(
mean_pred
=
mean2
,
mean_target
=
mean1
,
logvar_pred
=
torch
.
log
(
var2
),
logvar_target
=
torch
.
log
(
var1
))
metric
=
GaussianKLD
(
2
)
metric
.
prepare
()
metric
.
feed
(
data
,
'reals'
)
kld
=
metric
.
summary
()
# this is a quite loose limitation for we cannot choose delta which is
# small enough for precise kld calculation
np
.
testing
.
assert_almost_equal
(
kld
,
kld_manually
,
decimal
=
1
)
# assert (kld - kld_manually < 1e-1).all()
metric_base_2
=
GaussianKLD
(
2
,
base
=
'2'
)
metric_base_2
.
prepare
()
metric_base_2
.
feed
(
data
,
'reals'
)
kld_base_2
=
metric_base_2
.
summary
()
np
.
testing
.
assert_almost_equal
(
kld_base_2
,
kld
/
np
.
log
(
2
),
decimal
=
4
)
# assert kld_base_2 == kld / np.log(2)
# test wrong log_base
with
pytest
.
raises
(
AssertionError
):
GaussianKLD
(
2
,
base
=
'10'
)
# test other reduction --> mean
metric
=
GaussianKLD
(
2
,
reduction
=
'mean'
)
metric
.
prepare
()
metric
.
feed
(
data
,
'reals'
)
kld
=
metric
.
summary
()
# test other reduction --> sum
metric
=
GaussianKLD
(
2
,
reduction
=
'sum'
)
metric
.
prepare
()
metric
.
feed
(
data
,
'reals'
)
kld
=
metric
.
summary
()
# test other reduction --> error
with
pytest
.
raises
(
AssertionError
):
metric
=
GaussianKLD
(
2
,
reduction
=
'none'
)
tests/test_cores/test_optimizers.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
from
mmgen.core
import
build_optimizers
class
ExampleModel
(
nn
.
Module
):
def
__init__
(
self
):
super
(
ExampleModel
,
self
).
__init__
()
self
.
model1
=
nn
.
Conv2d
(
3
,
8
,
kernel_size
=
3
)
self
.
model2
=
nn
.
Conv2d
(
3
,
4
,
kernel_size
=
3
)
def
forward
(
self
,
x
):
return
x
def
test_build_optimizers
():
base_lr
=
0.0001
base_wd
=
0.0002
momentum
=
0.9
# basic config with ExampleModel
optimizer_cfg
=
dict
(
model1
=
dict
(
type
=
'SGD'
,
lr
=
base_lr
,
weight_decay
=
base_wd
,
momentum
=
momentum
),
model2
=
dict
(
type
=
'SGD'
,
lr
=
base_lr
,
weight_decay
=
base_wd
,
momentum
=
momentum
))
model
=
ExampleModel
()
optimizers
=
build_optimizers
(
model
,
optimizer_cfg
)
param_dict
=
dict
(
model
.
named_parameters
())
assert
isinstance
(
optimizers
,
dict
)
for
i
in
range
(
2
):
optimizer
=
optimizers
[
f
'model
{
i
+
1
}
'
]
param_groups
=
optimizer
.
param_groups
[
0
]
assert
isinstance
(
optimizer
,
torch
.
optim
.
SGD
)
assert
optimizer
.
defaults
[
'lr'
]
==
base_lr
assert
optimizer
.
defaults
[
'momentum'
]
==
momentum
assert
optimizer
.
defaults
[
'weight_decay'
]
==
base_wd
assert
len
(
param_groups
[
'params'
])
==
2
assert
torch
.
equal
(
param_groups
[
'params'
][
0
],
param_dict
[
f
'model
{
i
+
1
}
.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
1
],
param_dict
[
f
'model
{
i
+
1
}
.bias'
])
# basic config with Parallel model
model
=
torch
.
nn
.
DataParallel
(
ExampleModel
())
optimizers
=
build_optimizers
(
model
,
optimizer_cfg
)
param_dict
=
dict
(
model
.
named_parameters
())
assert
isinstance
(
optimizers
,
dict
)
for
i
in
range
(
2
):
optimizer
=
optimizers
[
f
'model
{
i
+
1
}
'
]
param_groups
=
optimizer
.
param_groups
[
0
]
assert
isinstance
(
optimizer
,
torch
.
optim
.
SGD
)
assert
optimizer
.
defaults
[
'lr'
]
==
base_lr
assert
optimizer
.
defaults
[
'momentum'
]
==
momentum
assert
optimizer
.
defaults
[
'weight_decay'
]
==
base_wd
assert
len
(
param_groups
[
'params'
])
==
2
assert
torch
.
equal
(
param_groups
[
'params'
][
0
],
param_dict
[
f
'module.model
{
i
+
1
}
.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
1
],
param_dict
[
f
'module.model
{
i
+
1
}
.bias'
])
# basic config with ExampleModel (one optimizer)
optimizer_cfg
=
dict
(
type
=
'SGD'
,
lr
=
base_lr
,
weight_decay
=
base_wd
,
momentum
=
momentum
)
model
=
ExampleModel
()
optimizer
=
build_optimizers
(
model
,
optimizer_cfg
)
param_dict
=
dict
(
model
.
named_parameters
())
assert
isinstance
(
optimizers
,
dict
)
param_groups
=
optimizer
.
param_groups
[
0
]
assert
isinstance
(
optimizer
,
torch
.
optim
.
SGD
)
assert
optimizer
.
defaults
[
'lr'
]
==
base_lr
assert
optimizer
.
defaults
[
'momentum'
]
==
momentum
assert
optimizer
.
defaults
[
'weight_decay'
]
==
base_wd
assert
len
(
param_groups
[
'params'
])
==
4
assert
torch
.
equal
(
param_groups
[
'params'
][
0
],
param_dict
[
'model1.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
1
],
param_dict
[
'model1.bias'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
2
],
param_dict
[
'model2.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
3
],
param_dict
[
'model2.bias'
])
# basic config with Parallel model (one optimizer)
model
=
torch
.
nn
.
DataParallel
(
ExampleModel
())
optimizer
=
build_optimizers
(
model
,
optimizer_cfg
)
param_dict
=
dict
(
model
.
named_parameters
())
assert
isinstance
(
optimizers
,
dict
)
param_groups
=
optimizer
.
param_groups
[
0
]
assert
isinstance
(
optimizer
,
torch
.
optim
.
SGD
)
assert
optimizer
.
defaults
[
'lr'
]
==
base_lr
assert
optimizer
.
defaults
[
'momentum'
]
==
momentum
assert
optimizer
.
defaults
[
'weight_decay'
]
==
base_wd
assert
len
(
param_groups
[
'params'
])
==
4
assert
torch
.
equal
(
param_groups
[
'params'
][
0
],
param_dict
[
'module.model1.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
1
],
param_dict
[
'module.model1.bias'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
2
],
param_dict
[
'module.model2.weight'
])
assert
torch
.
equal
(
param_groups
[
'params'
][
3
],
param_dict
[
'module.model2.bias'
])
tests/test_cores/test_scheduler.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
logging
import
shutil
import
sys
import
tempfile
from
unittest.mock
import
MagicMock
,
call
import
torch
import
torch.nn
as
nn
from
mmcv.runner
import
PaviLoggerHook
,
build_runner
from
torch.utils.data
import
DataLoader
def
_build_demo_runner
(
runner_type
=
'EpochBasedRunner'
,
max_epochs
=
1
,
max_iters
=
None
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
linear
=
nn
.
Linear
(
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
linear
(
x
)
def
train_step
(
self
,
x
,
optimizer
,
**
kwargs
):
return
dict
(
loss
=
self
(
x
))
def
val_step
(
self
,
x
,
optimizer
,
**
kwargs
):
return
dict
(
loss
=
self
(
x
))
model
=
Model
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.02
,
momentum
=
0.95
)
log_config
=
dict
(
interval
=
1
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
])
tmp_dir
=
tempfile
.
mkdtemp
()
runner
=
build_runner
(
dict
(
type
=
runner_type
),
default_args
=
dict
(
model
=
model
,
work_dir
=
tmp_dir
,
optimizer
=
optimizer
,
logger
=
logging
.
getLogger
(),
max_epochs
=
max_epochs
,
max_iters
=
max_iters
))
runner
.
register_checkpoint_hook
(
dict
(
interval
=
1
))
runner
.
register_logger_hooks
(
log_config
)
return
runner
def
test_linear_lr_updater_scheduler
():
sys
.
modules
[
'pavi'
]
=
MagicMock
()
loader
=
DataLoader
(
torch
.
ones
((
10
,
2
)))
runner
=
_build_demo_runner
()
# add momentum LR scheduler
lr_config
=
dict
(
policy
=
'Linear'
,
by_epoch
=
False
,
target_lr
=
0
,
start
=
0
,
interval
=
1
)
runner
.
register_lr_hook
(
lr_config
)
runner
.
register_hook_from_cfg
(
dict
(
type
=
'IterTimerHook'
))
# add pavi hook
hook
=
PaviLoggerHook
(
interval
=
1
,
add_graph
=
False
,
add_last_ckpt
=
True
)
runner
.
register_hook
(
hook
)
runner
.
run
([
loader
],
[(
'train'
,
1
)])
shutil
.
rmtree
(
runner
.
work_dir
)
# TODO: use a more elegant way to check values
assert
hasattr
(
hook
,
'writer'
)
calls
=
[
call
(
'train'
,
{
'learning_rate'
:
0.018000000000000002
,
'momentum'
:
0.95
},
2
),
call
(
'train'
,
{
'learning_rate'
:
0.014
,
'momentum'
:
0.95
},
4
),
call
(
'train'
,
{
'learning_rate'
:
0.01
,
'momentum'
:
0.95
},
6
),
]
hook
.
writer
.
add_scalars
.
assert_has_calls
(
calls
,
any_order
=
True
)
tests/test_cores/test_tensor2img.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
import
torch
from
torchvision.utils
import
make_grid
from
mmgen.models.misc
import
tensor2img
def
test_tensor2img
():
tensor_4d_1
=
torch
.
FloatTensor
(
2
,
3
,
4
,
4
).
uniform_
(
0
,
1
)
tensor_4d_2
=
torch
.
FloatTensor
(
1
,
3
,
4
,
4
).
uniform_
(
0
,
1
)
tensor_4d_3
=
torch
.
FloatTensor
(
3
,
1
,
4
,
4
).
uniform_
(
0
,
1
)
tensor_4d_4
=
torch
.
FloatTensor
(
1
,
1
,
4
,
4
).
uniform_
(
0
,
1
)
tensor_3d_1
=
torch
.
FloatTensor
(
3
,
4
,
4
).
uniform_
(
0
,
1
)
tensor_3d_2
=
torch
.
FloatTensor
(
3
,
6
,
6
).
uniform_
(
0
,
1
)
tensor_3d_3
=
torch
.
FloatTensor
(
1
,
6
,
6
).
uniform_
(
0
,
1
)
tensor_2d
=
torch
.
FloatTensor
(
4
,
4
).
uniform_
(
0
,
1
)
with
pytest
.
raises
(
TypeError
):
# input is not a tensor
tensor2img
(
4
)
with
pytest
.
raises
(
TypeError
):
# input is not a list of tensors
tensor2img
([
tensor_3d_1
,
4
])
with
pytest
.
raises
(
ValueError
):
# unsupported 5D tensor
tensor2img
(
torch
.
FloatTensor
(
2
,
2
,
3
,
4
,
4
).
uniform_
(
0
,
1
))
# 4d
rlt
=
tensor2img
(
tensor_4d_1
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_4d_1_np
=
make_grid
(
tensor_4d_1
,
nrow
=
1
,
normalize
=
False
).
numpy
()
tensor_4d_1_np
=
np
.
transpose
(
tensor_4d_1_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_4d_1_np
*
255
).
round
())
rlt
=
tensor2img
(
tensor_4d_2
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_4d_2_np
=
tensor_4d_2
.
squeeze
().
numpy
()
tensor_4d_2_np
=
np
.
transpose
(
tensor_4d_2_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_4d_2_np
*
255
).
round
())
rlt
=
tensor2img
(
tensor_4d_3
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_4d_3_np
=
make_grid
(
tensor_4d_3
,
nrow
=
1
,
normalize
=
False
).
numpy
()
tensor_4d_3_np
=
np
.
transpose
(
tensor_4d_3_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_4d_3_np
*
255
).
round
())
rlt
=
tensor2img
(
tensor_4d_4
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_4d_4_np
=
tensor_4d_4
.
squeeze
().
numpy
()
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_4d_4_np
*
255
).
round
())
# 3d
rlt
=
tensor2img
([
tensor_3d_1
,
tensor_3d_2
],
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_3d_1_np
=
tensor_3d_1
.
numpy
()
tensor_3d_1_np
=
np
.
transpose
(
tensor_3d_1_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
tensor_3d_2_np
=
tensor_3d_2
.
numpy
()
tensor_3d_2_np
=
np
.
transpose
(
tensor_3d_2_np
[[
2
,
1
,
0
],
:,
:],
(
1
,
2
,
0
))
np
.
testing
.
assert_almost_equal
(
rlt
[
0
],
(
tensor_3d_1_np
*
255
).
round
())
np
.
testing
.
assert_almost_equal
(
rlt
[
1
],
(
tensor_3d_2_np
*
255
).
round
())
rlt
=
tensor2img
(
tensor_3d_3
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_3d_3_np
=
tensor_3d_3
.
squeeze
().
numpy
()
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_3d_3_np
*
255
).
round
())
# 2d
rlt
=
tensor2img
(
tensor_2d
,
out_type
=
np
.
uint8
,
min_max
=
(
0
,
1
))
tensor_2d_np
=
tensor_2d
.
numpy
()
np
.
testing
.
assert_almost_equal
(
rlt
,
(
tensor_2d_np
*
255
).
round
())
rlt
=
tensor2img
(
tensor_2d
,
out_type
=
np
.
float32
,
min_max
=
(
0
,
1
))
np
.
testing
.
assert_almost_equal
(
rlt
,
tensor_2d_np
)
rlt
=
tensor2img
(
tensor_2d
,
out_type
=
np
.
float32
,
min_max
=
(
0.1
,
0.5
))
tensor_2d_np
=
(
np
.
clip
(
tensor_2d_np
,
0.1
,
0.5
)
-
0.1
)
/
0.4
np
.
testing
.
assert_almost_equal
(
rlt
,
tensor_2d_np
)
tests/test_cores/test_visualization_hook.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
tempfile
from
unittest.mock
import
MagicMock
import
mmcv
import
numpy
as
np
import
pytest
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
,
Dataset
from
mmgen.core
import
VisualizationHook
from
mmgen.utils
import
get_root_logger
class
ExampleDataset
(
Dataset
):
def
__getitem__
(
self
,
idx
):
img
=
torch
.
zeros
((
3
,
10
,
10
))
img
[:,
2
:
9
,
:]
=
1.
results
=
dict
(
imgs
=
img
)
return
results
def
__len__
(
self
):
return
1
class
ExampleModel
(
nn
.
Module
):
def
__init__
(
self
):
super
(
ExampleModel
,
self
).
__init__
()
self
.
test_cfg
=
None
def
train_step
(
self
,
data_batch
,
optimizer
):
output
=
dict
(
results
=
dict
(
img
=
data_batch
[
'imgs'
]))
return
output
def
test_visual_hook
():
with
pytest
.
raises
(
AssertionError
):
VisualizationHook
(
'temp'
,
[
1
,
2
,
3
])
test_dataset
=
ExampleDataset
()
test_dataset
.
evaluate
=
MagicMock
(
return_value
=
dict
(
test
=
'success'
))
img
=
torch
.
zeros
((
1
,
3
,
10
,
10
))
img
[:,
:,
2
:
9
,
:]
=
1.
model
=
ExampleModel
()
data_loader
=
DataLoader
(
test_dataset
,
batch_size
=
1
,
sampler
=
None
,
num_workers
=
0
,
shuffle
=
False
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
visual_hook
=
VisualizationHook
(
'visual'
,
[
'img'
],
interval
=
8
)
runner
=
mmcv
.
runner
.
IterBasedRunner
(
model
=
model
,
work_dir
=
tmpdir
,
logger
=
get_root_logger
())
runner
.
register_hook
(
visual_hook
)
runner
.
run
([
data_loader
],
[(
'train'
,
10
)],
10
)
img_saved
=
mmcv
.
imread
(
osp
.
join
(
tmpdir
,
'visual'
,
'iter_8.png'
),
flag
=
'unchanged'
)
np
.
testing
.
assert_almost_equal
(
img_saved
,
img
[
0
].
permute
(
1
,
2
,
0
)
*
127
+
128
)
tests/test_datasets/test_dataset_wrappers.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
torch.utils.data
import
Dataset
from
mmgen.datasets
import
RepeatDataset
def
test_repeat_dataset
():
class
ToyDataset
(
Dataset
):
def
__init__
(
self
):
super
(
ToyDataset
,
self
).
__init__
()
self
.
members
=
[
1
,
2
,
3
,
4
,
5
]
def
__len__
(
self
):
return
len
(
self
.
members
)
def
__getitem__
(
self
,
idx
):
return
self
.
members
[
idx
%
5
]
toy_dataset
=
ToyDataset
()
repeat_dataset
=
RepeatDataset
(
toy_dataset
,
2
)
assert
len
(
repeat_dataset
)
==
10
assert
repeat_dataset
[
2
]
==
3
assert
repeat_dataset
[
8
]
==
4
tests/test_datasets/test_grow_scale_img_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
pytest
from
mmgen.datasets
import
GrowScaleImgDataset
class
TestGrowScaleImgDataset
:
@
classmethod
def
setup_class
(
cls
):
cls
.
imgs_root
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'..'
,
'data/image'
)
cls
.
imgs_roots
=
{
'4'
:
cls
.
imgs_root
,
'8'
:
osp
.
join
(
cls
.
imgs_root
,
'img_root'
),
'32'
:
osp
.
join
(
cls
.
imgs_root
,
'img_root'
,
'grass'
)
}
cls
.
default_pipeline
=
[
dict
(
type
=
'LoadImageFromFile'
,
io_backend
=
'disk'
,
key
=
'real_img'
)
]
cls
.
len_per_stage
=
10
cls
.
gpu_samples_base
=
2
def
test_dynamic_unconditional_img_dataset
(
self
):
dataset
=
GrowScaleImgDataset
(
self
.
imgs_roots
,
self
.
default_pipeline
,
self
.
len_per_stage
,
gpu_samples_base
=
self
.
gpu_samples_base
)
assert
len
(
dataset
)
==
10
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
10
}
images in imgs_root:
{
self
.
imgs_root
}
'
)
assert
dataset
.
samples_per_gpu
==
2
dataset
.
update_annotations
(
8
)
assert
len
(
dataset
)
==
10
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
10
}
images in imgs_root:'
f
'
{
osp
.
join
(
self
.
imgs_root
,
"img_root"
)
}
'
)
assert
dataset
.
samples_per_gpu
==
2
dataset
=
GrowScaleImgDataset
(
self
.
imgs_roots
,
self
.
default_pipeline
,
20
,
gpu_samples_base
=
self
.
gpu_samples_base
,
gpu_samples_per_scale
=
{
'4'
:
10
,
'16'
:
13
})
assert
len
(
dataset
)
==
20
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
20
}
images in imgs_root:
{
self
.
imgs_root
}
'
)
assert
dataset
.
samples_per_gpu
==
10
dataset
.
update_annotations
(
8
)
assert
len
(
dataset
)
==
20
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
20
}
images in imgs_root:'
f
'
{
osp
.
join
(
self
.
imgs_root
,
"img_root"
)
}
'
)
assert
dataset
.
samples_per_gpu
==
2
dataset
=
GrowScaleImgDataset
(
self
.
imgs_roots
,
self
.
default_pipeline
,
5
,
test_mode
=
True
)
assert
len
(
dataset
)
==
5
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
5
}
images in imgs_root:
{
self
.
imgs_root
}
'
)
dataset
.
update_annotations
(
24
)
assert
len
(
dataset
)
==
5
img
=
dataset
[
2
][
'real_img'
]
assert
img
.
ndim
==
3
_path_str
=
osp
.
join
(
self
.
imgs_root
,
'img_root'
,
'grass'
)
assert
repr
(
dataset
)
==
(
f
'dataset_name:
{
dataset
.
__class__
}
, '
f
'total
{
5
}
images in imgs_root:
{
_path_str
}
'
)
with
pytest
.
raises
(
AssertionError
):
_
=
GrowScaleImgDataset
(
self
.
imgs_root
,
self
.
default_pipeline
,
10
,
gpu_samples_per_scale
=
10
)
with
pytest
.
raises
(
AssertionError
):
_
=
GrowScaleImgDataset
(
10
,
self
.
default_pipeline
,
10.
)
tests/test_datasets/test_paired_image_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
mmgen.datasets
import
PairedImageDataset
class
TestPairedImageDataset
(
object
):
@
classmethod
def
setup_class
(
cls
):
cls
.
imgs_root
=
osp
.
join
(
osp
.
dirname
(
osp
.
dirname
(
__file__
)),
'data/paired'
)
img_norm_cfg
=
dict
(
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
])
cls
.
default_pipeline
=
[
dict
(
type
=
'LoadPairedImageFromFile'
,
io_backend
=
'disk'
,
key
=
'pair'
,
domain_a
=
'a'
,
domain_b
=
'b'
),
dict
(
type
=
'Resize'
,
keys
=
[
'img_a'
,
'img_b'
],
scale
=
(
286
,
286
),
interpolation
=
'bicubic'
),
dict
(
type
=
'FixedCrop'
,
keys
=
[
'img_a'
,
'img_b'
],
crop_size
=
(
256
,
256
)),
dict
(
type
=
'Flip'
,
keys
=
[
'img_a'
,
'img_b'
],
direction
=
'horizontal'
),
dict
(
type
=
'RescaleToZeroOne'
,
keys
=
[
'img_a'
,
'img_b'
]),
dict
(
type
=
'Normalize'
,
keys
=
[
'img_a'
,
'img_b'
],
to_rgb
=
True
,
**
img_norm_cfg
),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img_a'
,
'img_b'
]),
dict
(
type
=
'Collect'
,
keys
=
[
'img_a'
,
'img_b'
],
meta_keys
=
[
'img_a_path'
,
'img_b_path'
])
]
def
test_paired_image_dataset
(
self
):
dataset
=
PairedImageDataset
(
self
.
imgs_root
,
pipeline
=
self
.
default_pipeline
)
assert
len
(
dataset
)
==
2
img
=
dataset
[
0
][
'img_a'
]
assert
img
.
ndim
==
3
img
=
dataset
[
0
][
'img_b'
]
assert
img
.
ndim
==
3
tests/test_datasets/test_persistent_worker.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
mmgen.datasets.builder
import
build_dataloader
,
build_dataset
class
TestPersistentWorker
(
object
):
@
classmethod
def
setup_class
(
cls
):
imgs_root
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'..'
,
'data/image'
)
train_pipeline
=
[
dict
(
type
=
'LoadImageFromFile'
,
io_backend
=
'disk'
,
key
=
'real_img'
)
]
cls
.
config
=
dict
(
samples_per_gpu
=
1
,
workers_per_gpu
=
4
,
drop_last
=
True
,
persistent_workers
=
True
)
cls
.
data_cfg
=
dict
(
type
=
'UnconditionalImageDataset'
,
imgs_root
=
imgs_root
,
pipeline
=
train_pipeline
,
test_mode
=
False
)
def
test_persistent_worker
(
self
):
# test non-persistent-worker
dataset
=
build_dataset
(
self
.
data_cfg
)
build_dataloader
(
dataset
,
**
self
.
config
)
tests/test_datasets/test_pipelines/test_augmentation.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
numpy
as
np
import
pytest
import
torch
from
mmgen.datasets.pipelines
import
(
CenterCropLongEdge
,
Flip
,
NumpyPad
,
RandomCropLongEdge
,
RandomImgNoise
,
Resize
)
class
TestAugmentations
(
object
):
@
classmethod
def
setup_class
(
cls
):
cls
.
results
=
dict
()
cls
.
img_gt
=
np
.
random
.
rand
(
256
,
128
,
3
).
astype
(
np
.
float32
)
cls
.
img_lq
=
np
.
random
.
rand
(
64
,
32
,
3
).
astype
(
np
.
float32
)
cls
.
results
=
dict
(
lq
=
cls
.
img_lq
,
gt
=
cls
.
img_gt
,
scale
=
4
,
lq_path
=
'fake_lq_path'
,
gt_path
=
'fake_gt_path'
)
cls
.
results
[
'img'
]
=
np
.
random
.
rand
(
256
,
256
,
3
).
astype
(
np
.
float32
)
cls
.
results
[
'mask'
]
=
np
.
random
.
rand
(
256
,
256
,
1
).
astype
(
np
.
float32
)
cls
.
results
[
'img_tensor'
]
=
torch
.
rand
((
3
,
256
,
256
))
cls
.
results
[
'mask_tensor'
]
=
torch
.
zeros
((
1
,
256
,
256
))
cls
.
results
[
'mask_tensor'
][:,
50
:
150
,
40
:
140
]
=
1.
@
staticmethod
def
assert_img_equal
(
img
,
ref_img
,
ratio_thr
=
0.999
):
"""Check if img and ref_img are matched approximately."""
assert
img
.
shape
==
ref_img
.
shape
assert
img
.
dtype
==
ref_img
.
dtype
area
=
ref_img
.
shape
[
-
1
]
*
ref_img
.
shape
[
-
2
]
diff
=
np
.
abs
(
img
.
astype
(
'int32'
)
-
ref_img
.
astype
(
'int32'
))
assert
np
.
sum
(
diff
<=
1
)
/
float
(
area
)
>
ratio_thr
@
staticmethod
def
check_keys_contain
(
result_keys
,
target_keys
):
"""Check if all elements in target_keys is in result_keys."""
return
set
(
target_keys
).
issubset
(
set
(
result_keys
))
@
staticmethod
def
check_flip
(
origin_img
,
result_img
,
flip_type
):
"""Check if the origin_img are flipped correctly into result_img in
different flip_types."""
h
,
w
,
c
=
origin_img
.
shape
if
flip_type
==
'horizontal'
:
# yapf: disable
for
i
in
range
(
h
):
for
j
in
range
(
w
):
for
k
in
range
(
c
):
if
result_img
[
i
,
j
,
k
]
!=
origin_img
[
i
,
w
-
1
-
j
,
k
]:
return
False
# yapf: enable
else
:
# yapf: disable
for
i
in
range
(
h
):
for
j
in
range
(
w
):
for
k
in
range
(
c
):
if
result_img
[
i
,
j
,
k
]
!=
origin_img
[
h
-
1
-
i
,
j
,
k
]:
return
False
# yapf: enable
return
True
def
test_flip
(
self
):
results
=
copy
.
deepcopy
(
self
.
results
)
with
pytest
.
raises
(
ValueError
):
Flip
(
keys
=
[
'lq'
,
'gt'
],
direction
=
'vertically'
)
# horizontal
np
.
random
.
seed
(
1
)
target_keys
=
[
'lq'
,
'gt'
,
'flip'
,
'flip_direction'
]
flip
=
Flip
(
keys
=
[
'lq'
,
'gt'
],
flip_ratio
=
1
,
direction
=
'horizontal'
)
results
=
flip
(
results
)
assert
self
.
check_keys_contain
(
results
.
keys
(),
target_keys
)
assert
self
.
check_flip
(
self
.
img_lq
,
results
[
'lq'
],
results
[
'flip_direction'
])
assert
self
.
check_flip
(
self
.
img_gt
,
results
[
'gt'
],
results
[
'flip_direction'
])
assert
results
[
'lq'
].
shape
==
self
.
img_lq
.
shape
assert
results
[
'gt'
].
shape
==
self
.
img_gt
.
shape
# vertical
results
=
copy
.
deepcopy
(
self
.
results
)
flip
=
Flip
(
keys
=
[
'lq'
,
'gt'
],
flip_ratio
=
1
,
direction
=
'vertical'
)
results
=
flip
(
results
)
assert
self
.
check_keys_contain
(
results
.
keys
(),
target_keys
)
assert
self
.
check_flip
(
self
.
img_lq
,
results
[
'lq'
],
results
[
'flip_direction'
])
assert
self
.
check_flip
(
self
.
img_gt
,
results
[
'gt'
],
results
[
'flip_direction'
])
assert
results
[
'lq'
].
shape
==
self
.
img_lq
.
shape
assert
results
[
'gt'
].
shape
==
self
.
img_gt
.
shape
assert
repr
(
flip
)
==
flip
.
__class__
.
__name__
+
(
f
"(keys=
{
[
'lq'
,
'gt'
]
}
, flip_ratio=1, "
f
"direction=
{
results
[
'flip_direction'
]
}
)"
)
# flip a list
# horizontal
flip
=
Flip
(
keys
=
[
'lq'
,
'gt'
],
flip_ratio
=
1
,
direction
=
'horizontal'
)
results
=
dict
(
lq
=
[
self
.
img_lq
,
np
.
copy
(
self
.
img_lq
)],
gt
=
[
self
.
img_gt
,
np
.
copy
(
self
.
img_gt
)],
scale
=
4
,
lq_path
=
'fake_lq_path'
,
gt_path
=
'fake_gt_path'
)
flip_rlt
=
flip
(
copy
.
deepcopy
(
results
))
assert
self
.
check_keys_contain
(
flip_rlt
.
keys
(),
target_keys
)
assert
self
.
check_flip
(
self
.
img_lq
,
flip_rlt
[
'lq'
][
0
],
flip_rlt
[
'flip_direction'
])
assert
self
.
check_flip
(
self
.
img_gt
,
flip_rlt
[
'gt'
][
0
],
flip_rlt
[
'flip_direction'
])
np
.
testing
.
assert_almost_equal
(
flip_rlt
[
'gt'
][
0
],
flip_rlt
[
'gt'
][
1
])
np
.
testing
.
assert_almost_equal
(
flip_rlt
[
'lq'
][
0
],
flip_rlt
[
'lq'
][
1
])
# vertical
flip
=
Flip
(
keys
=
[
'lq'
,
'gt'
],
flip_ratio
=
1
,
direction
=
'vertical'
)
flip_rlt
=
flip
(
copy
.
deepcopy
(
results
))
assert
self
.
check_keys_contain
(
flip_rlt
.
keys
(),
target_keys
)
assert
self
.
check_flip
(
self
.
img_lq
,
flip_rlt
[
'lq'
][
0
],
flip_rlt
[
'flip_direction'
])
assert
self
.
check_flip
(
self
.
img_gt
,
flip_rlt
[
'gt'
][
0
],
flip_rlt
[
'flip_direction'
])
np
.
testing
.
assert_almost_equal
(
flip_rlt
[
'gt'
][
0
],
flip_rlt
[
'gt'
][
1
])
np
.
testing
.
assert_almost_equal
(
flip_rlt
[
'lq'
][
0
],
flip_rlt
[
'lq'
][
1
])
# no flip
flip
=
Flip
(
keys
=
[
'lq'
,
'gt'
],
flip_ratio
=
0
,
direction
=
'vertical'
)
results
=
flip
(
copy
.
deepcopy
(
results
))
assert
self
.
check_keys_contain
(
results
.
keys
(),
target_keys
)
np
.
testing
.
assert_almost_equal
(
results
[
'gt'
][
0
],
self
.
img_gt
)
np
.
testing
.
assert_almost_equal
(
results
[
'lq'
][
0
],
self
.
img_lq
)
np
.
testing
.
assert_almost_equal
(
results
[
'gt'
][
0
],
results
[
'gt'
][
1
])
np
.
testing
.
assert_almost_equal
(
results
[
'lq'
][
0
],
results
[
'lq'
][
1
])
def
test_resize
(
self
):
with
pytest
.
raises
(
AssertionError
):
Resize
([],
scale
=
0.5
)
with
pytest
.
raises
(
AssertionError
):
Resize
([
'gt_img'
],
size_factor
=
32
,
scale
=
0.5
)
with
pytest
.
raises
(
AssertionError
):
Resize
([
'gt_img'
],
size_factor
=
32
,
keep_ratio
=
True
)
with
pytest
.
raises
(
AssertionError
):
Resize
([
'gt_img'
],
max_size
=
32
,
size_factor
=
None
)
with
pytest
.
raises
(
ValueError
):
Resize
([
'gt_img'
],
scale
=-
0.5
)
with
pytest
.
raises
(
TypeError
):
Resize
([
'gt_img'
],
(
0.4
,
0.2
))
with
pytest
.
raises
(
TypeError
):
Resize
([
'gt_img'
],
dict
(
test
=
None
))
target_keys
=
[
'alpha'
]
alpha
=
np
.
random
.
rand
(
240
,
320
).
astype
(
np
.
float32
)
results
=
dict
(
alpha
=
alpha
)
resize
=
Resize
(
keys
=
[
'alpha'
],
size_factor
=
32
,
max_size
=
None
)
resize_results
=
resize
(
results
)
assert
self
.
check_keys_contain
(
resize_results
.
keys
(),
target_keys
)
assert
resize_results
[
'alpha'
].
shape
==
(
224
,
320
,
1
)
resize
=
Resize
(
keys
=
[
'alpha'
],
size_factor
=
32
,
max_size
=
320
)
resize_results
=
resize
(
results
)
assert
self
.
check_keys_contain
(
resize_results
.
keys
(),
target_keys
)
assert
resize_results
[
'alpha'
].
shape
==
(
224
,
320
,
1
)
resize
=
Resize
(
keys
=
[
'alpha'
],
size_factor
=
32
,
max_size
=
200
)
resize_results
=
resize
(
results
)
assert
self
.
check_keys_contain
(
resize_results
.
keys
(),
target_keys
)
assert
resize_results
[
'alpha'
].
shape
==
(
192
,
192
,
1
)
resize
=
Resize
([
'gt_img'
],
(
-
1
,
200
))
results
=
dict
(
gt_img
=
self
.
results
[
'gt'
].
copy
())
resize_results
=
resize
(
results
)
assert
resize
.
scale
==
(
np
.
inf
,
200
)
assert
resize_results
[
'gt_img'
].
shape
==
(
400
,
200
,
3
)
resize
=
Resize
([
'gt_img'
],
(
-
1
,
200
))
results
=
dict
(
gt_img
=
self
.
results
[
'gt'
].
copy
().
transpose
(
1
,
0
,
2
))
resize_results
=
resize
(
results
)
assert
resize
.
scale
==
(
np
.
inf
,
200
)
assert
resize_results
[
'gt_img'
].
shape
==
(
200
,
400
,
3
)
results
=
dict
(
gt_img
=
self
.
results
[
'img'
].
copy
())
resize_keep_ratio
=
Resize
([
'gt_img'
],
scale
=
0.5
,
keep_ratio
=
True
)
results
=
resize_keep_ratio
(
results
)
assert
results
[
'gt_img'
].
shape
[:
2
]
==
(
128
,
128
)
assert
results
[
'scale_factor'
]
==
0.5
results
=
dict
(
gt_img
=
self
.
results
[
'img'
].
copy
())
resize_keep_ratio
=
Resize
([
'gt_img'
],
scale
=
(
128
,
128
),
keep_ratio
=
False
)
results
=
resize_keep_ratio
(
results
)
assert
results
[
'gt_img'
].
shape
[:
2
]
==
(
128
,
128
)
# test input with shape (256, 256)
results
=
dict
(
gt_img
=
self
.
results
[
'img'
][...,
0
].
copy
())
resize
=
Resize
([
'gt_img'
],
scale
=
(
128
,
128
),
keep_ratio
=
False
)
results
=
resize
(
results
)
assert
results
[
'gt_img'
].
shape
==
(
128
,
128
,
1
)
name_
=
str
(
resize_keep_ratio
)
assert
name_
==
resize_keep_ratio
.
__class__
.
__name__
+
(
f
"(keys=
{
[
'gt_img'
]
}
, scale=(128, 128), "
f
'keep_ratio=
{
False
}
, size_factor=None, '
'max_size=None,interpolation=bilinear)'
)
def
test_random_img_noise
():
img
=
np
.
random
.
randn
(
256
,
128
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
img
=
copy
.
deepcopy
(
img
))
noise_uniform
=
RandomImgNoise
([
'img'
],
1
,
2
,
distribution
=
'uniform'
)
results
=
noise_uniform
(
results
)
assert
(
results
[
'img'
]
-
img
<=
2
).
all
()
assert
(
results
[
'img'
]
-
img
>=
1
).
all
()
repr_str
=
noise_uniform
.
__class__
.
__name__
repr_str
+=
(
f
'(keys=
{
noise_uniform
.
keys
}
, '
f
'lower_bound=
{
noise_uniform
.
lower_bound
}
, '
f
'upper_bound=
{
noise_uniform
.
upper_bound
}
)'
)
assert
str
(
noise_uniform
)
==
repr_str
img
=
np
.
random
.
randn
(
256
,
128
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
img
=
copy
.
deepcopy
(
img
))
noise_normal
=
RandomImgNoise
([
'img'
],
distribution
=
'normal'
)
results
=
noise_normal
(
results
)
assert
(
results
[
'img'
]
-
img
<=
1
/
128.
).
all
()
assert
(
results
[
'img'
]
-
img
>=
0
).
all
()
repr_str
=
noise_normal
.
__class__
.
__name__
repr_str
+=
(
f
'(keys=
{
noise_normal
.
keys
}
, '
f
'lower_bound=
{
noise_normal
.
lower_bound
}
, '
f
'upper_bound=
{
noise_normal
.
upper_bound
}
)'
)
assert
str
(
noise_normal
)
==
repr_str
with
pytest
.
raises
(
AssertionError
):
RandomImgNoise
([])
with
pytest
.
raises
(
KeyError
):
RandomImgNoise
([
'img'
],
distribution
=
'test'
)
def
test_random_long_edge_crop
():
results
=
dict
(
img
=
np
.
random
.
rand
(
256
,
128
,
3
).
astype
(
np
.
float32
))
crop
=
RandomCropLongEdge
([
'img'
])
results
=
crop
(
results
)
assert
results
[
'img'
].
shape
==
(
128
,
128
,
3
)
repr_str
=
crop
.
__class__
.
__name__
repr_str
+=
(
f
'(keys=
{
crop
.
keys
}
)'
)
assert
str
(
crop
)
==
repr_str
def
test_center_long_edge_crop
():
results
=
dict
(
img
=
np
.
random
.
rand
(
256
,
128
,
3
).
astype
(
np
.
float32
))
crop
=
CenterCropLongEdge
([
'img'
])
results
=
crop
(
results
)
assert
results
[
'img'
].
shape
==
(
128
,
128
,
3
)
repr_str
=
crop
.
__class__
.
__name__
repr_str
+=
(
f
'(keys=
{
crop
.
keys
}
)'
)
assert
str
(
crop
)
==
repr_str
def
test_numpy_pad
():
results
=
dict
(
img
=
np
.
zeros
((
5
,
5
,
1
)))
pad
=
NumpyPad
([
'img'
],
((
2
,
2
),
(
0
,
0
),
(
0
,
0
)))
results
=
pad
(
results
)
assert
results
[
'img'
].
shape
==
(
9
,
5
,
1
)
repr_str
=
pad
.
__class__
.
__name__
repr_str
+=
(
f
'(keys=
{
pad
.
keys
}
, padding=
{
pad
.
padding
}
, kwargs=
{
pad
.
kwargs
}
)'
)
assert
str
(
pad
)
==
repr_str
tests/test_datasets/test_pipelines/test_compose.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
from
mmgen.datasets.pipelines
import
Compose
,
ImageToTensor
def
check_keys_equal
(
result_keys
,
target_keys
):
"""Check if all elements in target_keys is in result_keys."""
return
set
(
target_keys
)
==
set
(
result_keys
)
def
test_compose
():
with
pytest
.
raises
(
TypeError
):
Compose
(
'LoadAlpha'
)
target_keys
=
[
'img'
,
'meta'
]
img
=
np
.
random
.
randn
(
256
,
256
,
3
)
results
=
dict
(
img
=
img
,
abandoned_key
=
None
,
img_name
=
'test_image.png'
)
test_pipeline
=
[
dict
(
type
=
'Collect'
,
keys
=
[
'img'
],
meta_keys
=
[
'img_name'
]),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
])
]
compose
=
Compose
(
test_pipeline
)
compose_results
=
compose
(
results
)
assert
check_keys_equal
(
compose_results
.
keys
(),
target_keys
)
assert
check_keys_equal
(
compose_results
[
'meta'
].
data
.
keys
(),
[
'img_name'
])
results
=
None
image_to_tensor
=
ImageToTensor
(
keys
=
[])
test_pipeline
=
[
image_to_tensor
]
compose
=
Compose
(
test_pipeline
)
compose_results
=
compose
(
results
)
assert
compose_results
is
None
assert
repr
(
compose
)
==
(
compose
.
__class__
.
__name__
+
f
'(
\n
{
image_to_tensor
}
\n
)'
)
tests/test_datasets/test_pipelines/test_crop.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
numpy
as
np
import
pytest
from
mmgen.datasets.pipelines
import
Crop
,
FixedCrop
class
TestAugmentations
(
object
):
@
classmethod
def
setup_class
(
cls
):
cls
.
results
=
dict
()
cls
.
img_gt
=
np
.
random
.
rand
(
256
,
128
,
3
).
astype
(
np
.
float32
)
cls
.
img_lq
=
np
.
random
.
rand
(
64
,
32
,
3
).
astype
(
np
.
float32
)
cls
.
results
=
dict
(
lq
=
cls
.
img_lq
,
gt
=
cls
.
img_gt
,
scale
=
4
,
lq_path
=
'fake_lq_path'
,
gt_path
=
'fake_gt_path'
)
cls
.
results
[
'img'
]
=
np
.
random
.
rand
(
256
,
256
,
3
).
astype
(
np
.
float32
)
cls
.
results
[
'img_a'
]
=
np
.
random
.
rand
(
286
,
286
,
3
).
astype
(
np
.
float32
)
cls
.
results
[
'img_b'
]
=
np
.
random
.
rand
(
286
,
286
,
3
).
astype
(
np
.
float32
)
@
staticmethod
def
check_crop
(
result_img_shape
,
result_bbox
):
crop_w
=
result_bbox
[
2
]
-
result_bbox
[
0
]
"""Check if the result_bbox is in correspond to result_img_shape."""
crop_h
=
result_bbox
[
3
]
-
result_bbox
[
1
]
crop_shape
=
(
crop_h
,
crop_w
)
return
result_img_shape
==
crop_shape
@
staticmethod
def
check_crop_around_semi
(
alpha
):
return
((
alpha
>
0
)
&
(
alpha
<
255
)).
any
()
@
staticmethod
def
check_keys_contain
(
result_keys
,
target_keys
):
"""Check if all elements in target_keys is in result_keys."""
return
set
(
target_keys
).
issubset
(
set
(
result_keys
))
def
test_crop
(
self
):
with
pytest
.
raises
(
TypeError
):
Crop
([
'img'
],
(
0.23
,
0.1
))
# test center crop
results
=
copy
.
deepcopy
(
self
.
results
)
center_crop
=
Crop
([
'img'
],
crop_size
=
(
128
,
128
),
random_crop
=
False
)
results
=
center_crop
(
results
)
assert
results
[
'img_crop_bbox'
]
==
[
64
,
64
,
128
,
128
]
assert
np
.
array_equal
(
self
.
results
[
'img'
][
64
:
192
,
64
:
192
,
:],
results
[
'img'
])
# test random crop
results
=
copy
.
deepcopy
(
self
.
results
)
random_crop
=
Crop
([
'img'
],
crop_size
=
(
128
,
128
),
random_crop
=
True
)
results
=
random_crop
(
results
)
assert
0
<=
results
[
'img_crop_bbox'
][
0
]
<=
128
assert
0
<=
results
[
'img_crop_bbox'
][
1
]
<=
128
assert
results
[
'img_crop_bbox'
][
2
]
==
128
assert
results
[
'img_crop_bbox'
][
3
]
==
128
# test random crop for lager size than the original shape
results
=
copy
.
deepcopy
(
self
.
results
)
random_crop
=
Crop
([
'img'
],
crop_size
=
(
512
,
512
),
random_crop
=
True
)
results
=
random_crop
(
results
)
assert
np
.
array_equal
(
self
.
results
[
'img'
],
results
[
'img'
])
assert
str
(
random_crop
)
==
(
random_crop
.
__class__
.
__name__
+
"keys=['img'], crop_size=(512, 512), random_crop=True"
)
def
test_fixed_crop
(
self
):
with
pytest
.
raises
(
TypeError
):
FixedCrop
([
'img_a'
,
'img_b'
],
(
0.23
,
0.1
))
with
pytest
.
raises
(
TypeError
):
FixedCrop
([
'img_a'
,
'img_b'
],
(
256
,
256
),
(
0
,
0.1
))
# test shape consistency
results
=
copy
.
deepcopy
(
self
.
results
)
fixed_crop
=
FixedCrop
([
'img_a'
,
'img'
],
crop_size
=
(
128
,
128
))
with
pytest
.
raises
(
ValueError
):
results
=
fixed_crop
(
results
)
# test given pos crop
results
=
copy
.
deepcopy
(
self
.
results
)
given_pos_crop
=
FixedCrop
([
'img_a'
,
'img_b'
],
crop_size
=
(
256
,
256
),
crop_pos
=
(
1
,
1
))
results
=
given_pos_crop
(
results
)
assert
results
[
'img_a_crop_bbox'
]
==
[
1
,
1
,
256
,
256
]
assert
results
[
'img_b_crop_bbox'
]
==
[
1
,
1
,
256
,
256
]
assert
np
.
array_equal
(
self
.
results
[
'img_a'
][
1
:
257
,
1
:
257
,
:],
results
[
'img_a'
])
assert
np
.
array_equal
(
self
.
results
[
'img_b'
][
1
:
257
,
1
:
257
,
:],
results
[
'img_b'
])
# test given pos crop if pos > suitable pos
results
=
copy
.
deepcopy
(
self
.
results
)
given_pos_crop
=
FixedCrop
([
'img_a'
,
'img_b'
],
crop_size
=
(
256
,
256
),
crop_pos
=
(
280
,
280
))
results
=
given_pos_crop
(
results
)
assert
results
[
'img_a_crop_bbox'
]
==
[
280
,
280
,
6
,
6
]
assert
results
[
'img_b_crop_bbox'
]
==
[
280
,
280
,
6
,
6
]
assert
np
.
array_equal
(
self
.
results
[
'img_a'
][
280
:,
280
:,
:],
results
[
'img_a'
])
assert
np
.
array_equal
(
self
.
results
[
'img_b'
][
280
:,
280
:,
:],
results
[
'img_b'
])
assert
str
(
given_pos_crop
)
==
(
given_pos_crop
.
__class__
.
__name__
+
"keys=['img_a', 'img_b'], crop_size=(256, 256), "
+
'crop_pos=(280, 280)'
)
# test random initialized fixed crop
results
=
copy
.
deepcopy
(
self
.
results
)
random_fixed_crop
=
FixedCrop
([
'img_a'
,
'img_b'
],
crop_size
=
(
256
,
256
),
crop_pos
=
None
)
results
=
random_fixed_crop
(
results
)
assert
0
<=
results
[
'img_a_crop_bbox'
][
0
]
<=
30
assert
0
<=
results
[
'img_a_crop_bbox'
][
1
]
<=
30
assert
results
[
'img_a_crop_bbox'
][
2
]
==
256
assert
results
[
'img_a_crop_bbox'
][
3
]
==
256
x_offset
,
y_offset
,
crop_w
,
crop_h
=
results
[
'img_a_crop_bbox'
]
assert
x_offset
==
results
[
'img_b_crop_bbox'
][
0
]
assert
y_offset
==
results
[
'img_b_crop_bbox'
][
1
]
assert
crop_w
==
results
[
'img_b_crop_bbox'
][
2
]
assert
crop_h
==
results
[
'img_b_crop_bbox'
][
3
]
assert
np
.
array_equal
(
self
.
results
[
'img_a'
][
y_offset
:
y_offset
+
crop_h
,
x_offset
:
x_offset
+
crop_w
,
:],
results
[
'img_a'
])
assert
np
.
array_equal
(
self
.
results
[
'img_b'
][
y_offset
:
y_offset
+
crop_h
,
x_offset
:
x_offset
+
crop_w
,
:],
results
[
'img_b'
])
# test given pos crop for lager size than the original shape
results
=
copy
.
deepcopy
(
self
.
results
)
given_pos_crop
=
FixedCrop
([
'img_a'
,
'img_b'
],
crop_size
=
(
512
,
512
),
crop_pos
=
(
1
,
1
))
results
=
given_pos_crop
(
results
)
assert
results
[
'img_a_crop_bbox'
]
==
[
1
,
1
,
285
,
285
]
assert
results
[
'img_b_crop_bbox'
]
==
[
1
,
1
,
285
,
285
]
assert
np
.
array_equal
(
self
.
results
[
'img_a'
][
1
:,
1
:,
:],
results
[
'img_a'
])
assert
np
.
array_equal
(
self
.
results
[
'img_b'
][
1
:,
1
:,
:],
results
[
'img_b'
])
assert
str
(
given_pos_crop
)
==
(
given_pos_crop
.
__class__
.
__name__
+
"keys=['img_a', 'img_b'], crop_size=(512, 512), crop_pos=(1, 1)"
)
# test random initialized fixed crop for lager size
# than the original shape
results
=
copy
.
deepcopy
(
self
.
results
)
random_fixed_crop
=
FixedCrop
([
'img_a'
,
'img_b'
],
crop_size
=
(
512
,
512
),
crop_pos
=
None
)
results
=
random_fixed_crop
(
results
)
assert
results
[
'img_a_crop_bbox'
]
==
[
0
,
0
,
286
,
286
]
assert
results
[
'img_b_crop_bbox'
]
==
[
0
,
0
,
286
,
286
]
assert
np
.
array_equal
(
self
.
results
[
'img_a'
],
results
[
'img_a'
])
assert
np
.
array_equal
(
self
.
results
[
'img_b'
],
results
[
'img_b'
])
assert
str
(
random_fixed_crop
)
==
(
random_fixed_crop
.
__class__
.
__name__
+
"keys=['img_a', 'img_b'], crop_size=(512, 512), crop_pos=None"
)
tests/test_datasets/test_pipelines/test_formatting.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
import
torch
from
mmgen.datasets.pipelines
import
Collect
,
ImageToTensor
,
ToTensor
def
check_keys_contain
(
result_keys
,
target_keys
):
"""Check if all elements in target_keys is in result_keys."""
return
set
(
target_keys
).
issubset
(
set
(
result_keys
))
def
test_to_tensor
():
to_tensor
=
ToTensor
([
'str'
])
with
pytest
.
raises
(
TypeError
):
results
=
dict
(
str
=
'0'
)
to_tensor
(
results
)
target_keys
=
[
'tensor'
,
'numpy'
,
'sequence'
,
'int'
,
'float'
]
to_tensor
=
ToTensor
(
target_keys
)
ori_results
=
dict
(
tensor
=
torch
.
randn
(
2
,
3
),
numpy
=
np
.
random
.
randn
(
2
,
3
),
sequence
=
list
(
range
(
10
)),
int
=
1
,
float
=
0.1
)
results
=
to_tensor
(
ori_results
)
assert
check_keys_contain
(
results
.
keys
(),
target_keys
)
for
key
in
target_keys
:
assert
isinstance
(
results
[
key
],
torch
.
Tensor
)
assert
torch
.
equal
(
results
[
key
].
data
,
ori_results
[
key
])
# Add an additional key which is not in keys.
ori_results
=
dict
(
tensor
=
torch
.
randn
(
2
,
3
),
numpy
=
np
.
random
.
randn
(
2
,
3
),
sequence
=
list
(
range
(
10
)),
int
=
1
,
float
=
0.1
,
str
=
'test'
)
results
=
to_tensor
(
ori_results
)
assert
check_keys_contain
(
results
.
keys
(),
target_keys
)
for
key
in
target_keys
:
assert
isinstance
(
results
[
key
],
torch
.
Tensor
)
assert
torch
.
equal
(
results
[
key
].
data
,
ori_results
[
key
])
assert
repr
(
to_tensor
)
==
to_tensor
.
__class__
.
__name__
+
f
'(keys=
{
target_keys
}
)'
def
test_image_to_tensor
():
ori_results
=
dict
(
img
=
np
.
random
.
randn
(
256
,
256
,
3
))
keys
=
[
'img'
]
to_float32
=
False
image_to_tensor
=
ImageToTensor
(
keys
)
results
=
image_to_tensor
(
ori_results
)
assert
results
[
'img'
].
shape
==
torch
.
Size
([
3
,
256
,
256
])
assert
isinstance
(
results
[
'img'
],
torch
.
Tensor
)
assert
torch
.
equal
(
results
[
'img'
].
data
,
ori_results
[
'img'
])
assert
results
[
'img'
].
dtype
==
torch
.
float32
ori_results
=
dict
(
img
=
np
.
random
.
randint
(
256
,
size
=
(
256
,
256
)))
keys
=
[
'img'
]
to_float32
=
True
image_to_tensor
=
ImageToTensor
(
keys
)
results
=
image_to_tensor
(
ori_results
)
assert
results
[
'img'
].
shape
==
torch
.
Size
([
1
,
256
,
256
])
assert
isinstance
(
results
[
'img'
],
torch
.
Tensor
)
assert
torch
.
equal
(
results
[
'img'
].
data
,
ori_results
[
'img'
])
assert
results
[
'img'
].
dtype
==
torch
.
float32
assert
repr
(
image_to_tensor
)
==
(
image_to_tensor
.
__class__
.
__name__
+
f
'(keys=
{
keys
}
, to_float32=
{
to_float32
}
)'
)
def
test_collect
():
inputs
=
dict
(
img
=
np
.
random
.
randn
(
256
,
256
,
3
),
label
=
[
1
],
img_name
=
'test_image.png'
,
ori_shape
=
(
256
,
256
,
3
),
img_shape
=
(
256
,
256
,
3
),
pad_shape
=
(
256
,
256
,
3
),
flip_direction
=
'vertical'
,
img_norm_cfg
=
dict
(
to_bgr
=
False
))
keys
=
[
'img'
,
'label'
]
meta_keys
=
[
'img_shape'
,
'img_name'
,
'ori_shape'
]
collect
=
Collect
(
keys
,
meta_keys
=
meta_keys
)
results
=
collect
(
inputs
)
assert
set
(
list
(
results
.
keys
()))
==
set
([
'img'
,
'label'
,
'meta'
])
inputs
.
pop
(
'img'
)
assert
set
(
results
[
'meta'
].
data
.
keys
())
==
set
(
meta_keys
)
for
key
in
results
[
'meta'
].
data
:
assert
results
[
'meta'
].
data
[
key
]
==
inputs
[
key
]
assert
repr
(
collect
)
==
(
collect
.
__class__
.
__name__
+
f
'(keys=
{
keys
}
, meta_keys=
{
collect
.
meta_keys
}
)'
)
tests/test_datasets/test_pipelines/test_loading.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
pathlib
import
Path
import
mmcv
import
numpy
as
np
from
mmgen.datasets
import
LoadImageFromFile
def
test_load_image_from_file
():
path_baboon
=
Path
(
__file__
).
parent
/
'..'
/
'..'
/
'data'
/
'image'
/
'baboon.png'
img_baboon
=
mmcv
.
imread
(
str
(
path_baboon
),
flag
=
'color'
)
# read gt image
# input path is Path object
results
=
dict
(
gt_path
=
path_baboon
)
config
=
dict
(
io_backend
=
'disk'
,
key
=
'gt'
)
image_loader
=
LoadImageFromFile
(
**
config
)
results
=
image_loader
(
results
)
assert
results
[
'gt'
].
shape
==
(
480
,
500
,
3
)
np
.
testing
.
assert_almost_equal
(
results
[
'gt'
],
img_baboon
)
assert
results
[
'gt_path'
]
==
str
(
path_baboon
)
# input path is str
results
=
dict
(
gt_path
=
str
(
path_baboon
))
results
=
image_loader
(
results
)
assert
results
[
'gt'
].
shape
==
(
480
,
500
,
3
)
np
.
testing
.
assert_almost_equal
(
results
[
'gt'
],
img_baboon
)
assert
results
[
'gt_path'
]
==
str
(
path_baboon
)
assert
repr
(
image_loader
)
==
(
image_loader
.
__class__
.
__name__
+
(
'(io_backend=disk, key=gt, '
'flag=color, save_original_img=False)'
))
tests/test_datasets/test_pipelines/test_normalize.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
pytest
from
mmgen.datasets.pipelines
import
Normalize
class
TestAugmentations
(
object
):
@
staticmethod
def
assert_img_equal
(
img
,
ref_img
,
ratio_thr
=
0.999
):
"""Check if img and ref_img are matched approximately."""
assert
img
.
shape
==
ref_img
.
shape
assert
img
.
dtype
==
ref_img
.
dtype
area
=
ref_img
.
shape
[
-
1
]
*
ref_img
.
shape
[
-
2
]
diff
=
np
.
abs
(
img
.
astype
(
'int32'
)
-
ref_img
.
astype
(
'int32'
))
assert
np
.
sum
(
diff
<=
1
)
/
float
(
area
)
>
ratio_thr
@
staticmethod
def
check_keys_contain
(
result_keys
,
target_keys
):
"""Check if all elements in target_keys is in result_keys."""
return
set
(
target_keys
).
issubset
(
set
(
result_keys
))
def
check_normalize
(
self
,
origin_img
,
result_img
,
norm_cfg
):
"""Check if the origin_img are normalized correctly into result_img in
a given norm_cfg."""
target_img
=
result_img
.
copy
()
target_img
*=
norm_cfg
[
'std'
][
None
,
None
,
:]
target_img
+=
norm_cfg
[
'mean'
][
None
,
None
,
:]
if
norm_cfg
[
'to_rgb'
]:
target_img
=
target_img
[:,
::
-
1
,
...].
copy
()
self
.
assert_img_equal
(
origin_img
,
target_img
)
def
test_normalize
(
self
):
with
pytest
.
raises
(
TypeError
):
Normalize
([
'alpha'
],
dict
(
mean
=
[
123.675
,
116.28
,
103.53
]),
[
58.395
,
57.12
,
57.375
])
with
pytest
.
raises
(
TypeError
):
Normalize
([
'alpha'
],
[
123.675
,
116.28
,
103.53
],
dict
(
std
=
[
58.395
,
57.12
,
57.375
]))
target_keys
=
[
'merged'
,
'img_norm_cfg'
]
merged
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
merged
=
merged
)
config
=
dict
(
keys
=
[
'merged'
],
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
False
)
normalize
=
Normalize
(
**
config
)
normalize_results
=
normalize
(
results
)
assert
self
.
check_keys_contain
(
normalize_results
.
keys
(),
target_keys
)
self
.
check_normalize
(
merged
,
normalize_results
[
'merged'
],
normalize_results
[
'img_norm_cfg'
])
merged
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
merged
=
merged
)
config
=
dict
(
keys
=
[
'merged'
],
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
normalize
=
Normalize
(
**
config
)
normalize_results
=
normalize
(
results
)
assert
self
.
check_keys_contain
(
normalize_results
.
keys
(),
target_keys
)
self
.
check_normalize
(
merged
,
normalize_results
[
'merged'
],
normalize_results
[
'img_norm_cfg'
])
assert
normalize
.
__repr__
()
==
(
normalize
.
__class__
.
__name__
+
f
"(keys=
{
[
'merged'
]
}
, mean=
{
np
.
array
([
123.675
,
116.28
,
103.53
])
}
,"
f
' std=
{
np
.
array
([
58.395
,
57.12
,
57.375
])
}
, to_rgb=True)'
)
# input is an image list
merged
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
merged_2
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
merged
=
[
merged
,
merged_2
])
config
=
dict
(
keys
=
[
'merged'
],
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
False
)
normalize
=
Normalize
(
**
config
)
normalize_results
=
normalize
(
results
)
assert
self
.
check_keys_contain
(
normalize_results
.
keys
(),
target_keys
)
self
.
check_normalize
(
merged
,
normalize_results
[
'merged'
][
0
],
normalize_results
[
'img_norm_cfg'
])
self
.
check_normalize
(
merged_2
,
normalize_results
[
'merged'
][
1
],
normalize_results
[
'img_norm_cfg'
])
merged
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
merged_2
=
np
.
random
.
rand
(
240
,
320
,
3
).
astype
(
np
.
float32
)
results
=
dict
(
merged
=
[
merged
,
merged_2
])
config
=
dict
(
keys
=
[
'merged'
],
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
normalize
=
Normalize
(
**
config
)
normalize_results
=
normalize
(
results
)
assert
self
.
check_keys_contain
(
normalize_results
.
keys
(),
target_keys
)
self
.
check_normalize
(
merged
,
normalize_results
[
'merged'
][
0
],
normalize_results
[
'img_norm_cfg'
])
self
.
check_normalize
(
merged_2
,
normalize_results
[
'merged'
][
1
],
normalize_results
[
'img_norm_cfg'
])
tests/test_datasets/test_quicktest_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
from
mmgen.datasets.quick_test_dataset
import
QuickTestImageDataset
class
TestQuickTest
:
@
classmethod
def
setup_class
(
cls
):
cls
.
dataset
=
QuickTestImageDataset
(
size
=
(
256
,
256
))
def
test_quicktest_dataset
(
self
):
assert
len
(
self
.
dataset
)
==
10000
img
=
self
.
dataset
[
2
]
assert
img
[
'real_img'
].
shape
==
(
3
,
256
,
256
)
tests/test_datasets/test_singan_dataset.py
0 → 100644
View file @
c9a48a52
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
mmgen.datasets
import
SinGANDataset
class
TestSinGANDataset
(
object
):
@
classmethod
def
setup_class
(
cls
):
cls
.
imgs_root
=
osp
.
join
(
osp
.
dirname
(
osp
.
dirname
(
__file__
)),
'data/image/baboon.png'
)
cls
.
min_size
=
25
cls
.
max_size
=
250
cls
.
scale_factor_init
=
0.75
def
test_singan_dataset
(
self
):
dataset
=
SinGANDataset
(
self
.
imgs_root
,
min_size
=
self
.
min_size
,
max_size
=
self
.
max_size
,
scale_factor_init
=
self
.
scale_factor_init
)
assert
len
(
dataset
)
==
1000000
data_dict
=
dataset
[
0
]
assert
all
([
f
'real_scale
{
i
}
'
in
data_dict
for
i
in
range
(
10
)])
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment