Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
stylegan2_mmcv
Commits
1401de15
Commit
1401de15
authored
Jun 28, 2024
by
dongchy920
Browse files
stylegan2_mmcv
parents
Pipeline
#1274
canceled with stages
Changes
463
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3741 additions
and
0 deletions
+3741
-0
build/lib/mmgen/core/evaluation/evaluation.py
build/lib/mmgen/core/evaluation/evaluation.py
+403
-0
build/lib/mmgen/core/evaluation/metric_utils.py
build/lib/mmgen/core/evaluation/metric_utils.py
+283
-0
build/lib/mmgen/core/evaluation/metrics.py
build/lib/mmgen/core/evaluation/metrics.py
+1498
-0
build/lib/mmgen/core/hooks/__init__.py
build/lib/mmgen/core/hooks/__init__.py
+13
-0
build/lib/mmgen/core/hooks/ceph_hooks.py
build/lib/mmgen/core/hooks/ceph_hooks.py
+128
-0
build/lib/mmgen/core/hooks/ema_hook.py
build/lib/mmgen/core/hooks/ema_hook.py
+193
-0
build/lib/mmgen/core/hooks/pggan_fetch_data_hook.py
build/lib/mmgen/core/hooks/pggan_fetch_data_hook.py
+34
-0
build/lib/mmgen/core/hooks/pickle_data_hook.py
build/lib/mmgen/core/hooks/pickle_data_hook.py
+112
-0
build/lib/mmgen/core/hooks/visualization.py
build/lib/mmgen/core/hooks/visualization.py
+87
-0
build/lib/mmgen/core/hooks/visualize_training_samples.py
build/lib/mmgen/core/hooks/visualize_training_samples.py
+110
-0
build/lib/mmgen/core/optimizer/__init__.py
build/lib/mmgen/core/optimizer/__init__.py
+4
-0
build/lib/mmgen/core/optimizer/builder.py
build/lib/mmgen/core/optimizer/builder.py
+57
-0
build/lib/mmgen/core/registry.py
build/lib/mmgen/core/registry.py
+30
-0
build/lib/mmgen/core/runners/__init__.py
build/lib/mmgen/core/runners/__init__.py
+4
-0
build/lib/mmgen/core/runners/apex_amp_utils.py
build/lib/mmgen/core/runners/apex_amp_utils.py
+36
-0
build/lib/mmgen/core/runners/checkpoint.py
build/lib/mmgen/core/runners/checkpoint.py
+95
-0
build/lib/mmgen/core/runners/dynamic_iterbased_runner.py
build/lib/mmgen/core/runners/dynamic_iterbased_runner.py
+409
-0
build/lib/mmgen/core/runners/fp16_utils.py
build/lib/mmgen/core/runners/fp16_utils.py
+190
-0
build/lib/mmgen/core/scheduler/__init__.py
build/lib/mmgen/core/scheduler/__init__.py
+4
-0
build/lib/mmgen/core/scheduler/lr_updater.py
build/lib/mmgen/core/scheduler/lr_updater.py
+51
-0
No files found.
Too many changes to show.
To preserve performance only
463 of 463+
files are displayed.
Plain diff
Email patch
build/lib/mmgen/core/evaluation/evaluation.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
shutil
import
sys
from
copy
import
deepcopy
import
mmcv
import
torch
import
torch.distributed
as
dist
from
mmcv.runner
import
get_dist_info
from
prettytable
import
PrettyTable
from
torchvision.utils
import
save_image
from
mmgen.datasets
import
build_dataloader
,
build_dataset
def make_metrics_table(train_cfg, ckpt, eval_info, metrics):
    """Arrange evaluation results into a table.

    Args:
        train_cfg (str): Name of the training configuration.
        ckpt (str): Path of the evaluated model's weights.
        eval_info (str): Description of which model variant was evaluated
            (callers pass ``basic_table_info['sample_model']``); shown in
            the 'Eval' column.
        metrics (Metric): Metric objects.

    Returns:
        str: String of the eval table.
    """
    table = PrettyTable()
    # 14 selects a predefined prettytable border style by enum value
    # (NOTE(review): presumably ORGMODE in modern prettytable — confirm
    # against the installed prettytable version).
    table.set_style(14)
    table.add_column('Training configuration', [train_cfg])
    table.add_column('Checkpoint', [ckpt])
    table.add_column('Eval', [eval_info])
    # one column per metric; `result_str` is the metric's formatted summary
    for metric in metrics:
        table.add_column(metric.name, [metric.result_str])
    return table.get_string()
def make_vanilla_dataloader(img_path, batch_size, dist=False):
    """Build a plain dataloader reading `real_img` samples from disk.

    Images under ``img_path`` are loaded, normalized to roughly [-1, 1]
    (mean/std 127.5 per channel, BGR kept as-is), converted to tensors and
    collected with their paths.

    Args:
        img_path (str): Root directory containing the images.
        batch_size (int): Samples per GPU for the dataloader.
        dist (bool): Whether to build a distributed dataloader.
            Defaults to False.

    Returns:
        DataLoader: Shuffled dataloader over the images.
    """
    load_step = dict(
        type='LoadImageFromFile', key='real_img', io_backend='disk')
    norm_step = dict(
        type='Normalize',
        keys=['real_img'],
        mean=[127.5] * 3,
        std=[127.5] * 3,
        to_rgb=False)
    tensor_step = dict(type='ImageToTensor', keys=['real_img'])
    collect_step = dict(
        type='Collect', keys=['real_img'], meta_keys=['real_img_path'])

    dataset = build_dataset(
        dict(
            type='UnconditionalImageDataset',
            imgs_root=img_path,
            pipeline=[load_step, norm_step, tensor_step, collect_step],
        ))
    return build_dataloader(
        dataset,
        samples_per_gpu=batch_size,
        workers_per_gpu=4,
        dist=dist,
        shuffle=True)
@torch.no_grad()
def offline_evaluation(model,
                       data_loader,
                       metrics,
                       logger,
                       basic_table_info,
                       batch_size,
                       samples_path=None,
                       **kwargs):
    """Evaluate model in offline mode.

    This method first save generated images at local and then load them by
    dataloader.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        samples_path (str): Used to save generated images. If it's none, we'll
            give it a default directory and delete it after finishing the
            evaluation. Default to None.
        kwargs (dict): Other arguments.
    """
    # eval special and recon metric online only: PPL/GaussianKLD cannot be
    # computed from images reloaded off disk.
    online_metric_name = ['PPL', 'GaussianKLD']
    for metric in metrics:
        assert metric.name not in online_metric_name, 'Please eval ' \
            f'{metric.name} online'
    rank, ws = get_dist_info()

    # Decide where to dump generated samples; a temp dir is removed again at
    # the end of the function.
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images: reuse any images already present in `samples_path` and
    # only generate the remainder.
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0 and rank == 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)

    # if no images, `num_needed` should be zero
    total_batch_size = batch_size * ws
    for begin in range(0, num_needed, total_batch_size):
        # each rank generates its per-GPU share of the global batch
        end = min(begin + batch_size, max_num_images)
        fakes = model(
            None,
            num_batches=end - begin,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)
        global_end = min(begin + total_batch_size, max_num_images)
        if rank == 0:
            pbar.update(global_end - begin)

        # gather generated images from all ranks onto every rank
        if ws > 1:
            placeholder = [torch.zeros_like(fakes) for _ in range(ws)]
            dist.all_gather(placeholder, fakes)
            fakes = torch.cat(placeholder, dim=0)

        # save as three-channel; channel flip [2, 1, 0] converts BGR output
        # to RGB for `save_image`
        if fakes.size(1) == 3:
            fakes = fakes[:, [2, 1, 0], ...]
        elif fakes.size(1) == 1:
            fakes = torch.cat([fakes] * 3, dim=1)
        else:
            raise RuntimeError('Generated images must have one or three '
                               'channels in the first dimension, '
                               'not %d' % fakes.size(1))

        # only rank 0 writes images to disk
        if rank == 0:
            for i in range(global_end - begin):
                images = fakes[i:i + 1]
                # map from [-1, 1] to [0, 1] before saving
                images = ((images + 1) / 2)
                images = images.clamp_(0, 1)
                image_name = str(num_exist + begin + i) + '.png'
                save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0 and rank == 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(
        samples_path, batch_size, dist=ws > 1)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        if rank == 0:
            # prepare for pbar: count of real+fake images still to be fed
            total_need = (metric.num_real_need + metric.num_fake_need -
                          metric.num_real_feeded - metric.num_fake_feeded)
            pbar = mmcv.ProgressBar(total_need)
        # feed in real images until the metric reports it has enough
        for data in data_loader:
            # key for unconditional GAN
            if 'real_img' in data:
                reals = data['real_img']
            # key for conditional GAN
            elif 'img' in data:
                reals = data['img']
            else:
                raise KeyError('Cannot found key for images in data_dict. '
                               'Only support `real_img` for unconditional '
                               'datasets and `img` for conditional '
                               'datasets.')

            # grayscale -> 3-channel by repetition
            if reals.shape[1] == 1:
                reals = torch.cat([reals] * 3, dim=1)

            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(reals.shape[0] * ws)
        # feed in fake images (reloaded from disk, hence key 'real_img')
        for data in fake_dataloader:
            fakes = data['real_img']
            if fakes.shape[1] == 1:
                fakes = torch.cat([fakes] * 3, dim=1)
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(fakes.shape[0] * ws)
        if rank == 0:
            # only call summary at main device
            metric.summary()
            sys.stdout.write('\n')

    if rank == 0:
        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)

    # clean up the auto-created temp directory
    if delete_samples_path:
        shutil.rmtree(samples_path)
@torch.no_grad()
def online_evaluation(model, data_loader, metrics, logger, basic_table_info,
                      batch_size, **kwargs):
    """Evaluate model in online mode.

    This method evaluate model and displays eval progress bar.
    Different form `offline_evaluation`, this function will not save
    the images or read images from disks. Namely, there do not exist any IO
    operations in this function. Thus, in general, `online` mode will achieve a
    faster evaluation. However, this mode will take much more memory cost.
    To be noted that, we only support distributed evaluation for FID and IS
    currently.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        kwargs (dict): Other arguments.
    """
    # separate metrics into special metrics, probabilistic metrics and vanilla
    # metrics.
    # For vanilla metrics, images are generated in a random way, and are
    # shared by these metrics. For special metrics like 'PPL', images are
    # generated in a metric-special way and not shared between different
    # metrics.
    # For reconstruction metrics like 'GaussianKLD', they do not
    # receive images but receive a dict with corresponding probabilistic
    # parameter.
    rank, ws = get_dist_info()
    special_metrics = []
    recon_metrics = []
    vanilla_metrics = []
    special_metric_name = ['PPL']
    recon_metric_name = ['GaussianKLD']
    for metric in metrics:
        if ws > 1:
            assert metric.name in ['FID', 'IS'], (
                'We only support FID and IS for distributed evaluation '
                f'currently, but receive {metric.name}')
        if metric.name in special_metric_name:
            special_metrics.append(metric)
        elif metric.name in recon_metric_name:
            recon_metrics.append(metric)
        else:
            vanilla_metrics.append(metric)

    # define mmcv progress bar over the remaining real images to feed
    max_num_images = 0
    for metric in vanilla_metrics + recon_metrics:
        metric.prepare()
        max_num_images = max(max_num_images,
                             metric.num_real_need - metric.num_real_feeded)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} real images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)

    # avoid `data_loader` is None
    data_loader = [] if data_loader is None else data_loader
    for data in data_loader:
        # key for unconditional GAN
        if 'real_img' in data:
            reals = data['real_img']
        # key for conditional GAN
        elif 'img' in data:
            reals = data['img']
        else:
            raise KeyError('Cannot found key for images in data_dict. '
                           'Only support `real_img` for unconditional '
                           'datasets and `img` for conditional '
                           'datasets.')

        if reals.shape[1] not in [1, 3]:
            raise RuntimeError('real images should have one or three '
                               'channels in the first, '
                               'not % d' % reals.shape[1])
        # grayscale -> 3-channel by repetition
        if reals.shape[1] == 1:
            reals = reals.repeat(1, 3, 1, 1)
        # `num_feed` tracks the largest number of images actually consumed
        # by any metric this batch (<= 0 means every metric is satisfied).
        num_feed = 0
        for metric in vanilla_metrics:
            num_feed_ = metric.feed(reals, 'reals')
            num_feed = max(num_feed_, num_feed)
        for metric in recon_metrics:
            # run the model in reconstruction mode to get the probabilistic
            # parameters this metric consumes instead of raw images
            kwargs_ = deepcopy(kwargs)
            kwargs_['mode'] = 'reconstruction'
            prob_dict = model(reals, return_loss=False, **kwargs_)
            num_feed_ = metric.feed(prob_dict, 'reals')
            num_feed = max(num_feed_, num_feed)

        if num_feed <= 0:
            break

        if rank == 0:
            pbar.update(num_feed)

    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')

    # define mmcv progress bar over the fake images still required
    max_num_images = 0 if len(vanilla_metrics) == 0 else max(
        metric.num_fake_need for metric in vanilla_metrics)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)

    # sampling fake images and directly send them to metrics
    total_batch_size = batch_size * ws
    for _ in range(0, max_num_images, total_batch_size):
        fakes = model(
            None,
            num_batches=batch_size,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)

        if fakes.shape[1] not in [1, 3]:
            raise RuntimeError('fakes images should have one or three '
                               'channels in the first, '
                               'not % d' % fakes.shape[1])
        if fakes.shape[1] == 1:
            fakes = torch.cat([fakes] * 3, dim=1)

        for metric in vanilla_metrics:
            # feed in fake images
            metric.feed(fakes, 'fakes')

        if rank == 0:
            pbar.update(total_batch_size)

    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')

    # feed special metric, we do not consider distributed eval here
    for metric in special_metrics:
        metric.prepare()
        # the metric supplies its own sampler over the unwrapped model
        fakedata_iterator = iter(
            metric.get_sampler(model.module, batch_size,
                               basic_table_info['sample_model']))
        mmcv.print_log(
            f'Sample {metric.num_images} samples for evaluating '
            f'{metric.name}', 'mmgen')
        pbar = mmcv.ProgressBar(metric.num_images)
        for fakes in fakedata_iterator:
            num_left = metric.feed(fakes, 'fakes')
            pbar.update(fakes.shape[0])
            if num_left <= 0:
                break
        # finish the pbar stdout
        sys.stdout.write('\n')

    if rank == 0:
        # only the main device summarizes and logs the final table
        for metric in metrics:
            metric.summary()
        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)
build/lib/mmgen/core/evaluation/metric_utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
sys
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
mmcv.parallel
import
is_module_wrapper
from
mmgen.models.architectures.common
import
get_module_device
@torch.no_grad()
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    # ceil(num_samples / batch_size) iterations are needed
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        # a dirty walkround to support multiple datasets (mainly for the
        # unconditional dataset and conditional dataset). In our
        # implementation, unconditioanl dataset will return real images with
        # the key "real_img". However, the conditional dataset contains a key
        # "img" denoting the real images.
        if 'real_img' in data:
            # Mainly for the unconditional dataset in our MMGeneration
            img = data['real_img']
        else:
            # Mainly for conditional dataset in MMClassification
            img = data['img']
        pbar.update()

        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            # Tero's script model expects uint8 images in [0, 255]; inputs
            # here are assumed in roughly [-1, 1] before this rescale.
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            # pytorch InceptionV3 returns a list; flatten per-image features
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Attention: the number of features may be different as you want.
    features = torch.cat(feature_list, 0)
    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # to change the line after pbar
    sys.stdout.write('\n')

    return features
def
_hox_downsample
(
img
):
r
"""Downsample images with factor equal to 0.5.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
img (ndarray): Images with order "NHWC".
Returns:
ndarray: Downsampled images with order "NHWC".
"""
return
(
img
[:,
0
::
2
,
0
::
2
,
:]
+
img
[:,
1
::
2
,
0
::
2
,
:]
+
img
[:,
0
::
2
,
1
::
2
,
:]
+
img
[:,
1
::
2
,
1
::
2
,
:])
*
0.25
def
_f_special_gauss
(
size
,
sigma
):
r
"""Return a circular symmetric gaussian kernel.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
size (int): Size of Gaussian kernel.
sigma (float): Standard deviation for Gaussian blur kernel.
Returns:
ndarray: Gaussian kernel.
"""
radius
=
size
//
2
offset
=
0.0
start
,
stop
=
-
radius
,
radius
+
1
if
size
%
2
==
0
:
offset
=
0.5
stop
-=
1
x
,
y
=
np
.
mgrid
[
offset
+
start
:
stop
,
offset
+
start
:
stop
]
assert
len
(
x
)
==
size
g
=
np
.
exp
(
-
((
x
**
2
+
y
**
2
)
/
(
2.0
*
sigma
**
2
)))
return
g
/
g
.
sum
()
# Gaussian blur kernel
def get_gaussian_kernel():
    """Return the 5x5 binomial blur kernel as a (1, 1, 5, 5) float32 tensor.

    The kernel is the outer product of [1, 4, 6, 4, 1] with itself divided
    by 256, so its entries sum to 1.
    """
    row = np.array([1, 4, 6, 4, 1], np.float32)
    kernel = np.outer(row, row) / 256.0
    return torch.as_tensor(kernel.reshape(1, 1, 5, 5))
def get_pyramid_layer(image, gaussian_k, direction='down'):
    """Blur a 3-channel image with ``gaussian_k`` one pyramid step.

    With ``direction='down'`` the image is blurred with stride 2
    (halving spatial size); with ``direction='up'`` it is first upsampled
    by 2 and then blurred with stride 1.

    Args:
        image (Tensor): Input of order "NCHW" with 3 channels.
        gaussian_k (Tensor): Blur kernel of shape (1, 1, 5, 5).
        direction (str): 'down' or 'up'. Defaults to 'down'.

    Returns:
        Tensor: Blurred (and possibly resized) image, order "NCHW".
    """
    gaussian_k = gaussian_k.to(image.device)
    if direction == 'up':
        image = F.interpolate(image, scale_factor=2)
    conv_stride = 1 if direction == 'up' else 2
    # blur each channel independently with the single-channel kernel
    per_channel = [
        F.conv2d(
            image[:, c:c + 1, :, :],
            gaussian_k,
            padding=2,
            stride=conv_stride) for c in range(3)
    ]
    return torch.cat(per_channel, dim=1)
def gaussian_pyramid(original, n_pyramids, gaussian_k):
    """Build a Gaussian pyramid by repeated blur-and-downsample.

    Args:
        original (Tensor): Base image batch, order "NCHW".
        n_pyramids (int): Number of downsampling steps to apply.
        gaussian_k (Tensor): Gaussian kernel with shape (1, 1, 5, 5).

    Returns:
        list[Tensor]: ``n_pyramids + 1`` levels, starting with ``original``.
    """
    levels = [original]
    current = original
    for _ in range(n_pyramids):
        current = get_pyramid_layer(current, gaussian_k)
        levels.append(current)
    return levels
def laplacian_pyramid(original, n_pyramids, gaussian_k):
    """Calculate Laplacian pyramid.

    Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py

    Args:
        original (Tensor): Batch of Images with range [0, 1] and order "NCHW"
        n_pyramids (int): Levels of pyramids minus one.
        gaussian_k (Tensor): Gaussian kernel with shape (1, 1, 5, 5).

    Return:
        list[Tensor]. Laplacian pyramids of original.
    """
    # create gaussian pyramid, then take differences between each level and
    # the upsampled next level (pyramid up - diff)
    gauss_levels = gaussian_pyramid(original, n_pyramids, gaussian_k)
    laplacian = [
        gauss_levels[idx] -
        get_pyramid_layer(gauss_levels[idx + 1], gaussian_k, 'up')
        for idx in range(len(gauss_levels) - 1)
    ]
    # Add last gaussian pyramid
    laplacian.append(gauss_levels[-1])
    return laplacian
def
get_descriptors_for_minibatch
(
minibatch
,
nhood_size
,
nhoods_per_image
):
r
"""Get descriptors of one level of pyramids.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa
Args:
minibatch (Tensor): Pyramids of one level with order "NCHW".
nhood_size (int): Pixel neighborhood size.
nhoods_per_image (int): The number of descriptors per image.
Return:
Tensor: Descriptors of images from one level batch.
"""
S
=
minibatch
.
shape
# (minibatch, channel, height, width)
assert
len
(
S
)
==
4
and
S
[
1
]
==
3
N
=
nhoods_per_image
*
S
[
0
]
H
=
nhood_size
//
2
nhood
,
chan
,
x
,
y
=
np
.
ogrid
[
0
:
N
,
0
:
3
,
-
H
:
H
+
1
,
-
H
:
H
+
1
]
img
=
nhood
//
nhoods_per_image
x
=
x
+
np
.
random
.
randint
(
H
,
S
[
3
]
-
H
,
size
=
(
N
,
1
,
1
,
1
))
y
=
y
+
np
.
random
.
randint
(
H
,
S
[
2
]
-
H
,
size
=
(
N
,
1
,
1
,
1
))
idx
=
((
img
*
S
[
1
]
+
chan
)
*
S
[
2
]
+
y
)
*
S
[
3
]
+
x
return
minibatch
.
view
(
-
1
)[
idx
]
def finalize_descriptors(desc):
    r"""Normalize and reshape descriptors.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        desc (list or Tensor): List of descriptors of one level.

    Return:
        Tensor: Descriptors after normalized along channel and flattened.
    """
    if isinstance(desc, list):
        desc = torch.cat(desc, dim=0)
    # expect (neighborhood, channel, height, width)
    assert desc.ndim == 4
    stat_dims = (0, 2, 3)
    # standardize per channel, in place on `desc`
    desc -= torch.mean(desc, dim=stat_dims, keepdim=True)
    desc /= torch.std(desc, dim=stat_dims, keepdim=True)
    return desc.reshape(desc.shape[0], -1)
def compute_pr_distances(row_features,
                         col_features,
                         num_gpus,
                         rank,
                         col_batch_size=10000):
    r"""Compute distances between real images and fake images.

    This function is used for calculate Precision and Recall metric.
    Refer to:https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py # noqa

    Columns are padded and chunked so each rank handles an equal share;
    per-chunk distances are broadcast so rank 0 can assemble the full
    (rows x cols) matrix. Non-zero ranks return ``None``.
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    # round the batch count up to a multiple of num_gpus
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus +
                   1) * num_gpus
    # pad the column features so they chunk evenly
    padded = torch.nn.functional.pad(col_features,
                                     [0, 0, 0, -num_cols % num_batches])
    col_batches = padded.chunk(num_batches)

    gathered = []
    for chunk in col_batches[rank::num_gpus]:
        pairwise = torch.cdist(row_features.unsqueeze(0),
                               chunk.unsqueeze(0))[0]
        for src in range(num_gpus):
            shared = pairwise.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(shared, src=src)
            gathered.append(shared.cpu() if rank == 0 else None)

    if rank != 0:
        return None
    # drop the padded columns before returning
    return torch.cat(gathered, dim=1)[:, :num_cols]
def normalize(a):
    """L2 normalization.

    Args:
        a (Tensor): Tensor with shape [N, C].

    Returns:
        Tensor: Tensor after L2 normalization per-instance.
    """
    row_norms = torch.norm(a, dim=1, keepdim=True)
    return a / row_norms
def slerp(a, b, percent):
    """Spherical linear interpolation between two unnormalized vectors.

    Args:
        a (Tensor): Tensor with shape [N, C].
        b (Tensor): Tensor with shape [N, C].
        percent (float|Tensor): A float or tensor with shape broadcastable to
            the shape of input Tensors.

    Returns:
        Tensor: Spherical linear interpolation result with shape [N, C].
    """

    def _unit(v):
        # per-row L2 normalization (same contract as module-level
        # `normalize`, inlined here)
        return v / torch.norm(v, dim=1, keepdim=True)

    a = _unit(a)
    b = _unit(b)
    # cosine of the angle between each pair of rows
    cos_angle = (a * b).sum(-1, keepdim=True)
    theta = percent * torch.acos(cos_angle)
    # component of b orthogonal to a, renormalized
    ortho = _unit(b - cos_angle * a)
    result = a * torch.cos(theta) + ortho * torch.sin(theta)
    return _unit(result)
build/lib/mmgen/core/evaluation/metrics.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
pickle
from
abc
import
ABC
,
abstractmethod
from
copy
import
deepcopy
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
mmcv.runner
import
get_dist_info
from
scipy
import
linalg
,
signal
from
scipy.stats
import
entropy
from
torchvision
import
models
from
torchvision.models.inception
import
inception_v3
from
mmgen.models.architectures
import
InceptionV3
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.architectures.lpips
import
PerceptualLoss
from
mmgen.models.losses
import
gaussian_kld
from
mmgen.utils
import
MMGEN_CACHE_DIR
from
mmgen.utils.io_utils
import
download_from_url
from
..registry
import
METRICS
from
.metric_utils
import
(
_f_special_gauss
,
_hox_downsample
,
compute_pr_distances
,
finalize_descriptors
,
get_descriptors_for_minibatch
,
get_gaussian_kernel
,
laplacian_pyramid
,
slerp
)
TERO_INCEPTION_URL
=
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
# noqa
def load_inception(inception_args, metric):
    """Load Inception Model from given ``inception_args`` and ``metric``. This
    function would try to load Inception under the guidance of 'type' given in
    `inception_args`, if not given, we would try best to load Tero's ones. In
    detailly, we would first try to load the model from disk with the given
    'inception_path', and then try to download the checkpoint from
    'inception_url'. If both method are failed, pytorch version of Inception
    would be loaded.

    Args:
        inception_args (dict): Keyword args for inception net.
        metric (string): Metric to use the Inception. This argument would
            influence the pytorch's Inception loading.

    Returns:
        model (torch.nn.Module): Loaded Inception model.
        style (string): The version of the loaded Inception.
    """
    if not isinstance(inception_args, dict):
        raise TypeError('Receive invalid \'inception_args\': '
                        f'\'{inception_args}\'')

    _inception_args = deepcopy(inception_args)
    # fixed typo: was `inceptoin_type`
    inception_type = _inception_args.pop('type', None)

    # NOTE(review): lexicographic version compare — misorders e.g.
    # '1.10.0' < '1.6.0'; consider a proper version parse. Kept as-is to
    # preserve behavior.
    if torch.__version__ < '1.6.0':
        mmcv.print_log(
            'Current Pytorch Version not support script module, load '
            'Inception Model from torch model zoo. If you want to use '
            'Tero\'s script model, please update your Pytorch higher '
            f'than \'1.6\' (now is {torch.__version__})', 'mmgen')
        return _load_inception_torch(_inception_args, metric), 'pytorch'

    # load pytorch version is specific
    if inception_type != 'StyleGAN':
        return _load_inception_torch(_inception_args, metric), 'pytorch'

    # try to load Tero's version
    path = _inception_args.get('inception_path', TERO_INCEPTION_URL)

    # try to load `path` as a checkpoint on disk
    if 'http' not in path:
        model = _load_inception_from_path(path)
        if isinstance(model, torch.nn.Module):
            return model, 'StyleGAN'

    # try to parse `path` as a web url and download
    model = _load_inception_from_url(path)
    if isinstance(model, torch.nn.Module):
        return model, 'StyleGAN'

    raise RuntimeError('Cannot Load Inception Model, please check the input '
                       f'`inception_args`: {inception_args}')
def _load_inception_from_path(inception_path):
    """Load Tero's scripted Inception model from a local checkpoint.

    Args:
        inception_path (str): Path of the TorchScript checkpoint on disk.

    Returns:
        torch.nn.Module | None: The loaded model, or ``None`` on failure.
    """
    mmcv.print_log(
        'Try to load Tero\'s Inception Model from '
        f'\'{inception_path}\'.', 'mmgen')
    try:
        model = torch.jit.load(inception_path)
    except Exception as exc:
        mmcv.print_log(
            'Load Tero\'s Inception Model failed. '
            f'\'{exc}\' occurs.', 'mmgen')
        return None
    mmcv.print_log('Load Tero\'s Inception Model successfully.', 'mmgen')
    return model
def _load_inception_from_url(inception_url):
    """Download Tero's Inception checkpoint and load it from the cache dir.

    Args:
        inception_url (str): Download url; falls back to
            ``TERO_INCEPTION_URL`` when falsy.

    Returns:
        torch.nn.Module | None: The loaded model, or ``None`` on failure.
    """
    inception_url = inception_url or TERO_INCEPTION_URL
    mmcv.print_log(f'Try to download Inception Model from {inception_url}...',
                   'mmgen')
    try:
        path = download_from_url(inception_url, dest_dir=MMGEN_CACHE_DIR)
    except Exception as exc:
        mmcv.print_log(f'Download Failed. {exc} occurs.')
        return None
    mmcv.print_log('Download Finished.')
    return _load_inception_from_path(path)
def _load_inception_torch(inception_args, metric):
    """Load the pytorch-flavoured Inception for the given metric.

    Args:
        inception_args (dict): Keyword args forwarded to ``InceptionV3``
            (only used for FID).
        metric (str): Either 'FID' or 'IS'.

    Returns:
        torch.nn.Module: The Inception network for the metric.
    """
    assert metric in ['FID', 'IS']
    if metric == 'IS':
        inception_model = inception_v3(pretrained=True, transform_input=False)
        mmcv.print_log(
            'Load Inception V3 Network from Pytorch Model Zoo '
            'for IS calculation. The results can only used '
            'for monitoring purposes. To get more accuracy IS, '
            'please use Tero\'s Inception V3 checkpoints '
            'and use Bicubic Interpolation with Pillow backend '
            'for image resizing. More details may refer to '
            'https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/quick_run.md#is.',  # noqa
            'mmgen')
    else:
        # metric == 'FID'
        inception_model = InceptionV3([3], **inception_args)
    return inception_model
def _ssim_for_multi_scale(img1,
                          img2,
                          max_val=255,
                          filter_size=11,
                          filter_sigma=1.5,
                          k1=0.01,
                          k2=0.03):
    """Calculate SSIM (structural similarity) and contrast sensitivity.

    Ref:
    Image quality assessment: From error visibility to structural similarity.
    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
    For three-channel images, SSIM is calculated for each channel and then
    averaged.
    This function attempts to match the functionality of ssim_index_new.m by
    Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): the dynamic range of the images (i.e., the difference
            between the maximum the and minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel (will
            be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM calculation
            (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM calculation
            (0.03 in the original paper). Default to 0.03.

    Returns:
        tuple: Pair containing the mean SSIM and contrast sensitivity between
        `img1` and `img2`.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)
    _, height, width, _ = img1.shape

    # Filter size can't be larger than height or width of images.
    size = min(filter_size, height, width)

    # Scale down sigma if a smaller filter size is used.
    sigma = size * filter_sigma / filter_size if filter_size else 0

    if filter_size:
        # Gaussian-blurred local means and (raw) second moments
        window = np.reshape(_f_special_gauss(size, sigma), (1, size, size, 1))
        mu1 = signal.fftconvolve(img1, window, mode='valid')
        mu2 = signal.fftconvolve(img2, window, mode='valid')
        sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
        sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
        sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
    else:
        # Empty blur kernel so no need to convolve.
        mu1, mu2 = img1, img2
        sigma11 = img1 * img1
        sigma22 = img2 * img2
        sigma12 = img1 * img2

    mu11 = mu1 * mu1
    mu22 = mu2 * mu2
    mu12 = mu1 * mu2
    # turn raw second moments into (co)variances
    sigma11 -= mu11
    sigma22 -= mu22
    sigma12 -= mu12

    # Calculate intermediate values used by both ssim and cs_map.
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    v1 = 2.0 * sigma12 + c2
    v2 = sigma11 + sigma22 + c2
    # mean over H, W, C: one SSIM / CS value per image
    ssim = np.mean(
        (((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)),
        axis=(1, 2, 3))  # Return for each image individually.
    cs = np.mean(v1 / v2, axis=(1, 2, 3))
    return ssim, cs
def ms_ssim(img1,
            img2,
            max_val=255,
            filter_size=11,
            filter_sigma=1.5,
            k1=0.01,
            k2=0.03,
            weights=None):
    """Calculate MS-SSIM (multi-scale structural similarity).

    Ref:
    This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
    Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
    similarity for image quality assessment" (2003).
    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
    PGGAN's implementation:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): the dynamic range of the images (i.e., the difference
            between the maximum the and minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel (will
            be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM calculation
            (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM calculation
            (0.03 in the original paper). Default to 0.03.
        weights (list): List of weights for each level; if none, use five
            levels and the weights from the original paper. Default to None.

    Returns:
        float: MS-SSIM score between `img1` and `img2`.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    # Note: default weights don't sum to 1.0 but do match the paper / matlab
    # code.
    weights = np.array(
        weights if weights else [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    levels = weights.size
    im1, im2 = [x.astype(np.float32) for x in [img1, img2]]
    mssim = []
    mcs = []
    # one SSIM/CS pair per scale; images are halved between scales
    for _ in range(levels):
        ssim, cs = _ssim_for_multi_scale(
            im1,
            im2,
            max_val=max_val,
            filter_size=filter_size,
            filter_sigma=filter_sigma,
            k1=k1,
            k2=k2)
        mssim.append(ssim)
        mcs.append(cs)
        im1, im2 = [_hox_downsample(x) for x in [im1, im2]]

    # Clip to zero. Otherwise we get NaNs.
    mssim = np.clip(np.asarray(mssim), 0.0, np.inf)
    mcs = np.clip(np.asarray(mcs), 0.0, np.inf)

    # Average over images only at the end: weighted product of CS over the
    # coarse scales times SSIM at the finest remaining scale.
    return np.mean(
        np.prod(mcs[:-1, :]**weights[:-1, np.newaxis], axis=0) *
        (mssim[-1, :]**weights[-1]))
def sliced_wasserstein(distribution_a,
                       distribution_b,
                       dir_repeats=4,
                       dirs_per_repeat=128):
    r"""sliced Wasserstein distance of two sets of patches.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa

    Args:
        distribution_a (Tensor): Descriptors of first distribution.
        distribution_b (Tensor): Descriptors of second distribution.
        dir_repeats (int): The number of projection times. Default to 4.
        dirs_per_repeat (int): The number of directions per projection.
            Default to 128.

    Returns:
        float: sliced Wasserstein distance.
    """
    if torch.cuda.is_available():
        distribution_b = distribution_b.cuda()
    assert distribution_a.ndim == 2
    assert distribution_a.shape == distribution_b.shape
    assert dir_repeats > 0 and dirs_per_repeat > 0
    distribution_a = distribution_a.to(distribution_b.device)

    num_features = distribution_a.shape[1]
    mean_dists = []
    for _ in range(dir_repeats):
        # Sample random projection directions and normalize each column to
        # unit length.
        directions = torch.randn(num_features, dirs_per_repeat)
        directions = directions / torch.sqrt(
            torch.sum(directions**2, dim=0, keepdim=True))
        directions = directions.to(distribution_b.device)

        projection_a = torch.matmul(distribution_a, directions)
        projection_b = torch.matmul(distribution_b, directions)
        # To save cuda memory, we perform sort in cpu
        projection_a = torch.sort(projection_a.cpu(), dim=0)[0]
        projection_b = torch.sort(projection_b.cpu(), dim=0)[0]

        # 1-D Wasserstein distance per direction is the mean absolute gap
        # between the sorted projections.
        gaps = torch.abs(projection_a - projection_b)
        mean_dists.append(torch.mean(gaps).item())
    torch.cuda.empty_cache()
    return sum(mean_dists) / dir_repeats
class Metric(ABC):
    """The abstract base class of metrics. Basically, we split calculation
    into three steps. First, we initialize the metric object and do some
    preparation. Second, we will feed the real and fake images into metric
    object batch by batch, and we calculate intermediate results of these
    batches. Finally, We use these intermediate results to summarize the final
    result. And the result as a string can be obtained by property
    'result_str'.

    Args:
        num_images (int): The number of real/fake images needed to calculate
            metric.
        image_shape (tuple): Shape of the real/fake images with order "CHW".
    """

    def __init__(self, num_images, image_shape=None):
        self.num_images = num_images
        self.image_shape = image_shape
        self.num_real_need = num_images
        self.num_fake_need = num_images
        self.num_real_feeded = 0  # record of the fed real images
        self.num_fake_feeded = 0  # record of the fed fake images
        self._result_str = None  # string of metric result

    @property
    def result_str(self):
        """Get results in string format.

        Lazily calls ``summary()`` the first time it is accessed.

        Returns:
            str: results in string format
        """
        if not self._result_str:
            self.summary()
        return self._result_str

    @staticmethod
    def _slice_batch(batch, num_needed):
        """Take at most ``num_needed`` leading samples from ``batch``.

        Handles both a plain ``Tensor`` ("NCHW") and a ``dict`` of tensors
        that share the same batch dimension.

        Args:
            batch (Tensor | dict): The incoming batch.
            num_needed (int): Remaining number of samples the metric needs.

        Returns:
            tuple: ``(batch_to_feed, batch_size, end)`` where ``batch_size``
                is the original batch size and ``end`` the number of samples
                actually kept.
        """
        if isinstance(batch, dict):
            # use the first value's batch dimension as the batch size
            batch_size = [v for v in batch.values()][0].shape[0]
            end = min(batch_size, num_needed)
            batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
        else:
            batch_size = batch.shape[0]
            end = min(batch_size, num_needed)
            batch_to_feed = batch[:end, ...]
        return batch_to_feed, batch_size, end

    def feed(self, batch, mode):
        """Feed a image batch into metric calculator and perform intermediate
        operation in 'feed_op' function.

        Args:
            batch (Tensor | dict): Images or dict to be fed into
                metric object. If ``Tensor`` is passed, the order of
                ``Tensor`` should be "NCHW". If ``dict`` is passed, each term
                in the ``dict`` are ``Tensor`` with order "NCHW".
            mode (str): Mark the batch as real or fake images. Value can be
                'reals' or 'fakes'.

        Returns:
            int: Number of samples of this batch actually used (0 if the
                metric already has enough samples of this kind).
        """
        _, ws = get_dist_info()

        if mode == 'reals':
            if self.num_real_feeded == self.num_real_need:
                return 0

            batch_to_feed, batch_size, end = self._slice_batch(
                batch, self.num_real_need - self.num_real_feeded)
            # account for all replicas in distributed runs
            global_end = min(batch_size * ws,
                             self.num_real_need - self.num_real_feeded)
            self.feed_op(batch_to_feed, mode)
            self.num_real_feeded += global_end
            return end

        elif mode == 'fakes':
            if self.num_fake_feeded == self.num_fake_need:
                return 0

            # Fix: slice dict batches the same way as the 'reals' branch.
            # Previously `batch.shape[0]` was read before the isinstance
            # check, which crashed for dict inputs.
            batch_to_feed, batch_size, end = self._slice_batch(
                batch, self.num_fake_need - self.num_fake_feeded)
            global_end = min(batch_size * ws,
                             self.num_fake_need - self.num_fake_feeded)
            self.feed_op(batch_to_feed, mode)
            self.num_fake_feeded += global_end
            return end
        else:
            raise ValueError('The expected mode should be set to \'reals\' '
                             f'or \'fakes\', but got \'{mode}\'')

    def check(self):
        """Check the numbers of image."""
        assert self.num_real_feeded == self.num_fake_feeded == self.num_images

    @abstractmethod
    def prepare(self, *args, **kwargs):
        """please implement in subclass."""

    @abstractmethod
    def feed_op(self, batch, mode):
        """please implement in subclass."""

    @abstractmethod
    def summary(self):
        """please implement in subclass."""
@METRICS.register_module()
class FID(Metric):
    """FID metric.

    In this metric, we calculate the distance between real distributions and
    fake distributions. The distributions are modeled by the real samples and
    fake samples, respectively.

    `Inception_v3` is adopted as the feature extractor, which is widely used
    in StyleGAN and BigGAN.

    Args:
        num_images (int): The number of images to be tested.
        image_shape (tuple[int], optional): Image shape. Defaults to None.
        inception_pkl (str, optional): Path to reference inception pickle
            file. If `None`, the statistical value of real distribution will
            be calculated at running time. Defaults to None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. Defaults to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to `dict(normalize_input=False)`.
    """
    name = 'FID'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 inception_pkl=None,
                 bgr2rgb=True,
                 inception_args=None):
        super().__init__(num_images, image_shape=image_shape)
        # Avoid a mutable dict as default argument; `None` keeps the same
        # effective default, `dict(normalize_input=False)`.
        if inception_args is None:
            inception_args = dict(normalize_input=False)
        self.inception_pkl = inception_pkl
        self.real_feats = []
        self.fake_feats = []
        self.real_mean = None
        self.real_cov = None
        self.bgr2rgb = bgr2rgb
        self.device = 'cpu'

        self.inception_net, self.inception_style = load_inception(
            inception_args, 'FID')

        if torch.cuda.is_available():
            self.inception_net = self.inception_net.cuda()
            self.device = 'cuda'
        self.inception_net.eval()

        mmcv.print_log(
            f'FID: Adopt Inception in {self.inception_style} style', 'mmgen')

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # if `inception_pkl` is provided, read mean and cov stat
        if self.inception_pkl is not None and mmcv.is_filepath(
                self.inception_pkl):
            with open(self.inception_pkl, 'rb') as f:
                reference = pickle.load(f)
                self.real_mean = reference['mean']
                self.real_cov = reference['cov']
                mmcv.print_log(
                    f'Load reference inception pkl from {self.inception_pkl}',
                    'mmgen')
            # stats are precomputed, so no real images need to be fed
            self.num_real_feeded = self.num_images

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor with order "NCHW".
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if self.bgr2rgb:
            batch = batch[:, [2, 1, 0]]
        batch = batch.to(self.device)

        if self.inception_style == 'StyleGAN':
            # Tero's script expects uint8 images in [0, 255]; inputs are
            # assumed to be in [-1, 1].
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feat = self.inception_net(batch, return_features=True)
        else:
            feat = self.inception_net(batch)[0].view(batch.shape[0], -1)

        # gather all of images if using distributed training
        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder = [torch.zeros_like(feat) for _ in range(ws)]
            dist.all_gather(placeholder, feat)
            feat = torch.cat(placeholder, dim=0)

        # in distributed training, we only collect features at rank-0.
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            if mode == 'reals':
                self.real_feats.append(feat.cpu())
            elif mode == 'fakes':
                self.fake_feats.append(feat.cpu())
            else:
                # Fix: the previous message was missing the closing quote
                # after 'fakes'.
                raise ValueError(
                    "The expected mode should be set to 'reals' or 'fakes', "
                    f"but got '{mode}'")

    @staticmethod
    def _calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
        """Refer to the implementation from:

        https://github.com/rosinality/stylegan2-pytorch/blob/master/fid.py#L34
        """
        cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

        if not np.isfinite(cov_sqrt).all():
            print('product of cov matrices is singular')
            # regularize both covariances before retrying the matrix sqrt
            offset = np.eye(sample_cov.shape[0]) * eps
            cov_sqrt = linalg.sqrtm(
                (sample_cov + offset) @ (real_cov + offset))

        if np.iscomplexobj(cov_sqrt):
            # a small imaginary part is numerical noise; a large one is an
            # actual error
            if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
                m = np.max(np.abs(cov_sqrt.imag))
                raise ValueError(f'Imaginary component {m}')
            cov_sqrt = cov_sqrt.real

        mean_diff = sample_mean - real_mean
        mean_norm = mean_diff @ mean_diff
        trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(
            cov_sqrt)
        fid = mean_norm + trace
        return fid, mean_norm, trace

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        # calculate reference inception stat
        if self.real_mean is None:
            feats = torch.cat(self.real_feats, dim=0)
            assert feats.shape[0] >= self.num_images
            feats = feats[:self.num_images]
            feats_np = feats.numpy()
            self.real_mean = np.mean(feats_np, 0)
            self.real_cov = np.cov(feats_np, rowvar=False)

        # calculate fake inception stat
        fake_feats = torch.cat(self.fake_feats, dim=0)
        assert fake_feats.shape[0] >= self.num_images
        fake_feats = fake_feats[:self.num_images]
        fake_feats_np = fake_feats.numpy()
        fake_mean = np.mean(fake_feats_np, 0)
        fake_cov = np.cov(fake_feats_np, rowvar=False)

        # calculate distance between real and fake statistics
        fid, mean, cov = self._calc_fid(fake_mean, fake_cov, self.real_mean,
                                        self.real_cov)

        # results for print/table
        self._result_str = (f'{fid:.4f} ({mean:.5f}/{cov:.5f})')
        # results for log_buffer
        self._result_dict = dict(fid=fid, fid_mean=mean, fid_cov=cov)

        return fid, mean, cov

    def clear_fake_data(self):
        """Clear fake data."""
        self.fake_feats = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Whether to clear real data.
                Defaults to False.
        """
        self.clear_fake_data()
        if clear_reals:
            self.real_feats = []
            self.num_real_feeded = 0
@METRICS.register_module()
class MS_SSIM(Metric):
    """MS-SSIM (Multi-Scale Structure Similarity) metric.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults to
            None.
    """
    name = 'MS-SSIM'

    def __init__(self, num_images, image_shape=None):
        super().__init__(num_images, image_shape)
        # samples are consumed in interleaved pairs, so the count must be even
        assert num_images % 2 == 0
        self.num_pairs = num_images // 2

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # running sum of per-pair MS-SSIM scores, weighted by pair count
        self.sum = 0.0

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor with order "NCHW", assumed to
                be in range [-1, 1].
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        # MS-SSIM is computed on generated samples only.
        if mode == 'reals':
            return

        # map [-1, 1] -> [0, 1] and clamp
        normalized = ((minibatch + 1) / 2).clamp_(0, 1)

        def _to_uint8_nhwc(tensor):
            # "NCHW" float in [0, 1] -> "NHWC" uint8 in [0, 255]
            array = tensor.cpu().data.numpy().transpose((0, 2, 3, 1))
            return (array * 255).astype('uint8')

        # even-indexed and odd-indexed samples form the comparison pairs
        even_half = _to_uint8_nhwc(normalized[0::2])
        odd_half = _to_uint8_nhwc(normalized[1::2])
        score = ms_ssim(even_half, odd_half)
        self.sum += score * (normalized.shape[0] // 2)

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        self.check()
        average = self.sum / self.num_pairs
        self._result_str = str(round(average.item(), 4))
        return average
@METRICS.register_module()
class SWD(Metric):
    """SWD (Sliced Wasserstein distance) metric. We calculate the SWD of two
    sets of images in the following way. In every 'feed', we obtain the
    Laplacian pyramids of every images and extract patches from the Laplacian
    pyramids as descriptors. In 'summary', we normalize these descriptors
    along channel, and reshape them so that we can use these descriptors to
    represent the distribution of real/fake images. And we can calculate the
    sliced Wasserstein distance of the real and fake descriptors as the SWD
    of the real and fake images.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple): Image shape in order "CHW".
    """
    name = 'SWD'

    def __init__(self, num_images, image_shape):
        super().__init__(num_images, image_shape)

        self.nhood_size = 7  # height and width of the extracted patches
        self.nhoods_per_image = 128  # number of extracted patches per image
        self.dir_repeats = 4  # times of sampling directions
        self.dirs_per_repeat = 128  # number of directions per sampling
        self.resolutions = []
        res = image_shape[1]
        # Build the pyramid resolution ladder: halve the resolution each
        # level, stopping below 16 px or after at most 4 levels.
        while res >= 16 and len(self.resolutions) < 4:
            self.resolutions.append(res)
            res //= 2
        self.n_pyramids = len(self.resolutions)

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # one descriptor buffer per pyramid level, for reals and fakes
        self.real_descs = [[] for res in self.resolutions]
        self.fake_descs = [[] for res in self.resolutions]
        # Gaussian kernel used when constructing the Laplacian pyramids
        self.gaussian_k = get_gaussian_kernel()

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor with order "NCHW"; its "CHW"
                part must match ``self.image_shape``.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        assert minibatch.shape[1:] == self.image_shape
        if mode == 'reals':
            # extract patch descriptors from every level of the pyramid
            real_pyramid = laplacian_pyramid(minibatch, self.n_pyramids - 1,
                                             self.gaussian_k)
            # lod: layer_of_descriptors
            for lod, level in enumerate(real_pyramid):
                desc = get_descriptors_for_minibatch(level, self.nhood_size,
                                                     self.nhoods_per_image)
                self.real_descs[lod].append(desc)
        elif mode == 'fakes':
            fake_pyramid = laplacian_pyramid(minibatch, self.n_pyramids - 1,
                                             self.gaussian_k)
            for lod, level in enumerate(fake_pyramid):
                desc = get_descriptors_for_minibatch(level, self.nhood_size,
                                                     self.nhoods_per_image)
                self.fake_descs[lod].append(desc)
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results — one SWD value (scaled by 1e3)
                per pyramid level, followed by their mean.
        """
        self.check()
        # normalize/reshape the accumulated descriptors per level
        real_descs = [finalize_descriptors(d) for d in self.real_descs]
        fake_descs = [finalize_descriptors(d) for d in self.fake_descs]
        # free the raw per-batch buffers before the heavy SWD computation
        del self.real_descs
        del self.fake_descs
        distance = [
            sliced_wasserstein(dreal, dfake, self.dir_repeats,
                               self.dirs_per_repeat)
            for dreal, dfake in zip(real_descs, fake_descs)
        ]
        del real_descs
        del fake_descs

        distance = [d * 1e3 for d in distance]  # multiply by 10^3
        result = distance + [np.mean(distance)]
        self._result_str = ', '.join([str(round(d, 2)) for d in result])
        return result
@METRICS.register_module()
class PR(Metric):
    r"""Improved Precision and recall metric.

    In this metric, we draw real and generated samples respectively, and
    embed them into a high-dimensional feature space using a pre-trained
    classifier network. We use these features to estimate the corresponding
    manifold. We obtain the estimation by calculating pairwise Euclidean
    distances between all feature vectors in the set and, for each feature
    vector, construct a hypersphere with radius equal to the distance to its
    kth nearest neighbor. Together, these hyperspheres define a volume in
    the feature space that serves as an estimate of the true manifold.
    Precision is quantified by querying for each generated image whether
    the image is within the estimated manifold of real images.
    Symmetrically, recall is calculated by querying for each real image
    whether the image is within estimated manifold of generated image.

    Ref: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py # noqa

    Note that we highly recommend that users should download the vgg16
    script module from the following address. Then, the `vgg16_script` can
    be set with user's local path. If not given, we will use the vgg16 from
    pytorch model zoo. However, this may bring significant different in the
    final results.

    Tero's vgg16: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple): Image shape in order "CHW". Defaults to None.
        num_real_need (int | None, optional): The number of real images.
            Defaults to None.
        full_dataset (bool, optional): Whether to use full dataset for
            evaluation. Defaults to False.
        k (int, optional): Kth nearest parameter. Defaults to 3.
        bgr2rgb (bool, optional): Whether to change the order of image
            channel. Defaults to True.
        vgg16_script (str, optional): Path for the Tero's vgg16 module.
            Defaults to 'work_dirs/cache/vgg16.pt'.
        row_batch_size (int, optional): The batch size of row data.
            Defaults to 10000.
        col_batch_size (int, optional): The batch size of col data.
            Defaults to 10000.
    """
    name = 'PR'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 num_real_need=None,
                 full_dataset=False,
                 k=3,
                 bgr2rgb=True,
                 vgg16_script='work_dirs/cache/vgg16.pt',
                 row_batch_size=10000,
                 col_batch_size=10000):
        super().__init__(num_images, image_shape)
        mmcv.print_log('loading vgg16 for improved precision and recall...',
                       'mmgen')
        if os.path.isfile(vgg16_script):
            # Fix: load the user-provided `vgg16_script` path; previously the
            # default path 'work_dirs/cache/vgg16.pt' was hard-coded here,
            # silently ignoring a custom `vgg16_script` argument.
            self.vgg16 = torch.jit.load(vgg16_script).eval()
            self.use_tero_scirpt = True
        else:
            mmcv.print_log(
                'Cannot load Tero\'s script module. Use official '
                'vgg16 instead', 'mmgen')
            self.vgg16 = models.vgg16(pretrained=True).eval()
            self.use_tero_scirpt = False

        self.device = 'cpu'
        if torch.cuda.is_available():
            self.vgg16 = self.vgg16.cuda()
            self.device = 'cuda'

        self.k = k
        self.bgr2rgb = bgr2rgb
        self.full_dataset = full_dataset
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        if num_real_need:
            self.num_real_need = num_real_need
        if self.full_dataset:
            # effectively "no limit" on real images when using full dataset
            self.num_real_need = 10000000

    def prepare(self, *args, **kwargs):
        """Prepare for evaluating models with this metric."""
        self.features_of_reals = []
        self.features_of_fakes = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor with order "NCHW".
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        batch = batch.to(self.device)
        if self.bgr2rgb:
            batch = batch[:, [2, 1, 0], ...]
        if self.use_tero_scirpt:
            # Tero's script expects uint8 images in [0, 255]; inputs are
            # assumed to be in [-1, 1].
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        if mode == 'reals':
            self.features_of_reals.append(self.extract_features(batch))
        elif mode == 'fakes':
            self.features_of_fakes.append(self.extract_features(batch))
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    def check(self):
        """Check the numbers of image.

        Unlike the base class, in full-dataset mode only the number of fake
        images is validated.
        """
        if not self.full_dataset:
            assert (self.num_real_feeded == self.num_real_need
                    and self.num_fake_feeded == self.num_fake_need)
        else:
            assert self.num_fake_feeded == self.num_fake_need
            mmcv.print_log(
                f'Test for the full dataset with {self.num_real_feeded}'
                ' real images', 'mmgen')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results with 'precision' and 'recall'.
        """
        self.check()
        real_features = torch.cat(self.features_of_reals)
        gen_features = torch.cat(self.features_of_fakes)

        self._result_dict = {}
        rank, ws = get_dist_info()

        # precision queries fakes against the real manifold; recall is the
        # symmetric query of reals against the fake manifold
        for name, manifold, probes in [
            ('precision', real_features, gen_features),
            ('recall', gen_features, real_features)
        ]:
            kth = []
            for manifold_batch in manifold.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=manifold_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                # k+1 because the nearest neighbor of a manifold point is
                # itself (distance 0)
                kth.append(
                    distance.to(torch.float32).kthvalue(self.k + 1).values.to(
                        torch.float16) if rank == 0 else None)
            kth = torch.cat(kth) if rank == 0 else None

            pred = []
            for probes_batch in probes.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=probes_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                # a probe is "inside" the manifold if it falls within any
                # hypersphere
                pred.append((distance <= kth).any(
                    dim=1) if rank == 0 else None)
            self._result_dict[name] = float(
                torch.cat(pred).to(torch.float32).mean(
                ) if rank == 0 else 'nan')

        precision = self._result_dict['precision']
        recall = self._result_dict['recall']
        self._result_str = f'precision: {precision}, recall: {recall}'
        return self._result_dict

    def extract_features(self, images):
        """Extracting image features.

        Args:
            images (torch.Tensor): Images tensor.

        Returns:
            torch.Tensor: Vgg16 features of input images.
        """
        if self.use_tero_scirpt:
            feature = self.vgg16(images, return_features=True)
        else:
            batch = F.interpolate(images, size=(224, 224))
            before_fc = self.vgg16.features(batch)
            before_fc = before_fc.view(-1, 7 * 7 * 512)
            # features from the 4th layer of the classifier head
            feature = self.vgg16.classifier[:4](before_fc)

        return feature
@METRICS.register_module()
class IS(Metric):
    """IS (Inception Score) metric.

    The images are split into groups, and the inception score is calculated
    on each group of images, then the mean and standard deviation of the
    score is reported. The calculation of the inception score on a group of
    images involves first using the inception v3 model to calculate the
    conditional probability for each image (p(y|x)). The marginal probability
    is then calculated as the average of the conditional probabilities for
    the images in the group (p(y)). The KL divergence is then calculated for
    each image as the conditional probability multiplied by the log of the
    conditional probability minus the log of the marginal probability. The KL
    divergence is then summed over all images and averaged over all classes
    and the exponent of the result is calculated to give the final score.

    Ref: https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py # noqa

    Note that we highly recommend that users should download the Inception V3
    script module from the following address. Then, the `inception_pkl` can
    be set with user's local path. If not given, we will use the Inception V3
    from pytorch model zoo. However, this may bring significant different in
    the final results.

    Tero's Inception V3: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults
            to None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. In default, our model generate images in the BGR order.
            Thus, we use `True` as the default behavior. Please switch to
            `False`, if the input is in the `RGB` order. Defaults to True.
        resize (bool, optional): Whether resize image to 299x299. Defaults to
            True.
        splits (int, optional): The number of groups. Defaults to 10.
        use_pil_resize (bool, optional): Whether use Bicubic interpolation
            with Pillow's backend. If set as True, the evaluation process may
            be a little bit slow, but achieve a more accurate IS result.
            Defaults to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to ``dict(type='StyleGAN',
            inception_path=INCEPTION_URL)``.
    """
    name = 'IS'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 bgr2rgb=True,
                 resize=True,
                 splits=10,
                 use_pil_resize=True,
                 inception_args=dict(
                     type='StyleGAN', inception_path=TERO_INCEPTION_URL)):
        super().__init__(num_images, image_shape)
        mmcv.print_log('Loading Inception V3 for IS...', 'mmgen')
        model, style = load_inception(inception_args, 'IS')
        self.inception_model = model
        self.use_tero_script = style == 'StyleGAN'
        # IS does not consume real images; mark the real quota as met.
        self.num_real_feeded = self.num_images
        self.resize = resize
        self.splits = splits
        self.bgr2rgb = bgr2rgb
        self.use_pil_resize = use_pil_resize
        self._pil_resize_warned = False  # warn only once in `pil_resize`
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.inception_model = self.inception_model.cuda()
            self.device = 'cuda'
        self.inception_model.eval()

    def pil_resize(self, x):
        """Apply Bicubic interpolation with Pillow backend. Before and after
        interpolation operation, we have to perform a type conversion between
        torch.tensor and PIL.Image, and these operations make resize process
        a bit slow.

        Args:
            x (Tensor): Input tensor, should have four dimension and
                range in [-1, 1].

        Returns:
            torch.FloatTensor: Resized tensor.
        """
        if not self._pil_resize_warned:
            mmcv.print_log(
                '`use_pil_resize` is set as True, apply Bicubic '
                'interpolation with Pillow backend. We perform '
                'type conversion between torch.tensor and '
                'PIL.Image in this function and make this process '
                'a little bit slow.', 'mmgen')
            self._pil_resize_warned = True
        from PIL import Image

        if x.ndim != 4:
            raise ValueError('Input images should have 4 dimensions, '
                             'here receive input with {} '
                             'dimensions.'.format(x.ndim))
        # [-1, 1] float -> [0, 255] uint8 for PIL
        x = (x.clone() * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        x_np = [x_.permute(1, 2, 0).detach().cpu().numpy() for x_ in x]
        x_pil = [Image.fromarray(x_).resize((299, 299)) for x_ in x_np]
        x_ten = torch.cat(
            [torch.FloatTensor(np.array(x_)[None, ...]) for x_ in x_pil])
        # back to [-1, 1] float and "NCHW" order
        x_ten = (x_ten / 127.5 - 1).to(torch.float)
        return x_ten.permute(0, 3, 1, 2)

    def get_pred(self, x):
        """Get prediction from inception model.

        Args:
            x (Tensor): Input tensor.

        Returns:
            np.array: Inception score.
        """
        if self.use_tero_script:
            x = self.inception_model(x, no_output_bias=True)
        else:
            # specify the dimension to avoid warning
            x = F.softmax(self.inception_model(x), dim=1)
        return x

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # per-batch class-probability predictions for the fake images
        self.preds = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor with order "NCHW".
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if mode == 'reals':
            # IS is computed from generated samples only
            pass
        elif mode == 'fakes':
            if self.bgr2rgb:
                batch = batch[:, [2, 1, 0], ...]
            if self.resize:
                if self.use_pil_resize:
                    batch = self.pil_resize(batch)
                else:
                    batch = F.interpolate(
                        batch, size=(299, 299), mode='bilinear')
            if self.use_tero_script:
                # Tero's script expects uint8 images in [0, 255]
                batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            batch = batch.to(self.device)

            # get prediction
            pred = self.get_pred(batch)

            if dist.is_initialized():
                ws = dist.get_world_size()
                placeholder = [torch.zeros_like(pred) for _ in range(ws)]
                dist.all_gather(placeholder, pred)
                pred = torch.cat(placeholder, dim=0)

            # in distributed training, we only collect features at rank-0.
            if (dist.is_initialized()
                    and dist.get_rank() == 0) or not dist.is_initialized():
                self.preds.append(pred.cpu().numpy())
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        TODO: support `master_only`

        Returns:
            dict | list: Summarized results.
        """
        split_scores = []
        self.preds = np.concatenate(self.preds, axis=0)
        # check for the size
        assert self.preds.shape[0] >= self.num_images
        self.preds = self.preds[:self.num_images]
        # NOTE: when num_images is not divisible by splits, the trailing
        # remainder images are dropped from the split groups.
        for k in range(self.splits):
            part = self.preds[k * (self.num_images // self.splits):(k + 1) *
                              (self.num_images // self.splits), :]
            # marginal distribution p(y) for this group
            py = np.mean(part, axis=0)
            scores = []
            for i in range(part.shape[0]):
                pyx = part[i, :]
                # KL(p(y|x) || p(y)) per image
                scores.append(entropy(pyx, py))
            split_scores.append(np.exp(np.mean(scores)))

        mean, std = np.mean(split_scores), np.std(split_scores)

        # results for print/table
        self._result_str = f'mean: {mean:.3f}, std: {std:.3f}'
        # results for log_buffer
        self._result_dict = {'is': mean, 'is_std': std}

        return mean, std

    def clear_fake_data(self):
        """Clear fake data."""
        self.preds = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Whether to clear real data.
                Defaults to False.
        """
        # IS keeps no real-image state, so `clear_reals` has no effect here.
        self.clear_fake_data()
@
METRICS
.
register_module
()
class
PPL
(
Metric
):
r
"""Perceptual path length.
Measure the difference between consecutive images (their VGG16
embeddings) when interpolating between two random inputs. Drastic
changes mean that multiple features have changed together and that
they might be entangled.
Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py # noqa
Args:
num_images (int): The number of evaluated generated samples.
image_shape (tuple, optional): Image shape in order "CHW". Defaults
to None.
crop (bool, optional): Whether crop images. Defaults to True.
epsilon (float, optional): Epsilon parameter for path sampling.
Defaults to 1e-4.
space (str, optional): Latent space. Defaults to 'W'.
sampling (str, optional): Sampling mode, whether sampling in full
path or endpoints. Defaults to 'end'.
latent_dim (int, optional): Latent dimension of input noise.
Defaults to 512.
"""
name
=
'PPL'
    def __init__(self,
                 num_images,
                 image_shape=None,
                 crop=True,
                 epsilon=1e-4,
                 space='W',
                 sampling='end',
                 latent_dim=512):
        """Initialize the PPL metric.

        Args:
            num_images (int): The number of evaluated path samples; twice as
                many images are generated, one interpolation pair per sample.
            image_shape (tuple, optional): Image shape in order "CHW".
                Defaults to None.
            crop (bool, optional): Whether crop images. Defaults to True.
            epsilon (float, optional): Epsilon parameter for path sampling.
                Defaults to 1e-4.
            space (str, optional): Latent space. Defaults to 'W'.
            sampling (str, optional): Sampling mode, whether sampling in full
                path or endpoints. Defaults to 'end'.
            latent_dim (int, optional): Latent dimension of input noise.
                Defaults to 512.
        """
        super().__init__(num_images, image_shape=image_shape)
        self.crop = crop
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.latent_dim = latent_dim
        # each path sample needs a pair of images, so double the image count
        self.num_images = num_images * 2
        # PPL consumes no real images; mark the real quota as already met
        self.num_real_feeded = self.num_images
def
prepare
(
self
):
"""Prepare for evaluating models with this metric."""
self
.
dist_list
=
[]
@
torch
.
no_grad
()
def
feed_op
(
self
,
minibatch
,
mode
):
"""Feed data to the metric.
Args:
batch (Tensor): Input tensor.
mode (str): The mode of current data batch. 'reals' or 'fakes'.
"""
if
mode
==
'reals'
:
return
# use minibatch's device type to initialize a lpips calculator
if
not
hasattr
(
self
,
'percept'
):
self
.
percept
=
PerceptualLoss
(
use_gpu
=
minibatch
.
device
.
type
.
startswith
(
'cuda'
))
# crop and resize images
if
self
.
crop
:
c
=
minibatch
.
shape
[
2
]
//
8
minibatch
=
minibatch
[:,
:,
c
*
3
:
c
*
7
,
c
*
2
:
c
*
6
]
factor
=
minibatch
.
shape
[
2
]
//
256
if
factor
>
1
:
minibatch
=
F
.
interpolate
(
minibatch
,
size
=
(
256
,
256
),
mode
=
'bilinear'
,
align_corners
=
False
)
# calculator and store lpips score
distance
=
self
.
percept
(
minibatch
[::
2
],
minibatch
[
1
::
2
]).
view
(
minibatch
.
shape
[
0
]
//
2
)
/
(
self
.
epsilon
**
2
)
self
.
dist_list
.
append
(
distance
.
to
(
'cpu'
).
numpy
())
@
torch
.
no_grad
()
def
summary
(
self
):
"""Summarize the results.
Returns:
dict | list: Summarized results.
"""
distances
=
np
.
concatenate
(
self
.
dist_list
,
0
)
lo
=
np
.
percentile
(
distances
,
1
,
interpolation
=
'lower'
)
hi
=
np
.
percentile
(
distances
,
99
,
interpolation
=
'higher'
)
filtered_dist
=
np
.
extract
(
np
.
logical_and
(
lo
<=
distances
,
distances
<=
hi
),
distances
)
ppl_score
=
filtered_dist
.
mean
()
self
.
_result_str
=
f
'
{
ppl_score
:.
1
f
}
'
return
ppl_score
def
get_sampler
(
self
,
model
,
batch_size
,
sample_model
):
"""Get sampler for sampling along the path.
Args:
model (nn.Module): Generative model.
batch_size (int): Sampling batch size.
sample_model (str): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
Returns:
Object: A sampler for calculating path length regularization.
"""
if
sample_model
==
'ema'
:
generator
=
model
.
generator_ema
else
:
generator
=
model
.
generator
ppl_sampler
=
PPLSampler
(
generator
,
self
.
num_images
,
batch_size
,
self
.
space
,
self
.
sampling
,
self
.
epsilon
,
self
.
latent_dim
)
return
ppl_sampler
class PPLSampler:
    """StyleGAN series generator's sampling iterator for PPL metric.

    Args:
        generator (nn.Module): StyleGAN series' generator.
        num_images (int): The number of evaluated generated samples.
        batch_size (int): Batch size of generated images.
        space (str, optional): Latent space. Defaults to 'W'.
        sampling (str, optional): Sampling mode, whether sampling in full
            path or endpoints. Defaults to 'end'.
        epsilon (float, optional): Epsilon parameter for path sampling.
            Defaults to 1e-4.
        latent_dim (int, optional): Latent dimension of input noise.
            Defaults to 512.
    """

    def __init__(self,
                 generator,
                 num_images,
                 batch_size,
                 space='W',
                 sampling='end',
                 epsilon=1e-4,
                 latent_dim=512):
        assert space in ['Z', 'W']
        assert sampling in ['full', 'end']
        # Split ``num_images`` into full batches plus one residual batch.
        n_batch = num_images // batch_size
        resid = num_images - (n_batch * batch_size)
        self.batch_sizes = [batch_size] * n_batch + ([resid]
                                                     if resid > 0 else [])
        self.device = get_module_device(generator)
        self.generator = generator
        self.latent_dim = latent_dim
        self.space = space
        self.sampling = sampling
        self.epsilon = epsilon

    def __iter__(self):
        # Restart iteration from the first batch.
        self.idx = 0
        return self

    @torch.no_grad()
    def __next__(self):
        """Generate the next batch of path-endpoint images.

        Returns:
            Tensor: Generated images where consecutive pairs are the two
            endpoints (t and t + epsilon) of one interpolation path.
        """
        if self.idx >= len(self.batch_sizes):
            raise StopIteration
        batch = self.batch_sizes[self.idx]
        injected_noise = self.generator.make_injected_noise()
        # Two latent codes per path (start and end of the interpolation).
        inputs = torch.randn([batch * 2, self.latent_dim],
                             device=self.device)
        if self.sampling == 'full':
            # Sample the interpolation position uniformly along the path.
            lerp_t = torch.rand(batch, device=self.device)
        else:
            # 'end' mode: always evaluate at the path start (t = 0).
            lerp_t = torch.zeros(batch, device=self.device)

        if self.space == 'W':
            # Interpolate linearly in W space after the style mapping.
            assert hasattr(self.generator, 'style_mapping')
            latent = self.generator.style_mapping(inputs)
            latent_t0, latent_t1 = latent[::2], latent[1::2]
            latent_e0 = torch.lerp(latent_t0, latent_t1, lerp_t[:, None])
            latent_e1 = torch.lerp(latent_t0, latent_t1,
                                   lerp_t[:, None] + self.epsilon)
            # Re-interleave (e0, e1) pairs to match the generator's input
            # layout: [e0_0, e1_0, e0_1, e1_1, ...].
            latent_e = torch.stack([latent_e0, latent_e1],
                                   1).view(*latent.shape)
            image = self.generator([latent_e],
                                   input_is_latent=True,
                                   injected_noise=injected_noise)
        else:
            # Interpolate spherically in Z space (Gaussian prior).
            latent_t0, latent_t1 = inputs[::2], inputs[1::2]
            latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
            latent_e1 = slerp(latent_t0, latent_t1,
                              lerp_t[:, None] + self.epsilon)
            latent_e = torch.stack([latent_e0, latent_e1],
                                   1).view(*inputs.shape)
            image = self.generator([latent_e],
                                   input_is_latent=False,
                                   injected_noise=injected_noise)
        self.idx += 1
        return image
@METRICS.register_module()
class GaussianKLD(Metric):
    r"""Gaussian KLD (Kullback-Leibler divergence) metric. We calculate the
    KLD between two gaussian distribution via `mean` and `log_variance`.
    The passed batch should be a dict instance and contain ``mean_pred``,
    ``mean_target``, ``logvar_pred``, ``logvar_target``.
    When call ``feed`` operation, only ``reals`` mode is needed.

    The calculation of KLD can be formulated as:

    .. math::
        :nowrap:

        \begin{align}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{align}

    where `p` and `q` denote target and predicted distribution respectively.

    Args:
        num_images (int): The number of samples to be tested.
        base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        reduction (string, optional): Specifies the reduction to apply to the
            output. Support ``'batchmean'``, ``'sum'`` and ``'mean'``. If
            ``reduction == 'batchmean'``, the sum of the output will be divided
            by batchsize. If ``reduction == 'sum'``, the output will be summed.
            If ``reduction == 'mean'``, the output will be divided by the
            number of elements in the output. Defaults to ``'batchmean'``.
    """
    name = 'GaussianKLD'

    def __init__(self, num_images, base='e', reduction='batchmean'):
        super().__init__(num_images, image_shape=None)
        assert reduction in ['sum', 'batchmean', 'mean'], (
            'We only support reduction for \'batchmean\', \'sum\' '
            'and \'mean\'')
        assert base in ['e', '2'], (
            'We only support log_base for \'e\' and \'2\'')
        self.reduction = reduction
        # Only real-mode batches are consumed; mark fake quota as fulfilled.
        self.num_fake_feeded = self.num_images
        # Per-sample KLD is computed here; the reduction is applied in
        # ``summary`` instead.
        self.cal_kld = partial(
            gaussian_kld, weight=1, reduction='none', base=base)

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.kld = []
        self.num_real_feeded = 0

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (dict): Input dict containing ``mean_pred``,
                ``mean_target``, ``logvar_pred`` and ``logvar_target``.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if mode == 'fakes':
            return
        assert isinstance(batch, dict), ('To calculate GaussianKLD loss, a '
                                         'dict contains probabilistic '
                                         'parameters is required.')
        # check required keys
        require_keys = [
            'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
        ]
        if any([k not in batch for k in require_keys]):
            raise KeyError(f'The input dict must require {require_keys} at '
                           'the same time. But keys in the given dict are '
                           f'{batch.keys()}. Some of the requirements are '
                           'missing.')
        kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
                           batch['logvar_target'], batch['logvar_pred'])
        if dist.is_initialized():
            # Gather per-sample KLDs from all ranks into one tensor.
            ws = dist.get_world_size()
            placeholder = [torch.zeros_like(kld) for _ in range(ws)]
            dist.all_gather(placeholder, kld)
            kld = torch.cat(placeholder, dim=0)

        # in distributed training, we only collect features at rank-0.
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            self.kld.append(kld.cpu())

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        kld = torch.cat(self.kld, dim=0)
        assert kld.shape[0] >= self.num_images
        kld_np = kld.numpy()
        if self.reduction == 'sum':
            kld_result = np.sum(kld_np)
        elif self.reduction == 'mean':
            kld_result = np.mean(kld_np)
        else:
            # 'batchmean': sum over all elements divided by batch size.
            kld_result = np.sum(kld_np) / kld_np.shape[0]
        self._result_str = (f'{kld_result:.4f}')
        return kld_result
build/lib/mmgen/core/hooks/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
"""Training hooks for mmgen: uploading, EMA, data fetching, pickling and
visualization."""
from .ceph_hooks import PetrelUploadHook
from .ema_hook import ExponentialMovingAverageHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
from .pickle_data_hook import PickleDataHook
from .visualization import VisualizationHook
from .visualize_training_samples import VisualizeUnconditionalSamples

__all__ = [
    'VisualizeUnconditionalSamples', 'PGGANFetchDataHook',
    'ExponentialMovingAverageHook', 'VisualizationHook', 'PickleDataHook',
    'PetrelUploadHook'
]
build/lib/mmgen/core/hooks/ceph_hooks.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
mmcv
from
mmcv.runner
import
HOOKS
,
Hook
,
master_only
@HOOKS.register_module()
class PetrelUploadHook(Hook):
    """Upload Data with Petrel.

    With this hook, users can easily upload data to the cloud server for
    saving local spaces. Please read the notes below for using this hook,
    especially for the declaration of ``petrel``.

    One of the major functions is to transfer the checkpoint files from the
    local directory to the cloud server.

    .. note::
        ``petrel`` is a private package containing several commonly used
        ``AWS`` python API. Currently, this package is only for internal usage
        and will not be released to the public. We will support ``boto3`` in
        the future. We think this hook is an easy template for you to transfer
        to ``boto3``.

    Args:
        data_path (str, optional): Relative path of the data according to
            current working directory. Defaults to 'ckpt'.
        suffix (str, optional): Suffix for the data files. Defaults to '.pth'.
        ceph_path (str | None, optional): Path in the cloud server.
            Defaults to None.
        interval (int, optional): Uploading interval (by iterations).
            Default: -1.
        upload_after_run (bool, optional): Whether to upload after running.
            Defaults to True.
        rm_orig (bool, optional): Whether to removing the local files after
            uploading. Defaults to True.
    """
    # Default location of the petrel credentials/config file.
    cfg_path = '~/petreloss.conf'

    def __init__(self,
                 data_path='ckpt',
                 suffix='.pth',
                 ceph_path=None,
                 interval=-1,
                 upload_after_run=True,
                 rm_orig=True):
        super().__init__()
        self.interval = interval
        self.upload_after_run = upload_after_run
        self.data_path = data_path
        self.suffix = suffix
        self.ceph_path = ceph_path
        self.rm_orig = rm_orig

        # setup petrel client
        try:
            from petrel_client.client import Client
        except ImportError:
            raise ImportError('Please install petrel in advance.')

        self.client = Client(self.cfg_path)

    @staticmethod
    def upload_dir(client,
                   local_dir,
                   remote_dir,
                   exp_name=None,
                   suffix=None,
                   remove_local_file=True):
        """Upload a directory to the cloud server.

        Args:
            client (obj): AWS client.
            local_dir (str): Path for the local data.
            remote_dir (str): Path for the remote server.
            exp_name (str, optional): The experiment name. Defaults to None.
            suffix (str, optional): Suffix for the data files.
                Defaults to None.
            remove_local_file (bool, optional): Whether to removing the local
                files after uploading. Defaults to True.
        """
        files = mmcv.scandir(local_dir, suffix=suffix, recursive=False)
        files = [os.path.join(local_dir, x) for x in files]

        # remove the rebundant symlinks in the data directory
        files = [x for x in files if not os.path.islink(x)]

        # get the actual exp_name in work_dir
        if exp_name is None:
            exp_name = local_dir.split('/')[-1]

        mmcv.print_log(f'Uploading {len(files)} files to ceph.', 'mmgen')

        for file in files:
            with open(file, 'rb') as f:
                data = f.read()
            # Keep the path relative to the experiment directory so the
            # remote layout mirrors the local work_dir layout.
            _path_splits = file.split('/')
            idx = _path_splits.index(exp_name)
            _rel_path = '/'.join(_path_splits[idx:])
            _ceph_path = os.path.join(remote_dir, _rel_path)
            client.put(_ceph_path, data)

            # remove the local file to save space
            if remove_local_file:
                os.remove(file)

    @master_only
    def after_run(self, runner):
        """The behavior after the whole running.

        Args:
            runner (object): The runner.
        """
        if not self.upload_after_run:
            return

        _data_path = os.path.join(runner.work_dir, self.data_path)

        # get the actual exp_name in work_dir
        exp_name = runner.work_dir.split('/')[-1]

        self.upload_dir(
            self.client,
            _data_path,
            self.ceph_path,
            exp_name=exp_name,
            suffix=self.suffix,
            remove_local_file=self.rm_orig)
build/lib/mmgen/core/hooks/ema_hook.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
from
copy
import
deepcopy
import
mmcv
import
torch
from
mmcv.parallel
import
is_module_wrapper
from
mmcv.runner
import
HOOKS
,
Hook
@HOOKS.register_module()
class ExponentialMovingAverageHook(Hook):
    """Exponential Moving Average Hook.

    Exponential moving average is a trick that widely used in current GAN
    literature, e.g., PGGAN, StyleGAN, and BigGAN. This general idea of it is
    maintaining a model with the same architecture, but its parameters are
    updated as a moving average of the trained weights in the original model.
    In general, the model with moving averaged weights achieves better
    performance.

    Args:
        module_keys (str | tuple[str]): The name of the ema model. Note that we
            require these keys are followed by '_ema' so that we can easily
            find the original model by discarding the last four characters.
        interp_mode (str, optional): Mode of the interpolation method.
            Defaults to 'lerp'.
        interp_cfg (dict | None, optional): Set arguments of the interpolation
            function. Defaults to None.
        interval (int, optional): Evaluation interval (by iterations).
            Default: -1.
        start_iter (int, optional): Start iteration for ema. If the start
            iteration is not reached, the weights of ema model will maintain
            the same as the original one. Otherwise, its parameters are updated
            as a moving average of the trained weights in the original model.
            Default: 0.
        momentum_policy (str, optional): Policy of the momentum updating
            method. Defaults to 'fixed'.
        momentum_cfg (dict | None, optional): Set arguments of the momentum
            updater function. Defaults to None.
    """
    # Names of interpolation functions / momentum policies implemented as
    # methods on this class.
    _registered_interp_funcs = ['lerp']
    _registered_momentum_updaters = ['rampup', 'fixed']

    def __init__(self,
                 module_keys,
                 interp_mode='lerp',
                 interp_cfg=None,
                 interval=-1,
                 start_iter=0,
                 momentum_policy='fixed',
                 momentum_cfg=None):
        super().__init__()
        # check args
        assert interp_mode in self._registered_interp_funcs, (
            'Supported '
            f'interpolation functions are {self._registered_interp_funcs}, '
            f'but got {interp_mode}')

        assert momentum_policy in self._registered_momentum_updaters, (
            'Supported momentum policy are'
            f'{self._registered_momentum_updaters},'
            f' but got {momentum_policy}')

        assert isinstance(module_keys, str) or mmcv.is_tuple_of(
            module_keys, str)
        self.module_keys = (module_keys, ) if isinstance(module_keys,
                                                         str) else module_keys
        # sanity check for the format of module keys
        for k in self.module_keys:
            assert k.endswith(
                '_ema'), 'You should give keys that end with "_ema".'
        self.interp_mode = interp_mode
        self.interp_cfg = dict() if interp_cfg is None else deepcopy(
            interp_cfg)
        self.interval = interval
        self.start_iter = start_iter
        # The interpolation mode maps to a method of this class (e.g. lerp).
        assert hasattr(
            self, interp_mode
        ), f'Currently, we do not support {self.interp_mode} for EMA.'
        self.interp_func = getattr(self, interp_mode)

        self.momentum_cfg = dict() if momentum_cfg is None else deepcopy(
            momentum_cfg)
        self.momentum_policy = momentum_policy
        if momentum_policy != 'fixed':
            # A dynamic policy also maps to a method (e.g. rampup).
            assert hasattr(
                self, momentum_policy
            ), f'Currently, we do not support {self.momentum_policy} for EMA.'
            self.momentum_updater = getattr(self, momentum_policy)

    @staticmethod
    def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
        """Does a linear interpolation of two parameters/ buffers.

        Args:
            a (torch.Tensor): Interpolation start point, refer to orig state.
            b (torch.Tensor): Interpolation end point, refer to ema state.
            momentum (float, optional): The weight for the interpolation
                formula. Defaults to 0.999.
            momentum_nontrainable (float, optional): The weight for the
                interpolation formula used for nontrainable parameters.
                Defaults to 0..
            trainable (bool, optional): Whether input parameters is trainable.
                If set to False, momentum_nontrainable will be used.
                Defaults to True.

        Returns:
            torch.Tensor: Interpolation result.
        """
        m = momentum if trainable else momentum_nontrainable
        return a + (b - a) * m

    @staticmethod
    def rampup(runner, ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-8):
        """Ramp up ema momentum.

        Ref: https://github.com/NVlabs/stylegan3/blob/a5a69f58294509598714d1e88c9646c3d7c6ec94/training/training_loop.py#L300-L308 # noqa

        Args:
            runner (_type_): _description_
            ema_kimg (int, optional): Half-life of the exponential moving
                average of generator weights. Defaults to 10.
            ema_rampup (float, optional): EMA ramp-up coefficient.If set to
                None, then rampup will be disabled. Defaults to 0.05.
            batch_size (int, optional): Total batch size for one training
                iteration. Defaults to 4.
            eps (float, optional): Epsiolon to avoid ``batch_size`` divided by
                zero. Defaults to 1e-8.

        Returns:
            dict: Updated momentum.
        """
        cur_nimg = (runner.iter + 1) * batch_size
        ema_nimg = ema_kimg * 1000
        if ema_rampup is not None:
            # Shorten the effective half-life early in training.
            ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
        ema_beta = 0.5**(batch_size / max(ema_nimg, eps))
        return dict(momentum=ema_beta)

    def every_n_iters(self, runner, n):
        # Before ``start_iter`` the hook always runs (it copies the original
        # weights into the EMA model); afterwards it runs every n iters.
        if runner.iter < self.start_iter:
            return True
        return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False

    @torch.no_grad()
    def after_train_iter(self, runner):
        if not self.every_n_iters(runner, self.interval):
            return

        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model

        # update momentum
        _interp_cfg = deepcopy(self.interp_cfg)
        if self.momentum_policy != 'fixed':
            _updated_args = self.momentum_updater(runner, **self.momentum_cfg)
            _interp_cfg.update(_updated_args)

        for key in self.module_keys:
            # get current ema states
            ema_net = getattr(model, key)
            states_ema = ema_net.state_dict(keep_vars=False)
            # get currently original states
            net = getattr(model, key[:-4])
            states_orig = net.state_dict(keep_vars=True)

            for k, v in states_orig.items():
                if runner.iter < self.start_iter:
                    # Before start_iter, keep the EMA model identical to the
                    # original model.
                    states_ema[k].data.copy_(v.data)
                else:
                    states_ema[k] = self.interp_func(
                        v,
                        states_ema[k],
                        trainable=v.requires_grad,
                        **_interp_cfg).detach()
            ema_net.load_state_dict(states_ema, strict=True)

    def before_run(self, runner):
        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model
        # sanity check for ema model
        for k in self.module_keys:
            if not hasattr(model, k) and not hasattr(model, k[:-4]):
                raise RuntimeError(
                    f'Cannot find both {k[:-4]} and {k} network for EMA hook.')
            if not hasattr(model, k) and hasattr(model, k[:-4]):
                # Lazily build the EMA model as a deep copy of the original.
                setattr(model, k, deepcopy(getattr(model, k[:-4])))
                warnings.warn(
                    f'We do not suggest construct and initialize EMA model {k}'
                    ' in hook. You may explicitly define it by yourself.')
build/lib/mmgen/core/hooks/pggan_fetch_data_hook.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
mmcv.parallel
import
is_module_wrapper
from
mmcv.runner
import
HOOKS
,
Hook
@HOOKS.register_module()
class PGGANFetchDataHook(Hook):
    """PGGAN Fetch Data Hook.

    Keeps the training dataloader in sync with the progressively growing
    generator by pushing the model's next image scale to the dataloader
    before each data fetch.

    Args:
        interval (int, optional): The interval of calling this hook. If set
            to -1, the visualization hook will not be called. Defaults to 1.
    """

    def __init__(self, interval=1):
        super().__init__()
        self.interval = interval

    def before_fetch_train_data(self, runner):
        """The behavior before fetch train data.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return

        # Unwrap DDP/DataParallel wrappers to reach the actual model.
        wrapped = is_module_wrapper(runner.model)
        model = runner.model.module if wrapped else runner.model

        next_scale = model._next_scale_int
        if isinstance(next_scale, torch.Tensor):
            next_scale = next_scale.item()
        runner.data_loader.update_dataloader(next_scale)
build/lib/mmgen/core/hooks/pickle_data_hook.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
logging
import
os
import
pickle
import
mmcv
import
torch
from
mmcv.runner
import
HOOKS
,
Hook
from
mmcv.runner.dist_utils
import
master_only
@HOOKS.register_module()
class PickleDataHook(Hook):
    """Pickle Useful Data Hook.

    This hook will be used in SinGAN training for saving some important data
    that will be used in testing or inference.

    Args:
        output_dir (str): The output path for saving pickled data.
        data_name_list (list[str]): The list contains the name of results in
            outputs dict.
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        before_run (bool, optional): Whether to save before running.
            Defaults to False.
        after_run (bool, optional): Whether to save after running.
            Defaults to False.
        filename_tmpl (str, optional): Format string used to save images. The
            output file name will be formatted as this args.
            Defaults to 'iter_{}.pkl'.
    """

    def __init__(self,
                 output_dir,
                 data_name_list,
                 interval=-1,
                 before_run=False,
                 after_run=False,
                 filename_tmpl='iter_{}.pkl'):
        assert mmcv.is_list_of(data_name_list, str)
        self.output_dir = output_dir
        self.data_name_list = data_name_list
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        # Underscored to avoid clashing with the ``before_run``/``after_run``
        # hook methods below.
        self._before_run = before_run
        self._after_run = after_run

    @master_only
    def after_run(self, runner):
        """The behavior after the whole running.

        Args:
            runner (object): The runner.
        """
        if self._after_run:
            self._pickle_data(runner)

    @master_only
    def before_run(self, runner):
        """The behavior before the whole running.

        Args:
            runner (object): The runner.
        """
        if self._before_run:
            self._pickle_data(runner)

    @master_only
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        self._pickle_data(runner)

    def _pickle_data(self, runner):
        # Serialize the requested entries of ``runner.outputs['results']`` to
        # ``<work_dir>/<output_dir>/iter_<N>.pkl`` (tensors become ndarrays).
        filename = self.filename_tmpl.format(runner.iter + 1)
        if not hasattr(self, '_out_dir'):
            self._out_dir = os.path.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(self._out_dir)
        file_path = os.path.join(self._out_dir, filename)
        with open(file_path, 'wb') as f:
            data = runner.outputs['results']
            not_find_keys = []
            data_dict = {}
            for k in self.data_name_list:
                if k in data.keys():
                    data_dict[k] = self._get_numpy_data(data[k])
                else:
                    not_find_keys.append(k)
            pickle.dump(data_dict, f)
            mmcv.print_log(f'Pickle data in {filename}', 'mmgen')
            if len(not_find_keys) > 0:
                mmcv.print_log(
                    f'Cannot find keys for pickling: {not_find_keys}',
                    'mmgen',
                    level=logging.WARN)
            f.flush()

    def _get_numpy_data(self, data):
        # Recursively convert tensors (and lists of tensors) to numpy arrays
        # so the pickle file does not depend on torch/CUDA at load time.
        if isinstance(data, list):
            return [self._get_numpy_data(x) for x in data]

        if isinstance(data, torch.Tensor):
            return data.cpu().numpy()

        return data
build/lib/mmgen/core/hooks/visualization.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
mmcv
import
torch
from
mmcv.runner
import
HOOKS
,
Hook
from
mmcv.runner.dist_utils
import
master_only
from
torchvision.utils
import
save_image
@HOOKS.register_module('MMGenVisualizationHook')
class VisualizationHook(Hook):
    """Visualization hook.

    In this hook, we use the official api `save_image` in torchvision to save
    the visualization results.

    Args:
        output_dir (str): The file path to store visualizations.
        res_name_list (str): The list contains the name of results in outputs
            dict. The results in outputs dict must be a torch.Tensor with shape
            (n, c, h, w).
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to save images. The output file
            name will be formatted as this args. Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend users should preprocess the
            visualization results on their own. Here, we just provide a simple
            interface. Default: True.
        bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
            RGB. The final image we will save is following RGB style.
            Default: True.
        nrow (int): The number of samples in a row. Default: 1.
        padding (int): The number of padding pixels between each samples.
            Default: 4.
    """

    def __init__(self,
                 output_dir,
                 res_name_list,
                 interval=-1,
                 filename_tmpl='iter_{}.png',
                 rerange=True,
                 bgr2rgb=True,
                 nrow=1,
                 padding=4):
        assert mmcv.is_list_of(res_name_list, str)
        self.output_dir = output_dir
        self.res_name_list = res_name_list
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self.bgr2rgb = bgr2rgb
        self.rerange = rerange
        self.nrow = nrow
        self.padding = padding

    @master_only
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        results = runner.outputs['results']

        filename = self.filename_tmpl.format(runner.iter + 1)

        # Concatenate the requested result tensors side by side (width axis).
        img_list = [results[k] for k in self.res_name_list if k in results]
        img_cat = torch.cat(img_list, dim=3).detach()
        if self.rerange:
            img_cat = ((img_cat + 1) / 2)
        # Only swap channels for 3-channel images: indexing a 1-channel
        # tensor with [2, 1, 0] raises an IndexError. This mirrors the guard
        # used in VisualizeUnconditionalSamples.
        if self.bgr2rgb and img_cat.size(1) == 3:
            img_cat = img_cat[:, [2, 1, 0], ...]
        img_cat = img_cat.clamp_(0, 1)

        if not hasattr(self, '_out_dir'):
            self._out_dir = osp.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(self._out_dir)
        save_image(
            img_cat,
            osp.join(self._out_dir, filename),
            nrow=self.nrow,
            padding=self.padding)
build/lib/mmgen/core/hooks/visualize_training_samples.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
mmcv
import
torch
from
mmcv.runner
import
HOOKS
,
Hook
from
mmcv.runner.dist_utils
import
master_only
from
torchvision.utils
import
save_image
@HOOKS.register_module()
class VisualizeUnconditionalSamples(Hook):
    """Visualization hook for unconditional GANs.

    In this hook, we use the official api `save_image` in torchvision to save
    the visualization results.

    Args:
        output_dir (str): The file path to store visualizations.
        fixed_noise (bool, optional): Whether to use fixed noises in sampling.
            Defaults to True.
        num_samples (int, optional): The number of samples to show in
            visualization. Defaults to 16.
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to save images. The output file
            name will be formatted as this args. Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend users should preprocess the
            visualization results on their own. Here, we just provide a simple
            interface. Default: True.
        bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
            RGB. The final image we will save is following RGB style.
            Default: True.
        nrow (int): The number of samples in a row. Default: 1.
        padding (int): The number of padding pixels between each samples.
            Default: 4.
        kwargs (dict | None, optional): Key-word arguments for sampling
            function. Defaults to None.
    """

    def __init__(self,
                 output_dir,
                 fixed_noise=True,
                 num_samples=16,
                 interval=-1,
                 filename_tmpl='iter_{}.png',
                 rerange=True,
                 bgr2rgb=True,
                 nrow=4,
                 padding=0,
                 kwargs=None):
        self.output_dir = output_dir
        self.fixed_noise = fixed_noise
        self.num_samples = num_samples
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self.bgr2rgb = bgr2rgb
        self.rerange = rerange
        self.nrow = nrow
        self.padding = padding

        # the sampling noise will be initialized by the first sampling.
        self.sampling_noise = None

        self.kwargs = kwargs if kwargs is not None else dict()

    @master_only
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        # eval mode
        runner.model.eval()
        # no grad in sampling
        with torch.no_grad():
            # ``self.sampling_noise`` is None on the first call, which lets
            # the model draw fresh noise; that noise is captured below when
            # ``fixed_noise`` is enabled so later calls reuse it.
            outputs_dict = runner.model(
                self.sampling_noise,
                return_loss=False,
                num_batches=self.num_samples,
                return_noise=True,
                **self.kwargs)
            imgs = outputs_dict['fake_img']
            noise_ = outputs_dict['noise_batch']
        # initialize samling noise with the first returned noise
        if self.sampling_noise is None and self.fixed_noise:
            self.sampling_noise = noise_

        # train mode
        runner.model.train()

        filename = self.filename_tmpl.format(runner.iter + 1)
        if self.rerange:
            imgs = ((imgs + 1) / 2)
        if self.bgr2rgb and imgs.size(1) == 3:
            imgs = imgs[:, [2, 1, 0], ...]
        if imgs.size(1) == 1:
            # Tile grayscale images to 3 channels for saving as RGB.
            imgs = torch.cat([imgs, imgs, imgs], dim=1)
        imgs = imgs.clamp_(0, 1)

        mmcv.mkdir_or_exist(osp.join(runner.work_dir, self.output_dir))
        save_image(
            imgs,
            osp.join(runner.work_dir, self.output_dir, filename),
            nrow=self.nrow,
            padding=self.padding)
build/lib/mmgen/core/optimizer/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.builder
import
build_optimizers
__all__
=
[
'build_optimizers'
]
build/lib/mmgen/core/optimizer/builder.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.runner
import
build_optimizer
def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If `cfgs` contains several dicts for optimizers, then a dict for each
    constructed optimizers will be returned.
    If `cfgs` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    # unwrap DataParallel / DistributedDataParallel style containers
    if hasattr(model, 'module'):
        model = model.module

    # `cfgs` describes several optimizers only when every value is a dict;
    # in that case each key names a submodule of `model`.
    if all(isinstance(sub_cfg, dict) for sub_cfg in cfgs.values()):
        return {
            name: build_optimizer(getattr(model, name), sub_cfg.copy())
            for name, sub_cfg in cfgs.items()
        }

    # otherwise `cfgs` is a single optimizer config for the whole model
    return build_optimizer(model, cfgs)
build/lib/mmgen/core/registry.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.utils
import
Registry
,
build_from_cfg
METRICS
=
Registry
(
'metric'
)
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, which is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        # a list of configs yields a list of built modules, one per config
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return modules

    return build_from_cfg(cfg, registry, default_args)
def build_metric(cfg):
    """Build a metric calculator.

    Args:
        cfg (dict | list[dict]): Config of the metric(s) to construct.

    Returns:
        The metric calculator(s) built from the ``METRICS`` registry.
    """
    return build(cfg, METRICS, default_args=None)
build/lib/mmgen/core/runners/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.dynamic_iterbased_runner
import
DynamicIterBasedRunner
__all__
=
[
'DynamicIterBasedRunner'
]
build/lib/mmgen/core/runners/apex_amp_utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
try
:
from
apex
import
amp
except
ImportError
:
amp
=
None
def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
    """Initialize apex.amp for mixed-precision training.

    Args:
        models (nn.Module | list[Module]): Modules to be wrapped with
            apex.amp.
        optimizers (dict): Optimizers holding a ``'generator'`` and a
            ``'discriminator'`` entry.
        init_args (dict | None, optional): Config for amp initialization.
            Defaults to None.
        mode (str, optional): The mode used to initialize apex.amp.
            Different modes lead to different wrapping mode for models and
            optimizers. Defaults to 'gan'.

    Returns:
        Module, dict: Wrapped module and optimizers.
    """
    if mode != 'gan':
        raise NotImplementedError(f'Cannot initialize apex.amp with mode {mode}')

    amp_args = init_args or dict()

    # ``amp.initialize`` wants a flat optimizer list, so unpack the GAN pair,
    # wrap everything, and write the wrapped optimizers back into the dict.
    optimizer_pair = [optimizers['generator'], optimizers['discriminator']]
    models, optimizer_pair = amp.initialize(models, optimizer_pair, **amp_args)
    optimizers['generator'], optimizers['discriminator'] = optimizer_pair

    return models, optimizers
build/lib/mmgen/core/runners/checkpoint.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
time
from
tempfile
import
TemporaryDirectory
import
mmcv
import
torch
from
mmcv.parallel
import
is_module_wrapper
from
mmcv.runner.checkpoint
import
get_state_dict
,
weights_to_cpu
from
torch.optim
import
Optimizer
def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be
    saved in checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``'pavi://'`` prefix uploads
            the checkpoint to the pavi model cloud instead of the local disk.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer(s) to be
            saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save apex.amp
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If ``filename`` targets pavi but pavi is not installed.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    # always stamp the checkpoint with the mmcv version and save time
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # unwrap DataParallel / DistributedDataParallel before reading state
    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    # move all weights to CPU so the checkpoint can be loaded anywhere
    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint; a dict of optimizers
    # (e.g. generator/discriminator) is saved entry by entry
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()

    # save state_dict from apex.amp; imported lazily so apex stays optional
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        # strip the 'pavi://' scheme to get the cloud path
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            # NOTE: `model` is rebound to the pavi folder node here; the
            # nn.Module is no longer needed once `checkpoint` is built.
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        # serialize locally first, then upload the temporary file
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
build/lib/mmgen/core/runners/dynamic_iterbased_runner.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
platform
import
shutil
import
time
import
warnings
from
functools
import
partial
import
mmcv
import
torch
import
torch.distributed
as
dist
from
mmcv.parallel
import
collate
,
is_module_wrapper
from
mmcv.runner
import
HOOKS
,
RUNNERS
,
IterBasedRunner
,
get_host_info
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
from
.checkpoint
import
save_checkpoint
try
:
# If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from
torch.cuda.amp
import
GradScaler
except
ImportError
:
pass
class IterLoader:
    """Iteration based dataloader.

    This wrapper adapts an epoch-based PyTorch dataloader to the iter-based
    training procedure: ``next()`` can be called indefinitely and the epoch
    counter advances whenever the underlying loader is exhausted.

    Args:
        dataloader (object): Dataloader in PyTorch.
        runner (object): ``mmcv.Runner``
    """

    def __init__(self, dataloader, runner):
        self._dataloader = dataloader
        self.runner = runner
        self.iter_loader = iter(self._dataloader)
        # number of completed passes over the dataloader
        self._epoch = 0

    @property
    def epoch(self):
        """The number of current epoch.

        Returns:
            int: Epoch number.
        """
        return self._epoch

    def update_dataloader(self, curr_scale):
        """Update dataloader.

        Update the dataloader according to the `curr_scale`. This functionality
        is very helpful in training progressive growing GANs in which the
        dataloader should be updated according to the scale of the models in
        training.

        Args:
            curr_scale (int): The scale in current stage.
        """
        # update dataset, sampler, and samples per gpu in dataloader;
        # datasets without `update_annotations` are left untouched
        if hasattr(self._dataloader.dataset, 'update_annotations'):
            update_flag = self._dataloader.dataset.update_annotations(
                curr_scale)
        else:
            update_flag = False

        if update_flag:
            # the sampler should be updated with the modified dataset
            assert hasattr(self._dataloader.sampler, 'update_sampler')
            samples_per_gpu = None if not hasattr(
                self._dataloader.dataset, 'samples_per_gpu'
            ) else self._dataloader.dataset.samples_per_gpu
            self._dataloader.sampler.update_sampler(self._dataloader.dataset,
                                                    samples_per_gpu)
            # update samples per gpu
            if samples_per_gpu is not None:
                if dist.is_initialized():
                    # samples = samples_per_gpu
                    # self._dataloader.collate_fn = partial(
                    #     collate, samples_per_gpu=samples)

                    # rebuild the DataLoader so the new batch size takes
                    # effect, reusing the (already updated) sampler
                    self._dataloader = DataLoader(
                        self._dataloader.dataset,
                        batch_size=samples_per_gpu,
                        sampler=self._dataloader.sampler,
                        num_workers=self._dataloader.num_workers,
                        collate_fn=partial(
                            collate, samples_per_gpu=samples_per_gpu),
                        shuffle=False,
                        worker_init_fn=self._dataloader.worker_init_fn)
                    # restart iteration over the rebuilt loader
                    self.iter_loader = iter(self._dataloader)
                else:
                    raise NotImplementedError(
                        'Currently, we only support dynamic batch size in'
                        ' ddp, because the number of gpus in DataParallel '
                        'cannot be obtained easily.')

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # the loader is exhausted: bump the epoch, reshuffle the
            # distributed sampler if it supports it, and start over
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)
@RUNNERS.register_module()
class DynamicIterBasedRunner(IterBasedRunner):
    """Dynamic Iterbased Runner.

    In this Dynamic Iterbased Runner, we will pass the ``reducer`` to the
    ``train_step`` so that the models can be trained with dynamic architecture.
    More details and clarification can be found in this [tutorial](docs/en/tutorials/ddp_train_gans.md). # noqa

    Args:
        is_dynamic_ddp (bool, optional): Whether to adopt the dynamic ddp.
            Defaults to False.
        pass_training_status (bool, optional): Whether to pass the training
            status. Defaults to False.
        fp16_loss_scaler (dict | None, optional): Config for fp16 GradScaler
            from ``torch.cuda.amp``. Defaults to None.
        use_apex_amp (bool, optional): Whether to use apex.amp to start mixed
            precision training. Defaults to False.
    """

    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 fp16_loss_scaler=None,
                 use_apex_amp=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # unwrap DDP/DP containers to inspect the bare model
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model

        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status

        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for optimizer is None.
        # sanity check for whether `_model` contains self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer

        # add fp16 grad scaler, using pytorch official GradScaler
        self.with_fp16_grad_scaler = False
        if fp16_loss_scaler is not None:
            self.loss_scaler = GradScaler(**fp16_loss_scaler)
            self.with_fp16_grad_scaler = True
            mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')

        # flag to use amp in apex (NVIDIA)
        self.use_apex_amp = use_apex_amp

    def call_hook(self, fn_name):
        """Call all hooks.

        Unlike the base runner, hooks that do not implement ``fn_name`` are
        simply skipped, which allows custom hook methods such as
        ``before_fetch_train_data``.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            if hasattr(hook, fn_name):
                getattr(hook, fn_name)(self)

    def train(self, data_loader, **kwargs):
        """Run one training iteration.

        Args:
            data_loader (IterLoader): Iter-based loader to fetch a batch from.
        """
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model
        self.model.train()
        self.mode = 'train'
        # check if self.optimizer from model and track it: a dynamic model
        # may rebuild its optimizer, so re-read it every iteration
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer

        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        self.call_hook('before_fetch_train_data')
        data_batch = next(self.data_loader)
        self.call_hook('before_train_iter')

        # prepare input args for train_step
        # running status
        if self.pass_training_status:
            running_status = dict(iteration=self.iter, epoch=self.epoch)
            kwargs['running_status'] = running_status
        # ddp reducer for tracking dynamic computational graph
        if self.is_dynamic_ddp:
            kwargs.update(dict(ddp_reducer=self.model.reducer))

        if self.with_fp16_grad_scaler:
            kwargs.update(dict(loss_scaler=self.loss_scaler))

        if self.use_apex_amp:
            kwargs.update(dict(use_apex_amp=True))

        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

        # the loss scaler should be updated after ``train_step``
        if self.with_fp16_grad_scaler:
            self.loss_scaler.update()

        # further check for the cases where the optimizer is built in
        # `train_step`.
        if self.optimizer is None:
            if hasattr(_model, 'optimizer'):
                self.optimizer_from_model = True
                self.optimizer = _model.optimizer

        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer

        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g, [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
            max_iters (int | None, optional): Deprecated; set max_iters in the
                runner config instead.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        # wrap every loader so that `next()` never raises StopIteration
        iter_loaders = [IterLoader(x, self) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                # e.g. self.train / self.val, selected by workflow phase name
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        # wait for some hooks like loggers to finish
        time.sleep(1)
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               resume_loss_scaler=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the optimizer(s)
                if the checkpoint file includes optimizer(s). Default to True.
            resume_loss_scaler (bool, optional): Whether to resume the loss
                scaler (GradScaler) from ``torch.cuda.amp``. Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Default to 'default'.
        """
        if map_location == 'default':
            # map tensors onto the current CUDA device
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            # a single optimizer or a name->optimizer dict are both supported
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'loss_scaler' in checkpoint and resume_loss_scaler:
            self.loss_scaler.load_state_dict(checkpoint['loss_scaler'])

        if self.use_apex_amp:
            from apex import amp
            amp.load_state_dict(checkpoint['amp'])

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        # only persist the grad scaler when FP16 training is active
        _loss_scaler = self.loss_scaler if self.with_fp16_grad_scaler else None
        save_checkpoint(
            self.model,
            filepath,
            optimizer=optimizer,
            loss_scaler=_loss_scaler,
            save_apex_amp=self.use_apex_amp,
            meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_lr_hook(self, lr_config):
        """Register the learning-rate updater hook.

        Args:
            lr_config (dict | Hook | None): LR hook config (with a 'policy'
                key) or an already constructed hook. None disables it.
        """
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `
            # CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)
build/lib/mmgen/core/runners/fp16_utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
functools
from
collections
import
abc
from
inspect
import
getfullargspec
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.utils
import
TORCH_VERSION
try
:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from
torch.cuda.amp
import
autocast
except
ImportError
:
pass
def nan_to_num(x, nan=0.0, posinf=None, neginf=None, *, out=None):
    r"""Replaces :literal:`NaN`, positive infinity, and negative infinity
    values in :attr:`input` with the values specified by :attr:`nan`,
    :attr:`posinf`, and :attr:`neginf`, respectively. By default,
    :literal:`NaN`\s are replaced with zero, positive infinity is replaced with
    the greatest finite value representable by :attr:`input`'s dtype, and
    negative infinity is replaced with the least finite value representable by
    :attr:`input`'s dtype.

    .. note::
        This function is provided in ``PyTorch>=1.8.0``. Here is a
        reimplementation to avoid attribute error in lower PyTorch version.

    Args:
        x (Tensor): Input tensor.
        nan (Number, optional): the value to replace :literal:`NaN`\s with.
            Default is zero. The fallback path only supports ``nan=0``.
        posinf (Number, optional): if a Number, the value to replace positive
            infinity values with. If None, positive infinity values are
            replaced with the greatest finite value representable by
            :attr:`input`'s dtype. Default is None.
        neginf (Number, optional): if a Number, the value to replace negative
            infinity values with. If None, negative infinity values are
            replaced with the lowest finite value representable by
            :attr:`input`'s dtype. Default is None.
        out (Tensor, optional): Output tensor to write the result into.

    Returns:
        Tensor: Output tensor.

    Raises:
        TypeError: If ``x`` is not a Tensor (fallback path only).
    """
    try:
        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)
    except AttributeError:
        # Fallback for PyTorch < 1.8.0, where ``torch.nan_to_num`` is absent.
        if not isinstance(x, torch.Tensor):
            raise TypeError(
                f'argument input (position 1) must be Tensor, not {type(x)}')
        if posinf is None:
            posinf = torch.finfo(x.dtype).max
        if neginf is None:
            neginf = torch.finfo(x.dtype).min
        assert nan == 0
        # a better choice is to use nansum, but this function is not supported
        # in PyTorch 1.5
        # x.unsqueeze(0).nansum(0)
        # BUGFIX: replace NaNs out-of-place instead of mutating the caller's
        # tensor in place, matching the semantics of ``torch.nan_to_num``.
        x = torch.where(torch.isnan(x), torch.zeros_like(x), x)
        return torch.clamp(x, min=neginf, max=posinf, out=out)
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type..
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but all contained Tensors have been cast.
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)

    # Modules, strings and numpy arrays pass through untouched. Strings must
    # be intercepted here, before the generic Iterable branch below.
    if isinstance(inputs, (nn.Module, str, np.ndarray)):
        return inputs

    if isinstance(inputs, abc.Mapping):
        # rebuild the mapping with every value cast recursively
        converted = {
            key: cast_tensor_type(value, src_type, dst_type)
            for key, value in inputs.items()
        }
        return type(inputs)(converted)

    if isinstance(inputs, abc.Iterable):
        # rebuild the container from a lazily cast element stream
        return type(inputs)(
            cast_tensor_type(element, src_type, dst_type)
            for element in inputs)

    # scalars and any other unrecognized objects are returned as-is
    return inputs
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` means NO arguments are converted (note: this differs from
            mmcv, where `None` converts all arguments — see the in-line note
            below).
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # define output type by class itself: an `out_fp32` attribute on
            # the module overrides the decorator's `out_fp32` argument
            if hasattr(args[0], 'out_fp32') and args[0].out_fp32:
                _out_fp32 = True
            else:
                _out_fp32 = False
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            # Here, we change the default behaviour with Yu Xiong's
            # implementation: `apply_to=None` casts NOTHING instead of
            # everything.
            args_to_cast = [] if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                # positional args are matched to parameter names by position
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method, running it
            # under torch.cuda.amp.autocast when available
            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
                output = autocast(enabled=True)(old_func)(*new_args,
                                                          **new_kwargs)
            else:
                # output = old_func(*new_args, **new_kwargs)
                raise RuntimeError('Please use PyTorch >= 1.6.0')
            # cast the results back to fp32 if necessary
            if out_fp32 or _out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
build/lib/mmgen/core/scheduler/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.lr_updater
import
LinearLrUpdaterHook
__all__
=
[
'LinearLrUpdaterHook'
]
build/lib/mmgen/core/scheduler/lr_updater.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.runner
import
HOOKS
,
LrUpdaterHook
@HOOKS.register_module()
class LinearLrUpdaterHook(LrUpdaterHook):
    """Linear learning rate scheduler for image generation.

    In the beginning, the learning rate is 'base_lr' defined in mmcv.
    We give a target learning rate 'target_lr' and a start point 'start'
    (iteration / epoch). Before 'start', we fix learning rate as 'base_lr';
    After 'start', we linearly update learning rate to 'target_lr'.

    Args:
        target_lr (float): The target learning rate. Default: 0.
        start (int): The start point (iteration / epoch, specified by args
            'by_epoch' in its parent class in mmcv) to update learning rate.
            Default: 0.
        interval (int): The interval to update the learning rate. Default: 1.
    """

    def __init__(self, target_lr=0, start=0, interval=1, **kwargs):
        super().__init__(**kwargs)
        self.target_lr = target_lr
        self.start = start
        self.interval = interval

    def get_lr(self, runner, base_lr):
        """Calculates the learning rate.

        Args:
            runner (object): The passed runner.
            base_lr (float): Base learning rate.

        Returns:
            float: Current learning rate.
        """
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        assert max_progress >= self.start

        # Total number of scheduled update steps after `start`.
        # BUGFIX: the original code only guarded `max_progress == start`, so
        # any run with `max_progress - start < interval` divided by zero
        # below. Guarding on the integer step count covers both cases.
        total_steps = (max_progress - self.start) // self.interval
        if total_steps == 0:
            return base_lr

        # Before 'start', fix lr; After 'start', linearly update lr.
        factor = (max(0, progress - self.start) // self.interval) / total_steps
        return base_lr + (self.target_lr - base_lr) * factor
Prev
1
…
7
8
9
10
11
12
13
14
15
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment