Commit 57e0e891 authored by limm's avatar limm
Browse files

add part mmgeneration code

parent 04e07f48
接口参考手册
=================
mmgen.apis
--------------
.. automodule:: mmgen.apis
:members:
mmgen.core
--------------
evaluation
^^^^^^^^^^
.. automodule:: mmgen.core.evaluation
:members:
hooks
^^^^^^^^^^
.. automodule:: mmgen.core.hooks
:members:
optimizer
^^^^^^^^^^
.. automodule:: mmgen.core.optimizer
:members:
runners
^^^^^^^^^^
.. automodule:: mmgen.core.runners
:members:
scheduler
^^^^^^^^^^
.. automodule:: mmgen.core.scheduler
:members:
mmgen.datasets
--------------
datasets
^^^^^^^^^^
.. automodule:: mmgen.datasets
:members:
pipelines
^^^^^^^^^^
.. automodule:: mmgen.datasets.pipelines
:members:
mmgen.models
--------------
architectures
^^^^^^^^^^
.. automodule:: mmgen.models.architectures
:members:
common
^^^^^^^^^^
.. automodule:: mmgen.models.common
:members:
gans
^^^^^^^^^^^^
.. automodule:: mmgen.models.gans
:members:
losses
^^^^^^^^^^^^
.. automodule:: mmgen.models.losses
:members:
# 版本更新日志
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import subprocess
import sys
sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------
project = 'MMGeneration'
copyright = '2018-2020, OpenMMLab'
author = 'MMGeneration Authors'
version_file = '../../mmgen/version.py'
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
# The full version, including alpha/beta/rc tags
release = get_version()
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'recommonmark',
'sphinx_markdown_tables',
]
autodoc_mock_imports = [
'matplotlib', 'pycocotools', 'terminaltables', 'mmgen.version', 'mmcv.ops'
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
# The master toctree document.
master_doc = 'index'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
def builder_inited_handler(app):
subprocess.run(['./stat.py'])
def setup(app):
app.connect('builder-inited', builder_inited_handler)
# 常见问题解答
我们在这里罗列了许多用户遇到的一些常见的问题和相应的解决方案。如果您发现任何常见问题并有办法帮助他人解决该问题,欢迎丰富该列表。如果此处的内容未涵盖您的问题,请使用[提供的模版](https://github.com/open-mmlab/mmgeneration/blob/master/.github/ISSUE_TEMPLATE/error-report.md)来创建新问题,并确保将模版中所有必需的信息填写完整。
## 依赖项
- Linux
- Python 3.6+
- PyTorch 1.5+
- CUDA 9.2+ (如果您从源码编译PyTorch, CUDA 9.0也是兼容的)
- GCC 5.4+
- [MMCV (MMCV-FULL)](https://mmcv.readthedocs.io/en/latest/#installation)
下面是MMGeneration与MMCV版本兼容信息。为防止出错请安装正确的MMCV版本。
| MMGeneration version | MMCV version |
| :------------------: | :--------------: |
| master | mmcv-full>=1.3.0 |
注:如果您已安装mmcv,需要先卸载 `pip uninstall mmcv`。 如果同时安装了mmcv和mmcv-full,将会报错 `ModuleNotFoundError`
## 安装
1. 创建conda虚拟环境并激活。 (这里假设新环境叫 `open-mmlab`)
```shell
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
```
2. 安装 PyTorch 和 torchvision,参考[官方安装指令](https://pytorch.org/),比如,
```shell
conda install pytorch torchvision -c pytorch
```
注:确保您编译的CUDA版本和运行时CUDA版本相匹配。您可以在[PyTorch官网](https://pytorch.org/)检查预编译库支持的CUDA版本。
`示例1` 如果您在`/usr/local/cuda`下安装了 CUDA 10.1 并想要安装
PyTorch 1.5,您需要安装支持CUDA 10.1的PyTorch预编译版本。
```shell
conda install pytorch cudatoolkit=10.1 torchvision -c pytorch
```
`示例2`如果您在`/usr/local/cuda`下安装了 CUDA 9.2 并想要安装
PyTorch 1.5.1,您需要安装支持CUDA 9.2的PyTorch预编译版本。
```shell
conda install pytorch=1.5.1 cudatoolkit=9.2 torchvision=0.6.1 -c pytorch
```
如果您从源码编译PyTorch 而非安装预编译库, 您可以使用更多CUDA版本如9.0。
3. 安装 mmcv-full, 我们建议您按照下述方法安装预编译库。
```shell
pip install mmcv-full={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
```
请替换链接中的 `{cu_version}``{torch_version}` 为您想要的版本。 比如, 要安装支持 `CUDA 11``PyTorch 1.7.0``mmcv-full`, 使用下面命令:
```shell
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
```
可在[这里](https://github.com/open-mmlab/mmcv#install-with-pip)查看兼容了不同PyTorch和CUDA的MMCV版本信息。
您也可以选择按照下方命令从源码编译mmcv
```shell
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
MMCV_WITH_OPS=1 pip install -e . # package mmcv-full will be installed after this step
cd ..
```
或者直接运行
```shell
pip install mmcv-full
```
4. 克隆MMGeneration仓库。
```shell
git clone https://github.com/open-mmlab/mmgeneration.git
cd mmgeneration
```
5. 安装构建依赖项并安装MMGeneration。
```shell
pip install -r requirements.txt
pip install -v -e . # or "python setup.py develop"
```
注:
a. 依照上面的说明, MMGeneration 会以 `dev` 形式安装,
对代码进行的任何本地修改都将生效,而不需要重新安装。
b. 如果您想要使用 `opencv-python-headless` 而非 `opencv-python`
您可以在安装 MMCV 之前安装它。
### 安装CPU版本
本代码可在仅使用CPU的环境下编译 (当 CUDA 不可用时)。
### 一个从头开始的配置脚本
假设您已经安装了CUDA 10.1,下面是使用conda配置MMGeneration的完整脚本。
```shell
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.1 -c pytorch -y
# install the latest mmcv
# pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html
# install mmgeneration
git clone https://github.com/open-mmlab/mmgeneration.git
cd mmgeneration
pip install -r requirements.txt
pip install -v -e .
```
需要注意的是,mmcv-full 只在 PyTorch 1.x.0 上编译, 因为1.x.0与1.x.1通常是保持兼容性的。 如果您的 PyTorch 版本是1.x.1, 您可以安装兼容PyTorch 1.x.0的mmcv-full,通常运行良好。
```shell
# We can ignore the micro version of PyTorch
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.10/index.html
```
### 使用多个 MMGeneration 版本进行开发
训练和测试脚本已经修改了 `PYTHONPATH`, 以确保脚本使用当前目录中的`MMGeneration`
要使用安装在环境中的默认MMGeneration而不是您正在使用的,您可以删除脚本中的以下代码行
```shell
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH
```
## 验证
为了验证MMGeneration和所需的环境是否正确安装,我们可以运行示例Python代码来初始化一个非条件模型,并使用它来生成随机样本:
```python
from mmgen.apis import init_model sample_unconditional_model
config_file = 'configs/styleganv2/stylegan2_c2_lsun-church_256_b4x8_800k.py'
# you can download this checkpoint in advance and use a local file path.
checkpoint_file = 'https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth'
device = 'cuda:0'
# init a generatvie
model = init_model(config_file checkpoint_file device=device)
# sample images
fake_imgs = sample_unconditional_model(model 4)
```
当安装完成后,上面的代码可以成功运行。
欢迎来到 MMGeneration 的用户手册!
=======================================
.. toctree::
:maxdepth: 2
:caption: Get Started
get_started.md
modelzoo_statistics.md
.. toctree::
:maxdepth: 2
:caption: Quick Run
quick_run.md
.. toctree::
:maxdepth: 2
:caption: Tutorials
tutorials/index.rst
.. toctree::
:maxdepth: 2
:caption: Notes
changelog.md
faq.md
.. toctree::
:caption: Switch Language
switch_language.md
.. toctree::
:caption: API Reference
api.rst
Indices and tables
==================
* :ref:`genindex`
* :ref:`search`
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
# Model Zoo Statistics
- Number of papers: 11
- Number of checkpoints: 62
- [CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks](https://github.com/open-mmlab/mmgeneration/blob/master/configs/cyclegan) (6 ckpts)
- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://github.com/open-mmlab/mmgeneration/blob/master/configs/dcgan) (3 ckpts)
- [Geometric GAN](https://github.com/open-mmlab/mmgeneration/blob/master/configs/ggan) (3 ckpts)
- [Least Squares Generative Adversarial Networks](https://github.com/open-mmlab/mmgeneration/blob/master/configs/lsgan) (4 ckpts)
- [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://github.com/open-mmlab/mmgeneration/blob/master/configs/pggan) (3 ckpts)
- [Pix2Pix: Image-to-Image Translation with Conditional Adversarial Networks](https://github.com/open-mmlab/mmgeneration/blob/master/configs/pix2pix) (4 ckpts)
- [Positional Encoding as Spatial Inductive Bias in GANs (CVPR'2021)](https://github.com/open-mmlab/mmgeneration/blob/master/configs/positional_encoding_in_gans) (21 ckpts)
- [Singan: Learning a Generative Model from a Single Natural Image (ICCV'2019)](https://github.com/open-mmlab/mmgeneration/blob/master/configs/singan) (3 ckpts)
- [A Style-Based Generator Architecture for Generative Adversarial Networks (CVPR'2019)](https://github.com/open-mmlab/mmgeneration/blob/master/configs/styleganv1) (2 ckpts)
- [Analyzing and Improving the Image Quality of Stylegan (CVPR'2020)](https://github.com/open-mmlab/mmgeneration/blob/master/configs/styleganv2) (11 ckpts)
- [Improved Training of Wasserstein GANs](https://github.com/open-mmlab/mmgeneration/blob/master/configs/wgan-gp) (2 ckpts)
# 1: 在标准的数据集上训练和推理现有的模型
## 用现有的生成模型来生成图像
#!/usr/bin/env python
import glob
import os.path as osp
import re
url_prefix = 'https://github.com/open-mmlab/mmgeneration/blob/master/'
files = sorted(glob.glob('../../configs/*/README.md'))
stats = []
titles = []
num_ckpts = 0
for f in files:
url = osp.dirname(f.replace('../', url_prefix))
with open(f, 'r') as content_file:
content = content_file.read()
title = content.split('\n')[0].replace('# ', '')
titles.append(title)
ckpts = set(x.lower().strip()
for x in re.findall(r'https?://download.(.*?)\.pth', content)
if 'mmgen' in x)
num_ckpts += len(ckpts)
statsmsg = f"""
\t* [{title}]({url}) ({len(ckpts)} ckpts)
"""
stats.append((title, ckpts, statsmsg))
msglist = '\n'.join(x for _, _, x in stats)
modelzoo = f"""
# Model Zoo Statistics
* Number of papers: {len(titles)}
* Number of checkpoints: {num_ckpts}
{msglist}
"""
with open('modelzoo_statistics.md', 'w') as f:
f.write(modelzoo)
# Tutorial 8: 生成模型的应用
## 插值
以GAN为架构的生成模型学习将潜码空间中的点映射到生成的图像上。生成模型赋予了潜码空间的具体意义。一般来说,我们想探索潜码空间的结构,我们可以做的一件事是在潜码空间的两个端点之间插入一系列点,观察这些点生成的结果。(例如,我们认为,如果任何一个端点都不存在的特征出现在线性插值路径的中间点,则说明潜码空间是纠缠在一起的,动态属性没有得到适当的分离。)
我们为用户提供了一个应用脚本。你可以使用[apps/interpolate_sample.py](https://github.com/open-mmlab/mmgeneration/tree/master/apps/interpolate_sample.py)的以下命令进行无条件模型的插值。
```bash
python apps/interpolate_sample.py \
${CONFIG_FILE} \
${CHECKPOINT} \
[--show-mode ${SHOW_MODE}] \
[--endpoint ${ENDPOINT}] \
[--interval ${INTERVAL}] \
[--space ${SPACE}] \
[--samples-path ${SAMPLES_PATH}] \
[--batch-size ${BATCH_SIZE}] \
```
在这里,我们提供两种显示模式(SHOW_MODE),即序列(sequence)和组(group)。在序列模式下,我们首先对一连串的端点进行采样,然后按顺序对两个端点之间的点进行插值,生成的图像将被单独保存。在组模式下,我们先采样几对端点,然后在每对端点之间进行插值,生成的图像将被保存在一张图片中。此外,`space` 指的是潜码空间,你可以选择'z'或'w'(指StyleGAN系列中的风格空间),`endpoint` 表示你要采样的端点数量(在 `group` 模式中应设置为偶数),`interval`表示你在两个端点之间插值的点的数量(包括端点)。
注意,我们还提供了更多的自定义参数来定制你的插值程序。
请使用`python apps/interpolate_sample.py --help`来查看更多细节。
如同上面的方法,你可以使用[apps/conditional_interpolate.py](https://github.com/open-mmlab/mmgeneration/tree/master/apps/conditional_interpolate.py)和下列命令进行条件模型的插值。
```bash
python apps/conditional_interpolate.py \
${CONFIG_FILE} \
${CHECKPOINT} \
[--show-mode ${SHOW_MODE}] \
[--endpoint ${ENDPOINT}] \
[--interval ${INTERVAL}] \
[--embedding-name ${EMBEDDING_NAME}]
[--fix-z] \
[--fix-y] \
[--samples-path ${SAMPLES_PATH}] \
[--batch-size ${BATCH_SIZE}] \
```
在这里,与无条件模型不同,如果标签嵌入在 `conv_blocks` 之间共享,你需要提供嵌入层的名称。否则,你应该将 `embedding-name` 设置为 `NULL`。考虑到条件模型有噪声和标签作为输入,我们提供 `fix-z` 来固定噪声,`fix-y` 来固定标签。
## 投影
求生成网络 `g` 的逆是一个有趣的问题,有很多应用。例如,在潜码空间中操作一个给定的图像需要先为它找到一个匹配的潜码。一般来说,你可以通过对潜码进行优化来重建目标图像,使用 `lpips` 和像素级损失作为目标函数。
事实上,我们已经向用户提供了一个应用脚本,为给定的图像找到 `StyleGAN` 系列生成网络的匹配潜码向量 `w`。你可以使用[apps/stylegan_projector.py](https://github.com/open-mmlab/mmgeneration/tree/master/apps/stylegan_projector.py)的以下命令来执行投影。
```bash
python apps/stylegan_projector.py \
${CONFIG_FILE} \
${CHECKPOINT} \
${FILES}
[--results-path ${RESULTS_PATH}]
```
这里,`FILES` 指的是图像的路径,而投影的潜码和重建的图像将被保存在 `results-path` 中。
注意,我们还提供了更多的自定义参数来定制你的投影程序。请使用`python apps/stylegan_projector.py --help`来查看更多细节。
## 编辑
基于 StyleGAN 模型的一个常见应用是操纵潜码空间来控制合成图像的属性。在这里,我们向用户提供了一个基于[SeFa](https://arxiv.org/pdf/2007.06600.pdf)的简单而流行的算法。这里,我们在计算特征向量时对原始版本进行了修改,并提供了一个更灵活的接口。
为了操纵你的生成器,你可以用以下命令运行脚本[apps/modified_sefa.py](https://github.com/open-mmlab/mmgeneration/tree/master/apps/modified_sefa.py)
```shell
python apps/modified_sefa.py --cfg ${CONFIG} --ckpt ${CKPT} \
-i ${INDEX} -d ${DEGREE} --degree-step ${D_STEP} \
-l ${LAYER_NO} \
[--eigen-vector ${PATH_EIGEN_VEC}]
```
在这个脚本中,如果 `eigen-vector``None`,程序将计算生成器参数的特征向量。同时,我们将把该向量保存在 `ckpt` 文件的同一目录下,这样用户就可以应用这个预先计算的向量。`Positional Encoding as Spatial Inductive Bias in GANs` 的演示就来自这个脚本。下面是一个例子,供用户获得与我们的演示类似的结果。
`${INDEX}`表示我们将应用哪个特征向量来操作图像。在一般情况下,每个索引控制一个独立的属性,这是由 `StyleGAN` 中的解耦表示保证的。我们建议用户可以尝试不同的索引来找到你想要的那个属性。`--degree` 设定了乘法因子的范围。在我们的实验中,我们观察到像 `[-3, 8]` 这样的非对称范围是非常有帮助的。因此,我们允许在这个参数中设置下限和上限。`--layer``--l` 定义了我们将应用哪一层的特征向量。有些属性,比如光照,只与生成器中的 1-2 层有关。
以光照属性为例,我们在 MS-PIE-StyleGAN2-256 模型上运行以下命令。
```shell
python apps/modified_sefa.py \
--config configs/positional_encoding_in_gans/mspie-stylegan2_c2_config-f_ffhq_256-512_b3x8_1100k.py \
--ckpt https://download.openmmlab.com/mmgen/pe_in_gans/mspie-stylegan2_c2_config-f_ffhq_256-512_b3x8_1100k_20210406_144927-4f4d5391.pth \
-i 15 -d 8. --degree-step 0.5 -l 8 9 --sample-path ./work_dirs/sefa-exp/ \
--sample-cfg chosen_scale=4 randomize_noise=False
```
注意到,在设置 `chosen_scale=4` 之后,我们可以用一个简单的分辨率为256的生成器来操作512x512的图像。
# Tutorial 1: 配置系统 (config files)
# Tutorial 2: 自定义数据集
# Tutorial 4: 损失函数模块的设计思路
# Tutorial 3: 自定义模型
# Tutorial 6: 自定义配置
# Tutorial 5: MMGeneration 中的分布式训练
在本节中,我们将讨论生成模型的 `DDP`(分布式数据并行)训练,特别是 GANs 的训练。
## 分布式数据并行的训练方式总结
| DDP Model | find_unused_parameters | Static GANs | Dynamic GANs |
| :--------------------------------: | :--------------------: | :---------: | :----------: |
| MMDDP/PyTorch DDP | False | Error | Error |
| MMDDP/PyTorch DDP | True | Error | Error |
| DDP Wrapper | False | **No Bugs** | Error |
| DDP Wrapper | True | **No Bugs** | **No Bugs** |
| MMDDP/PyTorch DDP + Dynamic Runner | True | **No Bugs** | **No Bugs** |
在这个表格中,我们总结了生成对抗网络(GANs)的 DDP 训练方式。[`MMDDP/PyTorch DDP`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/parallel/distributed.py)表示用 `MMDistributedDataPrarallel` 直接封装 GAN 模型(包含生成器、判别器和损失模块)。然而,在这种方式下,我们无法对 GAN 模型应用对抗训练。主要原因是我们总是需要在 `train_step` 函数中对部分模型(只对判别器或生成器)的损失进行反向传播。
另一种使用 DDP 的方式是采用 [DDP Wrapper](https://github.com/open-mmlab/mmgeneration/tree/master/mmgen/core/ddp_wrapper.py),用 `MMDDP` 封装 GAN 模型中的每个模块,这在目前的实践中被广泛使用,例如,`MMEditing`[StyleGAN2-ADA-PyTorch](https://github.com/NVlabs/stylegan2-ada-pytorch)。这样一来,就有了一个重要的参数,`find_unused_parameters`。如表所示,对于训练动态架构的模型,如 PGGAN 和 StyleGANv1,用户必须设置这个参数为 `True`。 然而,一旦 `find_unused_parameters` 设置为 `True`,模型将在每个前向传播后重建 `bucket` 以同步梯度和信息,从而在反向传播过程中追踪计算图所需的张量。
`MMGeneration` 中,我们为用户设计了另一种采用 `DDP` 训练的方式,即 `MMDDP/PyTorch DDP + Dynamic Runner`。在具体说明这个新设计的细节之前,我们首先解释为什么用户应该使用它。尽管通过 `DDP Wrapper` 实现了动态 GAN 的训练,我们仍然发现了一些不便和缺点。
- `DDP Wrapper` 使用户无法调用或获得 GANs 中模块的函数或属性,例如,生成器和判别器。采用 `DDP Wrapper` 后,如果我们想调用 `generator` 中的函数,我们必须使用 `generator.module.xxx()`
- `DDP Wrapper` 将导致多余的桶重建。通过采用 `DDP Wrapper` 来避免 ddp 错误的真正原因是,GAN 模型中的每个模块在调用它们的 `forward` 函数后,会立即为反向传播重建桶。然而,正如 GAN 实践中所知道的,有很多情况下我们不需要为反向传播建立一个桶,例如,在更新判别器时为生成器建桶。
为了解决这些问题,我们试图找到一种方法来直接采用 `MMDDP` 并支持动态的 GAN 训练。在 `MMGeneration` 中,`DynamicIterBasedRunner` 帮助我们实现这一目标。重要的是,只需要少于十行的修改就能解决这个问题。
## MMDDP/PyTorch DDP + Dynamic Runner
在静态/动态GAN训练中采用 DDP 的关键点是在反向传播(判别器和生成器)之前构建(或检查)桶。因为这两个反向中需要梯度的参数来自 GAN 模型的不同部分。因此,我们的解决方案只是在每个反向传播之前显示地重建桶。
[mmgen/core/runners/dynamic_iterbased_runner.py](https://github.com/open-mmlab/mmgeneration/tree/master/mmgen/core/runners/dynamic_iterbased_runner.py)中,我们通过使用 **PyTorch private API** 获得 `reducer`
```python
if self.is_dynamic_ddp:
kwargs.update(dict(ddp_reducer=self.model.reducer))
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
```
通过如下对 train_step 的修改,reducer 可以帮助我们在当前反传中重建桶:
```python
if ddp_reducer is not None:
ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
```
一个完整用例如下:
```python
loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
# prepare for backward in ddp. If you do not call this function before
# back propagation, the ddp will not dynamically find the used params
# in current computation.
if ddp_reducer is not None:
ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
loss_disc.backward()
```
也就是说,用户应该在损失计算和损失反传之间准备 reducer。
在我们的 `MMGeneration` 中,这个功能被作为训练 DDP 模型的默认方式。在配置文件中,用户只需要添加以下配置来使用动态 ddp runner。
```python
# use dynamic runner
runner = dict(
type='DynamicIterBasedRunner',
is_dynamic_ddp=True,
pass_training_status=True)
```
*这个实现将使用 PyTorch 中的私有接口,我们将继续维护这一实现。*
## DDP Wrapper
当然,我们仍然支持使用 `DDP Wrapper` 来训练你的 GANs。如果你想切换到使用 DDP Wrapper,你应该这样修改配置文件。
```python
# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = True # True for dynamic model, False for static model
runner = dict(
type='DynamicIterBasedRunner',
is_dynamic_ddp=False, # Note that this flag should be False.
pass_training_status=True)
```
[`dcgan config file`](https://github.com/open-mmlab/mmgeneration/tree/master/configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py)中,我们已经提供了一个在 MMGeneration 中使用 `DDP Wrapper` 的例子。
.. toctree::
:maxdepth: 2
config.md
customize_dataset.md
customize_models.md
customize_losses.md
ddp_train_gans.md
customize_runtime.md
applications.md
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment