Commit 1401de15 authored by dongchy920's avatar dongchy920
Browse files

stylegan2_mmcv

parents
Pipeline #1274 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
import sys
from copy import deepcopy
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from prettytable import PrettyTable
from torchvision.utils import save_image
from mmgen.datasets import build_dataloader, build_dataset
def make_metrics_table(train_cfg, ckpt, eval_info, metrics):
    """Arrange evaluation results into a table.

    Args:
        train_cfg (str): Name of the training configuration.
        ckpt (str): Path of the evaluated model's weights.
        eval_info (str): Information of this evaluation (e.g. which sample
            model, 'ema' or 'orig', was used to generate images).
        metrics (Metric): Metric objects.

    Returns:
        str: String of the eval table.
    """
    table = PrettyTable()
    # 14 is a prettytable style constant (presumably ORGMODE in recent
    # prettytable releases) -- TODO(review): confirm against the pinned
    # prettytable version.
    table.set_style(14)
    table.add_column('Training configuration', [train_cfg])
    table.add_column('Checkpoint', [ckpt])
    table.add_column('Eval', [eval_info])
    for metric in metrics:
        table.add_column(metric.name, [metric.result_str])
    return table.get_string()
def make_vanilla_dataloader(img_path, batch_size, dist=False):
    """Build a plain dataloader that reads images from ``img_path``.

    Images are loaded from disk, normalized from [0, 255] to [-1, 1] and
    collected under the 'real_img' key.

    Args:
        img_path (str): Root directory of the images.
        batch_size (int): Batch size per GPU.
        dist (bool, optional): Whether to build a distributed dataloader.
            Defaults to False.

    Returns:
        DataLoader: The constructed dataloader.
    """
    loading_pipeline = [
        dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
        dict(
            type='Normalize',
            keys=['real_img'],
            mean=[127.5] * 3,
            std=[127.5] * 3,
            to_rgb=False),
        dict(type='ImageToTensor', keys=['real_img']),
        dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
    ]
    image_dataset = build_dataset(
        dict(
            type='UnconditionalImageDataset',
            imgs_root=img_path,
            pipeline=loading_pipeline,
        ))
    return build_dataloader(
        image_dataset,
        samples_per_gpu=batch_size,
        workers_per_gpu=4,
        dist=dist,
        shuffle=True)
@torch.no_grad()
def offline_evaluation(model,
                       data_loader,
                       metrics,
                       logger,
                       basic_table_info,
                       batch_size,
                       samples_path=None,
                       **kwargs):
    """Evaluate model in offline mode.

    This method first save generated images at local and then load them by
    dataloader.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        samples_path (str): Used to save generated images. If it's none, we'll
            give it a default directory and delete it after finishing the
            evaluation. Default to None.
        kwargs (dict): Other arguments.
    """
    # eval special and recon metric online only
    online_metric_name = ['PPL', 'GaussianKLD']
    for metric in metrics:
        assert metric.name not in online_metric_name, 'Please eval '\
            f'{metric.name} online'
    # rank: index of the current process; ws: world size (number of ranks)
    rank, ws = get_dist_info()
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        # no directory given: create a fresh temp dir and remove it afterwards
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True
    # sample images; previously saved samples in the directory are reused
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)
    if num_needed > 0 and rank == 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)
    # if no images, `num_needed` should be zero
    total_batch_size = batch_size * ws
    for begin in range(0, num_needed, total_batch_size):
        # per-rank batch; all ranks together generate `total_batch_size`
        # images per step. NOTE(review): `end` is capped with `batch_size`
        # while `global_end` uses `total_batch_size` -- presumably intended
        # for the all_gather below; confirm against upstream mmgen.
        end = min(begin + batch_size, max_num_images)
        fakes = model(
            None,
            num_batches=end - begin,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)
        global_end = min(begin + total_batch_size, max_num_images)
        if rank == 0:
            pbar.update(global_end - begin)
        # gather generated images
        if ws > 1:
            placeholder = [torch.zeros_like(fakes) for _ in range(ws)]
            dist.all_gather(placeholder, fakes)
            fakes = torch.cat(placeholder, dim=0)
        # save as three-channel
        if fakes.size(1) == 3:
            # reorder channels before torchvision's save_image -- presumably
            # the model outputs BGR (mmcv convention); TODO confirm
            fakes = fakes[:, [2, 1, 0], ...]
        elif fakes.size(1) == 1:
            # replicate grayscale into three channels
            fakes = torch.cat([fakes] * 3, dim=1)
        else:
            raise RuntimeError('Generated images must have one or three '
                               'channels in the first dimension, '
                               'not %d' % fakes.size(1))
        if rank == 0:
            for i in range(global_end - begin):
                images = fakes[i:i + 1]
                # map from [-1, 1] to [0, 1] before saving
                images = ((images + 1) / 2)
                images = images.clamp_(0, 1)
                image_name = str(num_exist + begin + i) + '.png'
                save_image(images, os.path.join(samples_path, image_name))
    if num_needed > 0 and rank == 0:
        sys.stdout.write('\n')
    # return if only save sampled images
    if len(metrics) == 0:
        return
    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(
        samples_path, batch_size, dist=ws > 1)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        if rank == 0:
            # prepare for pbar
            total_need = (
                metric.num_real_need + metric.num_fake_need -
                metric.num_real_feeded - metric.num_fake_feeded)
            pbar = mmcv.ProgressBar(total_need)
        # feed in real images
        for data in data_loader:
            # key for unconditional GAN
            if 'real_img' in data:
                reals = data['real_img']
            # key for conditional GAN
            elif 'img' in data:
                reals = data['img']
            else:
                raise KeyError('Cannot found key for images in data_dict. '
                               'Only support `real_img` for unconditional '
                               'datasets and `img` for conditional '
                               'datasets.')
            if reals.shape[1] == 1:
                # replicate grayscale into three channels
                reals = torch.cat([reals] * 3, dim=1)
            # `feed` returns the number of samples consumed; <= 0 means the
            # metric already has enough real images
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(reals.shape[0] * ws)
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            if fakes.shape[1] == 1:
                fakes = torch.cat([fakes] * 3, dim=1)
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(fakes.shape[0] * ws)
        if rank == 0:
            # only call summary at main device
            metric.summary()
            sys.stdout.write('\n')
    if rank == 0:
        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)
        if delete_samples_path:
            shutil.rmtree(samples_path)
@torch.no_grad()
def online_evaluation(model, data_loader, metrics, logger, basic_table_info,
                      batch_size, **kwargs):
    """Evaluate model in online mode.

    This method evaluate model and displays eval progress bar.
    Different form `offline_evaluation`, this function will not save
    the images or read images from disks. Namely, there do not exist any IO
    operations in this function. Thus, in general, `online` mode will achieve a
    faster evaluation. However, this mode will take much more memory cost.
    To be noted that, we only support distributed evaluation for FID and IS
    currently.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        kwargs (dict): Other arguments.
    """
    # separate metrics into special metrics, probabilistic metrics and vanilla
    # metrics.
    # For vanilla metrics, images are generated in a random way, and are
    # shared by these metrics. For special metrics like 'PPL', images are
    # generated in a metric-special way and not shared between different
    # metrics.
    # For reconstruction metrics like 'GaussianKLD', they do not
    # receive images but receive a dict with corresponding probabilistic
    # parameter.
    rank, ws = get_dist_info()  # rank of this process and world size
    special_metrics = []
    recon_metrics = []
    vanilla_metrics = []
    special_metric_name = ['PPL']
    recon_metric_name = ['GaussianKLD']
    for metric in metrics:
        if ws > 1:
            assert metric.name in [
                'FID', 'IS'
            ], ('We only support FID and IS for distributed evaluation '
                f'currently, but receive {metric.name}')
        if metric.name in special_metric_name:
            special_metrics.append(metric)
        elif metric.name in recon_metric_name:
            recon_metrics.append(metric)
        else:
            vanilla_metrics.append(metric)
    # define mmcv progress bar
    max_num_images = 0
    for metric in vanilla_metrics + recon_metrics:
        metric.prepare()
        # the metric requiring the most real images drives the loop below
        max_num_images = max(max_num_images,
                             metric.num_real_need - metric.num_real_feeded)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} real images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)
    # avoid `data_loader` is None
    data_loader = [] if data_loader is None else data_loader
    for data in data_loader:
        # key for unconditional GAN
        if 'real_img' in data:
            reals = data['real_img']
        # key for conditional GAN
        elif 'img' in data:
            reals = data['img']
        else:
            raise KeyError('Cannot found key for images in data_dict. '
                           'Only support `real_img` for unconditional '
                           'datasets and `img` for conditional '
                           'datasets.')
        if reals.shape[1] not in [1, 3]:
            raise RuntimeError('real images should have one or three '
                               'channels in the first, '
                               'not % d' % reals.shape[1])
        if reals.shape[1] == 1:
            # replicate grayscale into three channels
            reals = reals.repeat(1, 3, 1, 1)
        # track the largest number of samples any metric consumed this step;
        # <= 0 means every metric has enough real images
        num_feed = 0
        for metric in vanilla_metrics:
            num_feed_ = metric.feed(reals, 'reals')
            num_feed = max(num_feed_, num_feed)
        for metric in recon_metrics:
            # run the model in reconstruction mode to obtain the
            # probabilistic parameters the metric expects
            kwargs_ = deepcopy(kwargs)
            kwargs_['mode'] = 'reconstruction'
            prob_dict = model(reals, return_loss=False, **kwargs_)
            num_feed_ = metric.feed(prob_dict, 'reals')
            num_feed = max(num_feed_, num_feed)
        if num_feed <= 0:
            break
        if rank == 0:
            pbar.update(num_feed)
    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')
    # define mmcv progress bar
    max_num_images = 0 if len(vanilla_metrics) == 0 else max(
        metric.num_fake_need for metric in vanilla_metrics)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)
    # sampling fake images and directly send them to metrics
    total_batch_size = batch_size * ws
    for _ in range(0, max_num_images, total_batch_size):
        fakes = model(
            None,
            num_batches=batch_size,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)
        if fakes.shape[1] not in [1, 3]:
            raise RuntimeError('fakes images should have one or three '
                               'channels in the first, '
                               'not % d' % fakes.shape[1])
        if fakes.shape[1] == 1:
            fakes = torch.cat([fakes] * 3, dim=1)
        for metric in vanilla_metrics:
            # feed in fake images
            metric.feed(fakes, 'fakes')
        if rank == 0:
            pbar.update(total_batch_size)
    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')
    # feed special metric, we do not consider distributed eval here
    for metric in special_metrics:
        metric.prepare()
        # special metrics draw samples with a metric-specific sampler;
        # `model.module` assumes a wrapped (DDP-like) model here
        fakedata_iterator = iter(
            metric.get_sampler(model.module, batch_size,
                               basic_table_info['sample_model']))
        mmcv.print_log(
            f'Sample {metric.num_images} samples for evaluating {metric.name}',
            'mmgen')
        pbar = mmcv.ProgressBar(metric.num_images)
        for fakes in fakedata_iterator:
            num_left = metric.feed(fakes, 'fakes')
            pbar.update(fakes.shape[0])
            if num_left <= 0:
                break
        # finish the pbar stdout
        sys.stdout.write('\n')
    if rank == 0:
        for metric in metrics:
            metric.summary()
        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)
# Copyright (c) OpenMMLab. All rights reserved.
import sys
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.parallel import is_module_wrapper
from mmgen.models.architectures.common import get_module_device
@torch.no_grad()
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    # ceil(num_samples / batch_size) iterations are required
    num_iters = (num_samples + batch_size - 1) // batch_size
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)
    feature_list = []
    for curr_iter, data in enumerate(dataloader, start=1):
        # a dirty walkround to support multiple datasets (mainly for the
        # unconditional dataset and conditional dataset). In our
        # implementation, unconditioanl dataset will return real images with
        # the key "real_img". However, the conditional dataset contains a key
        # "img" denoting the real images.
        img = data['real_img'] if 'real_img' in data else data['img']
        pbar.update()
        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))
        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))
        if curr_iter >= num_iters:
            break
    # Attention: the number of features may be different as you want.
    features = torch.cat(feature_list, 0)
    assert features.shape[0] >= num_samples
    # to change the line after pbar
    sys.stdout.write('\n')
    return features[:num_samples]
def _hox_downsample(img):
r"""Downsample images with factor equal to 0.5.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
img (ndarray): Images with order "NHWC".
Returns:
ndarray: Downsampled images with order "NHWC".
"""
return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] +
img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25
def _f_special_gauss(size, sigma):
r"""Return a circular symmetric gaussian kernel.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
size (int): Size of Gaussian kernel.
sigma (float): Standard deviation for Gaussian blur kernel.
Returns:
ndarray: Gaussian kernel.
"""
radius = size // 2
offset = 0.0
start, stop = -radius, radius + 1
if size % 2 == 0:
offset = 0.5
stop -= 1
x, y = np.mgrid[offset + start:stop, offset + start:stop]
assert len(x) == size
g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
return g / g.sum()
def get_gaussian_kernel():
    """Return the 5x5 binomial Gaussian blur kernel with shape (1, 1, 5, 5).

    The kernel is the outer product of the binomial row [1, 4, 6, 4, 1]
    with itself, normalized by 256 so it sums to one.
    """
    row = np.array([1, 4, 6, 4, 1], np.float32)
    kernel_2d = np.outer(row, row) / 256.0
    return torch.as_tensor(kernel_2d.reshape(1, 1, 5, 5))
def get_pyramid_layer(image, gaussian_k, direction='down'):
    """Blur a 3-channel image and move it one pyramid level up or down.

    'down' blurs with stride 2 (half resolution); 'up' first doubles the
    resolution via nearest interpolation and then blurs with stride 1.
    """
    kernel = gaussian_k.to(image.device)
    if direction == 'up':
        image = F.interpolate(image, scale_factor=2)
        stride = 1
    else:
        stride = 2
    # blur each channel independently with the single-channel kernel
    blurred = [
        F.conv2d(image[:, c:c + 1, :, :], kernel, padding=2, stride=stride)
        for c in range(3)
    ]
    return torch.cat(blurred, dim=1)
def gaussian_pyramid(original, n_pyramids, gaussian_k):
    """Build a Gaussian pyramid with ``n_pyramids + 1`` levels.

    The first level is ``original`` itself; each following level is the
    previous one blurred and downsampled by a factor of two.
    """
    levels = [original]
    for _ in range(n_pyramids):
        levels.append(get_pyramid_layer(levels[-1], gaussian_k))
    return levels
def laplacian_pyramid(original, n_pyramids, gaussian_k):
    """Calculate Laplacian pyramid.

    Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py

    Args:
        original (Tensor): Batch of Images with range [0, 1] and order "NCHW"
        n_pyramids (int): Levels of pyramids minus one.
        gaussian_k (Tensor): Gaussian kernel with shape (1, 1, 5, 5).

    Return:
        list[Tensor]. Laplacian pyramids of original.
    """
    # create gaussian pyramid
    gaussians = gaussian_pyramid(original, n_pyramids, gaussian_k)
    # each Laplacian level is the detail lost between two adjacent
    # Gaussian levels (fine minus upsampled coarse)
    laplacians = [
        fine - get_pyramid_layer(coarse, gaussian_k, 'up')
        for fine, coarse in zip(gaussians[:-1], gaussians[1:])
    ]
    # the coarsest Gaussian level is kept as-is
    laplacians.append(gaussians[-1])
    return laplacians
def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
    r"""Get descriptors of one level of pyramids.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        minibatch (Tensor): Pyramids of one level with order "NCHW".
        nhood_size (int): Pixel neighborhood size.
        nhoods_per_image (int): The number of descriptors per image.

    Return:
        Tensor: Descriptors of images from one level batch.
    """
    S = minibatch.shape  # (minibatch, channel, height, width)
    assert len(S) == 4 and S[1] == 3
    # total number of descriptors across the whole batch
    N = nhoods_per_image * S[0]
    # half-width of the square neighborhood window
    H = nhood_size // 2
    # open grids: `nhood` indexes the descriptor, `chan` the channel, and
    # (x, y) span the (2H+1) x (2H+1) neighborhood offsets
    nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H + 1, -H:H + 1]
    # which source image each descriptor is drawn from
    img = nhood // nhoods_per_image
    # random window centers, kept H pixels away from the image borders
    x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1))
    y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1))
    # flat index into the NCHW tensor: (((img * C) + chan) * H + y) * W + x
    idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x
    return minibatch.view(-1)[idx]
def finalize_descriptors(desc):
    r"""Normalize and reshape descriptors.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        desc (list or Tensor): List of descriptors of one level.

    Return:
        Tensor: Descriptors after normalized along channel and flattened.
    """
    if isinstance(desc, list):
        desc = torch.cat(desc, dim=0)
    assert desc.ndim == 4  # (neighborhood, channel, height, width)
    # standardize per channel; note the in-place update of `desc`
    desc -= desc.mean(dim=(0, 2, 3), keepdim=True)
    desc /= desc.std(dim=(0, 2, 3), keepdim=True)
    return desc.reshape(desc.shape[0], -1)
def compute_pr_distances(row_features,
                         col_features,
                         num_gpus,
                         rank,
                         col_batch_size=10000):
    r"""Compute distances between real images and fake images.

    This function is used for calculate Precision and Recall metric.
    Refer to:https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py # noqa
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    # round the number of column chunks up to a multiple of `num_gpus`
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    # zero-pad the columns so they split evenly into `num_batches` chunks
    padded = torch.nn.functional.pad(col_features,
                                     [0, 0, 0, -num_cols % num_batches])
    col_batches = padded.chunk(num_batches)
    gathered = []
    # each rank computes every `num_gpus`-th chunk, then results are
    # broadcast so that rank 0 can assemble the full distance matrix
    for col_batch in col_batches[rank::num_gpus]:
        dist_batch = torch.cdist(
            row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            dist_broadcast = dist_batch.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(dist_broadcast, src=src)
            gathered.append(dist_broadcast.cpu() if rank == 0 else None)
    if rank != 0:
        return None
    # drop the padded columns before returning
    return torch.cat(gathered, dim=1)[:, :num_cols]
def normalize(a):
    """L2 normalization.

    Args:
        a (Tensor): Tensor with shape [N, C].

    Returns:
        Tensor: Tensor after L2 normalization per-instance.
    """
    row_norms = torch.norm(a, dim=1, keepdim=True)
    return a / row_norms
def slerp(a, b, percent):
    """Spherical linear interpolation between two unnormalized vectors.

    Args:
        a (Tensor): Tensor with shape [N, C].
        b (Tensor): Tensor with shape [N, C].
        percent (float|Tensor): A float or tensor with shape broadcastable to
            the shape of input Tensors.

    Returns:
        Tensor: Spherical linear interpolation result with shape [N, C].
    """

    def _unit(v):
        # L2-normalize each row
        return v / torch.norm(v, dim=1, keepdim=True)

    a = _unit(a)
    b = _unit(b)
    # cosine of the angle between the two unit vectors
    cos_angle = (a * b).sum(-1, keepdim=True)
    angle = percent * torch.acos(cos_angle)
    # component of b orthogonal to a, normalized
    ortho = _unit(b - cos_angle * a)
    mixed = a * torch.cos(angle) + ortho * torch.sin(angle)
    return _unit(mixed)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import pickle
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmcv.runner import get_dist_info
from scipy import linalg, signal
from scipy.stats import entropy
from torchvision import models
from torchvision.models.inception import inception_v3
from mmgen.models.architectures import InceptionV3
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.lpips import PerceptualLoss
from mmgen.models.losses import gaussian_kld
from mmgen.utils import MMGEN_CACHE_DIR
from mmgen.utils.io_utils import download_from_url
from ..registry import METRICS
from .metric_utils import (_f_special_gauss, _hox_downsample,
compute_pr_distances, finalize_descriptors,
get_descriptors_for_minibatch, get_gaussian_kernel,
laplacian_pyramid, slerp)
# URL of Tero Karras' pre-trained Inception snapshot used by the
# StyleGAN2-ADA metrics (downloaded when no local checkpoint is given).
TERO_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'  # noqa
def load_inception(inception_args, metric):
    """Load Inception Model from given ``inception_args`` and ``metric``.

    This function tries to load the Inception model under the guidance of
    the 'type' key given in ``inception_args``. If the type is 'StyleGAN',
    we first try to load Tero's script module from the local
    'inception_path' and then by downloading from 'inception_url'. If both
    methods fail, or no StyleGAN type is requested, the pytorch version of
    Inception is loaded.

    Args:
        inception_args (dict): Keyword args for inception net.
        metric (string): Metric to use the Inception. This argument would
            influence the pytorch's Inception loading.

    Returns:
        model (torch.nn.Module): Loaded Inception model.
        style (string): The version of the loaded Inception.
    """
    if not isinstance(inception_args, dict):
        raise TypeError('Receive invalid \'inception_args\': '
                        f'\'{inception_args}\'')
    _inception_args = deepcopy(inception_args)
    inception_type = _inception_args.pop('type', None)
    # Compare the torch version numerically: a plain string comparison would
    # wrongly treat e.g. '1.10.0' as older than '1.6.0'.
    torch_version = tuple(
        int(''.join(filter(str.isdigit, part)) or 0)
        for part in torch.__version__.split('.')[:2])
    if torch_version < (1, 6):
        mmcv.print_log(
            'Current Pytorch Version not support script module, load '
            'Inception Model from torch model zoo. If you want to use '
            'Tero\' script model, please update your Pytorch higher '
            f'than \'1.6\' (now is {torch.__version__})', 'mmgen')
        return _load_inception_torch(_inception_args, metric), 'pytorch'
    # load the pytorch version if Tero's one is not specified
    if inception_type != 'StyleGAN':
        return _load_inception_torch(_inception_args, metric), 'pytorch'
    # try to load Tero's version
    path = _inception_args.get('inception_path', TERO_INCEPTION_URL)
    # first try to parse `path` as a path on disk
    if 'http' not in path:
        model = _load_inception_from_path(path)
        if isinstance(model, torch.nn.Module):
            return model, 'StyleGAN'
    # then try to parse `path` as a web url and download
    model = _load_inception_from_url(path)
    if isinstance(model, torch.nn.Module):
        return model, 'StyleGAN'
    raise RuntimeError('Cannot Load Inception Model, please check the input '
                       f'`inception_args`: {inception_args}')
def _load_inception_from_path(inception_path):
    """Load Tero's script Inception from a local path.

    Returns the loaded script module, or None when loading fails.
    """
    mmcv.print_log(
        'Try to load Tero\'s Inception Model from '
        f'\'{inception_path}\'.', 'mmgen')
    try:
        loaded_model = torch.jit.load(inception_path)
    except Exception as e:
        mmcv.print_log(
            'Load Tero\'s Inception Model failed. '
            f'\'{e}\' occurs.', 'mmgen')
        return None
    mmcv.print_log('Load Tero\'s Inception Model successfully.', 'mmgen')
    return loaded_model
def _load_inception_from_url(inception_url):
    """Download Tero's Inception checkpoint from a url and load it.

    Falls back to ``TERO_INCEPTION_URL`` when ``inception_url`` is falsy;
    returns None when the download fails.
    """
    url = inception_url if inception_url else TERO_INCEPTION_URL
    mmcv.print_log(f'Try to download Inception Model from {url}...',
                   'mmgen')
    try:
        local_path = download_from_url(url, dest_dir=MMGEN_CACHE_DIR)
    except Exception as e:
        mmcv.print_log(f'Download Failed. {e} occurs.')
        return None
    mmcv.print_log('Download Finished.')
    # _load_inception_from_path handles its own failures and returns None
    return _load_inception_from_path(local_path)
def _load_inception_torch(inception_args, metric):
    """Load the pytorch/torchvision Inception network used by ``metric``."""
    assert metric in ['FID', 'IS']
    if metric == 'FID':
        return InceptionV3([3], **inception_args)
    # metric == 'IS'
    inception_model = inception_v3(pretrained=True, transform_input=False)
    mmcv.print_log(
        'Load Inception V3 Network from Pytorch Model Zoo '
        'for IS calculation. The results can only used '
        'for monitoring purposes. To get more accuracy IS, '
        'please use Tero\'s Inception V3 checkpoints '
        'and use Bicubic Interpolation with Pillow backend '
        'for image resizing. More details may refer to '
        'https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/quick_run.md#is.',  # noqa
        'mmgen')
    return inception_model
def _ssim_for_multi_scale(img1,
                          img2,
                          max_val=255,
                          filter_size=11,
                          filter_sigma=1.5,
                          k1=0.01,
                          k2=0.03):
    """Calculate SSIM (structural similarity) and contrast sensitivity.

    Ref:
    Image quality assessment: From error visibility to structural similarity.

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
    For three-channel images, SSIM is calculated for each channel and then
    averaged.
    This function attempts to match the functionality of ssim_index_new.m by
    Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): the dynamic range of the images (i.e., the difference
            between the maximum the and minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel (will
            be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM calculation
            (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM calculation
            (0.03 in the original paper). Default to 0.03.

    Returns:
        tuple: Pair containing the mean SSIM and contrast sensitivity between
        `img1` and `img2`.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)
    im_a = img1.astype(np.float32)
    im_b = img2.astype(np.float32)
    _, height, width, _ = im_a.shape
    # the blur kernel can't be larger than the height or width of the images
    size = min(filter_size, height, width)
    # scale sigma down proportionally when the kernel was shrunk
    sigma = size * filter_sigma / filter_size if filter_size else 0
    if filter_size:
        window = np.reshape(_f_special_gauss(size, sigma), (1, size, size, 1))
        mean_a = signal.fftconvolve(im_a, window, mode='valid')
        mean_b = signal.fftconvolve(im_b, window, mode='valid')
        var_a = signal.fftconvolve(im_a * im_a, window, mode='valid')
        var_b = signal.fftconvolve(im_b * im_b, window, mode='valid')
        cov_ab = signal.fftconvolve(im_a * im_b, window, mode='valid')
    else:
        # empty blur kernel, so no need to convolve
        mean_a, mean_b = im_a, im_b
        var_a = im_a * im_a
        var_b = im_b * im_b
        cov_ab = im_a * im_b
    mean_aa = mean_a * mean_a
    mean_bb = mean_b * mean_b
    mean_ab = mean_a * mean_b
    # turn the raw second moments into (co)variances
    var_a -= mean_aa
    var_b -= mean_bb
    cov_ab -= mean_ab
    # intermediate values shared by both ssim and cs
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    v1 = 2.0 * cov_ab + c2
    v2 = var_a + var_b + c2
    # return per-image scores, averaged over height, width and channels
    ssim = np.mean((((2.0 * mean_ab + c1) * v1) / ((mean_aa + mean_bb + c1) * v2)),
                   axis=(1, 2, 3))
    cs = np.mean(v1 / v2, axis=(1, 2, 3))
    return ssim, cs
def ms_ssim(img1,
            img2,
            max_val=255,
            filter_size=11,
            filter_sigma=1.5,
            k1=0.01,
            k2=0.03,
            weights=None):
    """Calculate MS-SSIM (multi-scale structural similarity).

    Ref:
    This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
    Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
    similarity for image quality assessment" (2003).
    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
    PGGAN's implementation:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): the dynamic range of the images (i.e., the difference
            between the maximum the and minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel (will
            be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM calculation
            (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM calculation
            (0.03 in the original paper). Default to 0.03.
        weights (list): List of weights for each level; if none, use five
            levels and the weights from the original paper. Default to None.

    Returns:
        float: MS-SSIM score between `img1` and `img2`.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)
    # Note: default weights don't sum to 1.0 but do match the paper / matlab
    # code.
    weights = np.array(
        weights if weights else [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    levels = weights.size
    scaled_a = img1.astype(np.float32)
    scaled_b = img2.astype(np.float32)
    ssim_per_level = []
    cs_per_level = []
    for _ in range(levels):
        ssim, cs = _ssim_for_multi_scale(
            scaled_a,
            scaled_b,
            max_val=max_val,
            filter_size=filter_size,
            filter_sigma=filter_sigma,
            k1=k1,
            k2=k2)
        ssim_per_level.append(ssim)
        cs_per_level.append(cs)
        # halve the resolution for the next level
        scaled_a = _hox_downsample(scaled_a)
        scaled_b = _hox_downsample(scaled_b)
    # clip to zero, otherwise we get NaNs from the fractional powers below
    mssim = np.clip(np.asarray(ssim_per_level), 0.0, np.inf)
    mcs = np.clip(np.asarray(cs_per_level), 0.0, np.inf)
    # average over images only at the end
    return np.mean(
        np.prod(mcs[:-1, :]**weights[:-1, np.newaxis], axis=0) *
        (mssim[-1, :]**weights[-1]))
def sliced_wasserstein(distribution_a,
                       distribution_b,
                       dir_repeats=4,
                       dirs_per_repeat=128):
    r"""sliced Wasserstein distance of two sets of patches.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa

    Args:
        distribution_a (Tensor): Descriptors of first distribution.
        distribution_b (Tensor): Descriptors of second distribution.
        dir_repeats (int): The number of projection times. Default to 4.
        dirs_per_repeat (int): The number of directions per projection.
            Default to 128.

    Returns:
        float: sliced Wasserstein distance.
    """
    if torch.cuda.is_available():
        distribution_b = distribution_b.cuda()
    assert distribution_a.ndim == 2
    assert distribution_a.shape == distribution_b.shape
    assert dir_repeats > 0 and dirs_per_repeat > 0
    distribution_a = distribution_a.to(distribution_b.device)
    distances = []
    for _ in range(dir_repeats):
        # random projection directions, normalized to unit length
        dirs = torch.randn(distribution_a.shape[1], dirs_per_repeat)
        dirs /= torch.sqrt(torch.sum((dirs**2), dim=0, keepdim=True))
        dirs = dirs.to(distribution_b.device)
        proj_a = torch.matmul(distribution_a, dirs)
        proj_b = torch.matmul(distribution_b, dirs)
        # to save cuda memory, we perform sort in cpu
        proj_a = torch.sort(proj_a.cpu(), dim=0)[0]
        proj_b = torch.sort(proj_b.cpu(), dim=0)[0]
        distances.append(torch.mean(torch.abs(proj_a - proj_b)).item())
    torch.cuda.empty_cache()
    return sum(distances) / dir_repeats
class Metric(ABC):
    """The abstract base class of metrics. Basically, we split calculation into
    three steps. First, we initialize the metric object and do some
    preparation. Second, we will feed the real and fake images into metric
    object batch by batch, and we calculate intermediate results of these
    batches. Finally, We use these intermediate results to summarize the final
    result. And the result as a string can be obtained by property
    'result_str'.

    Args:
        num_images (int): The number of real/fake images needed to calculate
            metric.
        image_shape (tuple): Shape of the real/fake images with order "CHW".
    """

    def __init__(self, num_images, image_shape=None):
        self.num_images = num_images
        self.image_shape = image_shape
        # separate budgets of real/fake images; subclasses may adjust them
        self.num_real_need = num_images
        self.num_fake_need = num_images
        self.num_real_feeded = 0  # record of the fed real images
        self.num_fake_feeded = 0  # record of the fed fake images
        self._result_str = None  # string of metric result

    @property
    def result_str(self):
        """Get results in string format.

        Returns:
            str: results in string format
        """
        # compute lazily on first access
        if not self._result_str:
            self.summary()
        return self._result_str

    def feed(self, batch, mode):
        """Feed an image batch into metric calculator and perform intermediate
        operation in 'feed_op' function.

        Args:
            batch (Tensor | dict): Images or dict to be fed into
                metric object. If ``Tensor`` is passed, the order of ``Tensor``
                should be "NCHW". If ``dict`` is passed, each term in the
                ``dict`` are ``Tensor`` with order "NCHW".
            mode (str): Mark the batch as real or fake images. Value can be
                'reals' or 'fakes',

        Returns:
            int: The number of samples fed on the local rank, or 0 if the
                metric already received enough images of this kind.
        """
        _, ws = get_dist_info()
        if mode == 'reals':
            num_need = self.num_real_need - self.num_real_feeded
        elif mode == 'fakes':
            num_need = self.num_fake_need - self.num_fake_feeded
        else:
            raise ValueError(
                'The expected mode should be set to \'reals\' or \'fakes\','
                f'but got \'{mode}\'')
        if num_need <= 0:
            return 0
        # A dict batch (e.g. probabilistic parameters for reconstruction
        # metrics) carries the batch size in its values; read the shape from
        # a value BEFORE slicing -- dicts have no `.shape` themselves.
        if isinstance(batch, dict):
            batch_size = next(iter(batch.values())).shape[0]
            end = min(batch_size, num_need)
            batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
        else:
            batch_size = batch.shape[0]
            end = min(batch_size, num_need)
            batch_to_feed = batch[:end, ...]
        # account for the samples fed on all ranks, not just the local one
        global_end = min(batch_size * ws, num_need)
        self.feed_op(batch_to_feed, mode)
        if mode == 'reals':
            self.num_real_feeded += global_end
        else:
            self.num_fake_feeded += global_end
        return end

    def check(self):
        """Check the numbers of image."""
        assert self.num_real_feeded == self.num_fake_feeded == self.num_images

    @abstractmethod
    def prepare(self, *args, **kwargs):
        """please implement in subclass."""

    @abstractmethod
    def feed_op(self, batch, mode):
        """please implement in subclass."""

    @abstractmethod
    def summary(self):
        """please implement in subclass."""
@METRICS.register_module()
class FID(Metric):
    """FID metric.

    In this metric, we calculate the distance between real distributions and
    fake distributions. The distributions are modeled by the real samples and
    fake samples, respectively.

    `Inception_v3` is adopted as the feature extractor, which is widely used in
    StyleGAN and BigGAN.

    Args:
        num_images (int): The number of images to be tested.
        image_shape (tuple[int], optional): Image shape. Defaults to None.
        inception_pkl (str, optional): Path to reference inception pickle file.
            If `None`, the statistical value of real distribution will be
            calculated at running time. Defaults to None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. Defaults to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to `dict(normalize_input=False)`.
    """
    name = 'FID'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 inception_pkl=None,
                 bgr2rgb=True,
                 inception_args=dict(normalize_input=False)):
        super().__init__(num_images, image_shape=image_shape)
        self.inception_pkl = inception_pkl
        self.real_feats = []
        self.fake_feats = []
        self.real_mean = None
        self.real_cov = None
        self.bgr2rgb = bgr2rgb
        self.device = 'cpu'
        self.inception_net, self.inception_style = load_inception(
            inception_args, 'FID')
        if torch.cuda.is_available():
            self.inception_net = self.inception_net.cuda()
            self.device = 'cuda'
        self.inception_net.eval()
        mmcv.print_log(f'FID: Adopt Inception in {self.inception_style} style',
                       'mmgen')

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # if `inception_pkl` is provided, read mean and cov stat
        if self.inception_pkl is not None and mmcv.is_filepath(
                self.inception_pkl):
            with open(self.inception_pkl, 'rb') as f:
                reference = pickle.load(f)
                self.real_mean = reference['mean']
                self.real_cov = reference['cov']
                mmcv.print_log(
                    f'Load reference inception pkl from {self.inception_pkl}',
                    'mmgen')
            # real statistics are taken from the pickle, so no real images
            # need to be fed at running time
            self.num_real_feeded = self.num_images

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if self.bgr2rgb:
            batch = batch[:, [2, 1, 0]]
        batch = batch.to(self.device)

        if self.inception_style == 'StyleGAN':
            # Tero's script module expects uint8 images in [0, 255]
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feat = self.inception_net(batch, return_features=True)
        else:
            feat = self.inception_net(batch)[0].view(batch.shape[0], -1)

        # gather all of images if using distributed training
        # NOTE(review): all_gather assumes equal batch sizes across ranks —
        # confirm the sampler pads the last batch accordingly.
        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder = [torch.zeros_like(feat) for _ in range(ws)]
            dist.all_gather(placeholder, feat)
            feat = torch.cat(placeholder, dim=0)

        # in distributed training, we only collect features at rank-0.
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            if mode == 'reals':
                self.real_feats.append(feat.cpu())
            elif mode == 'fakes':
                self.fake_feats.append(feat.cpu())
            else:
                raise ValueError(
                    "The expected mode should be set to 'reals' or "
                    f"'fakes', but got '{mode}'")

    @staticmethod
    def _calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
        """Compute the Frechet distance between two Gaussians.

        Refer to the implementation from:
        https://github.com/rosinality/stylegan2-pytorch/blob/master/fid.py#L34

        Args:
            sample_mean (ndarray): Mean of the fake features.
            sample_cov (ndarray): Covariance of the fake features.
            real_mean (ndarray): Mean of the real features.
            real_cov (ndarray): Covariance of the real features.
            eps (float, optional): Diagonal offset applied when the product
                of the covariance matrices is singular. Defaults to 1e-6.

        Returns:
            tuple(float): FID value, mean distance and covariance trace term.
        """
        cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
        if not np.isfinite(cov_sqrt).all():
            print('product of cov matrices is singular')
            # regularize by adding a small value to the diagonal
            offset = np.eye(sample_cov.shape[0]) * eps
            cov_sqrt = linalg.sqrtm(
                (sample_cov + offset) @ (real_cov + offset))
        if np.iscomplexobj(cov_sqrt):
            # tolerate tiny imaginary parts produced by numerical error
            if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
                m = np.max(np.abs(cov_sqrt.imag))
                raise ValueError(f'Imaginary component {m}')
            cov_sqrt = cov_sqrt.real
        mean_diff = sample_mean - real_mean
        mean_norm = mean_diff @ mean_diff
        trace = np.trace(sample_cov) + np.trace(
            real_cov) - 2 * np.trace(cov_sqrt)
        fid = mean_norm + trace
        return fid, mean_norm, trace

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        # calculate reference inception stat
        if self.real_mean is None:
            feats = torch.cat(self.real_feats, dim=0)
            assert feats.shape[0] >= self.num_images
            feats = feats[:self.num_images]
            feats_np = feats.numpy()
            self.real_mean = np.mean(feats_np, 0)
            self.real_cov = np.cov(feats_np, rowvar=False)

        # calculate fake inception stat
        fake_feats = torch.cat(self.fake_feats, dim=0)
        assert fake_feats.shape[0] >= self.num_images
        fake_feats = fake_feats[:self.num_images]
        fake_feats_np = fake_feats.numpy()
        fake_mean = np.mean(fake_feats_np, 0)
        fake_cov = np.cov(fake_feats_np, rowvar=False)

        # calculate distance between real and fake statistics
        fid, mean, cov = self._calc_fid(fake_mean, fake_cov, self.real_mean,
                                        self.real_cov)
        # results for print/table
        self._result_str = (f'{fid:.4f} ({mean:.5f}/{cov:.5f})')
        # results for log_buffer
        self._result_dict = dict(fid=fid, fid_mean=mean, fid_cov=cov)
        return fid, mean, cov

    def clear_fake_data(self):
        """Clear fake data."""
        self.fake_feats = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Whether to clear real data.
                Defaults to False.
        """
        self.clear_fake_data()
        if clear_reals:
            self.real_feats = []
            self.num_real_feeded = 0
@METRICS.register_module()
class MS_SSIM(Metric):
    """MS-SSIM (Multi-Scale Structure Similarity) metric.

    Generated images are consumed in consecutive pairs; the MS-SSIM score
    of each pair is accumulated and averaged over all pairs at summary
    time. Real images are ignored by this metric.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults to
            None.
    """
    name = 'MS-SSIM'

    def __init__(self, num_images, image_shape=None):
        super().__init__(num_images, image_shape)
        # images are scored pairwise, so an even count is required
        assert num_images % 2 == 0
        self.num_pairs = num_images // 2

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.sum = 0.0

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor in [-1, 1], order "NCHW".
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        # MS-SSIM only evaluates generated samples
        if mode == 'reals':
            return

        def to_uint8_nhwc(tensor):
            # "NCHW" float in [0, 1] -> "NHWC" uint8 in [0, 255]
            arr = tensor.cpu().data.numpy().transpose((0, 2, 3, 1))
            return (arr * 255).astype('uint8')

        imgs = ((minibatch + 1) / 2).clamp(0, 1)
        even_half = to_uint8_nhwc(imgs[0::2])
        odd_half = to_uint8_nhwc(imgs[1::2])
        pair_score = ms_ssim(even_half, odd_half)
        # weight the batch score by the number of pairs it contains
        self.sum += pair_score * (imgs.shape[0] // 2)

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        self.check()
        mean_score = self.sum / self.num_pairs
        self._result_str = str(round(mean_score.item(), 4))
        return mean_score
@METRICS.register_module()
class SWD(Metric):
    """SWD (Sliced Wasserstein distance) metric.

    For every fed batch we build a Laplacian pyramid of the images and
    extract local patches from each pyramid level as descriptors. At
    summary time the descriptors of every level are normalized along the
    channel dimension and reshaped so that they represent the real/fake
    image distributions; the sliced Wasserstein distance between the real
    and fake descriptor sets is then reported per level, plus their mean.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple): Image shape in order "CHW".
    """
    name = 'SWD'

    def __init__(self, num_images, image_shape):
        super().__init__(num_images, image_shape)

        self.nhood_size = 7  # height and width of the extracted patches
        self.nhoods_per_image = 128  # number of extracted patches per image
        self.dir_repeats = 4  # times of sampling directions
        self.dirs_per_repeat = 128  # number of directions per sampling

        # pyramid levels: halve the resolution down to 16, at most 4 levels
        res, self.resolutions = image_shape[1], []
        while res >= 16 and len(self.resolutions) < 4:
            self.resolutions.append(res)
            res //= 2
        self.n_pyramids = len(self.resolutions)

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # one descriptor list per pyramid level
        self.real_descs = [[] for _ in self.resolutions]
        self.fake_descs = [[] for _ in self.resolutions]
        self.gaussian_k = get_gaussian_kernel()

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor in order "NCHW".
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        assert minibatch.shape[1:] == self.image_shape
        if mode == 'reals':
            storage = self.real_descs
        elif mode == 'fakes':
            storage = self.fake_descs
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

        pyramid = laplacian_pyramid(minibatch, self.n_pyramids - 1,
                                    self.gaussian_k)
        # lod: layer_of_descriptors
        for lod, level in enumerate(pyramid):
            storage[lod].append(
                get_descriptors_for_minibatch(level, self.nhood_size,
                                              self.nhoods_per_image))

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        self.check()
        real_final = [finalize_descriptors(d) for d in self.real_descs]
        fake_final = [finalize_descriptors(d) for d in self.fake_descs]
        # free the raw per-batch descriptors as soon as possible
        del self.real_descs
        del self.fake_descs

        scores = []
        for dreal, dfake in zip(real_final, fake_final):
            swd = sliced_wasserstein(dreal, dfake, self.dir_repeats,
                                     self.dirs_per_repeat)
            scores.append(swd * 1e3)  # multiply by 10^3
        del real_final
        del fake_final

        result = scores + [np.mean(scores)]
        self._result_str = ', '.join(str(round(x, 2)) for x in result)
        return result
@METRICS.register_module()
class PR(Metric):
    r"""Improved Precision and recall metric.

    In this metric, we draw real and generated samples respectively, and
    embed them into a high-dimensional feature space using a pre-trained
    classifier network. We use these features to estimate the corresponding
    manifold. We obtain the estimation by calculating pairwise Euclidean
    distances between all feature vectors in the set and, for each feature
    vector, construct a hypersphere with radius equal to the distance to its
    kth nearest neighbor. Together, these hyperspheres define a volume in
    the feature space that serves as an estimate of the true manifold.
    Precision is quantified by querying for each generated image whether
    the image is within the estimated manifold of real images.
    Symmetrically, recall is calculated by querying for each real image
    whether the image is within estimated manifold of generated image.

    Ref: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py # noqa

    Note that we highly recommend that users should download the vgg16
    script module from the following address. Then, the `vgg16_script` can
    be set with user's local path. If not given, we will use the vgg16 from
    pytorch model zoo. However, this may bring significant different in the
    final results.

    Tero's vgg16: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple): Image shape in order "CHW". Defaults to None.
        num_real_need (int | None, optional): The number of real images.
            Defaults to None.
        full_dataset (bool, optional): Whether to use full dataset for
            evaluation. Defaults to False.
        k (int, optional): Kth nearest parameter. Defaults to 3.
        bgr2rgb (bool, optional): Whether to change the order of image
            channel. Defaults to True.
        vgg16_script (str, optional): Path for the Tero's vgg16 module.
            Defaults to 'work_dirs/cache/vgg16.pt'.
        row_batch_size (int, optional): The batch size of row data.
            Defaults to 10000.
        col_batch_size (int, optional): The batch size of col data.
            Defaults to 10000.
    """
    name = 'PR'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 num_real_need=None,
                 full_dataset=False,
                 k=3,
                 bgr2rgb=True,
                 vgg16_script='work_dirs/cache/vgg16.pt',
                 row_batch_size=10000,
                 col_batch_size=10000):
        super().__init__(num_images, image_shape)
        mmcv.print_log('loading vgg16 for improved precision and recall...',
                       'mmgen')
        if os.path.isfile(vgg16_script):
            # BUGFIX: load the user-provided path instead of the hardcoded
            # default 'work_dirs/cache/vgg16.pt'
            self.vgg16 = torch.jit.load(vgg16_script).eval()
            # NOTE: attribute name keeps the historical 'scirpt' spelling
            # for backward compatibility
            self.use_tero_scirpt = True
        else:
            mmcv.print_log(
                'Cannot load Tero\'s script module. Use official '
                'vgg16 instead', 'mmgen')
            self.vgg16 = models.vgg16(pretrained=True).eval()
            self.use_tero_scirpt = False
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.vgg16 = self.vgg16.cuda()
            self.device = 'cuda'
        self.k = k
        self.bgr2rgb = bgr2rgb
        self.full_dataset = full_dataset
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        if num_real_need:
            self.num_real_need = num_real_need
        if self.full_dataset:
            # effectively unbounded: consume every real image the loader
            # provides
            self.num_real_need = 10000000

    def prepare(self, *args, **kwargs):
        """Prepare for evaluating models with this metric."""
        self.features_of_reals = []
        self.features_of_fakes = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        batch = batch.to(self.device)
        if self.bgr2rgb:
            batch = batch[:, [2, 1, 0], ...]
        if self.use_tero_scirpt:
            # Tero's script module expects uint8 images in [0, 255]
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        if mode == 'reals':
            self.features_of_reals.append(self.extract_features(batch))
        elif mode == 'fakes':
            self.features_of_fakes.append(self.extract_features(batch))
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    def check(self):
        """Check the numbers of fed images; with ``full_dataset`` only the
        fake count is required to match."""
        if not self.full_dataset:
            assert (self.num_real_feeded == self.num_real_need
                    and self.num_fake_feeded == self.num_fake_need)
        else:
            assert self.num_fake_feeded == self.num_fake_need
            mmcv.print_log(
                f'Test for the full dataset with {self.num_real_feeded}'
                ' real images', 'mmgen')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        self.check()
        real_features = torch.cat(self.features_of_reals)
        gen_features = torch.cat(self.features_of_fakes)
        self._result_dict = {}

        rank, ws = get_dist_info()
        # precision: fraction of fakes inside the real manifold;
        # recall: fraction of reals inside the fake manifold.
        for name, manifold, probes in [
            ('precision', real_features, gen_features),
            ('recall', gen_features, real_features)
        ]:
            kth = []
            for manifold_batch in manifold.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=manifold_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                # k+1 because the closest neighbor of a point is itself
                kth.append(
                    distance.to(torch.float32).kthvalue(self.k + 1).values.
                    to(torch.float16) if rank == 0 else None)
            kth = torch.cat(kth) if rank == 0 else None
            pred = []
            for probes_batch in probes.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=probes_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                pred.append((distance <= kth).any(
                    dim=1) if rank == 0 else None)
            # results are only meaningful on rank-0; other ranks get NaN
            self._result_dict[name] = float(
                torch.cat(pred).to(torch.float32).mean() if rank ==
                0 else 'nan')

        precision = self._result_dict['precision']
        recall = self._result_dict['recall']
        self._result_str = f'precision: {precision}, recall:{recall}'
        return self._result_dict

    def extract_features(self, images):
        """Extracting image features.

        Args:
            images (torch.Tensor): Images tensor.

        Returns:
            torch.Tensor: Vgg16 features of input images.
        """
        if self.use_tero_scirpt:
            feature = self.vgg16(images, return_features=True)
        else:
            # official torchvision vgg16: resize, run conv features, then
            # take the activations of the first fully-connected layers
            batch = F.interpolate(images, size=(224, 224))
            before_fc = self.vgg16.features(batch)
            before_fc = before_fc.view(-1, 7 * 7 * 512)
            feature = self.vgg16.classifier[:4](before_fc)
        return feature
@METRICS.register_module()
class IS(Metric):
    """IS (Inception Score) metric.

    The images are split into groups, and the inception score is calculated
    on each group of images, then the mean and standard deviation of the score
    is reported. The calculation of the inception score on a group of images
    involves first using the inception v3 model to calculate the conditional
    probability for each image (p(y|x)). The marginal probability is then
    calculated as the average of the conditional probabilities for the images
    in the group (p(y)). The KL divergence is then calculated for each image as
    the conditional probability multiplied by the log of the conditional
    probability minus the log of the marginal probability. The KL divergence is
    then summed over all images and averaged over all classes and the exponent
    of the result is calculated to give the final score.

    Ref: https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py # noqa

    Note that we highly recommend that users should download the Inception V3
    script module from the following address. Then, the `inception_pkl` can
    be set with user's local path. If not given, we will use the Inception V3
    from pytorch model zoo. However, this may bring significant different in
    the final results.

    Tero's Inception V3: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults to
            None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. In default, our model generate images in the BGR order.
            Thus, we use `True` as the default behavior. Please switch to
            `False`, if the input is in the `RGB` order. Defaults to True.
        resize (bool, optional): Whether resize image to 299x299. Defaults to
            True.
        splits (int, optional): The number of groups. Defaults to 10.
        use_pil_resize (bool, optional): Whether use Bicubic interpolation with
            Pillow's backend. If set as True, the evaluation process may be a
            little bit slow, but achieve a more accurate IS result. Defaults
            to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to ``dict(type='StyleGAN', inception_path=INCEPTION_URL)``.
    """
    name = 'IS'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 bgr2rgb=True,
                 resize=True,
                 splits=10,
                 use_pil_resize=True,
                 inception_args=dict(
                     type='StyleGAN', inception_path=TERO_INCEPTION_URL)):
        super().__init__(num_images, image_shape)
        mmcv.print_log('Loading Inception V3 for IS...', 'mmgen')
        model, style = load_inception(inception_args, 'IS')
        self.inception_model = model
        self.use_tero_script = style == 'StyleGAN'
        # IS only evaluates generated samples; mark reals as already fed
        self.num_real_feeded = self.num_images
        self.resize = resize
        self.splits = splits
        self.bgr2rgb = bgr2rgb
        self.use_pil_resize = use_pil_resize
        self._pil_resize_warned = False  # print the slowness warning once
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.inception_model = self.inception_model.cuda()
            self.device = 'cuda'
        self.inception_model.eval()

    def pil_resize(self, x):
        """Apply Bicubic interpolation with Pillow backend. Before and after
        interpolation operation, we have to perform a type conversion between
        torch.tensor and PIL.Image, and these operations make resize process a
        bit slow.

        Args:
            x (Tensor): Input tensor, should have four dimension and
                range in [-1, 1].

        Returns:
            torch.FloatTensor: Resized tensor.
        """
        if not self._pil_resize_warned:
            mmcv.print_log(
                '`use_pil_resize` is set as True, apply Bicubic '
                'interpolation with Pillow backend. We perform '
                'type conversion between torch.tensor and '
                'PIL.Image in this function and make this process '
                'a little bit slow.', 'mmgen')
            self._pil_resize_warned = True
        from PIL import Image

        if x.ndim != 4:
            raise ValueError('Input images should have 4 dimensions, '
                             'here receive input with {} '
                             'dimensions.'.format(x.ndim))
        # [-1, 1] float -> [0, 255] uint8, HWC layout for PIL
        x = (x.clone() * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        x_np = [x_.permute(1, 2, 0).detach().cpu().numpy() for x_ in x]
        x_pil = [Image.fromarray(x_).resize((299, 299)) for x_ in x_np]
        x_ten = torch.cat(
            [torch.FloatTensor(np.array(x_)[None, ...]) for x_ in x_pil])
        # back to [-1, 1] float, NCHW layout
        x_ten = (x_ten / 127.5 - 1).to(torch.float)
        return x_ten.permute(0, 3, 1, 2)

    def get_pred(self, x):
        """Get prediction from inception model.

        Args:
            x (Tensor): Input tensor.

        Returns:
            np.array: Inception score.
        """
        if self.use_tero_script:
            x = self.inception_model(x, no_output_bias=True)
        else:
            # specify the dimension to avoid warning
            x = F.softmax(self.inception_model(x), dim=1)
        return x

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.preds = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if mode == 'reals':
            # IS does not use real images
            pass
        elif mode == 'fakes':
            if self.bgr2rgb:
                batch = batch[:, [2, 1, 0], ...]
            if self.resize:
                if self.use_pil_resize:
                    batch = self.pil_resize(batch)
                else:
                    batch = F.interpolate(
                        batch, size=(299, 299), mode='bilinear')
            if self.use_tero_script:
                # Tero's script module expects uint8 images in [0, 255]
                batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            batch = batch.to(self.device)

            # get prediction
            pred = self.get_pred(batch)

            if dist.is_initialized():
                ws = dist.get_world_size()
                placeholder = [torch.zeros_like(pred) for _ in range(ws)]
                dist.all_gather(placeholder, pred)
                pred = torch.cat(placeholder, dim=0)

            # in distributed training, we only collect features at rank-0.
            if (dist.is_initialized()
                    and dist.get_rank() == 0) or not dist.is_initialized():
                self.preds.append(pred.cpu().numpy())
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        TODO: support `master_only`

        Returns:
            dict | list: Summarized results.
        """
        split_scores = []
        # work on a local copy so ``summary`` stays idempotent and
        # ``self.preds`` keeps its list-of-arrays shape
        preds = np.concatenate(self.preds, axis=0)
        # check for the size
        assert preds.shape[0] >= self.num_images
        preds = preds[:self.num_images]
        for k in range(self.splits):
            part = preds[k * (self.num_images // self.splits):(k + 1) *
                         (self.num_images // self.splits), :]
            # marginal distribution p(y) of this split
            py = np.mean(part, axis=0)
            scores = []
            for i in range(part.shape[0]):
                pyx = part[i, :]
                scores.append(entropy(pyx, py))
            split_scores.append(np.exp(np.mean(scores)))

        mean, std = np.mean(split_scores), np.std(split_scores)

        # results for print/table
        self._result_str = f'mean: {mean:.3f}, std: {std:.3f}'
        # results for log_buffer
        self._result_dict = {'is': mean, 'is_std': std}

        return mean, std

    def clear_fake_data(self):
        """Clear fake data."""
        self.preds = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Whether to clear real data.
                Defaults to False. Unused here because IS never stores
                real images.
        """
        self.clear_fake_data()
@METRICS.register_module()
class PPL(Metric):
    r"""Perceptual path length.

    Measure the difference between consecutive images (their VGG16
    embeddings) when interpolating between two random inputs. Drastic
    changes mean that multiple features have changed together and that
    they might be entangled.

    Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults
            to None.
        crop (bool, optional): Whether crop images. Defaults to True.
        epsilon (float, optional): Epsilon parameter for path sampling.
            Defaults to 1e-4.
        space (str, optional): Latent space. Defaults to 'W'.
        sampling (str, optional): Sampling mode, whether sampling in full
            path or endpoints. Defaults to 'end'.
        latent_dim (int, optional): Latent dimension of input noise.
            Defaults to 512.
    """
    name = 'PPL'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 crop=True,
                 epsilon=1e-4,
                 space='W',
                 sampling='end',
                 latent_dim=512):
        super().__init__(num_images, image_shape=image_shape)
        self.crop = crop
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.latent_dim = latent_dim
        # every PPL sample is an image *pair*, so twice as many generated
        # images are required; NOTE this doubles ``num_images`` AFTER the
        # base class set num_real_need/num_fake_need from the original value
        self.num_images = num_images * 2
        # PPL ignores real images; mark them as already fed
        self.num_real_feeded = self.num_images

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.dist_list = []

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor; consecutive image pairs
                (indices 2i and 2i+1) are compared against each other.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        if mode == 'reals':
            return
        # use minibatch's device type to initialize a lpips calculator
        if not hasattr(self, 'percept'):
            self.percept = PerceptualLoss(
                use_gpu=minibatch.device.type.startswith('cuda'))
        # crop and resize images
        if self.crop:
            # keep the central region: rows [3/8, 7/8), cols [2/8, 6/8)
            c = minibatch.shape[2] // 8
            minibatch = minibatch[:, :, c * 3:c * 7, c * 2:c * 6]

        factor = minibatch.shape[2] // 256
        if factor > 1:
            # downsample large images to 256x256 before LPIPS
            minibatch = F.interpolate(
                minibatch,
                size=(256, 256),
                mode='bilinear',
                align_corners=False)
        # calculator and store lpips score; division by epsilon^2 converts
        # the pair distance into a (squared) path-length estimate
        distance = self.percept(minibatch[::2], minibatch[1::2]).view(
            minibatch.shape[0] // 2) / (
                self.epsilon**2)
        self.dist_list.append(distance.to('cpu').numpy())

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        distances = np.concatenate(self.dist_list, 0)
        # discard the 1% smallest and 1% largest distances as outliers
        # NOTE(review): the ``interpolation`` kwarg is deprecated in newer
        # numpy (renamed to ``method``) — confirm the pinned numpy version
        lo = np.percentile(distances, 1, interpolation='lower')
        hi = np.percentile(distances, 99, interpolation='higher')
        filtered_dist = np.extract(
            np.logical_and(lo <= distances, distances <= hi), distances)
        ppl_score = filtered_dist.mean()
        self._result_str = f'{ppl_score:.1f}'
        return ppl_score

    def get_sampler(self, model, batch_size, sample_model):
        """Get sampler for sampling along the path.

        Args:
            model (nn.Module): Generative model.
            batch_size (int): Sampling batch size.
            sample_model (str): Which model you want to use. ['ema',
                'orig']. Defaults to 'ema'.

        Returns:
            Object: A sampler for calculating path length regularization.
        """
        if sample_model == 'ema':
            generator = model.generator_ema
        else:
            generator = model.generator
        ppl_sampler = PPLSampler(generator, self.num_images, batch_size,
                                 self.space, self.sampling, self.epsilon,
                                 self.latent_dim)
        return ppl_sampler
class PPLSampler:
    """StyleGAN series generator's sampling iterator for PPL metric.

    Each iteration yields a batch of images generated from pairs of
    slightly-perturbed latents (offset by ``epsilon`` along the
    interpolation path), which the PPL metric then compares pairwise.

    Args:
        generator (nn.Module): StyleGAN series' generator.
        num_images (int): The number of evaluated generated samples.
        batch_size (int): Batch size of generated images.
        space (str, optional): Latent space. Defaults to 'W'.
        sampling (str, optional): Sampling mode, whether sampling in full
            path or endpoints. Defaults to 'end'.
        epsilon (float, optional): Epsilon parameter for path sampling.
            Defaults to 1e-4.
        latent_dim (int, optional): Latent dimension of input noise.
            Defaults to 512.
    """

    def __init__(self,
                 generator,
                 num_images,
                 batch_size,
                 space='W',
                 sampling='end',
                 epsilon=1e-4,
                 latent_dim=512):
        assert space in ['Z', 'W']
        assert sampling in ['full', 'end']
        # split ``num_images`` into full batches plus an optional remainder
        n_batch = num_images // batch_size
        resid = num_images - (n_batch * batch_size)
        self.batch_sizes = [batch_size] * n_batch + ([resid]
                                                     if resid > 0 else [])
        self.device = get_module_device(generator)
        self.generator = generator
        self.latent_dim = latent_dim
        self.space = space
        self.sampling = sampling
        self.epsilon = epsilon

    def __iter__(self):
        self.idx = 0
        return self

    @torch.no_grad()
    def __next__(self):
        if self.idx >= len(self.batch_sizes):
            raise StopIteration
        batch = self.batch_sizes[self.idx]
        injected_noise = self.generator.make_injected_noise()
        # two noise vectors per sample: interpolation endpoints
        inputs = torch.randn([batch * 2, self.latent_dim], device=self.device)
        if self.sampling == 'full':
            # sample a random point t on the path
            lerp_t = torch.rand(batch, device=self.device)
        else:
            # 'end' mode: always evaluate at the path endpoint t=0
            lerp_t = torch.zeros(batch, device=self.device)

        if self.space == 'W':
            assert hasattr(self.generator, 'style_mapping')
            latent = self.generator.style_mapping(inputs)
            latent_t0, latent_t1 = latent[::2], latent[1::2]
            # linear interpolation in W space at t and t + epsilon
            latent_e0 = torch.lerp(latent_t0, latent_t1, lerp_t[:, None])
            latent_e1 = torch.lerp(latent_t0, latent_t1,
                                   lerp_t[:, None] + self.epsilon)
            # interleave so that consecutive rows form (t, t+eps) pairs
            latent_e = torch.stack([latent_e0, latent_e1],
                                   1).view(*latent.shape)
            image = self.generator([latent_e],
                                   input_is_latent=True,
                                   injected_noise=injected_noise)
        else:
            latent_t0, latent_t1 = inputs[::2], inputs[1::2]
            # spherical interpolation in Z space at t and t + epsilon
            latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
            latent_e1 = slerp(latent_t0, latent_t1,
                              lerp_t[:, None] + self.epsilon)
            latent_e = torch.stack([latent_e0, latent_e1],
                                   1).view(*inputs.shape)
            image = self.generator([latent_e],
                                   input_is_latent=False,
                                   injected_noise=injected_noise)

        self.idx += 1
        return image
@METRICS.register_module()
class GaussianKLD(Metric):
    r"""Gaussian KLD (Kullback-Leibler divergence) metric. We calculate the
    KLD between two gaussian distribution via `mean` and `log_variance`.
    The passed batch should be a dict instance and contain ``mean_pred``,
    ``mean_target``, ``logvar_pred``, ``logvar_target``.
    When call ``feed`` operation, only ``reals`` mode is needed,

    The calculation of KLD can be formulated as:

    .. math::
        :nowrap:

        \begin{align}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{align}

    where `p` and `q` denote target and predicted distribution respectively.

    Args:
        num_images (int): The number of samples to be tested.
        base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        reduction (string, optional): Specifies the reduction to apply to the
            output. Support ``'batchmean'``, ``'sum'`` and ``'mean'``. If
            ``reduction == 'batchmean'``, the sum of the output will be divided
            by batchsize. If ``reduction == 'sum'``, the output will be summed.
            If ``reduction == 'mean'``, the output will be divided by the
            number of elements in the output. Defaults to ``'batchmean'``.
    """
    name = 'GaussianKLD'

    def __init__(self, num_images, base='e', reduction='batchmean'):
        super().__init__(num_images, image_shape=None)
        assert reduction in [
            'sum', 'batchmean', 'mean'
        ], ('We only support reduction for \'batchmean\', \'sum\' '
            'and \'mean\'')
        assert base in ['e',
                        '2'], ('We only support log_base for \'e\' and \'2\'')
        self.reduction = reduction
        # this metric only consumes 'reals' batches; mark fakes as done
        self.num_fake_feeded = self.num_images
        # per-sample KLD, reduced manually in ``summary``
        self.cal_kld = partial(
            gaussian_kld, weight=1, reduction='none', base=base)

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.kld = []
        self.num_real_feeded = 0

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (dict): Dict containing ``mean_pred``, ``mean_target``,
                ``logvar_pred`` and ``logvar_target`` tensors.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
        """
        # only 'reals' batches carry the distribution parameters
        if mode == 'fakes':
            return
        assert isinstance(batch, dict), ('To calculate GaussianKLD loss, a '
                                         'dict contains probabilistic '
                                         'parameters is required.')
        # check required keys
        require_keys = [
            'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
        ]
        if any([k not in batch for k in require_keys]):
            raise KeyError(f'The input dict must require {require_keys} at '
                           'the same time. But keys in the given dict are '
                           f'{batch.keys()}. Some of the requirements are '
                           'missing.')
        kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
                           batch['logvar_target'], batch['logvar_pred'])
        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder = [torch.zeros_like(kld) for _ in range(ws)]
            dist.all_gather(placeholder, kld)
            kld = torch.cat(placeholder, dim=0)

        # in distributed training, we only collect features at rank-0.
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            self.kld.append(kld.cpu())

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict | list: Summarized results.
        """
        kld = torch.cat(self.kld, dim=0)
        assert kld.shape[0] >= self.num_images
        # BUGFIX: truncate to exactly ``num_images`` samples so that padding
        # gathered from other ranks does not leak into the reduction
        # (consistent with FID/IS summaries)
        kld = kld[:self.num_images]
        kld_np = kld.numpy()
        if self.reduction == 'sum':
            kld_result = np.sum(kld_np)
        elif self.reduction == 'mean':
            kld_result = np.mean(kld_np)
        else:
            # 'batchmean': total divided by the number of samples
            kld_result = np.sum(kld_np) / kld_np.shape[0]
        self._result_str = (f'{kld_result:.4f}')
        return kld_result
# Copyright (c) OpenMMLab. All rights reserved.
from .ceph_hooks import PetrelUploadHook
from .ema_hook import ExponentialMovingAverageHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
from .pickle_data_hook import PickleDataHook
from .visualization import VisualizationHook
from .visualize_training_samples import VisualizeUnconditionalSamples
# Public API of the hooks subpackage.
__all__ = [
    'VisualizeUnconditionalSamples', 'PGGANFetchDataHook',
    'ExponentialMovingAverageHook', 'VisualizationHook', 'PickleDataHook',
    'PetrelUploadHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os
import mmcv
from mmcv.runner import HOOKS, Hook, master_only
@HOOKS.register_module()
class PetrelUploadHook(Hook):
    """Upload Data with Petrel.

    With this hook, users can easily upload data to the cloud server for
    saving local spaces. Please read the notes below for using this hook,
    especially for the declaration of ``petrel``.

    One of the major functions is to transfer the checkpoint files from the
    local directory to the cloud server.

    .. note::
        ``petrel`` is a private package containing several commonly used
        ``AWS`` python API. Currently, this package is only for internal usage
        and will not be released to the public. We will support ``boto3`` in
        the future. We think this hook is an easy template for you to transfer
        to ``boto3``.

    Args:
        data_path (str, optional): Relative path of the data according to
            current working directory. Defaults to 'ckpt'.
        suffix (str, optional): Suffix for the data files. Defaults to '.pth'.
        ceph_path (str | None, optional): Path in the cloud server.
            Defaults to None.
        interval (int, optional): Uploading interval (by iterations).
            Default: -1.
        upload_after_run (bool, optional): Whether to upload after running.
            Defaults to True.
        rm_orig (bool, optional): Whether to remove the local files after
            uploading. Defaults to True.
    """

    # default location of the petrel credential/config file
    cfg_path = '~/petreloss.conf'

    def __init__(self,
                 data_path='ckpt',
                 suffix='.pth',
                 ceph_path=None,
                 interval=-1,
                 upload_after_run=True,
                 rm_orig=True):
        super().__init__()
        self.interval = interval
        self.upload_after_run = upload_after_run
        self.data_path = data_path
        self.suffix = suffix
        self.ceph_path = ceph_path
        self.rm_orig = rm_orig
        # setup petrel client
        try:
            from petrel_client.client import Client
        except ImportError:
            raise ImportError('Please install petrel in advance.')
        self.client = Client(self.cfg_path)

    @staticmethod
    def upload_dir(client,
                   local_dir,
                   remote_dir,
                   exp_name=None,
                   suffix=None,
                   remove_local_file=True):
        """Upload a directory to the cloud server.

        Args:
            client (obj): AWS client.
            local_dir (str): Path for the local data.
            remote_dir (str): Path for the remote server.
            exp_name (str, optional): The experiment name. Defaults to None.
            suffix (str, optional): Suffix for the data files.
                Defaults to None.
            remove_local_file (bool, optional): Whether to remove the local
                files after uploading. Defaults to True.
        """
        files = mmcv.scandir(local_dir, suffix=suffix, recursive=False)
        files = [os.path.join(local_dir, x) for x in files]
        # remove the redundant symlinks in the data directory
        files = [x for x in files if not os.path.islink(x)]
        # get the actual exp_name in work_dir
        if exp_name is None:
            exp_name = local_dir.split('/')[-1]
        mmcv.print_log(f'Uploading {len(files)} files to ceph.', 'mmgen')
        for file in files:
            with open(file, 'rb') as f:
                data = f.read()
            # keep the path relative to the experiment directory so that the
            # remote layout mirrors the local work_dir
            # NOTE(review): raises ValueError if ``exp_name`` is not a path
            # component of ``file`` — callers must pass a consistent exp_name
            _path_splits = file.split('/')
            idx = _path_splits.index(exp_name)
            _rel_path = '/'.join(_path_splits[idx:])
            _ceph_path = os.path.join(remote_dir, _rel_path)
            client.put(_ceph_path, data)
            # remove the local file to save space
            if remove_local_file:
                os.remove(file)

    @master_only
    def after_run(self, runner):
        """The behavior after the whole running.

        Args:
            runner (object): The runner.
        """
        if not self.upload_after_run:
            return
        _data_path = os.path.join(runner.work_dir, self.data_path)
        # get the actual exp_name in work_dir
        exp_name = runner.work_dir.split('/')[-1]
        self.upload_dir(
            self.client,
            _data_path,
            self.ceph_path,
            exp_name=exp_name,
            suffix=self.suffix,
            remove_local_file=self.rm_orig)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class ExponentialMovingAverageHook(Hook):
    """Exponential Moving Average Hook.

    Exponential moving average is a trick that widely used in current GAN
    literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea of it is
    maintaining a model with the same architecture, but its parameters are
    updated as a moving average of the trained weights in the original model.
    In general, the model with moving averaged weights achieves better
    performance.

    Args:
        module_keys (str | tuple[str]): The name of the ema model. Note that we
            require these keys are followed by '_ema' so that we can easily
            find the original model by discarding the last four characters.
        interp_mode (str, optional): Mode of the interpolation method.
            Defaults to 'lerp'.
        interp_cfg (dict | None, optional): Set arguments of the interpolation
            function. Defaults to None.
        interval (int, optional): Evaluation interval (by iterations).
            Default: -1.
        start_iter (int, optional): Start iteration for ema. If the start
            iteration is not reached, the weights of ema model will maintain
            the same as the original one. Otherwise, its parameters are updated
            as a moving average of the trained weights in the original model.
            Default: 0.
        momentum_policy (str, optional): Policy of the momentum updating
            method. Defaults to 'fixed'.
        momentum_cfg (dict | None, optional): Set arguments of the momentum
            updater function. Defaults to None.
    """

    # names of static methods accepted as interp_mode / momentum_policy
    _registered_interp_funcs = ['lerp']
    _registered_momentum_updaters = ['rampup', 'fixed']

    def __init__(self,
                 module_keys,
                 interp_mode='lerp',
                 interp_cfg=None,
                 interval=-1,
                 start_iter=0,
                 momentum_policy='fixed',
                 momentum_cfg=None):
        super().__init__()
        # check args
        assert interp_mode in self._registered_interp_funcs, (
            'Supported '
            f'interpolation functions are {self._registered_interp_funcs}, '
            f'but got {interp_mode}')
        assert momentum_policy in self._registered_momentum_updaters, (
            'Supported momentum policy are'
            f'{self._registered_momentum_updaters},'
            f' but got {momentum_policy}')
        assert isinstance(module_keys, str) or mmcv.is_tuple_of(
            module_keys, str)
        self.module_keys = (module_keys, ) if isinstance(module_keys,
                                                         str) else module_keys
        # sanity check for the format of module keys
        for k in self.module_keys:
            assert k.endswith(
                '_ema'), 'You should give keys that end with "_ema".'
        self.interp_mode = interp_mode
        self.interp_cfg = dict() if interp_cfg is None else deepcopy(
            interp_cfg)
        self.interval = interval
        self.start_iter = start_iter
        assert hasattr(
            self, interp_mode
        ), f'Currently, we do not support {self.interp_mode} for EMA.'
        self.interp_func = getattr(self, interp_mode)
        self.momentum_cfg = dict() if momentum_cfg is None else deepcopy(
            momentum_cfg)
        self.momentum_policy = momentum_policy
        if momentum_policy != 'fixed':
            assert hasattr(
                self, momentum_policy
            ), f'Currently, we do not support {self.momentum_policy} for EMA.'
            self.momentum_updater = getattr(self, momentum_policy)

    @staticmethod
    def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
        """Does a linear interpolation of two parameters/ buffers.

        Args:
            a (torch.Tensor): Interpolation start point, refer to orig state.
            b (torch.Tensor): Interpolation end point, refer to ema state.
            momentum (float, optional): The weight for the interpolation
                formula. Defaults to 0.999.
            momentum_nontrainable (float, optional): The weight for the
                interpolation formula used for nontrainable parameters.
                Defaults to 0..
            trainable (bool, optional): Whether input parameters is trainable.
                If set to False, momentum_nontrainable will be used.
                Defaults to True.

        Returns:
            torch.Tensor: Interpolation result.
        """
        m = momentum if trainable else momentum_nontrainable
        return a + (b - a) * m

    @staticmethod
    def rampup(runner, ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-8):
        """Ramp up ema momentum.

        Ref: https://github.com/NVlabs/stylegan3/blob/a5a69f58294509598714d1e88c9646c3d7c6ec94/training/training_loop.py#L300-L308 # noqa

        Args:
            runner (object): The runner; its current iteration decides the
                number of images already seen.
            ema_kimg (int, optional): Half-life of the exponential moving
                average of generator weights. Defaults to 10.
            ema_rampup (float, optional): EMA ramp-up coefficient. If set to
                None, then rampup will be disabled. Defaults to 0.05.
            batch_size (int, optional): Total batch size for one training
                iteration. Defaults to 4.
            eps (float, optional): Epsilon to avoid ``batch_size`` divided by
                zero. Defaults to 1e-8.

        Returns:
            dict: Updated momentum.
        """
        cur_nimg = (runner.iter + 1) * batch_size
        ema_nimg = ema_kimg * 1000
        if ema_rampup is not None:
            # cap the half-life early in training so EMA ramps up gradually
            ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
        ema_beta = 0.5**(batch_size / max(ema_nimg, eps))
        return dict(momentum=ema_beta)

    def every_n_iters(self, runner, n):
        # always fire before ``start_iter`` so the EMA model keeps mirroring
        # the original weights until EMA actually starts
        if runner.iter < self.start_iter:
            return True
        return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False

    @torch.no_grad()
    def after_train_iter(self, runner):
        """Update the EMA weights after selected train iterations."""
        if not self.every_n_iters(runner, self.interval):
            return
        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model
        # update momentum
        _interp_cfg = deepcopy(self.interp_cfg)
        if self.momentum_policy != 'fixed':
            _updated_args = self.momentum_updater(runner, **self.momentum_cfg)
            _interp_cfg.update(_updated_args)
        for key in self.module_keys:
            # get current ema states
            ema_net = getattr(model, key)
            states_ema = ema_net.state_dict(keep_vars=False)
            # get currently original states; keep_vars=True so that
            # ``requires_grad`` of each entry is still accessible below
            net = getattr(model, key[:-4])
            states_orig = net.state_dict(keep_vars=True)
            for k, v in states_orig.items():
                if runner.iter < self.start_iter:
                    # before EMA starts, mirror the original weights
                    states_ema[k].data.copy_(v.data)
                else:
                    # interp_func(orig, ema) -> orig + (ema - orig) * momentum
                    states_ema[k] = self.interp_func(
                        v,
                        states_ema[k],
                        trainable=v.requires_grad,
                        **_interp_cfg).detach()
            ema_net.load_state_dict(states_ema, strict=True)

    def before_run(self, runner):
        model = runner.model.module if is_module_wrapper(
            runner.model) else runner.model
        # sanity check for ema model
        for k in self.module_keys:
            if not hasattr(model, k) and not hasattr(model, k[:-4]):
                raise RuntimeError(
                    f'Cannot find both {k[:-4]} and {k} network for EMA hook.')
            if not hasattr(model, k) and hasattr(model, k[:-4]):
                # fall back to cloning the original network as the EMA model
                setattr(model, k, deepcopy(getattr(model, k[:-4])))
                warnings.warn(
                    f'We do not suggest construct and initialize EMA model {k}'
                    ' in hook. You may explicitly define it by yourself.')
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class PGGANFetchDataHook(Hook):
    """PGGAN Fetch Data Hook.

    Before each (interval-th) data fetch, pushes the model's next scale into
    the dataloader so the loaded images match the growing architecture.

    Args:
        interval (int, optional): The interval of calling this hook. If set
            to -1, the visualization hook will not be called. Defaults to 1.
    """

    def __init__(self, interval=1):
        super().__init__()
        self.interval = interval

    def before_fetch_train_data(self, runner):
        """The behavior before fetch train data.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        # the model tracks the scale for the upcoming stage; it may be a
        # tensor buffer, so convert to a plain int before use
        next_scale = model._next_scale_int
        if torch.is_tensor(next_scale):
            next_scale = next_scale.item()
        runner.data_loader.update_dataloader(next_scale)
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import pickle
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
@HOOKS.register_module()
class PickleDataHook(Hook):
    """Pickle Useful Data Hook.

    This hook will be used in SinGAN training for saving some important data
    that will be used in testing or inference.

    Args:
        output_dir (str): The output path for saving pickled data.
        data_name_list (list[str]): The list contains the name of results in
            outputs dict.
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        before_run (bool, optional): Whether to save before running.
            Defaults to False.
        after_run (bool, optional): Whether to save after running.
            Defaults to False.
        filename_tmpl (str, optional): Format string used to save images. The
            output file name will be formatted as this args.
            Defaults to 'iter_{}.pkl'.
    """

    def __init__(self,
                 output_dir,
                 data_name_list,
                 interval=-1,
                 before_run=False,
                 after_run=False,
                 filename_tmpl='iter_{}.pkl'):
        assert mmcv.is_list_of(data_name_list, str)
        self.output_dir = output_dir
        self.data_name_list = data_name_list
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self._before_run = before_run
        self._after_run = after_run

    @master_only
    def after_run(self, runner):
        """The behavior after the whole running.

        Args:
            runner (object): The runner.
        """
        if self._after_run:
            self._pickle_data(runner)

    @master_only
    def before_run(self, runner):
        """The behavior before the whole running.

        Args:
            runner (object): The runner.
        """
        if self._before_run:
            self._pickle_data(runner)

    @master_only
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        self._pickle_data(runner)

    def _pickle_data(self, runner):
        """Pickle the requested results of the current iteration to disk.

        Args:
            runner (object): The runner.
        """
        filename = self.filename_tmpl.format(runner.iter + 1)
        if not hasattr(self, '_out_dir'):
            self._out_dir = os.path.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(self._out_dir)
        file_path = os.path.join(self._out_dir, filename)
        with open(file_path, 'wb') as f:
            data = runner.outputs['results']
            not_find_keys = []
            data_dict = {}
            for k in self.data_name_list:
                if k in data.keys():
                    data_dict[k] = self._get_numpy_data(data[k])
                else:
                    not_find_keys.append(k)
            pickle.dump(data_dict, f)
            # BUGFIX: the f-string previously contained no placeholder, so
            # the log never reported where the data was written.
            mmcv.print_log(f'Pickle data in {file_path}', 'mmgen')
            if len(not_find_keys) > 0:
                mmcv.print_log(
                    f'Cannot find keys for pickling: {not_find_keys}',
                    'mmgen',
                    level=logging.WARN)
            f.flush()

    def _get_numpy_data(self, data):
        """Recursively convert tensors (possibly nested in lists) to numpy.

        Args:
            data (Tensor | list | object): Data to be converted.

        Returns:
            object: Same structure with tensors replaced by numpy arrays.
        """
        if isinstance(data, list):
            return [self._get_numpy_data(x) for x in data]
        if isinstance(data, torch.Tensor):
            return data.cpu().numpy()
        return data
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image
@HOOKS.register_module('MMGenVisualizationHook')
class VisualizationHook(Hook):
    """Visualization hook.

    In this hook, we use the official api `save_image` in torchvision to save
    the visualization results.

    Args:
        output_dir (str): The file path to store visualizations.
        res_name_list (str): The list contains the name of results in outputs
            dict. The results in outputs dict must be a torch.Tensor with shape
            (n, c, h, w).
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to save images. The output file
            name will be formatted as this args. Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend users should preprocess the
            visualization results on their own. Here, we just provide a simple
            interface. Default: True.
        bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
            RGB. The final image we will save is following RGB style.
            Default: True.
        nrow (int): The number of samples in a row. Default: 1.
        padding (int): The number of padding pixels between each samples.
            Default: 4.
    """

    def __init__(self,
                 output_dir,
                 res_name_list,
                 interval=-1,
                 filename_tmpl='iter_{}.png',
                 rerange=True,
                 bgr2rgb=True,
                 nrow=1,
                 padding=4):
        assert mmcv.is_list_of(res_name_list, str)
        self.output_dir = output_dir
        self.res_name_list = res_name_list
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self.bgr2rgb = bgr2rgb
        self.rerange = rerange
        self.nrow = nrow
        self.padding = padding

    @master_only
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        results = runner.outputs['results']
        filename = self.filename_tmpl.format(runner.iter + 1)
        # keep the configured order of results; skip names not present
        img_list = [results[k] for k in self.res_name_list if k in results]
        img_cat = torch.cat(img_list, dim=3).detach()
        if self.rerange:
            img_cat = ((img_cat + 1) / 2)
        # BUGFIX: only swap channels for 3-channel images; indexing
        # ``[:, [2, 1, 0]]`` on grayscale tensors raised an IndexError.
        # This mirrors the guard used by ``VisualizeUnconditionalSamples``.
        if self.bgr2rgb and img_cat.size(1) == 3:
            img_cat = img_cat[:, [2, 1, 0], ...]
        img_cat = img_cat.clamp_(0, 1)
        if not hasattr(self, '_out_dir'):
            self._out_dir = osp.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(self._out_dir)
        save_image(
            img_cat,
            osp.join(self._out_dir, filename),
            nrow=self.nrow,
            padding=self.padding)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image
@HOOKS.register_module()
class VisualizeUnconditionalSamples(Hook):
    """Visualization hook for unconditional GANs.

    In this hook, we use the official api `save_image` in torchvision to save
    the visualization results.

    Args:
        output_dir (str): The file path to store visualizations.
        fixed_noise (bool, optional): Whether to use fixed noises in sampling.
            Defaults to True.
        num_samples (int, optional): The number of samples to show in
            visualization. Defaults to 16.
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to save images. The output
            file name will be formatted as this args. Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend users should preprocess the
            visualization results on their own. Here, we just provide a simple
            interface. Default: True.
        bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
            RGB. The final image we will save is following RGB style.
            Default: True.
        nrow (int): The number of samples in a row. Default: 4.
        padding (int): The number of padding pixels between each samples.
            Default: 0.
        kwargs (dict | None, optional): Key-word arguments for sampling
            function. Defaults to None.
    """

    def __init__(self,
                 output_dir,
                 fixed_noise=True,
                 num_samples=16,
                 interval=-1,
                 filename_tmpl='iter_{}.png',
                 rerange=True,
                 bgr2rgb=True,
                 nrow=4,
                 padding=0,
                 kwargs=None):
        self.output_dir = output_dir
        self.fixed_noise = fixed_noise
        self.num_samples = num_samples
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self.bgr2rgb = bgr2rgb
        self.rerange = rerange
        self.nrow = nrow
        self.padding = padding
        # lazily initialized from the first sampled batch (when
        # ``fixed_noise`` is enabled)
        self.sampling_noise = None
        self.kwargs = dict() if kwargs is None else kwargs

    @master_only
    def after_train_iter(self, runner):
        """Sample and save images after selected train iterations.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        # switch to eval mode and sample without tracking gradients
        runner.model.eval()
        with torch.no_grad():
            outputs = runner.model(
                self.sampling_noise,
                return_loss=False,
                num_batches=self.num_samples,
                return_noise=True,
                **self.kwargs)
        fake_imgs = outputs['fake_img']
        noise_batch = outputs['noise_batch']
        # remember the first noise batch so later calls reuse the same noise
        if self.fixed_noise and self.sampling_noise is None:
            self.sampling_noise = noise_batch
        # switch back to train mode
        runner.model.train()
        if self.rerange:
            fake_imgs = (fake_imgs + 1) / 2
        if self.bgr2rgb and fake_imgs.size(1) == 3:
            fake_imgs = fake_imgs[:, [2, 1, 0], ...]
        if fake_imgs.size(1) == 1:
            # replicate grayscale outputs to three channels for saving
            fake_imgs = torch.cat([fake_imgs] * 3, dim=1)
        fake_imgs = fake_imgs.clamp_(0, 1)
        save_dir = osp.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(save_dir)
        save_image(
            fake_imgs,
            osp.join(save_dir, self.filename_tmpl.format(runner.iter + 1)),
            nrow=self.nrow,
            padding=self.padding)
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizers
# Public API of the optimizer subpackage.
__all__ = ['build_optimizers']
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If `cfgs` contains several dicts for optimizers, then a dict for each
    constructed optimizers will be returned.
    If `cfgs` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    # unwrap DataParallel/DistributedDataParallel models
    if hasattr(model, 'module'):
        model = model.module
    # 'cfgs' configures multiple optimizers iff every value is itself a dict;
    # an empty dict is treated as the (empty) multi-optimizer case, matching
    # the original behavior.
    if all(isinstance(cfg, dict) for cfg in cfgs.values()):
        optimizers = {}
        for key, cfg in cfgs.items():
            # copy so the caller's config is not mutated downstream
            cfg_ = cfg.copy()
            # each key names a submodule of ``model`` to optimize
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers
    return build_optimizer(model, cfgs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
# Registry holding all evaluation metric classes (e.g. GaussianKLD).
METRICS = Registry('metric')
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules; either a single dict
            or a list of config dicts.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module (or a list of them for a list config).
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    # a list config yields one built module per entry
    return [build_from_cfg(item, registry, default_args) for item in cfg]
def build_metric(cfg):
    """Build a metric calculator.

    Args:
        cfg (dict | list[dict]): Config of the metric(s) to build.

    Returns:
        obj | list[obj]: Built metric instance(s) from the METRICS registry.
    """
    return build(cfg, METRICS)
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_iterbased_runner import DynamicIterBasedRunner
# Public API of the runner subpackage.
__all__ = ['DynamicIterBasedRunner']
# Copyright (c) OpenMMLab. All rights reserved.
try:
from apex import amp
except ImportError:
amp = None
def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
    """Initialize apex.amp for mixed-precision training.

    Args:
        models (nn.Module | list[Module]): Modules to be wrapped with
            apex.amp.
        optimizers (dict): Dict containing the ``'generator'`` and
            ``'discriminator'`` optimizers.
        init_args (dict | None, optional): Config for amp initialization.
            Defaults to None.
        mode (str, optional): The mode used to initialize apex.amp.
            Different modes lead to different wrapping modes for models and
            optimizers. Defaults to 'gan'.

    Returns:
        Module, dict: Wrapped module and optimizers.

    Raises:
        NotImplementedError: If ``mode`` is not supported.
    """
    init_args = init_args or dict()
    if mode == 'gan':
        # amp.initialize expects a flat list of optimizers; unpack the
        # generator/discriminator pair and repack after wrapping
        _optimizers = [optimizers['generator'], optimizers['discriminator']]
        models, _optimizers = amp.initialize(models, _optimizers, **init_args)
        optimizers['generator'] = _optimizers[0]
        optimizers['discriminator'] = _optimizers[1]
        return models, optimizers
    raise NotImplementedError(f'Cannot initialize apex.amp with mode {mode}')
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory
import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer
def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be
    saved in checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``'pavi://'`` prefix routes
            the checkpoint to the pavi model cloud instead of local disk.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer(s) to be
            saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save apex.amp
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
    # always serialize the bare module, not the DDP/DP wrapper
    if is_module_wrapper(model):
        model = model.module
    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)
    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint; a dict of optimizers
    # (e.g. generator/discriminator) is saved per-name
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()
    # save loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()
    # save state_dict from apex.amp
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()
    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            # create the remote folder on first use
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            # serialize locally first, then upload the temporary file
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings
from functools import partial
import mmcv
import torch
import torch.distributed as dist
from mmcv.parallel import collate, is_module_wrapper
from mmcv.runner import HOOKS, RUNNERS, IterBasedRunner, get_host_info
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .checkpoint import save_checkpoint
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import GradScaler
except ImportError:
pass
class IterLoader:
    """Iteration based dataloader.

    This wrapper for dataloader is to match the iteration-based training
    procedure.

    Args:
        dataloader (object): Dataloader in PyTorch.
        runner (object): ``mmcv.Runner``
    """

    def __init__(self, dataloader, runner):
        self._dataloader = dataloader
        self.runner = runner
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        """The number of current epoch.

        Returns:
            int: Epoch number.
        """
        return self._epoch

    def update_dataloader(self, curr_scale):
        """Update dataloader.

        Update the dataloader according to the `curr_scale`. This
        functionality is very helpful in training progressive growing GANs
        in which the dataloader should be updated according to the scale of
        the models in training.

        Args:
            curr_scale (int): The scale in current stage.
        """
        # update dataset, sampler, and samples per gpu in dataloader
        if hasattr(self._dataloader.dataset, 'update_annotations'):
            update_flag = self._dataloader.dataset.update_annotations(
                curr_scale)
        else:
            update_flag = False
        if update_flag:
            # the sampler should be updated with the modified dataset
            assert hasattr(self._dataloader.sampler, 'update_sampler')
            samples_per_gpu = None if not hasattr(
                self._dataloader.dataset, 'samples_per_gpu'
            ) else self._dataloader.dataset.samples_per_gpu
            self._dataloader.sampler.update_sampler(self._dataloader.dataset,
                                                    samples_per_gpu)
            # update samples per gpu
            if samples_per_gpu is not None:
                if dist.is_initialized():
                    # rebuild the DataLoader so the new batch size takes
                    # effect; an existing DataLoader cannot be mutated
                    # in place
                    self._dataloader = DataLoader(
                        self._dataloader.dataset,
                        batch_size=samples_per_gpu,
                        sampler=self._dataloader.sampler,
                        num_workers=self._dataloader.num_workers,
                        collate_fn=partial(
                            collate, samples_per_gpu=samples_per_gpu),
                        shuffle=False,
                        worker_init_fn=self._dataloader.worker_init_fn)
                    self.iter_loader = iter(self._dataloader)
                else:
                    raise NotImplementedError(
                        'Currently, we only support dynamic batch size in'
                        ' ddp, because the number of gpus in DataParallel '
                        'cannot be obtained easily.')

    def __next__(self):
        """Fetch the next batch, restarting the loader at the end of epoch."""
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # the underlying loader is exhausted: a new epoch begins
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)
@RUNNERS.register_module()
class DynamicIterBasedRunner(IterBasedRunner):
"""Dynamic Iterbased Runner.
In this Dynamic Iterbased Runner, we will pass the ``reducer`` to the
``train_step`` so that the models can be trained with dynamic architecture.
More details and clarification can be found in this [tutorial](docs/en/tutorials/ddp_train_gans.md). # noqa
Args:
is_dynamic_ddp (bool, optional): Whether to adopt the dynamic ddp.
Defaults to False.
pass_training_status (bool, optional): Whether to pass the training
status. Defaults to False.
fp16_loss_scaler (dict | None, optional): Config for fp16 GradScaler
from ``torch.cuda.amp``. Defaults to None.
use_apex_amp (bool, optional): Whether to use apex.amp to start mixed
precision training. Defaults to False.
"""
    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 fp16_loss_scaler=None,
                 use_apex_amp=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # unwrap DDP/DP so model attributes can be inspected directly
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model
        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status
        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for optimizer is None.
        # sanity check for whether `_model` contains self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer
        # add fp16 grad scaler, using pytorch official GradScaler
        self.with_fp16_grad_scaler = False
        if fp16_loss_scaler is not None:
            self.loss_scaler = GradScaler(**fp16_loss_scaler)
            self.with_fp16_grad_scaler = True
            mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')
        # flag to use amp in apex (NVIDIA)
        self.use_apex_amp = use_apex_amp
def call_hook(self, fn_name):
"""Call all hooks.
Args:
fn_name (str): The function name in each hook to be called, such as
"before_train_epoch".
"""
for hook in self._hooks:
if hasattr(hook, fn_name):
getattr(hook, fn_name)(self)
def train(self, data_loader, **kwargs):
if is_module_wrapper(self.model):
_model = self.model.module
else:
_model = self.model
self.model.train()
self.mode = 'train'
# check if self.optimizer from model and track it
if self.optimizer_from_model:
self.optimizer = _model.optimizer
self.data_loader = data_loader
self._epoch = data_loader.epoch
self.call_hook('before_fetch_train_data')
data_batch = next(self.data_loader)
self.call_hook('before_train_iter')
# prepare input args for train_step
# running status
if self.pass_training_status:
running_status = dict(iteration=self.iter, epoch=self.epoch)
kwargs['running_status'] = running_status
# ddp reducer for tracking dynamic computational graph
if self.is_dynamic_ddp:
kwargs.update(dict(ddp_reducer=self.model.reducer))
if self.with_fp16_grad_scaler:
kwargs.update(dict(loss_scaler=self.loss_scaler))
if self.use_apex_amp:
kwargs.update(dict(use_apex_amp=True))
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
# the loss scaler should be updated after ``train_step``
if self.with_fp16_grad_scaler:
self.loss_scaler.update()
# further check for the cases where the optimizer is built in
# `train_step`.
if self.optimizer is None:
if hasattr(_model, 'optimizer'):
self.optimizer_from_model = True
self.optimizer = _model.optimizer
# check if self.optimizer from model and track it
if self.optimizer_from_model:
self.optimizer = _model.optimizer
if not isinstance(outputs, dict):
raise TypeError('model.train_step() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._inner_iter += 1
self._iter += 1
def run(self, data_loaders, workflow, max_iters=None, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, iters) to specify the
running order and iterations. E.g, [('train', 10000),
('val', 1000)] means running 10000 iterations for training and
1000 iterations for validation, iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_iters is not None:
warnings.warn(
'setting max_iters in run is deprecated, '
'please set max_iters in runner_config', DeprecationWarning)
self._max_iters = max_iters
assert self._max_iters is not None, (
'max_iters must be specified during instantiation')
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d iters', workflow,
self._max_iters)
self.call_hook('before_run')
iter_loaders = [IterLoader(x, self) for x in data_loaders]
self.call_hook('before_epoch')
while self.iter < self._max_iters:
for i, flow in enumerate(workflow):
self._inner_iter = 0
mode, iters = flow
if not isinstance(mode, str) or not hasattr(self, mode):
raise ValueError(
'runner has no method named "{}" to run a workflow'.
format(mode))
iter_runner = getattr(self, mode)
for _ in range(iters):
if mode == 'train' and self.iter >= self._max_iters:
break
iter_runner(iter_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_epoch')
self.call_hook('after_run')
def resume(self,
checkpoint,
resume_optimizer=True,
resume_loss_scaler=True,
map_location='default'):
"""Resume model from checkpoint.
Args:
checkpoint (str): Checkpoint to resume from.
resume_optimizer (bool, optional): Whether resume the optimizer(s)
if the checkpoint file includes optimizer(s). Default to True.
resume_loss_scaler (bool, optional): Whether to resume the loss
scaler (GradScaler) from ``torch.cuda.amp``. Defaults to True.
map_location (str, optional): Same as :func:`torch.load`.
Default to 'default'.
"""
if map_location == 'default':
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id))
else:
checkpoint = self.load_checkpoint(
checkpoint, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
self._inner_iter = checkpoint['meta']['iter']
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k])
else:
raise TypeError(
'Optimizer should be dict or torch.optim.Optimizer '
f'but got {type(self.optimizer)}')
if 'loss_scaler' in checkpoint and resume_loss_scaler:
self.loss_scaler.load_state_dict(checkpoint['loss_scaler'])
if self.use_apex_amp:
from apex import amp
amp.load_state_dict(checkpoint['amp'])
self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
def save_checkpoint(self,
out_dir,
filename_tmpl='iter_{}.pth',
meta=None,
save_optimizer=True,
create_symlink=True):
"""Save checkpoint to file.
Args:
out_dir (str): Directory to save checkpoint files.
filename_tmpl (str, optional): Checkpoint file template.
Defaults to 'iter_{}.pth'.
meta (dict, optional): Metadata to be saved in checkpoint.
Defaults to None.
save_optimizer (bool, optional): Whether save optimizer.
Defaults to True.
create_symlink (bool, optional): Whether create symlink to the
latest checkpoint file. Defaults to True.
"""
if meta is None:
meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
elif isinstance(meta, dict):
meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
else:
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.iter + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
_loss_scaler = self.loss_scaler if self.with_fp16_grad_scaler else None
save_checkpoint(
self.model,
filepath,
optimizer=optimizer,
loss_scaler=_loss_scaler,
save_apex_amp=self.use_apex_amp,
meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, 'latest.pth')
if platform.system() != 'Windows':
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
def register_lr_hook(self, lr_config):
if lr_config is None:
return
if isinstance(lr_config, dict):
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
# If the type of policy is all in lower case, e.g., 'cyclic',
# then its first letter will be capitalized, e.g., to be 'Cyclic'.
# This is for the convenient usage of Lr updater.
# Since this is not applicable for `
# CosineAnnealingLrUpdater`,
# the string will not be changed if it contains capital letters.
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
lr_config['type'] = hook_type
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook)
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from collections import abc
from inspect import getfullargspec
import numpy as np
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import autocast
except ImportError:
pass
def nan_to_num(x, nan=0.0, posinf=None, neginf=None, *, out=None):
    r"""Replaces :literal:`NaN`, positive infinity, and negative infinity
    values in :attr:`input` with the values specified by :attr:`nan`,
    :attr:`posinf`, and :attr:`neginf`, respectively. By default,
    :literal:`NaN`s are replaced with zero, positive infinity is replaced with
    the greatest finite value representable by :attr:`input`'s dtype, and
    negative infinity is replaced with the least finite value representable by
    :attr:`input`'s dtype.

    .. note::

        This function is provided in ``PyTorch>=1.8.0``. Here is a
        reimplementation to avoid attribute error in lower PyTorch version.

    Args:
        x (Tensor): Input tensor.
        nan (Number, optional): the value to replace :literal:`NaN`\s with.
            Default is zero.
        posinf (Number, optional): if a Number, the value to replace positive
            infinity values with. If None, positive infinity values are
            replaced with the greatest finite value representable by
            :attr:`input`'s dtype. Default is None.
        neginf (Number, optional): if a Number, the value to replace negative
            infinity values with. If None, negative infinity values are
            replaced with the lowest finite value representable by
            :attr:`input`'s dtype. Default is None.

    Returns:
        Tensor: Output tensor.
    """
    try:
        return torch.nan_to_num(
            x, nan=nan, posinf=posinf, neginf=neginf, out=out)
    except AttributeError:
        if not isinstance(x, torch.Tensor):
            raise TypeError(
                f'argument input (position 1) must be Tensor, not {type(x)}')
        if posinf is None:
            posinf = torch.finfo(x.dtype).max
        if neginf is None:
            neginf = torch.finfo(x.dtype).min
        # Work on a copy so the caller's tensor is left untouched, matching
        # the out-of-place semantics of ``torch.nan_to_num``. (Previously the
        # input was mutated in place and only ``nan == 0`` was supported.)
        result = x.clone()
        result[torch.isnan(result)] = nan
        return torch.clamp(result, min=neginf, max=posinf, out=out)
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type..
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but all contained Tensors have been cast.
    """
    # Tensors are cast directly.
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    # Leaf values that must not be recursed into. Note that str is itself
    # an Iterable, so it has to be intercepted before the Iterable branch.
    if isinstance(inputs, (nn.Module, str, np.ndarray)):
        return inputs
    # Mappings: rebuild the same mapping type with converted values.
    if isinstance(inputs, abc.Mapping):
        converted = {
            key: cast_tensor_type(value, src_type, dst_type)
            for key, value in inputs.items()
        }
        return type(inputs)(converted)
    # Other iterables (list, tuple, ...): rebuild the same container type.
    if isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(element, src_type, dst_type)
            for element in inputs)
    # Anything else passes through unchanged.
    return inputs
def _torch_supports_autocast():
    """Return True if ``torch.cuda.amp.autocast`` is usable (PyTorch>=1.6).

    Version components are compared numerically. The previous lexicographic
    string comparison (``TORCH_VERSION >= '1.6.0'``) mis-ordered versions
    such as ``'1.10.0'`` (lexicographically smaller than ``'1.6.0'``), which
    wrongly disabled autocast on newer PyTorch releases.
    """
    if TORCH_VERSION == 'parrots':
        return False
    try:
        # Strip local-version suffixes like '+cu113' before parsing.
        major, minor = TORCH_VERSION.split('+')[0].split('.')[:2]
        return (int(major), int(minor)) >= (1, 6)
    except ValueError:
        # Unrecognized version string; be conservative.
        return False


def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # define output type by class itself
            if hasattr(args[0], 'out_fp32') and args[0].out_fp32:
                _out_fp32 = True
            else:
                _out_fp32 = False

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            # Here, we change the default behaviour with Yu Xiong's
            # implementation
            args_to_cast = [] if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if _torch_supports_autocast():
                output = autocast(enabled=True)(old_func)(*new_args,
                                                          **new_kwargs)
            else:
                # output = old_func(*new_args, **new_kwargs)
                raise RuntimeError('Please use PyTorch >= 1.6.0')

            # cast the results back to fp32 if necessary
            if out_fp32 or _out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_updater import LinearLrUpdaterHook
__all__ = ['LinearLrUpdaterHook']
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS, LrUpdaterHook
@HOOKS.register_module()
class LinearLrUpdaterHook(LrUpdaterHook):
    """Linear learning rate scheduler for image generation.

    In the beginning, the learning rate is 'base_lr' defined in mmcv.
    We give a target learning rate 'target_lr' and a start point 'start'
    (iteration / epoch). Before 'start', we fix learning rate as 'base_lr';
    After 'start', we linearly update learning rate to 'target_lr'.

    Args:
        target_lr (float): The target learning rate. Default: 0.
        start (int): The start point (iteration / epoch, specified by args
            'by_epoch' in its parent class in mmcv) to update learning rate.
            Default: 0.
        interval (int): The interval to update the learning rate. Default: 1.
    """

    def __init__(self, target_lr=0, start=0, interval=1, **kwargs):
        super().__init__(**kwargs)
        self.target_lr = target_lr
        self.start = start
        self.interval = interval

    def get_lr(self, runner, base_lr):
        """Calculates the learning rate.

        Args:
            runner (object): The passed runner.
            base_lr (float): Base learning rate.

        Returns:
            float: Current learning rate.
        """
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        assert max_progress >= self.start

        # Total number of update steps between `start` and the end of
        # training. This is 0 both when `max_progress == start` and when
        # `interval > max_progress - start`; the previous code only guarded
        # the first case and raised ZeroDivisionError in the second.
        total_steps = (max_progress - self.start) // self.interval
        if total_steps == 0:
            return base_lr

        # Before 'start', fix lr; After 'start', linearly update lr.
        factor = (max(0, progress - self.start) //
                  self.interval) / total_steps
        return base_lr + (self.target_lr - base_lr) * factor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment