Commit b7536f78 authored by limm

add another part of mmgeneration code

parent 57e0e891
include mmgen/model-index.yml
recursive-include mmgen/configs *.py *.yml
recursive-include mmgen/tools *.sh *.py
include requirements/*.txt
include mmgen/VERSION
include mmgen/.mim/model-index.yml
include mmgen/.mim/demo/*/*
recursive-include mmgen/.mim/configs *.py *.yml
recursive-include mmgen/.mim/tools *.sh *.py
<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png" width="400"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab 官网</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
&nbsp;&nbsp;&nbsp;&nbsp;
<b><font size="5">OpenMMLab 开放平台</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div>&nbsp;</div>
</div>
[![PyPI](https://img.shields.io/pypi/v/mmgen)](https://pypi.org/project/mmgen)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmgeneration.readthedocs.io/en/latest/)
[![badge](https://github.com/open-mmlab/mmgeneration/workflows/build/badge.svg)](https://github.com/open-mmlab/mmgeneration/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmgeneration/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmgeneration)
[![license](https://img.shields.io/github/license/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[📘Documentation](https://mmgeneration.readthedocs.io/en/latest/) |
[🛠️Installation](https://mmgeneration.readthedocs.io/en/latest/get_started.html#installation) |
[👀Model Zoo](https://mmgeneration.readthedocs.io/en/latest/modelzoo_statistics.html) |
[🆕Update News](https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/changelog.md) |
[🚀Ongoing Projects](https://github.com/open-mmlab/mmgeneration/projects) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmgeneration/issues)
[English](README.md) | 简体中文
## What's New
We have merged MMGeneration into [MMEditing](https://github.com/open-mmlab/mmediting/tree/1.x) and added support for new generative tasks and algorithms. Please check out the following new features:
- 🌟 Text-to-image generation
  - [GLIDE](https://github.com/open-mmlab/mmediting/tree/1.x/projects/glide/configs/README.md)
  - [Disco-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/disco_diffusion/README.md)
  - [Stable-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/stable_diffusion/README.md)
- 🌟 3D-aware generation
  - [EG3D](https://github.com/open-mmlab/mmediting/tree/1.x/configs/eg3d/README.md)
## Introduction
MMGeneration is a powerful toolkit for generative models, based on PyTorch and [MMCV](https://github.com/open-mmlab/mmcv), with a particular focus on GAN models.
The master branch works with **PyTorch 1.5+**.
<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114534478-9a65a900-9c81-11eb-8087-de8b6816eed8.png" width="800"/>
</div>
## Major Features
- **High-quality and high-performance training:** We currently support training of Unconditional GANs, Internal GANs, and Image Translation Models. Training of conditional models will be supported soon.
- **Powerful toolkit of applications:** A plentiful toolkit containing various applications of GANs is provided for users. Our framework integrates GAN interpolation, GAN projection, and GAN manipulation. Play with your GANs! ([Application tutorial](docs/tutorials/applications.md))
- **Efficient distributed training for generative models:** For the highly dynamic training in generative models, we adopt a new way, `MMDDP`, to train dynamic models. ([DDP tutorial](docs/tutorials/ddp_train_gans.md))
- **New modular design for flexible combination:** For complex loss modules, we propose a new design that lets users customize the links between modules, enabling flexible combination of different modules. ([Tutorial for the new modular design](docs/tutorials/customize_losses.md))
<table>
<thead>
<tr>
<td>
<div align="center">
<b> Training Visualization</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114509105-b6f4e780-9c67-11eb-8644-110b3cb01314.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Interpolation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114679300-9fd4f900-9d3e-11eb-8f37-c36a018c02f7.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Projection</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114524392-c11ee200-9c77-11eb-8b6d-37bc637f5626.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Manipulation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114523716-20302700-9c77-11eb-804e-327ae1ca0c5b.gif" width="200"/>
</div></td>
</tr>
</thead>
</table>
## Highlights
- **Positional Encoding as Spatial Inductive Bias in GANs (CVPR2021)** has been released in `MMGeneration`. [\[Config\]](configs/positional_encoding_in_gans/README.md), [\[Project Page\]](https://nbei.github.io/gan-pos-encoding.html)
- We support training the current mainstream Conditional GANs; more methods and pre-trained weights will be released soon, so stay tuned.
- Mixed-precision training has been preliminarily supported for `StyleGAN2`; see [here](configs/styleganv2/README.md) for a detailed comparison of the different implementations.
## Changelog
v0.7.3 was released on 14/04/2023. Please refer to [changelog.md](docs/zh_cn/changelog.md) for details and release history.
## Installation
MMGeneration depends on [PyTorch](https://pytorch.org/) and [MMCV](https://github.com/open-mmlab/mmcv). Below are quick steps for installation.
**Step 1.**
Install PyTorch following the [official instructions](https://pytorch.org/get-started/locally/), e.g.
```shell
pip3 install torch torchvision
```
**Step 2.**
Install MMCV with [MIM](https://github.com/open-mmlab/mim)
```shell
pip3 install openmim
mim install mmcv-full
```
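If your environment needs a specific MMCV version (the supported range is declared in `mmgen/__init__.py`), MIM also accepts pip-style version specifiers; the exact version below is only an illustration:
```shell
mim install mmcv-full==1.7.0
```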
**Step 3.**
Install MMGeneration from source
```shell
git clone https://github.com/open-mmlab/mmgeneration.git
cd mmgeneration
pip3 install -e .
```
Please refer to [get_started.md](docs/zh_cn/get_started.md) for more detailed installation instructions.
## Getting Started
Please refer to [Getting Started](docs/zh_cn/get_started.md) for the basic usage of `MMGeneration`. For other details and tutorials, please go to our [documentation](https://mmgeneration.readthedocs.io/).
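As a quick taste, the snippet below samples images from a pre-trained unconditional GAN with the high-level APIs; this is a minimal sketch, and the config/checkpoint paths are placeholders to replace with real files:
```python
from mmgen.apis import init_model, sample_unconditional_model

# Placeholder paths: use any unconditional GAN config and its checkpoint.
config_file = 'configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k.py'
checkpoint_file = 'path/to/checkpoint.pth'

model = init_model(config_file, checkpoint=checkpoint_file, device='cuda:0')
# Draw 16 samples in batches of 4 from the EMA generator.
fake_imgs = sample_unconditional_model(
    model, num_samples=16, num_batches=4, sample_model='ema')
print(fake_imgs.shape)  # e.g. torch.Size([16, 3, 256, 256])
```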
## Model Zoo
These algorithms have been carefully studied and supported in our frameworks.
<details open>
<summary>Unconditional GANs (click to collapse)</summary>
- [DCGAN](configs/dcgan/README.md) (ICLR'2016)
- [WGAN-GP](configs/wgan-gp/README.md) (NIPS'2017)
- [LSGAN](configs/lsgan/README.md) (ICCV'2017)
- [GGAN](configs/ggan/README.md) (arXiv'2017)
- [PGGAN](configs/pggan/README.md) (ICLR'2018)
- [StyleGANV1](configs/styleganv1/README.md) (CVPR'2019)
- [StyleGANV2](configs/styleganv2/README.md) (CVPR'2020)
- [StyleGANV3](configs/styleganv3/README.md) (NeurIPS'2021)
- [Positional Encoding in GANs](configs/positional_encoding_in_gans/README.md) (CVPR'2021)
</details>
<details open>
<summary>Conditional GANs (click to collapse)</summary>
- [SNGAN](configs/sngan_proj/README.md) (ICLR'2018)
- [Projection GAN](configs/sngan_proj/README.md) (ICLR'2018)
- [SAGAN](configs/sagan/README.md) (ICML'2019)
- [BIGGAN/BIGGAN-DEEP](configs/biggan/README.md) (ICLR'2019)
</details>
<details open>
<summary>Tricks for GANs (click to collapse)</summary>
- [ADA](configs/ada/README.md) (NeurIPS'2020)
</details>
<details open>
<summary>Image2Image Translation (click to collapse)</summary>
- [Pix2Pix](configs/pix2pix/README.md) (CVPR'2017)
- [CycleGAN](configs/cyclegan/README.md) (ICCV'2017)
</details>
<details open>
<summary>Internal Learning (click to collapse)</summary>
- [SinGAN](configs/singan/README.md) (ICCV'2019)
</details>
<details open>
<summary>Denoising Diffusion Probabilistic Models (click to collapse)</summary>
- [Improved DDPM](configs/improved_ddpm/README.md) (arXiv'2021)
</details>
## Related Applications
- [MMGEN-FaceStylor](https://github.com/open-mmlab/MMGEN-FaceStylor)
## Contributing
We appreciate all contributions to improve MMGeneration. Please refer to the [contributing guideline](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) to learn how to participate in the project.
## Citation
If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2021mmgeneration,
title={{MMGeneration}: OpenMMLab Generative Model Toolbox and Benchmark},
author={MMGeneration Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmgeneration}},
year={2020}
}
```
## License
This project is released under the [Apache 2.0 license](LICENSE). Some operations in `MMGeneration` are under other licenses. Please refer to [LICENSES.md](LICENSES.md) and check it carefully if you use our code for commercial matters.
## Projects in OpenMMLab
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision
- [MIM](https://github.com/open-mmlab/mim): MIM is the unified entry point for OpenMMLab projects, algorithms, and models
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab object detection toolbox
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab next-generation platform for general 3D object detection
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab toolbox for full-pipeline text detection, recognition, and understanding
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab few-shot learning toolbox and benchmark
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab next-generation video understanding toolbox
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab unified platform for video object perception
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow estimation toolbox and benchmark
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab toolbox for generating images and videos
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework
## Welcome to the OpenMMLab Community
Scan the QR codes below to follow the OpenMMLab team's [official Zhihu account](https://www.zhihu.com/people/openmmlab) and join the OpenMMLab team's [official QQ group](https://jq.qq.com/?_wv=1027&k=K0QI8ByU)
<div align="center">
<img src="https://user-images.githubusercontent.com/22982797/115827101-66874200-a43e-11eb-9abf-831094c27ef4.JPG" height="400" /> <img src="https://user-images.githubusercontent.com/25839884/203927852-e15def4d-a0eb-4dfc-9bfb-7cf09ea945d0.png" height="400" />
</div>
In the OpenMMLab community, we will
- 📢 share cutting-edge technologies at the core of AI frameworks
- 💻 walk through the source code of commonly used PyTorch modules
- 📰 publish the latest news about OpenMMLab
- 🚀 introduce state-of-the-art algorithms developed by OpenMMLab
- 🏃 provide more efficient Q&A and feedback channels
- 🔥 offer a platform for developers from all walks of life to exchange ideas
Packed with useful content 📘 and waiting for you 💗, the OpenMMLab community looks forward to your joining 👬
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from .version import __version__, parse_version_info, version_info
def digit_version(version_str):
    """Convert a version string into a list of ints for comparison."""
    digit_version = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            # 'rc' pre-releases should sort before the final release, so the
            # patch number is decremented and the rc number is appended.
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version
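# Illustrative results:
#   digit_version('1.3.0')    -> [1, 3, 0]
#   digit_version('1.3.0rc1') -> [1, 3, -1, 1]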
mmcv_minimum_version = '1.3.0'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)
assert (mmcv_version >= digit_version(mmcv_minimum_version)
and mmcv_version <= digit_version(mmcv_maximum_version)), \
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
__all__ = ['__version__', 'version_info', 'parse_version_info']
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (init_model, sample_conditional_model,
sample_ddpm_model, sample_img2img_model,
sample_unconditional_model)
from .train import set_random_seed, train_model
__all__ = [
'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
'sample_unconditional_model', 'sample_conditional_model',
'sample_ddpm_model'
]
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import is_list_of
from mmgen.datasets.pipelines import Compose
from mmgen.models import BaseTranslationModel, build_model
def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
"""Initialize a detector from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
        device (str, optional): Which device the model will be placed on.
            Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Defaults to None.
Returns:
nn.Module: The constructed unconditional model.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
model = build_model(
config.model, train_cfg=config.train_cfg, test_cfg=config.test_cfg)
if checkpoint is not None:
load_checkpoint(model, checkpoint, map_location='cpu')
model._cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
@torch.no_grad()
def sample_unconditional_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
**kwargs):
"""Sampling from unconditional models.
Args:
model (nn.Module): Unconditional models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The number of batch size for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
Returns:
Tensor: Generated image tensor.
"""
# set eval mode
model.eval()
# construct sampling list for batches
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
res_list = []
# inference
for batches in batches_list:
res = model.sample_from_noise(
None, num_batches=batches, sample_model=sample_model, **kwargs)
res_list.append(res.cpu())
results = torch.cat(res_list, dim=0)
return results
@torch.no_grad()
def sample_conditional_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
label=None,
**kwargs):
"""Sampling from conditional models.
Args:
model (nn.Module): Conditional models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The number of batch size for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
        label (int | torch.Tensor | list[int], optional): Labels used to
            generate images. Defaults to None.
Returns:
Tensor: Generated image tensor.
"""
# set eval mode
model.eval()
# construct sampling list for batches
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
# check and convert the input labels
if isinstance(label, int):
label = torch.LongTensor([label] * num_samples)
elif isinstance(label, torch.Tensor):
label = label.type(torch.int64)
if label.numel() == 1:
# repeat single tensor
# call view(-1) to avoid nested tensor like [[[1]]]
label = label.view(-1).repeat(num_samples)
else:
# flatten multi tensors
label = label.view(-1)
elif isinstance(label, list):
if is_list_of(label, int):
label = torch.LongTensor(label)
# `nargs='+'` parse single integer as list
if label.numel() == 1:
# repeat single tensor
label = label.repeat(num_samples)
else:
raise TypeError('Only support `int` for label list elements, '
f'but receive {type(label[0])}')
elif label is None:
pass
else:
raise TypeError('Only support `int`, `torch.Tensor`, `list[int]` or '
f'None as label, but receive {type(label)}.')
# check the length of the (converted) label
if label is not None and label.size(0) != num_samples:
raise ValueError('Number of elements in the label list should be ONE '
'or the length of `num_samples`. Requires '
f'{num_samples}, but receive {label.size(0)}.')
# make label list
label_list = []
for n in range(n_repeat):
if label is None:
label_list.append(None)
else:
label_list.append(label[n * num_batches:(n + 1) * num_batches])
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
if label is None:
label_list.append(None)
else:
label_list.append(label[(n + 1) * num_batches:])
res_list = []
# inference
for batches, labels in zip(batches_list, label_list):
res = model.sample_from_noise(
None,
num_batches=batches,
label=labels,
sample_model=sample_model,
**kwargs)
res_list.append(res.cpu())
results = torch.cat(res_list, dim=0)
return results
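# Illustrative calls (all draw 16 samples of class 3 in batches of 4):
#   sample_conditional_model(model, 16, 4, label=3)
#   sample_conditional_model(model, 16, 4, label=[3])
#   sample_conditional_model(model, 16, 4, label=torch.LongTensor([3]))
# A per-sample label list of length `num_samples` is also accepted:
#   sample_conditional_model(model, 16, 4, label=list(range(16)))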
def sample_img2img_model(model, image_path, target_domain=None, **kwargs):
"""Sampling from translation models.
Args:
model (nn.Module): The loaded model.
image_path (str): File path of input image.
        target_domain (str, optional): Target domain of the output image. If
            None, the model's default domain is used. Defaults to None.
Returns:
Tensor: Translated image tensor.
"""
assert isinstance(model, BaseTranslationModel)
# get source domain and target domain
if target_domain is None:
target_domain = model._default_domain
source_domain = model.get_other_domains(target_domain)[0]
cfg = model._cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = Compose(cfg.test_pipeline)
# prepare data
data = dict()
# dirty code to deal with test data pipeline
data['pair_path'] = image_path
data[f'img_{source_domain}_path'] = image_path
data[f'img_{target_domain}_path'] = image_path
data = test_pipeline(data)
if device.type == 'cpu':
data = collate([data], samples_per_gpu=1)
data['meta'] = []
else:
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
source_image = data[f'img_{source_domain}']
# forward the model
with torch.no_grad():
results = model(
source_image,
test_mode=True,
target_domain=target_domain,
**kwargs)
output = results['target']
return output
@torch.no_grad()
def sample_ddpm_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
same_noise=False,
**kwargs):
"""Sampling from ddpm models.
Args:
model (nn.Module): DDPM models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The number of batch size for inference.
Defaults to 4.
        sample_model (str, optional): Which model you want to use. ['ema',
            'orig']. Defaults to 'ema'.
        same_noise (bool, optional): Whether every batch starts denoising
            from the same shared noise sample. Defaults to False.
Returns:
list[Tensor | dict]: Generated image tensor.
"""
model.eval()
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
noise_batch = torch.randn(model.image_shape) if same_noise else None
res_list = []
# inference
for idx, batches in enumerate(batches_list):
mmcv.print_log(
f'Start to sample batch [{idx+1} / '
f'{len(batches_list)}]', 'mmgen')
noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \
if same_noise else None
res = model.sample_from_noise(
noise_batch_,
num_batches=batches,
sample_model=sample_model,
show_pbar=True,
**kwargs)
if isinstance(res, dict):
res = {k: v.cpu() for k, v in res.items()}
elif isinstance(res, torch.Tensor):
res = res.cpu()
else:
            raise ValueError('Sample results should be \'dict\' or '
                             f'\'torch.Tensor\', but received \'{type(res)}\'')
res_list.append(res)
# gather the res_list
if isinstance(res_list[0], dict):
res_dict = dict()
for t in res_list[0].keys():
# num_samples x 3 x H x W
res_dict[t] = torch.cat([res[t] for res in res_list], dim=0)
return res_dict
else:
return torch.cat(res_list, dim=0)
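# Note: with `same_noise=True`, every inference batch starts denoising from
# one shared noise sample, which is handy when comparing sampling kwargs.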
# Copyright (c) OpenMMLab. All rights reserved.
import os
from copy import deepcopy
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import HOOKS, IterBasedRunner, OptimizerHook, build_runner
from mmcv.runner import set_random_seed as set_random_seed_mmcv
from mmcv.utils import build_from_cfg
from mmgen.core.ddp_wrapper import DistributedDataParallelWrapper
from mmgen.core.optimizer import build_optimizers
from mmgen.core.runners.apex_amp_utils import apex_amp_initialize
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.utils import get_root_logger
def set_random_seed(seed, deterministic=False, use_rank_shift=True):
"""Set random seed.
In this function, we just modify the default behavior of the similar
function defined in MMCV.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
        use_rank_shift (bool): Whether to add rank number to the random seed
            to have different random seeds in different threads.
            Default: True.
"""
set_random_seed_mmcv(
seed, deterministic=deterministic, use_rank_shift=use_rank_shift)
def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# default loader config
loader_cfg = dict(
samples_per_gpu=cfg.data.samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
persistent_workers=cfg.data.get('persistent_workers', False),
seed=cfg.seed)
# The overall dataloader settings
loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
# The specific datalaoder settings
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
# dirty code for use apex amp
# apex.amp request that models should be in cuda device before
# initialization.
if cfg.get('apex_amp', None):
assert distributed, (
'Currently, apex.amp is only supported with DDP training.')
model = model.cuda()
# build optimizer
if cfg.optimizer:
optimizer = build_optimizers(model, cfg.optimizer)
# In GANs, we allow building optimizer in GAN model.
else:
optimizer = None
_use_apex_amp = False
if cfg.get('apex_amp', None):
model, optimizer = apex_amp_initialize(model, optimizer,
**cfg.apex_amp)
_use_apex_amp = True
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
if use_ddp_wrapper:
mmcv.print_log('Use DDP Wrapper.', 'mmgen')
model = DistributedDataParallelWrapper(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# allow users to define the runner
if cfg.get('runner', None):
runner = build_runner(
cfg.runner,
dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
use_apex_amp=_use_apex_amp,
meta=meta))
else:
runner = IterBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
# set if use dynamic ddp in training
# is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))
    # an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
# In GANs, we can directly optimize parameter in `train_step` function.
if cfg.get('optimizer_cfg', None) is None:
optimizer_config = None
elif fp16_cfg is not None:
raise NotImplementedError('Fp16 has not been supported.')
# optimizer_config = Fp16OptimizerHook(
# **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
# default to use OptimizerHook
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# update `out_dir` in ckpt hook
if cfg.checkpoint_config is not None:
cfg.checkpoint_config['out_dir'] = os.path.join(
cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
# # DistSamplerSeedHook should be used with EpochBasedRunner
# if distributed:
# runner.register_hook(DistSamplerSeedHook())
# In general, we do NOT adopt standard evaluation hook in GAN training.
# Thus, if you want a eval hook, you need further define the key of
# 'evaluation' in the config.
# register eval hooks
if validate and cfg.get('evaluation', None) is not None:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
# Support batch_size > 1 in validation
val_loader_cfg = {
**loader_cfg, 'shuffle': False,
**cfg.data.get('val_data_loader', {})
}
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
eval_cfg = deepcopy(cfg.get('evaluation'))
priority = eval_cfg.pop('priority', 'LOW')
eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
eval_hook = build_from_cfg(eval_cfg, HOOKS)
runner.register_hook(eval_hook, priority=priority)
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_iters)
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
from .hooks import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
from .registry import * # noqa: F401, F403
from .runners import * # noqa: F401, F403
from .scheduler import * # noqa: F401, F403
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index
@MODULE_WRAPPERS.register_module('mmgen.DDPWrapper')
class DistributedDataParallelWrapper(nn.Module):
"""A DistributedDataParallel wrapper for models in MMGeneration.
    In MMGeneration, there is a need to wrap different modules in the models
    with separate DistributedDataParallel. Otherwise, it will cause
    errors for GAN training.
    More specifically, the GAN model usually has two sub-modules:
    generator and discriminator. If we wrap both of them in one
    standard DistributedDataParallel, it will cause errors during training,
    because when we update the parameters of the generator (or discriminator),
    the parameters of the discriminator (or generator) are not updated, which
    is not allowed for DistributedDataParallel.
So we design this wrapper to separately wrap DistributedDataParallel
for generator and discriminator.
In this wrapper, we perform two operations:
1. Wrap the modules in the models with separate MMDistributedDataParallel.
Note that only modules with parameters will be wrapped.
2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
    Note that the arguments of this wrapper are the same as those in
    `torch.nn.parallel.distributed.DistributedDataParallel`.
Args:
module (nn.Module): Module that needs to be wrapped.
device_ids (list[int | `torch.device`]): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
dim (int, optional): Same as that in the official scatter function in
pytorch. Defaults to 0.
broadcast_buffers (bool): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
Defaults to False.
find_unused_parameters (bool, optional): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
Traverse the autograd graph of all tensors contained in returned
value of the wrapped module’s forward function. Defaults to False.
kwargs (dict): Other arguments used in
`torch.nn.parallel.distributed.DistributedDataParallel`.
"""
def __init__(self,
module,
device_ids,
dim=0,
broadcast_buffers=False,
find_unused_parameters=False,
**kwargs):
super().__init__()
assert len(device_ids) == 1, (
'Currently, DistributedDataParallelWrapper only supports one'
'single CUDA device for each process.'
f'The length of device_ids must be 1, but got {len(device_ids)}.')
self.module = module
self.dim = dim
self.to_ddp(
device_ids=device_ids,
dim=dim,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
**kwargs)
self.output_device = _get_device_index(device_ids[0], True)
def to_ddp(self, device_ids, dim, broadcast_buffers,
find_unused_parameters, **kwargs):
"""Wrap models with separate MMDistributedDataParallel.
It only wraps the modules with parameters.
"""
for name, module in self.module._modules.items():
if next(module.parameters(), None) is None:
module = module.cuda()
elif all(not p.requires_grad for p in module.parameters()):
module = module.cuda()
else:
module = MMDistributedDataParallel(
module.cuda(),
device_ids=device_ids,
dim=dim,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
**kwargs)
self.module._modules[name] = module
def scatter(self, inputs, kwargs, device_ids):
"""Scatter function.
Args:
inputs (Tensor): Input Tensor.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
device_ids (int): Device id.
"""
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
"""Forward function.
Args:
inputs (tuple): Input data.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
"""Train step function.
Args:
inputs (Tensor): Input Tensor.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
"""Validation step function.
Args:
inputs (tuple): Input data.
kwargs (dict): Args for ``scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output
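# Minimal usage sketch (illustrative; assumes a GAN-style module whose
# parameterized children are, e.g., `generator` and `discriminator`):
#
#   model = build_model(cfg.model, ...).cuda()
#   model = DistributedDataParallelWrapper(
#       model,
#       device_ids=[torch.cuda.current_device()],
#       broadcast_buffers=False,
#       find_unused_parameters=False)
#
# Each parameterized child is wrapped in its own MMDistributedDataParallel,
# so the generator and discriminator can be optimized in alternating steps
# without tripping DDP's unused-parameter checks.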
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
offline_evaluation, online_evaluation)
from .metric_utils import slerp
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
sliced_wasserstein)
__all__ = [
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'offline_evaluation',
'online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
'TranslationEvalHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
import sys
import warnings
from bisect import bisect_right
import mmcv
import torch
from mmcv.runner import HOOKS, Hook, get_dist_info
from ..registry import build_metric
@HOOKS.register_module()
class GenerativeEvalHook(Hook):
"""Evaluation Hook for Generative Models.
This evaluation hook can be used to evaluate unconditional and conditional
models. Note that only ``FID`` and ``IS`` metric are supported for the
distributed training now. In the future, we will support more metrics for
the evaluation during the training procedure.
    In our config system, you only need to add `evaluation` with the detailed
    configurations. Below are several usage cases for different situations.
    What you need to do is to add these lines at the end of your config file.
    Then, you can use this evaluation hook in the training procedure.
    To be noted that, this evaluation hook supports evaluation with dynamic
    intervals, because FID or other metrics may fluctuate frequently at the
    end of the training process.
# TODO: fix the online doc
#. Only use FID for evaluation
.. code-block:: python
:linenos:
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
sample_kwargs=dict(sample_model='ema'))
#. Use FID and IS simultaneously and save the best checkpoints respectively
.. code-block:: python
:linenos:
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=[dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
dict(type='IS',
num_images=50000)],
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
#. Use dynamic evaluation intervals
.. code-block:: python
:linenos:
        # interval = 10000, if iter < 500000,
        # interval = 4000,  if 500000 <= iter < 750000,
        # interval = 2000,  if iter >= 750000
evaluation = dict(
type='GenerativeEvalHook',
            interval=dict(milestones=[500000, 750000],
                          interval=[10000, 4000, 2000]),
metrics=[dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
dict(type='IS',
num_images=50000)],
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
Args:
dataloader (DataLoader): A PyTorch dataloader.
        interval (int | dict): Evaluation interval. If an int is passed, the
            hook runs at that fixed interval. If a dict is passed, its keys
            and values are interpreted as the 'milestones' and 'interval'
            of the evaluation. Default: 1.
dist (bool, optional): Whether to use distributed evaluation.
Defaults to True.
metrics (dict | list[dict], optional): Configs for metrics that will be
used in evaluation hook. Defaults to None.
sample_kwargs (dict | None, optional): Additional keyword arguments for
sampling images. Defaults to None.
save_best_ckpt (bool, optional): Whether to save the best checkpoint
according to ``best_metric``. Defaults to ``True``.
        best_metric (str | list, optional): Which metric to be used in saving
            the best checkpoint. Multiple metrics are supported by passing a
            list of metric names, e.g., ``['fid', 'is']``.
            Defaults to ``'fid'``.
"""
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -math.inf, 'less': math.inf}
greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'is']
less_keys = ['loss', 'fid']
_supported_best_metrics = ['fid', 'is']
def __init__(self,
dataloader,
interval=1,
dist=True,
metrics=None,
sample_kwargs=None,
save_best_ckpt=True,
best_metric='fid'):
assert metrics is not None
self.dataloader = dataloader
self.dist = dist
self.sample_kwargs = sample_kwargs if sample_kwargs else dict()
self.save_best_ckpt = save_best_ckpt
self.best_metric = best_metric
if isinstance(interval, int):
self.interval = interval
elif isinstance(interval, dict):
if 'milestones' not in interval or 'interval' not in interval:
raise KeyError(
'`milestones` and `interval` must exist in interval dict '
'if you want to use the dynamic interval evaluation '
                    f'strategy. But received [{[k for k in interval.keys()]}] '
                    'in the interval dict.')
self.milestones = interval['milestones']
self.interval = interval['interval']
# check if length of interval match with the milestones
if len(self.interval) != len(self.milestones) + 1:
raise ValueError(
f'Length of `interval`(={len(self.interval)}) cannot '
f'match length of `milestones`(={len(self.milestones)}).')
# check if milestones is in order
for idx in range(len(self.milestones) - 1):
former, latter = self.milestones[idx], self.milestones[idx + 1]
if former >= latter:
                    raise ValueError(
                        'Elements in `milestones` should be in '
                        'ascending order.')
else:
            raise TypeError('`interval` only supports `int` or `dict`, '
                            f'but received {type(interval)} instead.')
if isinstance(best_metric, str):
self.best_metric = [self.best_metric]
if self.save_best_ckpt:
not_supported = set(self.best_metric) - set(
self._supported_best_metrics)
assert len(not_supported) == 0, (
f'{not_supported} is not supported for saving best ckpt')
self.metrics = build_metric(metrics)
if isinstance(metrics, dict):
self.metrics = [self.metrics]
for metric in self.metrics:
metric.prepare()
# add support for saving best ckpt
if self.save_best_ckpt:
self.rule = {}
self.compare_func = {}
self._curr_best_score = {}
self._curr_best_ckpt_path = {}
for name in self.best_metric:
if name in self.greater_keys:
self.rule[name] = 'greater'
else:
self.rule[name] = 'less'
self.compare_func[name] = self.rule_map[self.rule[name]]
self._curr_best_score[name] = self.init_value_map[
self.rule[name]]
self._curr_best_ckpt_path[name] = None
def get_current_interval(self, runner):
"""Get current evaluation interval.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
if isinstance(self.interval, int):
return self.interval
else:
curr_iter = runner.iter + 1
index = bisect_right(self.milestones, curr_iter)
return self.interval[index]
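    # Example (illustrative): with milestones=[500000, 750000] and
    # interval=[10000, 4000, 2000], `bisect_right` yields
    #   iter + 1 < 500000            -> interval[0] = 10000
    #   500000 <= iter + 1 < 750000  -> interval[1] = 4000
    #   iter + 1 >= 750000           -> interval[2] = 2000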
def before_run(self, runner):
"""The behavior before running.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
if self.save_best_ckpt is not None:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating an empty one.')
runner.meta = dict()
runner.meta.setdefault('hook_msgs', dict())
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
interval = self.get_current_interval(runner)
if not self.every_n_iters(runner, interval):
return
runner.model.eval()
batch_size = self.dataloader.batch_size
rank, ws = get_dist_info()
total_batch_size = batch_size * ws
# sample real images
max_real_num_images = max(metric.num_images - metric.num_real_feeded
for metric in self.metrics)
# define mmcv progress bar
if rank == 0 and max_real_num_images > 0:
mmcv.print_log(
f'Sample {max_real_num_images} real images for evaluation',
'mmgen')
pbar = mmcv.ProgressBar(max_real_num_images)
if max_real_num_images > 0:
for data in self.dataloader:
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
                    raise KeyError('Cannot find key for images in data_dict. '
                                   'Only support `real_img` for unconditional '
                                   'datasets and `img` for conditional '
                                   'datasets.')
if reals.shape[1] not in [1, 3]:
                    raise RuntimeError('real images should have one or three '
                                       'channels in the first dimension, '
                                       'not %d' % reals.shape[1])
if reals.shape[1] == 1:
reals = reals.repeat(1, 3, 1, 1)
num_feed = 0
for metric in self.metrics:
num_feed_ = metric.feed(reals, 'reals')
num_feed = max(num_feed_, num_feed)
if num_feed <= 0:
break
if rank == 0:
pbar.update(num_feed)
max_num_images = max(metric.num_images for metric in self.metrics)
if rank == 0:
mmcv.print_log(
f'Sample {max_num_images} fake images for evaluation', 'mmgen')
# define mmcv progress bar
if rank == 0:
pbar = mmcv.ProgressBar(max_num_images)
# sampling fake images and directly send them to metrics
for _ in range(0, max_num_images, total_batch_size):
with torch.no_grad():
fakes = runner.model(
None,
num_batches=batch_size,
return_loss=False,
**self.sample_kwargs)
for metric in self.metrics:
# feed in fake images
metric.feed(fakes, 'fakes')
if rank == 0:
pbar.update(total_batch_size)
runner.log_buffer.clear()
        # a dirty workaround to change the line at the end of pbar
if rank == 0:
sys.stdout.write('\n')
for metric in self.metrics:
with torch.no_grad():
metric.summary()
for name, val in metric._result_dict.items():
runner.log_buffer.output[name] = val
# record best metric and save the best ckpt
if self.save_best_ckpt and name in self.best_metric:
self._save_best_ckpt(runner, val, name)
runner.log_buffer.ready = True
runner.model.train()
# clear all current states for next evaluation
for metric in self.metrics:
metric.clear()
def _save_best_ckpt(self, runner, new_score, metric_name):
"""Save checkpoint with best metric score.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
new_score (float): New metric score.
metric_name (str): Name of metric.
"""
curr_iter = f'iter_{runner.iter + 1}'
if self.compare_func[metric_name](new_score,
self._curr_best_score[metric_name]):
best_ckpt_name = f'best_{metric_name}_{curr_iter}.pth'
runner.meta['hook_msgs'][f'best_score_{metric_name}'] = new_score
if self._curr_best_ckpt_path[metric_name] and osp.isfile(
self._curr_best_ckpt_path[metric_name]):
os.remove(self._curr_best_ckpt_path[metric_name])
self._curr_best_ckpt_path[metric_name] = osp.join(
runner.work_dir, best_ckpt_name)
runner.save_checkpoint(
runner.work_dir, best_ckpt_name, create_symlink=False)
runner.meta['hook_msgs'][
f'best_ckpt_{metric_name}'] = self._curr_best_ckpt_path[
metric_name]
self._curr_best_score[metric_name] = new_score
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(f'Best {metric_name} is {new_score:0.4f} '
f'at {curr_iter}.')
@HOOKS.register_module()
class TranslationEvalHook(GenerativeEvalHook):
"""Evaluation Hook for Translation Models.
This evaluation hook can be used to evaluate translation models. Note
that only ``FID`` and ``IS`` metric are supported for the distributed
training now. In the future, we will support more metrics for the
evaluation during the training procedure.
In our config system, you only need to add `evaluation` with the detailed
configureations. Below is several usage cases for different situations.
What you need to do is to add these lines at the end of your config file.
Then, you can use this evaluation hook in the training procedure.
To be noted that, this evaluation hook support evaluation with dynamic
intervals for FID or other metrics may fluctuate frequently at the end of
the training process.
# TODO: fix the online doc
#. Only use FID for evaluation
    .. code-block:: python
        :linenos:
evaluation = dict(
type='TranslationEvalHook',
target_domain='photo',
interval=10000,
metrics=dict(type='FID', num_images=106, bgr2rgb=True))
#. Use FID and IS simultaneously and save the best checkpoints respectively
    .. code-block:: python
        :linenos:
evaluation = dict(
type='TranslationEvalHook',
target_domain='photo',
interval=10000,
metrics=[
dict(type='FID', num_images=106, bgr2rgb=True),
dict(
type='IS',
num_images=106,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
#. Use dynamic evaluation intervals
    .. code-block:: python
        :linenos:
# interval = 10000 if iter < 100000,
# interval = 4000, if 100000 <= iter < 200000,
# interval = 2000, if iter >= 200000
evaluation = dict(
type='TranslationEvalHook',
interval=dict(milestones=[100000, 200000],
interval=[10000, 4000, 2000]),
target_domain='zebra',
metrics=[
dict(type='FID', num_images=140, bgr2rgb=True),
dict(type='IS', num_images=140)
],
best_metric=['fid', 'is'])
Args:
target_domain (str): Target domain of output image.
"""
def __init__(self, *args, target_domain, **kwargs):
super().__init__(*args, **kwargs)
self.target_domain = target_domain
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
interval = self.get_current_interval(runner)
if not self.every_n_iters(runner, interval):
return
runner.model.eval()
source_domain = runner.model.module.get_other_domains(
self.target_domain)[0]
# feed real images
max_num_images = max(metric.num_images for metric in self.metrics)
for metric in self.metrics:
if metric.num_real_feeded >= metric.num_real_need:
continue
mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
# feed in real images
for data in self.dataloader:
# key for translation model
if f'img_{self.target_domain}' in data:
reals = data[f'img_{self.target_domain}']
# key for conditional GAN
else:
                    raise KeyError(
                        'Cannot find key for images in data_dict.')
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
'mmgen')
rank, ws = get_dist_info()
# define mmcv progress bar
if rank == 0:
pbar = mmcv.ProgressBar(max_num_images)
# feed in fake images
for data in self.dataloader:
# key for translation model
if f'img_{source_domain}' in data:
with torch.no_grad():
output_dict = runner.model(
data[f'img_{source_domain}'],
test_mode=True,
target_domain=self.target_domain,
**self.sample_kwargs)
fakes = output_dict['target']
# key Error
else:
                raise KeyError('Cannot find key for images in data_dict.')
# sampling fake images and directly send them to metrics
# pbar update number for one proc
num_update = 0
for metric in self.metrics:
if metric.num_fake_feeded >= metric.num_fake_need:
continue
num_feed = metric.feed(fakes, 'fakes')
num_update = max(num_update, num_feed)
if num_feed <= 0:
break
if rank == 0:
if num_update > 0:
pbar.update(num_update * ws)
runner.log_buffer.clear()
        # a dirty workaround to change the line at the end of pbar
if rank == 0:
sys.stdout.write('\n')
for metric in self.metrics:
with torch.no_grad():
metric.summary()
for name, val in metric._result_dict.items():
runner.log_buffer.output[name] = val
# record best metric and save the best ckpt
if self.save_best_ckpt and name in self.best_metric:
self._save_best_ckpt(runner, val, name)
runner.log_buffer.ready = True
runner.model.train()
# clear all current states for next evaluation
for metric in self.metrics:
metric.clear()
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
import sys
from copy import deepcopy
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from prettytable import PrettyTable
from torchvision.utils import save_image
from mmgen.datasets import build_dataloader, build_dataset
def make_metrics_table(train_cfg, ckpt, eval_info, metrics):
"""Arrange evaluation results into a table.
Args:
        train_cfg (str): Name of the training configuration.
        ckpt (str): Path of the evaluated model's weights.
        eval_info (str): Additional information of this evaluation (e.g.,
            which sample model is used).
        metrics (Metric): Metric objects.
Returns:
str: String of the eval table.
"""
table = PrettyTable()
table.set_style(14)
table.add_column('Training configuration', [train_cfg])
table.add_column('Checkpoint', [ckpt])
table.add_column('Eval', [eval_info])
for metric in metrics:
table.add_column(metric.name, [metric.result_str])
return table.get_string()
def make_vanilla_dataloader(img_path, batch_size, dist=False):
pipeline = [
dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
dataset = build_dataset(
dict(
type='UnconditionalImageDataset',
imgs_root=img_path,
pipeline=pipeline,
))
dataloader = build_dataloader(
dataset,
samples_per_gpu=batch_size,
workers_per_gpu=4,
dist=dist,
shuffle=True)
return dataloader
@torch.no_grad()
def offline_evaluation(model,
data_loader,
metrics,
logger,
basic_table_info,
batch_size,
samples_path=None,
**kwargs):
"""Evaluate model in offline mode.
This method first save generated images at local and then load them by
dataloader.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table, including training configuration and ckpt.
        samples_path (str): Used to save generated images. If it's None, a
            default directory will be used and deleted after finishing the
            evaluation. Defaults to None.
kwargs (dict): Other arguments.
"""
# eval special and recon metric online only
online_metric_name = ['PPL', 'GaussianKLD']
for metric in metrics:
assert metric.name not in online_metric_name, 'Please eval '\
f'{metric.name} online'
rank, ws = get_dist_info()
delete_samples_path = False
if samples_path:
mmcv.mkdir_or_exist(samples_path)
else:
temp_path = './work_dirs/temp_samples'
# if temp_path exists, add suffix
suffix = 1
samples_path = temp_path
while os.path.exists(samples_path):
samples_path = temp_path + '_' + str(suffix)
suffix += 1
os.makedirs(samples_path)
delete_samples_path = True
# sample images
num_exist = len(
list(
mmcv.scandir(
samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
if basic_table_info['num_samples'] > 0:
max_num_images = basic_table_info['num_samples']
else:
max_num_images = max(metric.num_images for metric in metrics)
num_needed = max(max_num_images - num_exist, 0)
if num_needed > 0 and rank == 0:
mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
'mmgen')
# define mmcv progress bar
pbar = mmcv.ProgressBar(num_needed)
# if no images, `num_needed` should be zero
total_batch_size = batch_size * ws
for begin in range(0, num_needed, total_batch_size):
end = min(begin + batch_size, max_num_images)
fakes = model(
None,
num_batches=end - begin,
return_loss=False,
sample_model=basic_table_info['sample_model'],
**kwargs)
global_end = min(begin + total_batch_size, max_num_images)
if rank == 0:
pbar.update(global_end - begin)
# gather generated images
if ws > 1:
placeholder = [torch.zeros_like(fakes) for _ in range(ws)]
dist.all_gather(placeholder, fakes)
fakes = torch.cat(placeholder, dim=0)
# save as three-channel
if fakes.size(1) == 3:
fakes = fakes[:, [2, 1, 0], ...]
elif fakes.size(1) == 1:
fakes = torch.cat([fakes] * 3, dim=1)
else:
raise RuntimeError('Generated images must have one or three '
'channels in the first dimension, '
'not %d' % fakes.size(1))
if rank == 0:
for i in range(global_end - begin):
images = fakes[i:i + 1]
images = ((images + 1) / 2)
images = images.clamp_(0, 1)
image_name = str(num_exist + begin + i) + '.png'
save_image(images, os.path.join(samples_path, image_name))
if num_needed > 0 and rank == 0:
sys.stdout.write('\n')
# return if only save sampled images
if len(metrics) == 0:
return
# empty cache to release GPU memory
torch.cuda.empty_cache()
fake_dataloader = make_vanilla_dataloader(
samples_path, batch_size, dist=ws > 1)
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
if rank == 0:
# prepare for pbar
total_need = (
metric.num_real_need + metric.num_fake_need -
metric.num_real_feeded - metric.num_fake_feeded)
pbar = mmcv.ProgressBar(total_need)
# feed in real images
for data in data_loader:
# key for unconditional GAN
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
raise KeyError('Cannot found key for images in data_dict. '
'Only support `real_img` for unconditional '
'datasets and `img` for conditional '
'datasets.')
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
num_left = metric.feed(reals, 'reals')
if num_left <= 0:
break
if rank == 0:
pbar.update(reals.shape[0] * ws)
# feed in fake images
for data in fake_dataloader:
fakes = data['real_img']
if fakes.shape[1] == 1:
fakes = torch.cat([fakes] * 3, dim=1)
num_left = metric.feed(fakes, 'fakes')
if num_left <= 0:
break
if rank == 0:
pbar.update(fakes.shape[0] * ws)
if rank == 0:
# only call summary at main device
metric.summary()
sys.stdout.write('\n')
if rank == 0:
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'],
metrics)
logger.info('\n' + table_str)
if delete_samples_path:
shutil.rmtree(samples_path)
@torch.no_grad()
def online_evaluation(model, data_loader, metrics, logger, basic_table_info,
batch_size, **kwargs):
"""Evaluate model in online mode.
This method evaluate model and displays eval progress bar.
    Different from `offline_evaluation`, this function will not save the
    images or read images from disks. Namely, there are no IO operations in
    this function. Thus, in general, `online` mode will achieve a faster
    evaluation. However, this mode will consume much more memory.
To be noted that, we only support distributed evaluation for FID and IS
currently.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
batch_size (int): Batch size of images fed into metrics.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table, including training configuration and ckpt.
kwargs (dict): Other arguments.
"""
# separate metrics into special metrics, probabilistic metrics and vanilla
# metrics.
# For vanilla metrics, images are generated in a random way, and are
# shared by these metrics. For special metrics like 'PPL', images are
# generated in a metric-special way and not shared between different
# metrics.
# For reconstruction metrics like 'GaussianKLD', they do not
# receive images but receive a dict with corresponding probabilistic
# parameter.
rank, ws = get_dist_info()
special_metrics = []
recon_metrics = []
vanilla_metrics = []
special_metric_name = ['PPL']
recon_metric_name = ['GaussianKLD']
for metric in metrics:
if ws > 1:
assert metric.name in [
'FID', 'IS'
], ('We only support FID and IS for distributed evaluation '
f'currently, but receive {metric.name}')
if metric.name in special_metric_name:
special_metrics.append(metric)
elif metric.name in recon_metric_name:
recon_metrics.append(metric)
else:
vanilla_metrics.append(metric)
# define mmcv progress bar
max_num_images = 0
for metric in vanilla_metrics + recon_metrics:
metric.prepare()
max_num_images = max(max_num_images,
metric.num_real_need - metric.num_real_feeded)
if rank == 0:
mmcv.print_log(f'Sample {max_num_images} real images for evaluation',
'mmgen')
pbar = mmcv.ProgressBar(max_num_images)
# avoid `data_loader` is None
data_loader = [] if data_loader is None else data_loader
for data in data_loader:
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
            raise KeyError('Cannot find key for images in data_dict. '
                           'Only support `real_img` for unconditional '
                           'datasets and `img` for conditional '
                           'datasets.')
        if reals.shape[1] not in [1, 3]:
            raise RuntimeError('real images should have one or three '
                               'channels in the first dimension, '
                               'not %d' % reals.shape[1])
if reals.shape[1] == 1:
reals = reals.repeat(1, 3, 1, 1)
num_feed = 0
for metric in vanilla_metrics:
num_feed_ = metric.feed(reals, 'reals')
num_feed = max(num_feed_, num_feed)
for metric in recon_metrics:
kwargs_ = deepcopy(kwargs)
kwargs_['mode'] = 'reconstruction'
prob_dict = model(reals, return_loss=False, **kwargs_)
num_feed_ = metric.feed(prob_dict, 'reals')
num_feed = max(num_feed_, num_feed)
if num_feed <= 0:
break
if rank == 0:
pbar.update(num_feed)
if rank == 0:
# finish the pbar stdout
sys.stdout.write('\n')
# define mmcv progress bar
max_num_images = 0 if len(vanilla_metrics) == 0 else max(
metric.num_fake_need for metric in vanilla_metrics)
if rank == 0:
mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
'mmgen')
pbar = mmcv.ProgressBar(max_num_images)
# sampling fake images and directly send them to metrics
total_batch_size = batch_size * ws
for _ in range(0, max_num_images, total_batch_size):
fakes = model(
None,
num_batches=batch_size,
return_loss=False,
sample_model=basic_table_info['sample_model'],
**kwargs)
        if fakes.shape[1] not in [1, 3]:
            raise RuntimeError('fake images should have one or three '
                               'channels in the first dimension, '
                               'not %d' % fakes.shape[1])
if fakes.shape[1] == 1:
fakes = torch.cat([fakes] * 3, dim=1)
for metric in vanilla_metrics:
# feed in fake images
metric.feed(fakes, 'fakes')
if rank == 0:
pbar.update(total_batch_size)
if rank == 0:
# finish the pbar stdout
sys.stdout.write('\n')
# feed special metric, we do not consider distributed eval here
for metric in special_metrics:
metric.prepare()
fakedata_iterator = iter(
metric.get_sampler(model.module, batch_size,
basic_table_info['sample_model']))
mmcv.print_log(
f'Sample {metric.num_images} samples for evaluating {metric.name}',
'mmgen')
pbar = mmcv.ProgressBar(metric.num_images)
for fakes in fakedata_iterator:
num_left = metric.feed(fakes, 'fakes')
pbar.update(fakes.shape[0])
if num_left <= 0:
break
# finish the pbar stdout
sys.stdout.write('\n')
if rank == 0:
for metric in metrics:
metric.summary()
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'],
metrics)
logger.info('\n' + table_str)
# Copyright (c) OpenMMLab. All rights reserved.
import sys
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.parallel import is_module_wrapper
from mmgen.models.architectures.common import get_module_device
@torch.no_grad()
def extract_inception_features(dataloader,
inception,
num_samples,
inception_style='pytorch'):
"""Extract inception features for FID metric.
Args:
dataloader (:obj:`DataLoader`): Dataloader for images.
inception (nn.Module): Inception network.
num_samples (int): The number of samples to be extracted.
inception_style (str): The style of Inception network, "pytorch" or
"stylegan". Defaults to "pytorch".
Returns:
torch.Tensor: Inception features.
"""
batch_size = dataloader.batch_size
num_iters = num_samples // batch_size
if num_iters * batch_size < num_samples:
num_iters += 1
# define mmcv progress bar
pbar = mmcv.ProgressBar(num_iters)
feature_list = []
curr_iter = 1
for data in dataloader:
        # a dirty workaround to support multiple datasets (mainly for the
        # unconditional dataset and conditional dataset). In our
        # implementation, the unconditional dataset returns real images with
        # the key "real_img", while the conditional dataset contains a key
        # "img" denoting the real images.
if 'real_img' in data:
# Mainly for the unconditional dataset in our MMGeneration
img = data['real_img']
else:
# Mainly for conditional dataset in MMClassification
img = data['img']
pbar.update()
        # if the inception network is not wrapped with a module wrapper,
        # put the images on its device manually
        if not is_module_wrapper(inception):
img = img.to(get_module_device(inception))
if inception_style == 'stylegan':
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
feature = inception(img, return_features=True)
else:
feature = inception(img)[0].view(img.shape[0], -1)
feature_list.append(feature.to('cpu'))
if curr_iter >= num_iters:
break
curr_iter += 1
    # Attention: the total number of extracted features may exceed
    # ``num_samples`` because the last batch is not truncated.
features = torch.cat(feature_list, 0)
assert features.shape[0] >= num_samples
features = features[:num_samples]
# to change the line after pbar
sys.stdout.write('\n')
return features
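A usage sketch for `extract_inception_features`; the stub module and toy dataloader below are assumptions for illustration only, standing in for the real InceptionV3 wrapper, which for the 'pytorch' style returns a list whose first element is a feature map:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class StubInception(nn.Module):
    """Hypothetical stand-in for the InceptionV3 wrapper."""

    def __init__(self):
        super().__init__()
        # a parameter so ``get_module_device`` can infer a device
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, img):
        # mimic the expected interface: ``inception(img)[0]`` is a
        # (N, C, 1, 1) feature map
        return [F.adaptive_avg_pool2d(img, 1)]


imgs = torch.rand(10, 3, 32, 32)
loader = DataLoader([{'real_img': im} for im in imgs], batch_size=4)
feats = extract_inception_features(loader, StubInception(), num_samples=10)
print(feats.shape)  # torch.Size([10, 3])
```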
def _hox_downsample(img):
r"""Downsample images with factor equal to 0.5.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
img (ndarray): Images with order "NHWC".
Returns:
ndarray: Downsampled images with order "NHWC".
"""
return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] +
img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25
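`_hox_downsample` is simply 2x2 average pooling in NHWC layout; a quick consistency check against `F.avg_pool2d` (assuming even spatial sizes):

```python
import numpy as np
import torch
import torch.nn.functional as F

img = np.random.rand(2, 8, 8, 3).astype(np.float32)  # NHWC
ref = F.avg_pool2d(torch.from_numpy(img).permute(0, 3, 1, 2), 2)
out = torch.from_numpy(_hox_downsample(img)).permute(0, 3, 1, 2)
assert torch.allclose(out, ref, atol=1e-6)
```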
def _f_special_gauss(size, sigma):
r"""Return a circular symmetric gaussian kernel.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa
Args:
size (int): Size of Gaussian kernel.
sigma (float): Standard deviation for Gaussian blur kernel.
Returns:
ndarray: Gaussian kernel.
"""
radius = size // 2
offset = 0.0
start, stop = -radius, radius + 1
if size % 2 == 0:
offset = 0.5
stop -= 1
x, y = np.mgrid[offset + start:stop, offset + start:stop]
assert len(x) == size
g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
return g / g.sum()
# Gaussian blur kernel
def get_gaussian_kernel():
kernel = np.array([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6],
[4, 16, 24, 16, 4], [1, 4, 6, 4, 1]],
np.float32) / 256.0
gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5))
return gaussian_k
def get_pyramid_layer(image, gaussian_k, direction='down'):
    """Blur an RGB image with the Gaussian kernel and step one pyramid
    level down (stride-2 convolution) or up (2x nearest upsampling
    followed by blurring)."""
    gaussian_k = gaussian_k.to(image.device)
if direction == 'up':
image = F.interpolate(image, scale_factor=2)
multiband = [
F.conv2d(
image[:, i:i + 1, :, :],
gaussian_k,
padding=2,
stride=1 if direction == 'up' else 2) for i in range(3)
]
image = torch.cat(multiband, dim=1)
return image
def gaussian_pyramid(original, n_pyramids, gaussian_k):
    """Build a Gaussian pyramid of ``n_pyramids + 1`` levels, each one
    blurred and downsampled from the previous."""
    x = original
    # pyramid down
    pyramids = [original]
for _ in range(n_pyramids):
x = get_pyramid_layer(x, gaussian_k)
pyramids.append(x)
return pyramids
def laplacian_pyramid(original, n_pyramids, gaussian_k):
"""Calculate Laplacian pyramid.
Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py
Args:
original (Tensor): Batch of Images with range [0, 1] and order "NCHW"
n_pyramids (int): Levels of pyramids minus one.
gaussian_k (Tensor): Gaussian kernel with shape (1, 1, 5, 5).
    Returns:
        list[Tensor]: Laplacian pyramid of the original image.
"""
# create gaussian pyramid
pyramids = gaussian_pyramid(original, n_pyramids, gaussian_k)
# pyramid up - diff
laplacian = []
for i in range(len(pyramids) - 1):
diff = pyramids[i] - get_pyramid_layer(pyramids[i + 1], gaussian_k,
'up')
laplacian.append(diff)
# Add last gaussian pyramid
laplacian.append(pyramids[len(pyramids) - 1])
return laplacian
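Each pyramid level halves the spatial size, and the original image is recovered exactly by upsampling each coarser level and adding back the stored residuals, which is what makes the decomposition usable for the SWD metric. A small sanity check, assuming a side length divisible by `2**n_pyramids`:

```python
import torch

img = torch.rand(1, 3, 64, 64)  # range [0, 1], order "NCHW"
gk = get_gaussian_kernel()
lap = laplacian_pyramid(img, n_pyramids=2, gaussian_k=gk)
print([lvl.shape[-1] for lvl in lap])  # [64, 32, 16]

# reconstruct: upsample each coarser level and add the residual back
recon = lap[-1]
for lvl in reversed(lap[:-1]):
    recon = get_pyramid_layer(recon, gk, 'up') + lvl
assert torch.allclose(recon, img, atol=1e-5)
```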
def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
r"""Get descriptors of one level of pyramids.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa
Args:
minibatch (Tensor): Pyramids of one level with order "NCHW".
nhood_size (int): Pixel neighborhood size.
nhoods_per_image (int): The number of descriptors per image.
    Returns:
        Tensor: Descriptors of images from one pyramid level.
"""
S = minibatch.shape # (minibatch, channel, height, width)
assert len(S) == 4 and S[1] == 3
N = nhoods_per_image * S[0]
H = nhood_size // 2
nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H + 1, -H:H + 1]
img = nhood // nhoods_per_image
x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1))
y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1))
idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x
return minibatch.view(-1)[idx]
def finalize_descriptors(desc):
r"""Normalize and reshape descriptors.
Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa
Args:
desc (list or Tensor): List of descriptors of one level.
    Returns:
        Tensor: Descriptors normalized along the channel dimension and
            flattened.
"""
if isinstance(desc, list):
desc = torch.cat(desc, dim=0)
assert desc.ndim == 4 # (neighborhood, channel, height, width)
desc -= torch.mean(desc, dim=(0, 2, 3), keepdim=True)
desc /= torch.std(desc, dim=(0, 2, 3), keepdim=True)
desc = desc.reshape(desc.shape[0], -1)
return desc
def compute_pr_distances(row_features,
col_features,
num_gpus,
rank,
col_batch_size=10000):
r"""Compute distances between real images and fake images.
This function is used for calculate Precision and Recall metric.
Refer to:https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py # noqa
"""
assert 0 <= rank < num_gpus
num_cols = col_features.shape[0]
num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
col_batches = torch.nn.functional.pad(
col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
dist_batches = []
for col_batch in col_batches[rank::num_gpus]:
dist_batch = torch.cdist(
row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
for src in range(num_gpus):
dist_broadcast = dist_batch.clone()
if num_gpus > 1:
torch.distributed.broadcast(dist_broadcast, src=src)
dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
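With `num_gpus=1` and `rank=0` the broadcast is skipped and the function reduces to a column-chunked `torch.cdist`; a toy single-process run:

```python
import torch

real_feats = torch.randn(4, 16)
fake_feats = torch.randn(5, 16)
dists = compute_pr_distances(
    real_feats, fake_feats, num_gpus=1, rank=0, col_batch_size=2)
print(dists.shape)  # torch.Size([4, 5])
assert torch.allclose(dists, torch.cdist(real_feats, fake_feats), atol=1e-5)
```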
def normalize(a):
"""L2 normalization.
Args:
a (Tensor): Tensor with shape [N, C].
Returns:
Tensor: Tensor after L2 normalization per-instance.
"""
return a / torch.norm(a, dim=1, keepdim=True)
def slerp(a, b, percent):
"""Spherical linear interpolation between two unnormalized vectors.
Args:
a (Tensor): Tensor with shape [N, C].
b (Tensor): Tensor with shape [N, C].
percent (float|Tensor): A float or tensor with shape broadcastable to
the shape of input Tensors.
Returns:
Tensor: Spherical linear interpolation result with shape [N, C].
"""
a = normalize(a)
b = normalize(b)
d = (a * b).sum(-1, keepdim=True)
p = percent * torch.acos(d)
c = normalize(b - d * a)
d = a * torch.cos(p) + c * torch.sin(p)
return normalize(d)
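Because both endpoints are re-normalized, the interpolation stays on the unit sphere and reduces to the normalized inputs at `percent` 0 and 1. A quick check:

```python
import torch

a, b = torch.randn(1, 512), torch.randn(1, 512)
for t in (0.0, 0.5, 1.0):
    print(t, torch.norm(slerp(a, b, t)).item())  # norm stays 1.0
assert torch.allclose(slerp(a, b, 0.0), normalize(a), atol=1e-5)
assert torch.allclose(slerp(a, b, 1.0), normalize(b), atol=1e-5)
```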
# Copyright (c) OpenMMLab. All rights reserved.
from .ceph_hooks import PetrelUploadHook
from .ema_hook import ExponentialMovingAverageHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
from .pickle_data_hook import PickleDataHook
from .visualization import VisualizationHook
from .visualize_training_samples import VisualizeUnconditionalSamples
__all__ = [
'VisualizeUnconditionalSamples', 'PGGANFetchDataHook',
'ExponentialMovingAverageHook', 'VisualizationHook', 'PickleDataHook',
'PetrelUploadHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os
import mmcv
from mmcv.runner import HOOKS, Hook, master_only
@HOOKS.register_module()
class PetrelUploadHook(Hook):
"""Upload Data with Petrel.
    With this hook, users can easily upload data to the cloud server to save
    local space. Please read the notes below before using this hook,
    especially the declaration of ``petrel``.
One of the major functions is to transfer the checkpoint files from the
local directory to the cloud server.
.. note::
        ``petrel`` is a private package containing several commonly used
        ``AWS`` python APIs. Currently, this package is only for internal
        usage and will not be released to the public. We will support
        ``boto3`` in the future. This hook can serve as an easy template
        for porting to ``boto3``.
Args:
data_path (str, optional): Relative path of the data according to
current working directory. Defaults to 'ckpt'.
suffix (str, optional): Suffix for the data files. Defaults to '.pth'.
ceph_path (str | None, optional): Path in the cloud server.
Defaults to None.
interval (int, optional): Uploading interval (by iterations).
Default: -1.
upload_after_run (bool, optional): Whether to upload after running.
Defaults to True.
        rm_orig (bool, optional): Whether to remove the local files after
            uploading. Defaults to True.
"""
cfg_path = '~/petreloss.conf'
def __init__(self,
data_path='ckpt',
suffix='.pth',
ceph_path=None,
interval=-1,
upload_after_run=True,
rm_orig=True):
super().__init__()
self.interval = interval
self.upload_after_run = upload_after_run
self.data_path = data_path
self.suffix = suffix
self.ceph_path = ceph_path
self.rm_orig = rm_orig
# setup petrel client
try:
from petrel_client.client import Client
except ImportError:
raise ImportError('Please install petrel in advance.')
self.client = Client(self.cfg_path)
@staticmethod
def upload_dir(client,
local_dir,
remote_dir,
exp_name=None,
suffix=None,
remove_local_file=True):
"""Upload a directory to the cloud server.
Args:
client (obj): AWS client.
local_dir (str): Path for the local data.
remote_dir (str): Path for the remote server.
exp_name (str, optional): The experiment name. Defaults to None.
suffix (str, optional): Suffix for the data files.
Defaults to None.
            remove_local_file (bool, optional): Whether to remove the local
                files after uploading. Defaults to True.
"""
files = mmcv.scandir(local_dir, suffix=suffix, recursive=False)
files = [os.path.join(local_dir, x) for x in files]
        # remove the redundant symlinks in the data directory
files = [x for x in files if not os.path.islink(x)]
# get the actual exp_name in work_dir
if exp_name is None:
exp_name = local_dir.split('/')[-1]
mmcv.print_log(f'Uploading {len(files)} files to ceph.', 'mmgen')
for file in files:
with open(file, 'rb') as f:
data = f.read()
_path_splits = file.split('/')
idx = _path_splits.index(exp_name)
_rel_path = '/'.join(_path_splits[idx:])
_ceph_path = os.path.join(remote_dir, _rel_path)
client.put(_ceph_path, data)
# remove the local file to save space
if remove_local_file:
os.remove(file)
@master_only
def after_run(self, runner):
"""The behavior after the whole running.
Args:
runner (object): The runner.
"""
if not self.upload_after_run:
return
_data_path = os.path.join(runner.work_dir, self.data_path)
# get the actual exp_name in work_dir
exp_name = runner.work_dir.split('/')[-1]
self.upload_dir(
self.client,
_data_path,
self.ceph_path,
exp_name=exp_name,
suffix=self.suffix,
remove_local_file=self.rm_orig)
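Like other mmcv hooks, this one is meant to be registered through the config system. A sketch of how it might be enabled in a training config; the ceph bucket path is a placeholder and the snippet assumes a valid `~/petreloss.conf`:

```python
# hypothetical training-config snippet
custom_hooks = [
    dict(
        type='PetrelUploadHook',
        data_path='ckpt',
        suffix='.pth',
        ceph_path='s3://my_bucket/mmgen_exps',  # placeholder path
        upload_after_run=True,
        rm_orig=True)
]
```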
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class ExponentialMovingAverageHook(Hook):
"""Exponential Moving Average Hook.
    Exponential moving average is a trick widely used in the current GAN
    literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to
    maintain a model with the same architecture whose parameters are updated
    as a moving average of the trained weights in the original model. In
    general, the model with moving averaged weights achieves better
    performance.
Args:
        module_keys (str | tuple[str]): The name(s) of the ema model. Note
            that these keys must end with '_ema' so that the original model
            can easily be found by discarding the last four characters.
interp_mode (str, optional): Mode of the interpolation method.
Defaults to 'lerp'.
interp_cfg (dict | None, optional): Set arguments of the interpolation
function. Defaults to None.
interval (int, optional): Evaluation interval (by iterations).
Default: -1.
start_iter (int, optional): Start iteration for ema. If the start
iteration is not reached, the weights of ema model will maintain
the same as the original one. Otherwise, its parameters are updated
as a moving average of the trained weights in the original model.
Default: 0.
momentum_policy (str, optional): Policy of the momentum updating
method. Defaults to 'fixed'.
momentum_cfg (dict | None, optional): Set arguments of the momentum
updater function. Defaults to None.
"""
_registered_interp_funcs = ['lerp']
_registered_momentum_updaters = ['rampup', 'fixed']
def __init__(self,
module_keys,
interp_mode='lerp',
interp_cfg=None,
interval=-1,
start_iter=0,
momentum_policy='fixed',
momentum_cfg=None):
super().__init__()
# check args
assert interp_mode in self._registered_interp_funcs, (
'Supported '
f'interpolation functions are {self._registered_interp_funcs}, '
f'but got {interp_mode}')
        assert momentum_policy in self._registered_momentum_updaters, (
            'Supported momentum policies are '
            f'{self._registered_momentum_updaters}, '
            f'but got {momentum_policy}')
assert isinstance(module_keys, str) or mmcv.is_tuple_of(
module_keys, str)
self.module_keys = (module_keys, ) if isinstance(module_keys,
str) else module_keys
# sanity check for the format of module keys
for k in self.module_keys:
assert k.endswith(
'_ema'), 'You should give keys that end with "_ema".'
self.interp_mode = interp_mode
self.interp_cfg = dict() if interp_cfg is None else deepcopy(
interp_cfg)
self.interval = interval
self.start_iter = start_iter
assert hasattr(
self, interp_mode
), f'Currently, we do not support {self.interp_mode} for EMA.'
self.interp_func = getattr(self, interp_mode)
self.momentum_cfg = dict() if momentum_cfg is None else deepcopy(
momentum_cfg)
self.momentum_policy = momentum_policy
if momentum_policy != 'fixed':
assert hasattr(
self, momentum_policy
), f'Currently, we do not support {self.momentum_policy} for EMA.'
self.momentum_updater = getattr(self, momentum_policy)
@staticmethod
def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
"""Does a linear interpolation of two parameters/ buffers.
Args:
a (torch.Tensor): Interpolation start point, refer to orig state.
b (torch.Tensor): Interpolation end point, refer to ema state.
momentum (float, optional): The weight for the interpolation
formula. Defaults to 0.999.
momentum_nontrainable (float, optional): The weight for the
interpolation formula used for nontrainable parameters.
                Defaults to 0.
trainable (bool, optional): Whether input parameters is trainable.
If set to False, momentum_nontrainable will be used.
Defaults to True.
Returns:
torch.Tensor: Interpolation result.
"""
m = momentum if trainable else momentum_nontrainable
return a + (b - a) * m
@staticmethod
def rampup(runner, ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-8):
"""Ramp up ema momentum.
Ref: https://github.com/NVlabs/stylegan3/blob/a5a69f58294509598714d1e88c9646c3d7c6ec94/training/training_loop.py#L300-L308 # noqa
Args:
            runner (object): The runner.
ema_kimg (int, optional): Half-life of the exponential moving
average of generator weights. Defaults to 10.
            ema_rampup (float, optional): EMA ramp-up coefficient. If set to
                None, ramp-up is disabled. Defaults to 0.05.
batch_size (int, optional): Total batch size for one training
iteration. Defaults to 4.
            eps (float, optional): Epsilon to avoid dividing ``batch_size``
                by zero. Defaults to 1e-8.
Returns:
dict: Updated momentum.
"""
cur_nimg = (runner.iter + 1) * batch_size
ema_nimg = ema_kimg * 1000
if ema_rampup is not None:
ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
ema_beta = 0.5**(batch_size / max(ema_nimg, eps))
return dict(momentum=ema_beta)
def every_n_iters(self, runner, n):
if runner.iter < self.start_iter:
return True
return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False
@torch.no_grad()
def after_train_iter(self, runner):
if not self.every_n_iters(runner, self.interval):
return
model = runner.model.module if is_module_wrapper(
runner.model) else runner.model
# update momentum
_interp_cfg = deepcopy(self.interp_cfg)
if self.momentum_policy != 'fixed':
_updated_args = self.momentum_updater(runner, **self.momentum_cfg)
_interp_cfg.update(_updated_args)
for key in self.module_keys:
# get current ema states
ema_net = getattr(model, key)
states_ema = ema_net.state_dict(keep_vars=False)
# get currently original states
net = getattr(model, key[:-4])
states_orig = net.state_dict(keep_vars=True)
for k, v in states_orig.items():
if runner.iter < self.start_iter:
states_ema[k].data.copy_(v.data)
else:
states_ema[k] = self.interp_func(
v,
states_ema[k],
trainable=v.requires_grad,
**_interp_cfg).detach()
ema_net.load_state_dict(states_ema, strict=True)
def before_run(self, runner):
model = runner.model.module if is_module_wrapper(
runner.model) else runner.model
# sanity check for ema model
for k in self.module_keys:
if not hasattr(model, k) and not hasattr(model, k[:-4]):
raise RuntimeError(
f'Cannot find both {k[:-4]} and {k} network for EMA hook.')
if not hasattr(model, k) and hasattr(model, k[:-4]):
setattr(model, k, deepcopy(getattr(model, k[:-4])))
                warnings.warn(
                    f'We do not suggest constructing and initializing the '
                    f'EMA model {k} in the hook. You may explicitly define '
                    'it by yourself.')
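As a worked example of `rampup`: with `batch_size=32` at iteration 999, `cur_nimg = 1000 * 32 = 32000` and `ema_nimg = min(10 * 1000, 32000 * 0.05) = 1600`, so the momentum becomes `0.5 ** (32 / 1600) ≈ 0.9862`. Below is a sketch of registering the hook in a config, assuming the model exposes (or lets the hook deep-copy) a `generator_ema` module:

```python
# hypothetical config snippet
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=1,
        interp_cfg=dict(momentum=0.999))
]
```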
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class PGGANFetchDataHook(Hook):
"""PGGAN Fetch Data Hook.
Args:
        interval (int, optional): The interval of calling this hook. If set
            to -1, the hook will not be called. Defaults to 1.
"""
def __init__(self, interval=1):
super().__init__()
self.interval = interval
def before_fetch_train_data(self, runner):
"""The behavior before fetch train data.
Args:
runner (object): The runner.
"""
if not self.every_n_iters(runner, self.interval):
return
_module = runner.model.module if is_module_wrapper(
runner.model) else runner.model
_next_scale_int = _module._next_scale_int
if isinstance(_next_scale_int, torch.Tensor):
_next_scale_int = _next_scale_int.item()
runner.data_loader.update_dataloader(_next_scale_int)
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import pickle
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
@HOOKS.register_module()
class PickleDataHook(Hook):
"""Pickle Useful Data Hook.
This hook will be used in SinGAN training for saving some important data
that will be used in testing or inference.
Args:
output_dir (str): The output path for saving pickled data.
        data_name_list (list[str]): The list containing the names of results
            in the outputs dict.
        interval (int): The interval of calling this hook. If set to -1,
            the hook will not be called. Default: -1.
before_run (bool, optional): Whether to save before running.
Defaults to False.
after_run (bool, optional): Whether to save after running.
Defaults to False.
        filename_tmpl (str, optional): Format string used to name the saved
            data files. The output file name will be formatted with this
            template. Defaults to 'iter_{}.pkl'.
"""
def __init__(self,
output_dir,
data_name_list,
interval=-1,
before_run=False,
after_run=False,
filename_tmpl='iter_{}.pkl'):
assert mmcv.is_list_of(data_name_list, str)
self.output_dir = output_dir
self.data_name_list = data_name_list
self.interval = interval
self.filename_tmpl = filename_tmpl
self._before_run = before_run
self._after_run = after_run
@master_only
def after_run(self, runner):
"""The behavior after each train iteration.
Args:
runner (object): The runner.
"""
if self._after_run:
self._pickle_data(runner)
@master_only
def before_run(self, runner):
"""The behavior after each train iteration.
Args:
runner (object): The runner.
"""
if self._before_run:
self._pickle_data(runner)
@master_only
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (object): The runner.
"""
if not self.every_n_iters(runner, self.interval):
return
self._pickle_data(runner)
def _pickle_data(self, runner):
filename = self.filename_tmpl.format(runner.iter + 1)
if not hasattr(self, '_out_dir'):
self._out_dir = os.path.join(runner.work_dir, self.output_dir)
mmcv.mkdir_or_exist(self._out_dir)
file_path = os.path.join(self._out_dir, filename)
with open(file_path, 'wb') as f:
data = runner.outputs['results']
not_find_keys = []
data_dict = {}
for k in self.data_name_list:
if k in data.keys():
data_dict[k] = self._get_numpy_data(data[k])
else:
not_find_keys.append(k)
pickle.dump(data_dict, f)
mmcv.print_log(f'Pickle data in {filename}', 'mmgen')
if len(not_find_keys) > 0:
mmcv.print_log(
f'Cannot find keys for pickling: {not_find_keys}',
'mmgen',
level=logging.WARN)
f.flush()
def _get_numpy_data(self, data):
if isinstance(data, list):
return [self._get_numpy_data(x) for x in data]
if isinstance(data, torch.Tensor):
return data.cpu().numpy()
return data
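A sketch of configuring this hook for SinGAN-style training; the keys in `data_name_list` are placeholders and must actually exist in `runner.outputs['results']`:

```python
# hypothetical config snippet
custom_hooks = [
    dict(
        type='PickleDataHook',
        output_dir='pickle',
        interval=-1,
        after_run=True,
        data_name_list=['fixed_noises', 'noise_weights'])
]
```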
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image
@HOOKS.register_module('MMGenVisualizationHook')
class VisualizationHook(Hook):
"""Visualization hook.
In this hook, we use the official api `save_image` in torchvision to save
the visualization results.
Args:
output_dir (str): The file path to store visualizations.
        res_name_list (list[str]): The list containing the names of results
            in the outputs dict. The results in the outputs dict must be
            torch.Tensor with shape (n, c, h, w).
interval (int): The interval of calling this hook. If set to -1,
the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to name the saved images.
            The output file name will be formatted with this template.
            Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend that users preprocess the
            visualization results on their own. Here, we just provide a
            simple interface. Default: True.
        bgr2rgb (bool): Whether to reorder the channel dimension from BGR to
            RGB. The saved image will follow RGB order. Default: True.
        nrow (int): The number of samples in a row. Default: 1.
        padding (int): The number of padding pixels between samples.
            Default: 4.
"""
def __init__(self,
output_dir,
res_name_list,
interval=-1,
filename_tmpl='iter_{}.png',
rerange=True,
bgr2rgb=True,
nrow=1,
padding=4):
assert mmcv.is_list_of(res_name_list, str)
self.output_dir = output_dir
self.res_name_list = res_name_list
self.interval = interval
self.filename_tmpl = filename_tmpl
self.bgr2rgb = bgr2rgb
self.rerange = rerange
self.nrow = nrow
self.padding = padding
@master_only
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (object): The runner.
"""
if not self.every_n_iters(runner, self.interval):
return
results = runner.outputs['results']
filename = self.filename_tmpl.format(runner.iter + 1)
# img_list = [x for k, x in results.items() if k in self.res_name_list]
img_list = [results[k] for k in self.res_name_list if k in results]
img_cat = torch.cat(img_list, dim=3).detach()
if self.rerange:
img_cat = ((img_cat + 1) / 2)
if self.bgr2rgb:
img_cat = img_cat[:, [2, 1, 0], ...]
img_cat = img_cat.clamp_(0, 1)
if not hasattr(self, '_out_dir'):
self._out_dir = osp.join(runner.work_dir, self.output_dir)
mmcv.mkdir_or_exist(self._out_dir)
save_image(
img_cat,
osp.join(self._out_dir, filename),
nrow=self.nrow,
padding=self.padding)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image
@HOOKS.register_module()
class VisualizeUnconditionalSamples(Hook):
"""Visualization hook for unconditional GANs.
In this hook, we use the official api `save_image` in torchvision to save
the visualization results.
Args:
output_dir (str): The file path to store visualizations.
fixed_noise (bool, optional): Whether to use fixed noises in sampling.
Defaults to True.
num_samples (int, optional): The number of samples to show in
visualization. Defaults to 16.
interval (int): The interval of calling this hook. If set to -1,
the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to name the saved images.
            The output file name will be formatted with this template.
            Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend that users preprocess the
            visualization results on their own. Here, we just provide a
            simple interface. Default: True.
        bgr2rgb (bool): Whether to reorder the channel dimension from BGR to
            RGB. The saved image will follow RGB order. Default: True.
        nrow (int): The number of samples in a row. Default: 4.
        padding (int): The number of padding pixels between samples.
            Default: 0.
        kwargs (dict | None, optional): Keyword arguments for the sampling
            function. Defaults to None.
"""
def __init__(self,
output_dir,
fixed_noise=True,
num_samples=16,
interval=-1,
filename_tmpl='iter_{}.png',
rerange=True,
bgr2rgb=True,
nrow=4,
padding=0,
kwargs=None):
self.output_dir = output_dir
self.fixed_noise = fixed_noise
self.num_samples = num_samples
self.interval = interval
self.filename_tmpl = filename_tmpl
self.bgr2rgb = bgr2rgb
self.rerange = rerange
self.nrow = nrow
self.padding = padding
# the sampling noise will be initialized by the first sampling.
self.sampling_noise = None
self.kwargs = kwargs if kwargs is not None else dict()
@master_only
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (object): The runner.
"""
if not self.every_n_iters(runner, self.interval):
return
# eval mode
runner.model.eval()
# no grad in sampling
with torch.no_grad():
outputs_dict = runner.model(
self.sampling_noise,
return_loss=False,
num_batches=self.num_samples,
return_noise=True,
**self.kwargs)
imgs = outputs_dict['fake_img']
noise_ = outputs_dict['noise_batch']
        # initialize sampling noise with the first returned noise
if self.sampling_noise is None and self.fixed_noise:
self.sampling_noise = noise_
# train mode
runner.model.train()
filename = self.filename_tmpl.format(runner.iter + 1)
if self.rerange:
imgs = ((imgs + 1) / 2)
if self.bgr2rgb and imgs.size(1) == 3:
imgs = imgs[:, [2, 1, 0], ...]
if imgs.size(1) == 1:
imgs = torch.cat([imgs, imgs, imgs], dim=1)
imgs = imgs.clamp_(0, 1)
mmcv.mkdir_or_exist(osp.join(runner.work_dir, self.output_dir))
save_image(
imgs,
osp.join(runner.work_dir, self.output_dir, filename),
nrow=self.nrow,
padding=self.padding)
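A sketch of enabling this hook in an unconditional GAN config; with the defaults above it saves a 4-per-row grid of fixed-noise samples under `work_dir/training_samples` (the interval and directory name here are assumptions):

```python
# hypothetical config snippet
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000)
]
```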