# Licenses for special operations
In this file, we list the operations that use licenses other than Apache 2.0. Users should be careful when adopting these operations in any commercial matters.
| Operation | Files | License |
| :------------------: | :-----------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| conv2d_gradfix | [mmgen/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/ops/conv2d_gradfix.py) | NVIDIA License |
| compute_pr_distances | [mmgen/core/evaluation/metric_utils.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/core/evaluation/metric_utils.py) | NVIDIA License |
include mmgen/model-index.yml
recursive-include mmgen/configs *.py *.yml
recursive-include mmgen/tools *.sh *.py
include requirements/*.txt
include mmgen/VERSION
include mmgen/.mim/model-index.yml
include mmgen/.mim/demo/*/*
recursive-include mmgen/.mim/configs *.py *.yml
recursive-include mmgen/.mim/tools *.sh *.py
# StyleGAN2
## Paper
- https://arxiv.org/pdf/1912.04958
## Model Architecture
In StyleGAN, most generated images show a characteristic droplet-like artifact. StyleGAN2 revises the AdaIN normalization to address this: the authors remove the normalization inside the AdaIN layer and move the noise input B and the bias b out of the style block, which yields better generation quality.
<div align=center>
<img src="./images/AdaIN.png"/>
</div>
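This change amounts to "weight demodulation": instead of normalizing the activations with AdaIN, StyleGAN2 folds the style scaling and a normalizing factor directly into the convolution weights. Below is a minimal PyTorch sketch of the idea; the function and variable names are illustrative and not the actual MMGeneration implementation.
```python
import torch
import torch.nn.functional as F


def modulated_conv2d(x, weight, style, demodulate=True, eps=1e-8):
    """Sketch of StyleGAN2 weight (de)modulation.

    x:      (N, C_in, H, W) input feature maps
    weight: (C_out, C_in, k, k) convolution weight
    style:  (N, C_in) per-sample style scales from the mapping network
    """
    n, c_in, h, w = x.shape
    c_out, _, k, _ = weight.shape
    # modulate: scale the input channels of the weight with the style
    w_mod = weight[None] * style[:, None, :, None, None]  # (N, C_out, C_in, k, k)
    if demodulate:
        # demodulate: rescale each output channel to unit expected std,
        # replacing AdaIN's explicit normalization of the activations
        demod = torch.rsqrt(w_mod.pow(2).sum(dim=[2, 3, 4]) + eps)
        w_mod = w_mod * demod[:, :, None, None, None]
    # a grouped convolution applies a different weight to every sample
    x = x.reshape(1, n * c_in, h, w)
    w_mod = w_mod.reshape(n * c_out, c_in, k, k)
    out = F.conv2d(x, w_mod, padding=k // 2, groups=n)
    return out.reshape(n, c_out, *out.shape[2:])
```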
## Algorithm
StyleGAN2 is an improved version of StyleGAN, which in turn was built on top of PGGAN.
A GAN (generative adversarial network) consists of a generator G and a discriminator D. The generator G produces fake data x' from a latent z randomly sampled from a normal distribution P(z); these, together with real samples x drawn from the image distribution P(data), are fed to the discriminator D. The discriminator tries to assign as high a probability as possible to x and as low a probability as possible to x', while the generator tries to make its samples be judged real with as high a probability as possible. Through this adversarial game, the model generates increasingly realistic images.
<div align=center>
<img src="./images/GAN2.png"/>
</div>
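The adversarial objective described above fits in a few lines of PyTorch. The toy networks and hyperparameters in this sketch are placeholders; only the structure of the two loss terms matters:
```python
import torch
import torch.nn as nn

# toy stand-ins for the generator G and discriminator D
G = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 784), nn.Tanh())
D = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

bce = nn.BCEWithLogitsLoss()
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)

real = torch.rand(16, 784)  # stand-in for real samples x ~ P(data)
z = torch.randn(16, 64)     # latent z ~ P(z)

# discriminator step: push D(x) towards 1 and D(x') towards 0
fake = G(z).detach()
loss_d = bce(D(real), torch.ones(16, 1)) + bce(D(fake), torch.zeros(16, 1))
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# generator step: push D(G(z)) towards 1, i.e. fool the discriminator
loss_g = bce(D(G(z)), torch.ones(16, 1))
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```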
PGGAN (progressively growing GAN) solves two problems of traditional GANs by first training at a low resolution and then raising the resolution layer by layer: mode collapse (the generated data is only a subset of the original data, because the generator favors samples that the discriminator has trouble judging) and the difficulty of training on high-resolution images (a generator that starts out producing high-resolution images directly is easily caught by the discriminator, and the resulting large gradient updates in backpropagation can make the generator collapse). PGGAN trains at low resolution first and then raises the resolution step by step with a smooth fade-in of new layers:
<div align=center>
<img src="./images/PGGAN.png"/>
</div>
StyleGAN
Although PGGAN can generate high-resolution fake images, it offers no control over the style or details of what it generates. StyleGAN disentangles the features so that they become independent of one another, making it possible to edit a single aspect of an image in isolation. Concretely, StyleGAN modifies the generator: in the figure below, the left side is a traditional generator and the right side is the StyleGAN generator, which consists of two parts, a Mapping network and a Synthesis network. The Mapping network controls the style of the image, while the Synthesis network generates the image.
<div align=center>
<img src="./images/styleGAN.png"/>
</div>
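A minimal sketch of this two-part design: the mapping network is an MLP that turns the latent z into the better-disentangled latent w, which then styles every block of the synthesis network. The sizes below follow the paper's defaults but are only illustrative.
```python
import torch
import torch.nn as nn


class MappingNetwork(nn.Module):
    """Sketch of StyleGAN's mapping network: an 8-layer MLP z -> w."""

    def __init__(self, latent_dim=512, num_layers=8):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [nn.Linear(latent_dim, latent_dim), nn.LeakyReLU(0.2)]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # PixelNorm on z first, as StyleGAN does
        z = z / torch.sqrt(z.pow(2).mean(dim=1, keepdim=True) + 1e-8)
        return self.net(z)


w = MappingNetwork()(torch.randn(4, 512))  # w then styles each synthesis block
```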
StyleGAN2 builds on StyleGAN by revising the normalization as described above and by improving the loss function and training scheme.
## Environment Setup
### Docker (Option 1)
Pull the Docker image from [光源](https://www.sourcefind.cn/#/service-list):
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py39-latest
```
Create a container and mount a directory for development:
```
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
# Note 1: replace {name} with a custom container name; the suggested pattern is {framework_dtkVersion_userName}, optionally with a purpose prefix
# Note 2: replace {docker_image} with the image to create the container from, e.g. pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest [image name:tag]
# Note 3: -v mounts a host path to the given path inside the container
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```
cd docker
docker build --no-cache -t gan2_pytorch:1.0 .
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
Conda is recommended for configuring the environment on online nodes.
Create and activate a conda environment with python=3.9:
```
conda create -n styleGan2 python=3.9
conda activate styleGan2
```
The special deep-learning libraries required for the DCU cards used in this project can be downloaded from the [光合](https://developer.hpccube.com/tool/) developer community.
```
DTK driver: dtk23.04
python:python3.9
pytorch:1.10.0
torchvision:0.11.0
```
Install the remaining dependencies:
```
pip install -r requirements.txt
```
## Dataset
- Training set for unconditional models: [FFHQ](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP)
Download and extract it into the current directory.
The data directory is structured as follows:
```
./data/ffhq/
├── 00000.png
└── 00001.png
```
- Alternatively, a small dataset with only 100 samples ([data](https://pan.baidu.com/s/1ep4k_9w1YR9Cg8Q5NQmNYQ?pwd=1234)) can be used for training and testing.
## Training
### Single machine, single card
```
sh tools/dist_train.sh configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py 1 --work-dir work_dirs/experiments/stylegan2_c2_ffhq_1024_b4x8/
```
### Single machine, multiple cards
```
sh tools/dist_train.sh configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py 4 --work-dir work_dirs/experiments/stylegan2_c2_ffhq_1024_b4x8/
```
## Accuracy
Download the weight file [stylegan2_c2_ffhq_1024_b4x8](https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth) and test the model accuracy:
```
sh tools/eval.sh configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth --batch-size 10 --online
```
Test metrics of the model on the FFHQ dataset:
| Model | Data type | FID | P/R |
| :------: | :------: | :------: | :------: |
| [stylegan2_c2_ffhq_1024_b4x8](https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth) | fp16 | 92.0757 | 0.65/0.48 |
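For reference, FID compares Inception-v3 feature statistics of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2(C_r C_g)^(1/2)). A minimal sketch of the computation, assuming both feature sets have already been extracted (the random features in the usage line are placeholders):
```python
import numpy as np
from scipy import linalg


def fid(feat_real, feat_fake):
    """Fréchet Inception Distance between two feature arrays of shape (N, D)."""
    mu1, mu2 = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    sigma1 = np.cov(feat_real, rowvar=False)
    sigma2 = np.cov(feat_fake, rowvar=False)
    # matrix square root of the covariance product
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)


# toy usage with random features; real evaluation uses Inception-v3 activations
print(fid(np.random.randn(100, 64), np.random.randn(100, 64)))
```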
## Inference
### Unconditional GAN inference
```
python demo/unconditional_demo.py configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py stylegan2_c2_ffhq_1024_b4x8_20210407_150045-618c9024.pth
```
<div align=center>
<img src="./images/unconditional_samples.png"/>
</div>
### Conditional GAN inference
```
python demo/conditional_demo.py configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py stylegan2_c2_ffhq_1024_b4x8_20210407_150045-618c9024.pth --label 1
```
<div align=center>
<img src="./images/conditional_samples_2.png"/>
</div>
### Adjusting generated face details via interpolation
```
python apps/interpolate_sample.py configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py stylegan2_c2_ffhq_1024_b4x8_20210407_150045-618c9024.pth --show-mode group --samples-path result
```
<div align=center>
<img src="./images/group.png"/>
</div>
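The app interpolates between endpoint latents using either linear (lerp) or spherical (slerp) interpolation, selectable via `--interp-mode`. A minimal sketch of the two modes (written out for illustration; the repository's own `slerp` lives in `mmgen.core.evaluation`):
```python
import torch


def lerp(a, b, t):
    # straight line between two latent codes
    return a + (b - a) * t


def slerp(a, b, t, eps=1e-7):
    # walk along the great circle between a and b, which keeps the norm of
    # the interpolated latent close to that of the endpoints
    a_n = a / a.norm(dim=-1, keepdim=True)
    b_n = b / b.norm(dim=-1, keepdim=True)
    omega = torch.acos((a_n * b_n).sum(-1).clamp(-1 + eps, 1 - eps))
    so = torch.sin(omega)
    return (torch.sin((1 - t) * omega) / so).unsqueeze(-1) * a + \
        (torch.sin(t * omega) / so).unsqueeze(-1) * b


a, b = torch.randn(1, 512), torch.randn(1, 512)
frames = [lerp(a, b, t) for t in torch.linspace(0, 1, 10)]
```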
## Application Scenarios
### Algorithm Category
Face generation
### Key Application Industries
Security, transportation, education
## Source Repository and Issue Feedback
[https://developer.hpccube.com/codes/modelzoo/stylegan2_mmcv](https://developer.hpccube.com/codes/modelzoo/stylegan2_mmcv)
## References
[https://github.com/open-mmlab/mmgeneration/tree/master](https://github.com/open-mmlab/mmgeneration/tree/master)
<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png" width="400"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab 官网</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
&nbsp;&nbsp;&nbsp;&nbsp;
<b><font size="5">OpenMMLab 开放平台</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div>&nbsp;</div>
</div>
[![PyPI](https://img.shields.io/pypi/v/mmgen)](https://pypi.org/project/mmgen)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmgeneration.readthedocs.io/en/latest/)
[![badge](https://github.com/open-mmlab/mmgeneration/workflows/build/badge.svg)](https://github.com/open-mmlab/mmgeneration/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmgeneration/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmgeneration)
[![license](https://img.shields.io/github/license/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[📘Documentation](https://mmgeneration.readthedocs.io/en/latest/) |
[🛠️Installation](https://mmgeneration.readthedocs.io/en/latest/get_started.html#installation) |
[👀Model Zoo](https://mmgeneration.readthedocs.io/en/latest/modelzoo_statistics.html) |
[🆕Update News](https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/changelog.md) |
[🚀Ongoing Projects](https://github.com/open-mmlab/mmgeneration/projects) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmgeneration/issues)
[English](README.md) | 简体中文
## What's New
We have merged MMGeneration into [MMEditing](https://github.com/open-mmlab/mmediting/tree/1.x) and added support for new generative tasks and algorithms. Check out the following new features:
- 🌟 Text-to-image generation
  - [GLIDE](https://github.com/open-mmlab/mmediting/tree/1.x/projects/glide/configs/README.md)
  - [Disco-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/disco_diffusion/README.md)
  - [Stable-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/stable_diffusion/README.md)
- 🌟 3D generation
  - [EG3D](https://github.com/open-mmlab/mmediting/tree/1.x/configs/eg3d/README.md)
## Introduction
MMGeneration is a powerful toolkit for generative models, especially GANs, based on PyTorch and [MMCV](https://github.com/open-mmlab/mmcv).
The master branch works with **PyTorch 1.5** and above.
<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114534478-9a65a900-9c81-11eb-8087-de8b6816eed8.png" width="800"/>
</div>
## Major Features
- **High-quality, high-performance training:** we currently support training of Unconditional GANs, Internal GANs, and Image Translation Models. Training of conditional models will come soon.
- **Powerful application toolkit:** a rich toolkit covering multiple GAN applications is provided. GAN interpolation, projection, and editing are integrated into our framework. Have fun with your GANs! ([Applications tutorial](docs/tutorials/applications.md))
- **Efficient distributed training for generative models:** for the highly dynamic training in generative models, we adopt a new method, `MMDDP`, to train dynamic models. ([DDP tutorial](docs/tutorials/ddp_train_gans.md))
- **New modular design for flexible combination:** for complex loss modules, we propose a new design that customizes the links between modules, enabling flexible combinations of different modules. ([Modular design tutorial](docs/tutorials/customize_losses.md))
<table>
<thead>
<tr>
<td>
<div align="center">
<b>Training Visualization</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114509105-b6f4e780-9c67-11eb-8644-110b3cb01314.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b>GAN Interpolation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114679300-9fd4f900-9d3e-11eb-8f37-c36a018c02f7.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b>GAN Projection</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114524392-c11ee200-9c77-11eb-8b6d-37bc637f5626.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b>GAN Manipulation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114523716-20302700-9c77-11eb-804e-327ae1ca0c5b.gif" width="200"/>
</div></td>
</tr>
</thead>
</table>
## Highlights
- **Positional Encoding as Spatial Inductive Bias in GANs (CVPR 2021)** has been released in `MMGeneration`. [\[Config\]](configs/positional_encoding_in_gans/README.md), [\[Project Page\]](https://nbei.github.io/gan-pos-encoding.html)
- Training of the current mainstream Conditional GAN models is supported; more methods and pre-trained weights will be released soon, so stay tuned.
- Mixed-precision training is initially supported for `StyleGAN2`; see [here](configs/styleganv2/README.md) for a detailed comparison of the different implementations.
## Changelog
v0.7.3 was released on 14/04/2023. Please refer to [changelog.md](docs/zh_cn/changelog.md) for details and release history.
## Installation
MMGeneration depends on [PyTorch](https://pytorch.org/) and [MMCV](https://github.com/open-mmlab/mmcv). Below are quick steps for installation.
**Step 1.**
Install PyTorch following the [official instructions](https://pytorch.org/get-started/locally/), e.g.
```
pip3 install torch torchvision
```
**Step 2.**
Install MMCV with [MIM](https://github.com/open-mmlab/mim)
```
pip3 install openmim
mim install mmcv-full
```
**Step 3.**
Install MMGeneration from source
```
git clone https://github.com/open-mmlab/mmgeneration.git
cd mmgeneration
pip3 install -e .
```
Please refer to [get_started.md](docs/zh/get_started.md) for more detailed installation instructions.
## Getting Started
Please refer to [Getting Started](docs/zh_cn/get_started.md) for the basic usage of `MMGeneration`. For other details and tutorials, please check our [documentation](https://mmgeneration.readthedocs.io/). A quick sampling example is sketched below.
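As a quick sanity check after installation, unconditional sampling can be done through the high-level API. A minimal sketch, assuming a config/checkpoint pair has been downloaded (the paths are placeholders):
```python
from mmgen.apis import init_model, sample_unconditional_model

# placeholder paths; use any config/checkpoint pair from the model zoo
config = 'configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py'
checkpoint = 'stylegan2_c2_ffhq_1024_b4x8_20210407_150045-618c9024.pth'

model = init_model(config, checkpoint, device='cuda:0')
fakes = sample_unconditional_model(model, num_samples=4)  # (4, 3, H, W)
```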
## Model Zoo
These algorithms have been carefully studied and supported in our framework.
<details open>
<summary>Unconditional GANs (click to collapse)</summary>

- [DCGAN](configs/dcgan/README.md) (ICLR'2016)
- [WGAN-GP](configs/wgan-gp/README.md) (NIPS'2017)
- [LSGAN](configs/lsgan/README.md) (ICCV'2017)
- [GGAN](configs/ggan/README.md) (arXiv'2017)
- [PGGAN](configs/pggan/README.md) (ICLR'2018)
- [StyleGANV1](configs/styleganv1/README.md) (CVPR'2019)
- [StyleGANV2](configs/styleganv2/README.md) (CVPR'2020)
- [StyleGANV3](configs/styleganv3/README.md) (NeurIPS'2021)
- [Positional Encoding in GANs](configs/positional_encoding_in_gans/README.md) (CVPR'2021)
</details>
<details open>
<summary>Conditional GANs (click to collapse)</summary>

- [SNGAN](configs/sngan_proj/README.md) (ICLR'2018)
- [Projection GAN](configs/sngan_proj/README.md) (ICLR'2018)
- [SAGAN](configs/sagan/README.md) (ICML'2019)
- [BIGGAN/BIGGAN-DEEP](configs/biggan/README.md) (ICLR'2019)
</details>
<details open>
<summary>Tricks for GANs (click to collapse)</summary>

- [ADA](configs/ada/README.md) (NeurIPS'2020)
</details>
<details open>
<summary>Image2Image Translation (click to collapse)</summary>

- [Pix2Pix](configs/pix2pix/README.md) (CVPR'2017)
- [CycleGAN](configs/cyclegan/README.md) (ICCV'2017)
</details>
<details open>
<summary>Internal Learning (click to collapse)</summary>

- [SinGAN](configs/singan/README.md) (ICCV'2019)
</details>
<details open>
<summary>Denoising Diffusion Probabilistic Models (click to collapse)</summary>

- [Improved DDPM](configs/improved_ddpm/README.md) (arXiv'2021)
</details>
## Related Applications
- [MMGEN-FaceStylor](https://github.com/open-mmlab/MMGEN-FaceStylor)
## Contributing
We appreciate all contributions to improving MMGeneration. Please refer to the [contributing guideline](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) for guidance on participating in the project.
## Citation
If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2021mmgeneration,
    title={{MMGeneration}: OpenMMLab Generative Model Toolbox and Benchmark},
    author={MMGeneration Contributors},
    howpublished = {\url{https://github.com/open-mmlab/mmgeneration}},
    year={2020}
}
```
## License
This project is released under the [Apache 2.0 license](LICENSE). Some operations in `MMGeneration` use other licenses. Please refer to the [licenses](LICENSES.md) and check them carefully if you use our code for commercial matters.
## Other Projects in OpenMMLab
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision
- [MIM](https://github.com/open-mmlab/mim): MIM is the unified entry point for OpenMMLab projects, algorithms, and models
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab object detection toolbox
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab toolbox for the full pipeline of text detection, recognition, and understanding
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab few-shot learning toolbox and benchmark
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation video understanding toolbox
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab unified video object perception platform
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow estimation toolbox and benchmark
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework
## Welcome to the OpenMMLab Community
Scan the QR codes below to follow the OpenMMLab team's [official Zhihu account](https://www.zhihu.com/people/openmmlab) and to join the official [QQ group](https://jq.qq.com/?_wv=1027&k=K0QI8ByU)
<div align="center">
<img src="https://user-images.githubusercontent.com/22982797/115827101-66874200-a43e-11eb-9abf-831094c27ef4.JPG" height="400" /> <img src="https://user-images.githubusercontent.com/25839884/203927852-e15def4d-a0eb-4dfc-9bfb-7cf09ea945d0.png" height="400" />
</div>
In the OpenMMLab community, we will
- 📢 share the latest core technologies of AI frameworks
- 💻 explain the source code of commonly used PyTorch modules
- 📰 post news about OpenMMLab
- 🚀 introduce cutting-edge algorithms developed by OpenMMLab
- 🏃 provide more efficient Q&A and feedback channels
- 🔥 offer a platform for in-depth exchanges with developers from all industries
The OpenMMLab community is packed with useful content 📘 and looks forward to having you join us 👬
import argparse
import os
import sys

import mmcv
import torch
import torch.nn as nn
from mmcv import Config, DictAction
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip # noqa

from mmgen.apis import set_random_seed  # isort:skip # noqa
from mmgen.core.evaluation import slerp  # isort:skip # noqa
from mmgen.models import build_model  # isort:skip # noqa
from mmgen.models.architectures import BigGANDeepGenerator, BigGANGenerator  # isort:skip # noqa
from mmgen.models.architectures.common import get_module_device  # isort:skip # noqa
# yapf: enable

_default_embedding_name = dict(
    BigGANGenerator='shared_embedding',
    BigGANDeepGenerator='shared_embedding',
    SNGANGenerator='NULL',
    SAGANGenerator='NULL')
def parse_args():
    parser = argparse.ArgumentParser(
        description='Sampling from latents\' interpolation')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--use-cpu',
        action='store_true',
        help='whether to use cpu device for sampling')
    parser.add_argument(
        '--embedding-name',
        type=str,
        default=None,
        help='name of conditional model\'s embedding layer')
    parser.add_argument(
        '--fix-z',
        action='store_true',
        help='whether to fix the noise for conditional model')
    parser.add_argument(
        '--fix-y',
        action='store_true',
        help='whether to fix the label for conditional model')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--samples-path', type=str, help='path to store images.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='which model (ema/orig) to use in sampling.')
    parser.add_argument(
        '--show-mode',
        choices=['group', 'sequence'],
        default='sequence',
        help='mode to show interpolation result.')
    parser.add_argument(
        '--interp-mode',
        choices=['lerp', 'slerp'],
        default='lerp',
        help='mode to sample from endpoints\' interpolation.')
    parser.add_argument(
        '--endpoint', type=int, default=2, help='The number of endpoints.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=2,
        help='batch size used in generator sampling.')
    parser.add_argument(
        '--interval',
        type=int,
        default=10,
        help='The number of intervals between two endpoints.')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    args = parser.parse_args()
    return args
@torch.no_grad()
def batch_inference(generator,
                    noise,
                    embedding=None,
                    num_batches=-1,
                    max_batch_size=16,
                    dict_key=None,
                    **kwargs):
    """Inference function to get a batch of desired data from the output
    dictionary of a generator.

    Args:
        generator (nn.Module): Generator of a conditional model.
        noise (Tensor | list[torch.Tensor] | None): A batch of noise tensors.
        embedding (Tensor, optional): Embedding tensor of label for
            conditional models. Defaults to None.
        num_batches (int, optional): The number of batches for
            inference. Defaults to -1.
        max_batch_size (int, optional): The maximum batch size for
            inference. Defaults to 16.
        dict_key (str, optional): Key used to get results from the output
            dictionary of the generator. Defaults to None.

    Returns:
        torch.Tensor: Tensor of output image, noise batch or label batch.
    """
    # split noise into groups
    if noise is not None:
        if isinstance(noise, torch.Tensor):
            num_batches = noise.shape[0]
            noise_group = torch.split(noise, max_batch_size, 0)
        else:
            num_batches = noise[0].shape[0]
            noise_group = torch.split(noise[0], max_batch_size, 0)
            noise_group = [[noise_tensor] for noise_tensor in noise_group]
    else:
        noise_group = [None] * (
            num_batches // max_batch_size +
            (1 if num_batches % max_batch_size > 0 else 0))

    # split embedding into groups
    if embedding is not None:
        assert isinstance(embedding, torch.Tensor)
        num_batches = embedding.shape[0]
        embedding_group = torch.split(embedding, max_batch_size, 0)
    else:
        embedding_group = [None] * (
            num_batches // max_batch_size +
            (1 if num_batches % max_batch_size > 0 else 0))

    # split batch size into groups
    batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
    if num_batches % max_batch_size > 0:
        batchsize_group += [num_batches % max_batch_size]

    device = get_module_device(generator)
    outputs = []
    for _noise, _embedding, _num_batches in zip(noise_group, embedding_group,
                                                batchsize_group):
        if isinstance(_noise, torch.Tensor):
            _noise = _noise.to(device)
        if isinstance(_noise, list):
            _noise = [ele.to(device) for ele in _noise]
        if _embedding is not None:
            _embedding = _embedding.to(device)
        output = generator(
            _noise, label=_embedding, num_batches=_num_batches, **kwargs)
        output = output[dict_key] if dict_key else output
        if isinstance(output, list):
            output = output[0]
        # once obtaining sampled results, we immediately put them on cpu
        # to save cuda memory
        outputs.append(output.to('cpu'))
    outputs = torch.cat(outputs, dim=0)
    return outputs
@torch.no_grad()
def sample_from_path(generator,
                     latent_a,
                     latent_b,
                     label_a,
                     label_b,
                     intervals,
                     embedding_name=None,
                     interp_mode='lerp',
                     **kwargs):
    interp_alphas = torch.linspace(0, 1, intervals)
    interp_samples = []

    device = get_module_device(generator)
    if embedding_name is None:
        generator_name = generator.__class__.__name__
        assert generator_name in _default_embedding_name
        embedding_name = _default_embedding_name[generator_name]
    embedding_fn = getattr(generator, embedding_name, nn.Identity())
    embedding_a = embedding_fn(label_a.to(device))
    embedding_b = embedding_fn(label_b.to(device))

    for alpha in interp_alphas:
        # calculate latent interpolation
        if interp_mode == 'lerp':
            latent_interp = torch.lerp(latent_a, latent_b, alpha)
        else:
            assert latent_a.ndim == latent_b.ndim == 2
            latent_interp = slerp(latent_a, latent_b, alpha)
        # calculate embedding interpolation
        embedding_interp = embedding_a + (
            embedding_b - embedding_a) * alpha.to(embedding_a.dtype)
        if isinstance(generator, (BigGANDeepGenerator, BigGANGenerator)):
            kwargs.update(dict(use_outside_embedding=True))
        sample = batch_inference(generator, latent_interp, embedding_interp,
                                 **kwargs)
        interp_samples.append(sample)
    return interp_samples
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # get noises corresponding to each endpoint
    noise_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='noise_batch',
        return_noise=True,
        **kwargs)

    # get labels corresponding to each endpoint
    label_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='label',
        return_noise=True,
        **kwargs)

    # fix the label
    if args.fix_y:
        label_batch = label_batch[0] * torch.ones_like(label_batch)
    # fix the noise
    if args.fix_z:
        noise_batch = torch.cat(
            [noise_batch[0:1, ]] * noise_batch.shape[0], dim=0)

    if args.show_mode == 'sequence':
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], label_batch[:-1, ],
                                   label_batch[1:, ], args.interval,
                                   args.embedding_name, args.interp_mode,
                                   **kwargs)
    else:
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], label_batch[:-1, ],
                                   label_batch[1:, ], args.interval,
                                   args.embedding_name, args.interp_mode,
                                   **kwargs)
    # reorder results
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    # save images
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        for i in range(results.shape[0]):
            image = results[i:i + 1]
            save_image(
                image,
                os.path.join(args.samples_path, '{:0>5d}'.format(i) + '.png'))
    else:
        save_image(
            results,
            os.path.join(args.samples_path, 'group.png'),
            nrow=args.interval)


if __name__ == '__main__':
    main()
import argparse
import os
import sys

import imageio
import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip # noqa

from mmgen.apis import set_random_seed  # isort:skip # noqa
from mmgen.core.evaluation import slerp  # isort:skip # noqa
from mmgen.models import build_model  # isort:skip # noqa
# yapf: enable
def parse_args():
    parser = argparse.ArgumentParser(
        description='Sampling from latents\' interpolation')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--use-cpu',
        action='store_true',
        help='whether to use cpu device for sampling')
    parser.add_argument(
        '--export-video',
        action='store_true',
        help='If true, export a video rather than images')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--samples-path', type=str, help='path to store images.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='which model (ema/orig) to use in sampling.')
    parser.add_argument(
        '--show-mode',
        choices=['group', 'sequence'],
        default='sequence',
        help='mode to show interpolation result.')
    parser.add_argument(
        '--interp-mode',
        choices=['lerp', 'slerp'],
        default='lerp',
        help='mode to sample from endpoints\' interpolation.')
    parser.add_argument(
        '--proj-latent',
        type=str,
        default=None,
        help='Projection image files produced by stylegan_projector.py. If '
        'this argument is given, the projected latents will be used as the '
        'input noises.')
    parser.add_argument(
        '--endpoint', type=int, default=2, help='The number of endpoints.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=2,
        help='batch size used in generator sampling.')
    parser.add_argument(
        '--interval',
        type=int,
        default=10,
        help='The number of intervals between two endpoints.')
    parser.add_argument(
        '--space',
        choices=['z', 'w'],
        default='w',
        help='Interpolation space.')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    args = parser.parse_args()
    return args
@torch.no_grad()
def batch_inference(generator,
                    noise,
                    num_batches=-1,
                    max_batch_size=16,
                    dict_key=None,
                    **kwargs):
    """Inference function to get a batch of desired data from the output
    dictionary of a generator.

    Args:
        generator (nn.Module): Generator of a conditional model.
        noise (Tensor | list[torch.Tensor] | None): A batch of noise tensors.
        num_batches (int, optional): The number of batches for
            inference. Defaults to -1.
        max_batch_size (int, optional): The maximum batch size for
            inference. Defaults to 16.
        dict_key (str, optional): Key used to get results from the output
            dictionary of the generator. Defaults to None.

    Returns:
        torch.Tensor: Tensor of output image, noise batch or label batch.
    """
    # split noise into groups
    if noise is not None:
        if isinstance(noise, torch.Tensor):
            num_batches = noise.shape[0]
            noise_group = torch.split(noise, max_batch_size, 0)
        else:
            num_batches = noise[0].shape[0]
            noise_group = torch.split(noise[0], max_batch_size, 0)
            noise_group = [[noise_tensor] for noise_tensor in noise_group]
    else:
        noise_group = [None] * (
            num_batches // max_batch_size +
            (1 if num_batches % max_batch_size > 0 else 0))

    # split batch size into groups
    batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
    if num_batches % max_batch_size > 0:
        batchsize_group += [num_batches % max_batch_size]

    outputs = []
    for _noise, _num_batches in zip(noise_group, batchsize_group):
        if isinstance(_noise, torch.Tensor):
            _noise = _noise.cuda()
        if isinstance(_noise, list):
            _noise = [ele.cuda() for ele in _noise]
        output = generator(_noise, num_batches=_num_batches, **kwargs)
        output = output[dict_key] if dict_key else output
        if isinstance(output, list):
            output = output[0]
        # once obtaining sampled results, we immediately put them on cpu
        # to save cuda memory
        outputs.append(output.to('cpu'))
    outputs = torch.cat(outputs, dim=0)
    return outputs
def layout_grid(video_out,
                all_img,
                grid_w=1,
                grid_h=1,
                float_to_uint8=True,
                chw_to_hwc=True,
                to_numpy=True):
    r"""Arrange images into video frames.

    Ref: https://github.com/NVlabs/stylegan3/blob/a5a69f58294509598714d1e88c9646c3d7c6ec94/gen_video.py#L28 # noqa

    Args:
        video_out (Writer): Video writer.
        all_img (torch.Tensor): All images to be displayed in the video.
        grid_w (int, optional): Column number in a frame. Defaults to 1.
        grid_h (int, optional): Row number in a frame. Defaults to 1.
        float_to_uint8 (bool, optional): Convert tensor values from `float`
            to `uint8`. Defaults to True.
        chw_to_hwc (bool, optional): Change channel order from `chw` to
            `hwc`. Defaults to True.
        to_numpy (bool, optional): Convert images from `torch.Tensor` to
            `np.ndarray`. Defaults to True.

    Returns:
        Writer: Video writer.
    """
    batch_size, channels, img_h, img_w = all_img.shape
    assert batch_size % (grid_w * grid_h) == 0
    images_per_frame = grid_w * grid_h
    n_frames = batch_size // images_per_frame
    all_img = all_img.reshape(images_per_frame, n_frames, channels, img_h,
                              img_w).permute(1, 0, 2, 3, 4).reshape(
                                  n_frames, images_per_frame, channels, img_h,
                                  img_w)
    for i in range(0, n_frames):
        img = all_img[i]
        if float_to_uint8:
            img = (img * 255.).clamp(0, 255).to(torch.uint8)
        img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
        img = img.permute(2, 0, 3, 1, 4)
        img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
        if chw_to_hwc:
            img = img.permute(1, 2, 0)
        if to_numpy:
            img = img.cpu().numpy()
        video_out.append_data(img)
    return video_out
def crack_integer(integer):
    """Crack an integer into the product of its two nearest factors.

    Args:
        integer (int): A positive integer.

    Returns:
        tuple: Two integers.
    """
    start = int(np.sqrt(integer))
    factor = integer / start
    while int(factor) != factor:
        start += 1
        factor = integer / start
    return int(factor), start
@torch.no_grad()
def sample_from_path(generator,
                     latent_a,
                     latent_b,
                     intervals,
                     interp_mode='lerp',
                     space='z',
                     **kwargs):
    interp_alphas = np.linspace(0, 1, intervals)
    interp_samples = []
    for alpha in interp_alphas:
        if interp_mode == 'lerp':
            latent_interp = torch.lerp(latent_a, latent_b, alpha)
        else:
            assert latent_a.ndim == latent_b.ndim == 2
            latent_interp = slerp(latent_a, latent_b, alpha)
        if space == 'w':
            latent_interp = [latent_interp]
        sample = batch_inference(generator, latent_interp, **kwargs)
        interp_samples.append(sample)
    return interp_samples
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')

    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # if proj_latent is given, reset args.endpoint
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        setattr(args, 'endpoint', proj_n)
        assert args.space == 'w', 'Projected latents are w or w-plus latents.'
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        noise_batch = torch.cat(noise_batch, dim=0).cuda()
        if args.use_cpu:
            noise_batch = noise_batch.to('cpu')

    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0, \
            'We need paired images in group mode, so keep the endpoint an ' \
            'even number'

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # remind users to fix the injected noises
    if kwargs.get('randomize_noise', True):
        mmcv.print_log(
            'Hint: For Style-Based GANs, you can add '
            '`--sample-cfg randomize_noise=False` to fix injected noises',
            'mmgen')

    # get noises corresponding to each endpoint
    if not args.proj_latent:
        noise_batch = batch_inference(
            generator,
            None,
            num_batches=args.endpoint,
            dict_key='noise_batch' if args.space == 'z' else 'latent',
            return_noise=True,
            **kwargs)

    if args.space == 'w':
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True
    if args.show_mode == 'sequence':
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    # reorder results
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    # save images
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        if args.export_video:
            # render video
            video_out = imageio.get_writer(
                os.path.join(args.samples_path, 'lerp.mp4'),
                mode='I',
                fps=60,
                codec='libx264',
                bitrate='12M')
            video_out = layout_grid(video_out, results)
            video_out.close()
        else:
            for i in range(results.shape[0]):
                image = results[i:i + 1]
                save_image(
                    image,
                    os.path.join(args.samples_path,
                                 '{:0>5d}'.format(i) + '.png'))
    else:
        if args.export_video:
            # render video
            video_out = imageio.get_writer(
                os.path.join(args.samples_path, 'lerp.mp4'),
                mode='I',
                fps=60,
                codec='libx264',
                bitrate='12M')
            n_pair = args.endpoint // 2
            grid_w, grid_h = crack_integer(n_pair)
            video_out = layout_grid(
                video_out, results, grid_h=grid_h, grid_w=grid_w)
            video_out.close()
        else:
            save_image(
                results,
                os.path.join(args.samples_path, 'group.png'),
                nrow=args.interval)


if __name__ == '__main__':
    main()
"""Modified SeFa (closed-form factorization)
This gan editing method is modified according to Sefa. More details can be
found in Positional Encoding as Spatial Inductive Bias in GANs, CVPR2021.
The major modifications are:
- Calculate eigen vectors on the matrix with all style modulation weights in
styleconvs;
- Allow to adopt unsymetric degree to be more robust to different samples.
"""
import argparse
import os
import sys

import mmcv
import numpy as np
import torch
from mmcv import DictAction
from mmcv.runner import load_checkpoint
from torchvision import utils

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip # noqa

from mmgen.apis import set_random_seed  # isort:skip # noqa
from mmgen.models import build_model  # isort:skip # noqa
# yapf: enable
def calc_eigens(args, state_dict):
    # get all of the style modulation weights except for weights in `to_rgb`
    modulated = {
        k: v
        for k, v in state_dict.items()
        if 'style_modulation' in k and 'to_rgb' not in k and 'weight' in k
    }
    weight_mat = []
    for _, v in modulated.items():
        weight_mat.append(v)
    W = torch.cat(weight_mat, dim=0)
    eigen_vector = torch.svd(W).V

    # save eigenvectors
    output_path = os.path.splitext(args.ckpt)[0] + '_eigen-vec-mod.pth'
    torch.save({'ckpt': args.ckpt, 'eigen_vector': eigen_vector}, output_path)
    return eigen_vector
if __name__ == '__main__':
    # set device
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    # disable gradient computation
    torch.set_grad_enabled(False)

    parser = argparse.ArgumentParser(
        description='Apply modified closed-form factorization')
    # sefa args
    parser.add_argument(
        '-i', '--index', type=int, default=0, help='index of eigenvector')
    parser.add_argument(
        '-d',
        '--degree',
        type=float,
        nargs='+',
        default=[2.],
        help='scalar factors for moving latent vectors along eigenvector',
    )
    parser.add_argument(
        '--degree-step',
        type=float,
        default=0.25,
        help='The step of changing degrees')
    parser.add_argument('-l', '--layer-num', nargs='+', type=int, default=None)
    parser.add_argument(
        '--eigen-vector',
        type=str,
        default=None,
        help='Path to the eigenvectors')
    # gan args
    parser.add_argument(
        '--randomize-noise',
        action='store_true',
        help='whether to use random noise in the middle layers')
    parser.add_argument('--ckpt', type=str, help='Path to the checkpoint')
    parser.add_argument('--config', type=str, help='Path to model config')
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation-mean', type=int, default=4096)
    parser.add_argument('--noise-channels', type=int, default=512)
    parser.add_argument('--input-scale', type=int, default=4)
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    # system args
    parser.add_argument('--num-samples', type=int, default=2)
    parser.add_argument('--sample-path', type=str, default=None)
    parser.add_argument('--random-seed', type=int, default=2020)

    args = parser.parse_args()
    set_random_seed(args.random_seed)

    cfg = mmcv.Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    mmcv.print_log('Building models and loading checkpoints', 'mmgen')
    # build model
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.eval()
    load_checkpoint(model, args.ckpt, map_location='cpu')

    # get generator
    if model.use_ema:
        generator = model.generator_ema
    else:
        generator = model.generator
    generator = generator.to(device)
    generator.eval()

    mmcv.print_log('Calculating or loading eigenvectors', 'mmgen')
    # load/calculate eigenvectors for the current weights
    if args.eigen_vector is None:
        eigen_vector = calc_eigens(args, generator.state_dict())
    else:
        eigen_vector = torch.load(args.eigen_vector)['eigen_vector']
    eigen_vector = eigen_vector.to(device)

    if args.truncation < 1:
        # TODO: get mean latent
        mean_latent = generator.get_mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    noise = torch.randn((args.num_samples, args.noise_channels),
                        device=device)
    latent = generator.style_mapping(noise)

    # kwargs for different gan models
    kwargs = dict()
    # mspie-stylegan2
    if args.input_scale > 0:
        kwargs['chosen_scale'] = args.input_scale
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    mmcv.print_log('Sampling images with modified SeFa', 'mmgen')
    sample = generator([latent], input_is_latent=True, **args.sample_cfg)
    # the first row contains the original samples
    img_list = [sample]

    if len(args.degree) == 1:
        factor_list = np.arange(-args.degree[0], args.degree[0] + 0.001,
                                args.degree_step)
    else:
        factor_list = np.arange(args.degree[0], args.degree[1] + 0.001,
                                args.degree_step)
    for fac in factor_list:
        direction = fac * eigen_vector[:, args.index].unsqueeze(0)
        if args.layer_num is None:
            latent_input = [latent + direction]
        else:
            latent_all = latent.unsqueeze(1).repeat(1, generator.num_latents,
                                                    1)
            for l_num in args.layer_num:
                latent_all[:, l_num] = latent + direction
            latent_input = [latent_all]
        sample = generator(
            latent_input, input_is_latent=True, **args.sample_cfg)
        img_list.append(sample)

    mmcv.mkdir_or_exist(args.sample_path)
    if args.layer_num is None:
        filename = (
            f'{args.sample_path}/entangle-i{args.index}-d{args.degree}'
            f'-t{args.degree_step}_{str(args.random_seed).zfill(6)}.png')
    else:
        filename = (f'{args.sample_path}/entangle-i{args.index}-d{args.degree}'
                    f'-t{args.degree_step}-l{args.layer_num}'
                    f'_{str(args.random_seed).zfill(6)}.png')
    img = torch.cat(img_list, dim=0)[:, [2, 1, 0]]
    utils.save_image(
        img,
        filename,
        nrow=args.num_samples,
        padding=0,
        normalize=True,
        range=(-1, 1))
    mmcv.print_log(f'Save images to {filename}', 'mmgen')
import argparse
import math
import os

try:
    import clip
except ImportError:
    raise ImportError(
        'To use styleclip, the openai CLIP package needs to be installed '
        'first')

import mmcv
import torch
import torchvision
from mmcv import Config, DictAction
from torch import optim
from tqdm import tqdm

from mmgen.apis import init_model
from mmgen.models.losses import CLIPLoss, FaceIdLoss

from mmgen.apis import set_random_seed  # isort:skip # noqa
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='model config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--use-cpu',
        action='store_true',
        help='whether to use cpu device for sampling')
    parser.add_argument(
        '--description',
        type=str,
        default='a person with purple hair',
        help='the text that guides the editing/generation')
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument(
        '--mode',
        type=str,
        default='generate',
        choices=['edit', 'generate'],
        help='choose between editing an image and generating a new one')
    parser.add_argument(
        '--l2-lambda',
        type=float,
        default=0.008,
        help='weight of the latent distance, used for editing only')
    parser.add_argument(
        '--id-lambda',
        type=float,
        default=0.000,
        help='weight of the id loss, used for editing only')
    parser.add_argument(
        '--proj-latent',
        type=str,
        default=None,
        help='Projection image files produced by stylegan_projector.py. If '
        'this argument is given, the projected latent will be used as the '
        'initial latent.')
    parser.add_argument(
        '--truncation',
        type=float,
        default=0.7,
        help='used only for the initial latent vector, and only when a latent '
        'code path is not provided')
    parser.add_argument(
        '--step', type=int, default=2000, help='Optimization iterations')
    parser.add_argument(
        '--save-interval',
        type=int,
        default=20,
        help='if > 0, saves intermediate results during the optimization')
    parser.add_argument(
        '--results-dir', type=str, default='work_dirs/styleclip/')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    args = parser.parse_args()
    return args
def main():
    args = parse_args()
    # set cudnn_benchmark
    cfg = Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    os.makedirs(args.results_dir, exist_ok=True)
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()

    model = init_model(args.config, args.checkpoint, device='cpu')
    g_ema = model.generator_ema
    g_ema.eval()
    if not args.use_cpu:
        g_ema = g_ema.cuda()
    mean_latent = g_ema.get_mean_latent()

    # if proj_latent is given, use it as the initial latent code
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        assert proj_n == 1
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        latent_code_init = torch.cat(noise_batch, dim=0).cuda()
    elif args.mode == 'edit':
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            results = g_ema([latent_code_init_not_trunc],
                            return_latents=True,
                            truncation=args.truncation,
                            truncation_latent=mean_latent)
        latent_code_init = results['latent']
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    with torch.no_grad():
        img_orig = g_ema([latent_code_init],
                         input_is_latent=True,
                         randomize_noise=False)

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
    id_loss = FaceIdLoss(
        facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda'))
    optimizer = optim.Adam([latent], lr=args.lr)
    pbar = tqdm(range(args.step))
    mmcv.print_log(f'Description: {args.description}', 'mmgen')

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]['lr'] = lr

        img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)
        img_gen = img_gen[:, [2, 1, 0], ...]

        # clip loss
        c_loss = clip_loss(image=img_gen, text=text_inputs)
        if args.id_lambda > 0:
            i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
        else:
            i_loss = 0

        if args.mode == 'edit':
            l2_loss = ((latent_code_init - latent)**2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f'loss: {loss.item():.4f};')

        if args.save_interval > 0 and (i % args.save_interval == 0):
            with torch.no_grad():
                img_gen = g_ema([latent],
                                input_is_latent=True,
                                randomize_noise=False)
            img_gen = img_gen[:, [2, 1, 0], ...]
            torchvision.utils.save_image(
                img_gen,
                os.path.join(args.results_dir, f'{str(i).zfill(5)}.png'),
                normalize=True,
                range=(-1, 1))

    if args.mode == 'edit':
        img_orig = img_orig[:, [2, 1, 0], ...]
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    torchvision.utils.save_image(
        final_result.detach().cpu(),
        os.path.join(args.results_dir, 'final_result.png'),
        normalize=True,
        scale_each=True,
        range=(-1, 1))


if __name__ == '__main__':
    main()
r"""
This app is used to invert the styleGAN series synthesis network. We find
the matching latent vector w for given images so that we can manipulate
images in the latent feature space.
Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/projector.py # noqa
"""
import argparse
import os
import sys
from collections import OrderedDict

import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmcv import Config
from mmcv.runner import load_checkpoint
from PIL import Image
from torch import optim
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip # noqa

from mmgen.apis import set_random_seed  # isort:skip # noqa
from mmgen.models import build_model  # isort:skip # noqa
from mmgen.models.architectures.lpips import PerceptualLoss  # isort:skip # noqa
# yapf: enable
def parse_args():
    parser = argparse.ArgumentParser(
        description='Image projector to the StyleGAN-based generator latent '
        'spaces')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        'files',
        metavar='FILES',
        nargs='+',
        help='path to image files to be projected')
    parser.add_argument(
        '--results-path', type=str, help='path to store projection results.')
    parser.add_argument(
        '--use-cpu',
        action='store_true',
        help='whether to use cpu device for sampling')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='which model (ema/orig) to use in sampling.')
    parser.add_argument(
        '--lr-rampup',
        type=float,
        default=0.05,
        help='proportion of the learning rate warmup iters in the total iters')
    parser.add_argument(
        '--lr-rampdown',
        type=float,
        default=0.25,
        help='proportion of the learning rate decay iters in the total iters')
    parser.add_argument(
        '--lr', type=float, default=0.1, help='maximum learning rate')
    parser.add_argument(
        '--noise',
        type=float,
        default=0.05,
        help='strength of the noise level')
    parser.add_argument(
        '--noise-ramp',
        type=float,
        default=0.75,
        help='proportion of the noise level decay iters in the total iters',
    )
    parser.add_argument(
        '--total-iters', type=int, default=1000, help='optimize iterations')
    parser.add_argument(
        '--noise-regularize',
        type=float,
        default=1e5,
        help='weight of the noise regularization',
    )
    parser.add_argument(
        '--mse', type=float, default=0, help='weight of the mse loss')
    parser.add_argument(
        '--n-mean-latent',
        type=int,
        default=10000,
        help='sampling times to obtain the mean latent')
    parser.add_argument(
        '--w-plus',
        action='store_true',
        help='allow using distinct latent codes for each layer',
    )
    args = parser.parse_args()
    return args
def noise_regularize(noises):
    loss = 0
    for noise in noises:
        size = noise.shape[2]
        while True:
            loss = (
                loss +
                (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) +
                (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2))
            if size <= 8:
                break
            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2
    return loss


def noise_normalize_(noises):
    for noise in noises:
        mean = noise.mean()
        std = noise.std()
        noise.data.add_(-mean).div_(std)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp


def latent_noise(latent, strength):
    noise = torch.randn_like(latent) * strength
    return latent + noise
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    generator.eval()
    device = 'cpu'
    if not args.use_cpu:
        generator = generator.cuda()
        device = 'cuda'

    img_size = min(generator.out_size, 256)
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # read images
    print('----' * 20)
    print('Reading images!')
    imgs = []
    for imgfile in args.files:
        img = Image.open(imgfile).convert('RGB')
        img = transform(img)
        img = img[[2, 1, 0], ...]
        imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    # get mean and standard deviation of style latents
    with torch.no_grad():
        noise_sample = torch.randn(
            args.n_mean_latent, generator.style_channels, device=device)
        latent_out = generator.style_mapping(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      args.n_mean_latent)**0.5
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
        imgs.shape[0], 1)
    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
    latent_in.requires_grad = True

    # define lpips loss
    percept = PerceptualLoss(use_gpu=device.startswith('cuda'))

    # initialize layer noises
    noises_single = generator.make_injected_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    pbar = tqdm(range(args.total_iters))

    # run optimization
    for i in pbar:
        t = i / args.total_iters
        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
        optimizer.param_groups[0]['lr'] = lr
        noise_strength = latent_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        latent_n = latent_noise(latent_in, noise_strength.item())
        img_gen = generator([latent_n],
                            input_is_latent=True,
                            injected_noise=noises)

        batch, channel, height, width = img_gen.shape
        if height > 256:
            factor = height // 256
            img_gen = img_gen.reshape(batch, channel, height // factor,
                                      factor, width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)
        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        noise_normalize_(noises)

        pbar.set_description(
            f' perceptual: {p_loss.item():.4f}, noise regularize:'
            f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')

    results = generator([latent_in.detach().clone()],
                        input_is_latent=True,
                        injected_noise=noises)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)
    mmcv.mkdir_or_exist(args.results_path)

    # save projection results
    result_file = OrderedDict()
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])
        result_file[input_name] = {
            'img': img_gen[i],
            'latent': latent_in[i],
            'injected_noise': noise_single,
        }
        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + '-project.png'
        save_image(results[i], os.path.join(args.results_path, img_name))

    torch.save(result_file,
               os.path.join(args.results_path, 'project_result.pt'))


if __name__ == '__main__':
    main()
dataset_type = 'UnconditionalImageDataset'

# Note that a `Resize` operation with the `pillow` backend and `bicubic`
# interpolation is a must for correct IS evaluation.
val_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='real_img',
        io_backend='disk',
    ),
    dict(
        type='Resize',
        keys=['real_img'],
        scale=(299, 299),
        backend='pillow',
        interpolation='bicubic'),
    dict(
        type='Normalize',
        keys=['real_img'],
        mean=[127.5] * 3,
        std=[127.5] * 3,
        to_rgb=True),
    dict(type='ImageToTensor', keys=['real_img']),
    dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]

data = dict(
    samples_per_gpu=None,
    workers_per_gpu=4,
    val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
dataset_type = 'mmcls.CIFAR10'

# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the default
# order in CIFAR-10 is RGB.
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# Different from the classification task, the val/test splits also use the
# training part, which is the same as in StyleGAN-ADA.
data = dict(
    samples_per_gpu=None,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type, data_prefix='data/cifar10',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
    test=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'

# This config is used to extract the inception state of the CIFAR dataset.
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification.
# The default order in CIFAR-10 is RGB, so we set `to_rgb` to `False`.
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
train_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# Different from the classification task, the val/test splits also use the
# training part, which is the same as in StyleGAN-ADA.
data = dict(
    samples_per_gpu=None,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type, data_prefix='data/cifar10',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
    test=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the
# default channel order in CIFAR-10 is RGB.
# CIFAR dataset w/o augmentations: the `RandomFlip` and `RandomCrop`
# augmentations are removed.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same as in StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=500,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=train_pipeline)),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
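# Added note on `RepeatDataset`: its length is `times * len(dataset)`, so
# with CIFAR-10's 50,000 training images and `times=500`, one "epoch"
# exposes 25,000,000 samples, avoiding frequent dataloader restarts during
# long GAN training runs.
times, cifar10_train_len = 500, 50000
print(times * cifar10_train_len)  # -> 25000000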
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the
# default channel order in CIFAR-10 is RGB.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same as in StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# CIFAR dataset without the usual augmentations.
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the
# default channel order in CIFAR-10 is RGB.
# Follow the pipeline in
# https://github.com/pfnet-research/sngan_projection/blob/master/datasets/cifar10.py
# Only the `RandomImgNoise` augmentation is adopted.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='RandomImgNoise', keys=['img']),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same as in StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same as in StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'UnconditionalImageDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
val_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` and `imgs_root` need to be set.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=100,
dataset=dict(
type=dataset_type, imgs_root=None, pipeline=train_pipeline)),
val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
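# Sketch (added for illustration): after replacing the `None` placeholders,
# the train split can be built with mmgen's dataset builder. The config
# path and image root below are hypothetical.
from mmcv import Config
from mmgen.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/unconditional_imgs_flip.py')  # hypothetical path
cfg.data.samples_per_gpu = 4
cfg.data.train.dataset.imgs_root = './data/ffhq/images'  # hypothetical root
train_set = build_dataset(cfg.data.train)
print(len(train_set))  # RepeatDataset: 100 * number of images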
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Resize', keys=['real_img'], scale=(128, 128)),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` and `imgs_root` need to be set.
data = dict(
    # `samples_per_gpu` should match the setting for the first scale,
    # e.g. '4': 64 in this case
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type,
# just an example
imgs_roots={'128': './data/lsun/bedroom_train'},
pipeline=train_pipeline,
gpu_samples_base=4,
        # note that this should be adjusted with the total number of GPUs
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4
},
len_per_stage=-1))
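# Added note: `gpu_samples_per_scale` appears to give per-GPU batch sizes
# (hence the note about GPU count above), so the effective global batch at
# each resolution is per_gpu * num_gpus. A quick illustration, assuming
# 8 GPUs:
gpu_samples_per_scale = {'4': 64, '8': 32, '16': 16, '32': 8, '64': 4}
num_gpus = 8  # assumption for illustration
for scale, per_gpu in gpu_samples_per_scale.items():
    print(f'{scale}x{scale}: global batch = {per_gpu * num_gpus}')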
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` and `imgs_root` need to be set.
data = dict(
    # `samples_per_gpu` should match the setting for the first scale,
    # e.g. '4': 64 in this case
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type,
# just an example
imgs_roots={
'64': './data/celebahq/imgs_64',
'256': './data/celebahq/imgs_256',
'512': './data/celebahq/imgs_512',
'1024': './data/celebahq/imgs_1024'
},
pipeline=train_pipeline,
gpu_samples_base=4,
        # note that this should be adjusted with the total number of GPUs
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4
},
len_per_stage=300000))
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(
        type=dataset_type,
        imgs_roots={
            '1024': './data/ffhq/images',
            '256': './data/ffhq/ffhq_imgs/ffhq_256',
            '64': './data/ffhq/ffhq_imgs/ffhq_64'
        },
pipeline=train_pipeline,
gpu_samples_base=4,
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4,
'128': 4,
'256': 4,
'512': 4,
'1024': 4
},
len_per_stage=300000))
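# Sketch (added for illustration): a minimal check that this FFHQ base
# config builds, assuming it is saved under the hypothetical path below
# and that mmgen is installed.
from mmcv import Config
from mmgen.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/grow_scale_imgs_ffhq.py')  # hypothetical path
train_set = build_dataset(cfg.data.train)
print(len(train_set))  # len_per_stage=300000 appears to fix the per-stage length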