Commit 310493b2 (mashun1): stylegan3
*pyc*
*idea*
*.mp4
*.pkl
out
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
FROM nvcr.io/nvidia/pytorch:21.08-py3
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
WORKDIR /workspace
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
ENTRYPOINT ["/entry.sh"]
Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
NVIDIA Source Code License for StyleGAN3
=======================================================================
1. Definitions
"Licensor" means any person or entity that distributes its Work.
"Software" means the original work of authorship made available under
this License.
"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.
The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. Notwithstanding
the foregoing, NVIDIA and its affiliates may use the Work and any
derivative works commercially. As used herein, "non-commercially"
means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.
3.5 Trademarks. This License does not grant any rights to use any
Licensor’s or its affiliates’ names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.
=======================================================================
# stylegan3
## Paper
**Alias-Free Generative Adversarial Networks**
* https://nvlabs-fi-cdn.nvidia.com/stylegan3/stylegan3-paper.pdf
## Model Architecture
The figure below shows the generator architecture: the `Mapping network` (maps the latent code z to w), `Fourier feat.` (Fourier features), and `ToRGB` (converts the feature maps to 3-channel RGB output).
![Alt text](readme_imgs/image-1.png)
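In code, the pickled generator exposes this two-stage structure directly: `G.mapping` maps z to w, and `G.synthesis` renders the RGB image. A minimal sketch, assuming this repo's `dnnlib` and `legacy` modules are importable and the pickle path is illustrative (see `gen_images.py` for the full script):
```python
# Minimal two-stage generation sketch (assumed pickle path).
import numpy as np
import torch
import dnnlib
import legacy

device = torch.device('cuda')
with dnnlib.util.open_url('pretrained_models/stylegan3-r-ffhq-1024x1024.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)
c = torch.zeros([1, G.c_dim], device=device)        # unconditional: empty label
w = G.mapping(z, c, truncation_psi=1.0)             # latent z -> intermediate w
img = G.synthesis(w)                                # w -> RGB, roughly in [-1, 1]
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)  # to uint8, as in this repo's scripts
```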
## Algorithm
Purpose: the algorithm generates high-quality images.
Principle:
![Alt text](readme_imgs/image-2.png)
The existing problems are analyzed with signal-processing methods, and the network architecture is modified to increase equivariance while preserving FID. The concrete steps are as follows (a sketch using the repo's equivariance metrics follows the list):
1. Fourier features + removing the noise inputs + removing skip connections + reducing the number of layers + disabling mixing regularization and path-length regularization: keeps FID similar to StyleGAN2 while slightly improving translation equivariance.
2. Larger margins + sinc filters: improves translation equivariance, but FID degrades.
3. Adding a nonlinearity between up/downsampling: improves translation equivariance.
4. Lowering the cutoff frequency: improves translation equivariance; FID drops below StyleGAN2's.
5. Transformed Fourier features: improves FID.
6. Flexible layer specifications: improves the quality of equivariance.
7. Rotation equivariance: improves rotation equivariance and reduces the number of trainable parameters (by replacing 3x3 convolutions with 1x1 convolutions).
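The paper quantifies these changes with the equivariance metrics EQ-T and EQ-R, which ship with this repo (see `calc_metrics.py` further below). A sketch of invoking them programmatically, assuming a loaded generator `G` (as in the sketch above) and a prepared dataset zip at an assumed path:
```python
# Sketch: measure translation/rotation equivariance with the repo's own
# metrics; mirrors the call made in calc_metrics.py.
import torch
import dnnlib
from metrics import metric_main

device = torch.device('cuda')
dataset_kwargs = dnnlib.EasyDict(
    class_name='training.dataset.ImageFolderDataset',
    path='datasets/afhqv2-512x512.zip',      # assumed dataset location
    resolution=G.img_resolution,
    use_labels=(G.c_dim != 0))
for metric in ['eqt50k_int', 'eqr50k']:      # EQ-T and EQ-R
    result = metric_main.calc_metric(metric=metric, G=G,
        dataset_kwargs=dataset_kwargs, num_gpus=1, rank=0, device=device)
    print(result['results'])
```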
## Environment Setup
### Docker (Option 1)
```shell
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04.1-py39-latest
docker run --shm-size 10g --network=host --name=stylegan3 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute project path>:/home/ -it <your IMAGE ID> bash
pip install -r requirements.txt
```
### Docker (Option 2)
```shell
# Run from the directory containing the Dockerfile
docker build -t <IMAGE_NAME>:<TAG> .
# Replace <your IMAGE ID> with the ID of the image built above
docker run -it --shm-size 10g --network=host --name=stylegan3 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined <your IMAGE ID> bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep-learning libraries required by this project can be downloaded from the 光合 developer community:
https://developer.hpccube.com/tool/
DTK driver: dtk23.04.1
python: python3.9
torch: 1.13.1
torchvision: 0.14.1
torchaudio: 0.13.1
deepspeed: 0.9.2
apex: 0.1
Tip: the versions of the DTK driver, Python, torch, and the other DCU-related tools above must match each other exactly.
2. Install the remaining standard dependencies from requirements.txt:
```shell
pip install -r requirements.txt
```
## Datasets
https://github.com/NVlabs/ffhq-dataset
https://github.com/NVlabs/metfaces-dataset
https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq
Note: for training, the entire dataset folder must first be converted into the tool's archive format; the following command generates the corresponding zip package:
```shell
# Example: prepare the AFHQv2 dataset
python dataset_tool.py --source=downloads/afhqv2 --dest=datasets/afhqv2-512x512.zip
```
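The generated archive stores each image as an uncompressed PNG plus an optional `dataset.json` holding class labels (see the `dataset_tool.py` source further below). A quick sanity check of the zip, assuming the path from the command above:
```python
# Sketch: inspect the archive produced by dataset_tool.py.
import json
import zipfile

with zipfile.ZipFile('datasets/afhqv2-512x512.zip') as z:
    names = z.namelist()
    pngs = [n for n in names if n.endswith('.png')]
    print(f'{len(pngs)} images, first: {pngs[0]}')   # e.g. 00000/img00000000.png
    if 'dataset.json' in names:
        labels = json.loads(z.read('dataset.json'))['labels']
        print('labels:', 'none' if labels is None else len(labels))
```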
## Inference
### Model Download
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan3
```
pretrained_models/
├── xxx.pkl
└── stylegan3-r-ffhq-1024x1024.pkl
```
Note: the page above hosts several models; download only the ones you need.
### Commands
```shell
# Generate an image
python gen_images.py --outdir=out --trunc=1 --seeds=2 --network=pretrained_models/stylegan3-r-ffhq-1024x1024.pkl
# Generate a video
python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 --network=pretrained_models/stylegan3-r-ffhq-1024x1024.pkl
```
Note: `--network` accepts either a local model file or a URL (e.g., `https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan3/files/xxx.pkl`). Downloaded network pickles are cached under `$HOME/.cache/dnnlib`, which can be overridden with the `DNNLIB_CACHE_DIR` environment variable. The default PyTorch extension build directory is `$HOME/.cache/torch_extensions`, which can be overridden with `TORCH_EXTENSIONS_DIR`.
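Both cache locations can also be redirected from Python before anything is downloaded; `dnnlib` consults `DNNLIB_CACHE_DIR` at lookup time (see `make_cache_dir_path` in `dnnlib/util.py` below). A sketch with assumed cache paths:
```python
# Sketch: redirect caches before the first download (paths are examples).
import os
os.environ['DNNLIB_CACHE_DIR'] = '/data/cache/dnnlib'
os.environ['TORCH_EXTENSIONS_DIR'] = '/data/cache/torch_extensions'

import torch
import dnnlib
import legacy

url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl'
with dnnlib.util.open_url(url) as f:        # downloaded once, cached afterwards
    G = legacy.load_network_pkl(f)['G_ema'].to(torch.device('cuda'))
```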
## Training
```shell
# Train a new network
# Example: train StyleGAN3-T on the AFHQv2 dataset with 4 DCUs.
python train.py --outdir=training-runs --cfg=stylegan3-t --data=datasets/afhqv2-512x512.zip \
    --gpus=4 --batch=16 --gamma=8.2 --mirror=1
# Fine-tune
# Example: fine-tune StyleGAN3-R on MetFaces-U with 4 DCUs, starting from the pre-trained FFHQ-U pickle.
python train.py --outdir=training-runs --cfg=stylegan3-r --data=datasets/metfacesu-1024x1024.zip \
    --gpus=4 --batch=16 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
    --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
```
Note: the most important options (`--gpus`, `--batch`, `--gamma`) must be specified explicitly and should be set with care. See `python train.py --help` for the complete list of options, general guidelines, recommended training configurations, and the expected training speed and memory usage in different scenarios.
The results of each training run are saved to a newly created directory, e.g. `training-runs/00000-stylegan3-t-afhqv2-512x512-gpus8-batch32-gamma8.2`. The training loop periodically (controlled by `--snap`) exports network pickles (`network-snapshot-<kimg>.pkl`) and random image grids (`fakes<kimg>.png`). For each exported pickle, FID is evaluated (controlled by `--metrics`) and logged to `metric-fid50k_full.jsonl`; various statistics are also logged to `training_stats.jsonl`.
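Since the logs are JSON-lines files, progress can be tracked with a few lines of Python. A sketch, assuming the record format written by `metric_main.report_metric` (one JSON object per line with `results` and `snapshot_pkl` fields) and an illustrative run directory:
```python
# Sketch: find the best-FID snapshot from a run's metric log (assumed path
# and record format: {'results': {'fid50k_full': ...}, 'snapshot_pkl': ...}).
import json

path = 'training-runs/00000-stylegan3-t-afhqv2-512x512-gpus4-batch16-gamma8.2/metric-fid50k_full.jsonl'
with open(path) as f:
    records = [json.loads(line) for line in f if line.strip()]
best = min(records, key=lambda r: r['results']['fid50k_full'])
print(best['snapshot_pkl'], best['results']['fid50k_full'])
```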
## Results
![image-20230710170324231](./pngs/show.gif)
### Accuracy
## Application Scenarios
### Algorithm Category
`AIGC`
### Key Application Industries
`Media, Research, Education`
## Source Repository and Feedback
* https://developer.hpccube.com/codes/modelzoo/stylegan3_pytorch
## References
* https://github.com/NVlabs/stylegan3
## [Alias-Free Generative Adversarial Networks (StyleGAN3)](https://nvlabs.github.io/stylegan3/)
### Model Overview
StyleGAN3 tackles a general problem of generative models (GANs): generation is not a naturally hierarchical process. Coarse features (the outputs of the GAN's shallow layers) largely control whether fine features (the outputs of the deep layers) appear at all, but not their precise positions. The root cause is that current generator architectures are stacks of convolution + nonlinearity + upsampling, which do not achieve good equivariance. The authors therefore redesign the existing generator architecture (primarily StyleGAN2's) so that it exhibits high-quality equivariance.
StyleGAN3 revisits the generative-model framework from the standpoint of signal-processing theory and proposes a new generator design that endows the network with equivariance.
### Model Architecture
The improvements mainly target the (StyleGAN2) generator; the full set of modifications to the generator network is summarized in the following table:
![image-20230629154853155](./pngs/stru.png)
### Datasets
Download [**FFHQ**](https://github.com/NVlabs/ffhq-dataset), [**MetFaces**](https://github.com/NVlabs/metfaces-dataset), [**AFHQv2**](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq), or prepare your own dataset.
For training, the entire dataset folder must first be converted into the tool's archive format; the following command generates the corresponding zip package:
```shell
# Example: prepare the AFHQv2 dataset
python dataset_tool.py --source=~/downloads/afhqv2 --dest=~/datasets/afhqv2-512x512.zip
```
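Training reads the zip back through `training.dataset.ImageFolderDataset`, the same class `avg_spectra.py` uses; images come back as uint8 NCHW tensors. A small sketch (adjust the path to your dataset):
```python
# Sketch: load the prepared zip the same way the training code does.
import torch
from training import dataset

ds = dataset.ImageFolderDataset(path='datasets/afhqv2-512x512.zip')
print(len(ds), ds.resolution)
loader = torch.utils.data.DataLoader(ds, batch_size=4)
images, labels = next(iter(loader))
print(images.shape, images.dtype)   # e.g. [4, 3, 512, 512], torch.uint8
```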
### Training and Inference
#### Environment Setup
- Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
- 1–4 DCU cards.
- 64-bit Python 3.8 and dtk23.04, plus the matching torch build.
  PyTorch whl package download directory: [https://cancon.hpccube.com:65024/4/main/pytorch/dtk23.04](https://cancon.hpccube.com:65024/4/main/pytorch/dtk23.04)
```shell
pip install torch*  # the downloaded torch whl packages
```
- Other dependencies:
```shell
pip install -r requirements.txt
```
- GCC 7 or later (Linux) or the Visual Studio compiler (Windows).
#### Inference
Pre-trained networks are stored as `*.pkl` files; both local and remote (URL) `pkl` files can be used.
Pre-trained models: [StyleGAN3 pretrained models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan3)
> Individual networks can also be accessed via
>
> `https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/<MODEL>`, where `<MODEL>` is one of:
>
> `stylegan3-t-ffhq-1024x1024.pkl`, `stylegan3-t-ffhqu-1024x1024.pkl`, `stylegan3-t-ffhqu-256x256.pkl`
>
> `stylegan3-r-ffhq-1024x1024.pkl`, `stylegan3-r-ffhqu-1024x1024.pkl`, `stylegan3-r-ffhqu-256x256.pkl`
>
> `stylegan3-t-metfaces-1024x1024.pkl`, `stylegan3-t-metfacesu-1024x1024.pkl`, `stylegan3-r-metfaces-1024x1024.pkl`, `stylegan3-r-metfacesu-1024x1024.pkl`
>
> `stylegan3-t-afhqv2-512x512.pkl`, `stylegan3-r-afhqv2-512x512.pkl`
```shell
# Generate an image using the pre-trained AFHQv2 model
python gen_images.py --outdir=out --trunc=1 --seeds=2 \
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
# Generate a video using the pre-trained AFHQv2 model
python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
```
The images from the first command are written under `out/*.png`, controlled by `--outdir`; the video goes to the file given by `--output`. `--network` accepts either a URL (as in the examples above) or a local `pkl` path.
Downloaded network pickles are cached under `$HOME/.cache/dnnlib`, which can be overridden with the `DNNLIB_CACHE_DIR` environment variable. The default PyTorch extension build directory is `$HOME/.cache/torch_extensions`, which can be overridden with `TORCH_EXTENSIONS_DIR`.
#### Training and Fine-tuning
New networks can be trained with `train.py`, for example:
```shell
# Example: train StyleGAN3-T on the AFHQv2 dataset with 4 DCUs.
python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
--gpus=4 --batch=16 --gamma=8.2 --mirror=1
# Example: fine-tune StyleGAN3-R on MetFaces-U with 4 DCUs, starting from the pre-trained FFHQ-U pickle.
python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
--gpus=4 --batch=16 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
--resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
```
The most important options (`--gpus`, `--batch`, `--gamma`) must be specified explicitly and should be set with care. See `python train.py --help` for the complete list of options, general guidelines, recommended training configurations, and the expected training speed and memory usage in different scenarios.
The results of each training run are saved to a newly created directory, e.g. `~/training-runs/00000-stylegan3-t-afhqv2-512x512-gpus8-batch32-gamma8.2`. The training loop periodically (controlled by `--snap`) exports network pickles (`network-snapshot-<kimg>.pkl`) and random image grids (`fakes<kimg>.png`). For each exported pickle, FID is evaluated (controlled by `--metrics`) and logged to `metric-fid50k_full.jsonl`; various statistics are also logged to `training_stats.jsonl`.
### Performance and Accuracy
##### DCU Training Loss
![image-20230710170324239](./pngs/loss.png)
##### Generation Results (MetFaces)
![image-20230710170324231](./pngs/show.gif)
##### Spectral Analysis
![image-20230710182732783](./pngs/acc.png)
### Note
`--gamma`: the per-dataset R1 regularization weight, and the single most important parameter to tune. It typically scales with the *square* of the resolution ratio: e.g., going from 256x256 to 512x512, the corresponding gamma goes from 2 to 8. The same value is used for the `-r` and `-t` configs, and it is slightly lower than for StyleGAN2 (a rule-of-thumb helper is sketched after these notes).
`--metrics`: evaluates the quality of the generated images against the dataset during training. Unless you need research-grade numbers for a paper, set it to `none`; otherwise it is **very time-consuming**.
`--aug`: `noaug` disables ADA. With at least 100k training images (counting flips) this makes little difference, but with fewer than 30k training images it helps a lot.
`--cbase=16384` speeds up training at the cost of some quality (with almost no impact at low resolutions such as 256x256).
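A tiny helper (hypothetical, not part of the repo) that encodes the square-of-the-resolution-ratio rule of thumb for `--gamma`; treat its output only as a starting point, since the best value remains dataset-dependent:
```python
# Hypothetical helper for the --gamma rule of thumb described above.
def suggest_gamma(resolution, base_res=256, base_gamma=2.0):
    """Scale gamma with the square of the resolution ratio."""
    return base_gamma * (resolution / base_res) ** 2

print(suggest_gamma(512))    # 8.0  (matches the 256 -> 512 example)
print(suggest_gamma(1024))   # 32.0 (starting point; tune per dataset)
```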
### Source Repository and Feedback
[https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3)
### Other References
[README_ORIGIN.md](./README_ORIGIN.md)
[Training configurations](https://github.com/NVlabs/stylegan3/blob/main/docs/configs.md)
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Compare average power spectra between real and generated images,
or between multiple generators."""
import os
import numpy as np
import torch
import torch.fft
import scipy.ndimage
import matplotlib.pyplot as plt
import click
import tqdm
import dnnlib
import legacy
from training import dataset
#----------------------------------------------------------------------------
# Setup an iterator for streaming images, in uint8 NCHW format, based on the
# respective command line options.
def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter
ext = source.split('.')[-1].lower()
if data_loader_kwargs is None:
data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
if ext == 'pkl':
if num is None:
raise click.ClickException('--num is required when --source points to network pickle')
with dnnlib.util.open_url(source) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device)
def generate_image(seed):
rnd = np.random.RandomState(seed)
z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device)
c = torch.zeros([1, G.c_dim], device=device)
if G.c_dim > 0:
c[:, rnd.randint(G.c_dim)] = 1
return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
_ = generate_image(seed) # warm up
image_iter = (generate_image(seed + idx) for idx in range(num))
return num, G.img_resolution, image_iter
elif ext == 'zip' or os.path.isdir(source):
dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed)
if num is not None and num != len(dataset_obj):
raise click.ClickException(f'--source contains fewer than {num} images')
data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs)
image_iter = (image.to(device) for image, _label in data_loader)
return len(dataset_obj), dataset_obj.resolution, image_iter
else:
raise click.ClickException('--source must point to network pickle, dataset zip, or directory')
#----------------------------------------------------------------------------
# Load average power spectrum from the specified .npz file and construct
# the corresponding heatmap for visualization.
def construct_heatmap(npz_file, smooth):
npz_data = np.load(npz_file)
spectrum = npz_data['spectrum']
image_size = npz_data['image_size']
hmap = np.log10(spectrum) * 10 # dB
hmap = np.fft.fftshift(hmap)
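# Duplicate the first row/column at the opposite edge so the periodic
# spectrum covers both Nyquist boundaries when plotted.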
hmap = np.concatenate([hmap, hmap[:1, :]], axis=0)
hmap = np.concatenate([hmap, hmap[:, :1]], axis=1)
if smooth > 0:
sigma = spectrum.shape[0] / image_size * smooth
hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest')
return hmap, image_size
#----------------------------------------------------------------------------
@click.group()
def main():
"""Compare average power spectra between real and generated images,
or between multiple generators.
Example:
\b
# Calculate dataset mean and std, needed in subsequent steps.
python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
\b
# Calculate average spectrum for the training data.
python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\
--dest=tmp/training-data.npz --mean=112.684 --std=69.509
\b
# Calculate average spectrum for a pre-trained generator.
python avg_spectra.py calc \\
--source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\
--dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
\b
# Display results.
python avg_spectra.py heatmap tmp/training-data.npz
python avg_spectra.py heatmap tmp/stylegan3-r.npz
python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
\b
# Save as PNG.
python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300
python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300
python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300
"""
#----------------------------------------------------------------------------
@main.command()
@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
def stats(source, num, seed, device=torch.device('cuda')):
"""Calculate dataset mean and standard deviation needed by 'calc'."""
torch.multiprocessing.set_start_method('spawn')
num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
# Accumulate moments.
moments = torch.zeros([3], dtype=torch.float64, device=device)
for image in tqdm.tqdm(image_iter, total=num_images):
image = image.to(torch.float64)
moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()])
moments = moments / moments[0]
# Compute mean and standard deviation.
mean = moments[1]
std = (moments[2] - moments[1].square()).sqrt()
print(f'--mean={mean:g} --std={std:g}')
#----------------------------------------------------------------------------
@main.command()
@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
@click.option('--dest', help='Where to store the result', metavar='NPZ', required=True)
@click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True)
@click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True)
@click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')):
"""Calculate average power spectrum and store it in .npz file."""
torch.multiprocessing.set_start_method('spawn')
num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
spectrum_size = image_size * interp
padding = spectrum_size - image_size
# Setup window function.
window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device)
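# Normalize the window to unit L2 energy, then take its outer product to
# form a separable 2D window shaped [1, 1, H, W] for NCHW broadcasting.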
window *= window.square().sum().rsqrt()
window = window.ger(window).unsqueeze(0).unsqueeze(1)
# Accumulate power spectrum.
spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device)
for image in tqdm.tqdm(image_iter, total=num_images):
image = (image.to(torch.float64) - mean) / std
image = torch.nn.functional.pad(image * window, [0, padding, 0, padding])
spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1])
spectrum /= num_images
# Save result.
if os.path.dirname(dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size)
#----------------------------------------------------------------------------
@main.command()
@click.argument('npz-file', nargs=1)
@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True)
def heatmap(npz_file, save, smooth, dpi):
"""Visualize 2D heatmap based on the given .npz file."""
hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth)
# Setup plot.
plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True)
freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size
ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True)
levels = np.linspace(-40, 20, num=13, endpoint=True)
# Draw heatmap.
plt.xlim(ticks[0], ticks[-1])
plt.ylim(ticks[0], ticks[-1])
plt.xticks(ticks)
plt.yticks(ticks)
plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues')
plt.gca().set_aspect('equal')
plt.colorbar(ticks=levels)
plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2)
# Display or save.
if save is None:
plt.show()
else:
if os.path.dirname(save):
os.makedirs(os.path.dirname(save), exist_ok=True)
plt.savefig(save)
#----------------------------------------------------------------------------
@main.command()
@click.argument('npz-files', nargs=-1, required=True)
@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
def slices(npz_files, save, dpi, smooth):
"""Visualize 1D slices based on the given .npz files."""
cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files]
for c in cases:
c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth)
c.label = os.path.splitext(os.path.basename(c.npz_file))[0]
# Check consistency.
image_size = cases[0].image_size
hmap_size = cases[0].hmap.shape[0]
if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases):
raise click.ClickException('All .npz must have the same resolution')
# Setup plot.
plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True)
hmap_center = hmap_size // 2
hmap_range = np.arange(hmap_center, hmap_size)
freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True)
freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True)
xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True)
xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True))
yticks = np.linspace(-50, 30, num=9, endpoint=True)
# Draw 0 degree slice.
plt.subplot(1, 2, 1)
plt.title('0\u00b0 slice')
plt.xlim(xticks0[0], xticks0[-1])
plt.ylim(yticks[0], yticks[-1])
plt.xticks(xticks0)
plt.yticks(yticks)
for c in cases:
plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label)
plt.grid()
plt.legend(loc='upper right')
# Draw 45 degree slice.
plt.subplot(1, 2, 2)
plt.title('45\u00b0 slice')
plt.xlim(xticks45[0], xticks45[-1])
plt.ylim(yticks[0], yticks[-1])
plt.xticks(xticks45)
plt.yticks(yticks)
for c in cases:
plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label)
plt.grid()
plt.legend(loc='upper right')
# Display or save.
if save is None:
plt.show()
else:
if os.path.dirname(save):
os.makedirs(os.path.dirname(save), exist_ok=True)
plt.savefig(save)
#----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Calculate quality metrics for previous training run or pretrained network pickle."""
import os
import click
import json
import tempfile
import copy
import torch
import dnnlib
import legacy
from metrics import metric_main
from metrics import metric_utils
from torch_utils import training_stats
from torch_utils import custom_ops
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix
#----------------------------------------------------------------------------
def subprocess_fn(rank, args, temp_dir):
dnnlib.util.Logger(should_flush=True)
# Init torch.distributed.
if args.num_gpus > 1:
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
if os.name == 'nt':
init_method = 'file:///' + init_file.replace('\\', '/')
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
else:
init_method = f'file://{init_file}'
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
#torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
# Init torch_utils.
sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
if rank != 0 or not args.verbose:
custom_ops.verbosity = 'none'
# Configure torch.
device = torch.device('cuda', rank)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
conv2d_gradfix.enabled = True
# Print network summary.
G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
if rank == 0 and args.verbose:
z = torch.empty([1, G.z_dim], device=device)
c = torch.empty([1, G.c_dim], device=device)
misc.print_module_summary(G, [z, c])
# Calculate each metric.
for metric in args.metrics:
if rank == 0 and args.verbose:
print(f'Calculating {metric}...')
progress = metric_utils.ProgressMonitor(verbose=args.verbose)
result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
if rank == 0:
metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
if rank == 0 and args.verbose:
print()
# Done.
if rank == 0 and args.verbose:
print('Exiting...')
#----------------------------------------------------------------------------
def parse_comma_separated_list(s):
if isinstance(s, list):
return s
if s is None or s.lower() == 'none' or s == '':
return []
return s.split(',')
#----------------------------------------------------------------------------
@click.command()
@click.pass_context
@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
"""Calculate quality metrics for previous training run or pretrained network pickle.
Examples:
\b
# Previous training run: look up options automatically, save result to JSONL file.
python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
--network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
\b
# Pre-trained network pickle: specify dataset explicitly, print result to stdout.
python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
\b
Recommended metrics:
fid50k_full Frechet inception distance against the full dataset.
kid50k_full Kernel inception distance against the full dataset.
pr50k3_full Precision and recall against the full dataset.
ppl2_wend Perceptual path length in W, endpoints, full image.
eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
eqr50k Equivariance w.r.t. rotation (EQ-R).
\b
Legacy metrics:
fid50k Frechet inception distance against 50k real images.
kid50k Kernel inception distance against 50k real images.
pr50k3 Precision and recall against 50k real images.
is50k Inception score for CIFAR-10.
"""
dnnlib.util.Logger(should_flush=True)
# Validate arguments.
args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
if not args.num_gpus >= 1:
ctx.fail('--gpus must be at least 1')
# Load network.
if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
ctx.fail('--network must point to a file or URL')
if args.verbose:
print(f'Loading network from "{network_pkl}"...')
with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
network_dict = legacy.load_network_pkl(f)
args.G = network_dict['G_ema'] # subclass of torch.nn.Module
# Initialize dataset options.
if data is not None:
args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
elif network_dict['training_set_kwargs'] is not None:
args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
else:
ctx.fail('Could not look up dataset options; please specify --data')
# Finalize dataset options.
args.dataset_kwargs.resolution = args.G.img_resolution
args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
if mirror is not None:
args.dataset_kwargs.xflip = mirror
# Print dataset options.
if args.verbose:
print('Dataset options:')
print(json.dumps(args.dataset_kwargs, indent=2))
# Locate run dir.
args.run_dir = None
if os.path.isfile(network_pkl):
pkl_dir = os.path.dirname(network_pkl)
if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
args.run_dir = pkl_dir
# Launch processes.
if args.verbose:
print('Launching processes...')
torch.multiprocessing.set_start_method('spawn')
with tempfile.TemporaryDirectory() as temp_dir:
if args.num_gpus == 1:
subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
else:
torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
#----------------------------------------------------------------------------
if __name__ == "__main__":
calc_metrics() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Tool for creating ZIP/PNG based datasets."""
import functools
import gzip
import io
import json
import os
import pickle
import re
import sys
import tarfile
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
import click
import numpy as np
import PIL.Image
from tqdm import tqdm
#----------------------------------------------------------------------------
def error(msg):
print('Error: ' + msg)
sys.exit(1)
#----------------------------------------------------------------------------
def parse_tuple(s: str) -> Tuple[int, int]:
'''Parse a 'M,N' or 'MxN' integer tuple.
Example:
'4x2' returns (4,2)
'0,1' returns (0,1)
'''
m = re.match(r'^(\d+)[x,](\d+)$', s)
if m:
return (int(m.group(1)), int(m.group(2)))
raise ValueError(f'cannot parse tuple {s}')
#----------------------------------------------------------------------------
def maybe_min(a: int, b: Optional[int]) -> int:
if b is not None:
return min(a, b)
return a
#----------------------------------------------------------------------------
def file_ext(name: Union[str, Path]) -> str:
return str(name).split('.')[-1]
#----------------------------------------------------------------------------
def is_image_ext(fname: Union[str, Path]) -> bool:
ext = file_ext(fname).lower()
return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
#----------------------------------------------------------------------------
def open_image_folder(source_dir, *, max_images: Optional[int]):
input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
# Load labels.
labels = {}
meta_fname = os.path.join(source_dir, 'dataset.json')
if os.path.isfile(meta_fname):
with open(meta_fname, 'r') as file:
labels = json.load(file)['labels']
if labels is not None:
labels = { x[0]: x[1] for x in labels }
else:
labels = {}
max_idx = maybe_min(len(input_images), max_images)
def iterate_images():
for idx, fname in enumerate(input_images):
arch_fname = os.path.relpath(fname, source_dir)
arch_fname = arch_fname.replace('\\', '/')
img = np.array(PIL.Image.open(fname))
yield dict(img=img, label=labels.get(arch_fname))
if idx >= max_idx-1:
break
return max_idx, iterate_images()
#----------------------------------------------------------------------------
def open_image_zip(source, *, max_images: Optional[int]):
with zipfile.ZipFile(source, mode='r') as z:
input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
# Load labels.
labels = {}
if 'dataset.json' in z.namelist():
with z.open('dataset.json', 'r') as file:
labels = json.load(file)['labels']
if labels is not None:
labels = { x[0]: x[1] for x in labels }
else:
labels = {}
max_idx = maybe_min(len(input_images), max_images)
def iterate_images():
with zipfile.ZipFile(source, mode='r') as z:
for idx, fname in enumerate(input_images):
with z.open(fname, 'r') as file:
img = PIL.Image.open(file) # type: ignore
img = np.array(img)
yield dict(img=img, label=labels.get(fname))
if idx >= max_idx-1:
break
return max_idx, iterate_images()
#----------------------------------------------------------------------------
def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
import cv2 # pip install opencv-python # pylint: disable=import-error
import lmdb # pip install lmdb # pylint: disable=import-error
with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
max_idx = maybe_min(txn.stat()['entries'], max_images)
def iterate_images():
with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
for idx, (_key, value) in enumerate(txn.cursor()):
try:
try:
img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
if img is None:
raise IOError('cv2.imdecode failed')
img = img[:, :, ::-1] # BGR => RGB
except IOError:
img = np.array(PIL.Image.open(io.BytesIO(value)))
yield dict(img=img, label=None)
if idx >= max_idx-1:
break
except:
print(sys.exc_info()[1])
return max_idx, iterate_images()
#----------------------------------------------------------------------------
def open_cifar10(tarball: str, *, max_images: Optional[int]):
images = []
labels = []
with tarfile.open(tarball, 'r:gz') as tar:
for batch in range(1, 6):
member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
with tar.extractfile(member) as file:
data = pickle.load(file, encoding='latin1')
images.append(data['data'].reshape(-1, 3, 32, 32))
labels.append(data['labels'])
images = np.concatenate(images)
labels = np.concatenate(labels)
images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
assert np.min(images) == 0 and np.max(images) == 255
assert np.min(labels) == 0 and np.max(labels) == 9
max_idx = maybe_min(len(images), max_images)
def iterate_images():
for idx, img in enumerate(images):
yield dict(img=img, label=int(labels[idx]))
if idx >= max_idx-1:
break
return max_idx, iterate_images()
#----------------------------------------------------------------------------
def open_mnist(images_gz: str, *, max_images: Optional[int]):
labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
assert labels_gz != images_gz
images = []
labels = []
with gzip.open(images_gz, 'rb') as f:
images = np.frombuffer(f.read(), np.uint8, offset=16)
with gzip.open(labels_gz, 'rb') as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8)
images = images.reshape(-1, 28, 28)
images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
assert labels.shape == (60000,) and labels.dtype == np.uint8
assert np.min(images) == 0 and np.max(images) == 255
assert np.min(labels) == 0 and np.max(labels) == 9
max_idx = maybe_min(len(images), max_images)
def iterate_images():
for idx, img in enumerate(images):
yield dict(img=img, label=int(labels[idx]))
if idx >= max_idx-1:
break
return max_idx, iterate_images()
#----------------------------------------------------------------------------
def make_transform(
transform: Optional[str],
output_width: Optional[int],
output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
def scale(width, height, img):
w = img.shape[1]
h = img.shape[0]
if width == w and height == h:
return img
img = PIL.Image.fromarray(img)
ww = width if width is not None else w
hh = height if height is not None else h
img = img.resize((ww, hh), PIL.Image.LANCZOS)
return np.array(img)
def center_crop(width, height, img):
crop = np.min(img.shape[:2])
img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
img = PIL.Image.fromarray(img, 'RGB')
img = img.resize((width, height), PIL.Image.LANCZOS)
return np.array(img)
def center_crop_wide(width, height, img):
ch = int(np.round(width * img.shape[0] / img.shape[1]))
if img.shape[1] < width or ch < height:
return None
img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
img = PIL.Image.fromarray(img, 'RGB')
img = img.resize((width, height), PIL.Image.LANCZOS)
img = np.array(img)
canvas = np.zeros([width, width, 3], dtype=np.uint8)
canvas[(width - height) // 2 : (width + height) // 2, :] = img
return canvas
if transform is None:
return functools.partial(scale, output_width, output_height)
if transform == 'center-crop':
if (output_width is None) or (output_height is None):
error('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop, output_width, output_height)
if transform == 'center-crop-wide':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop_wide, output_width, output_height)
assert False, 'unknown transform'
#----------------------------------------------------------------------------
def open_dataset(source, *, max_images: Optional[int]):
if os.path.isdir(source):
if source.rstrip('/').endswith('_lmdb'):
return open_lmdb(source, max_images=max_images)
else:
return open_image_folder(source, max_images=max_images)
elif os.path.isfile(source):
if os.path.basename(source) == 'cifar-10-python.tar.gz':
return open_cifar10(source, max_images=max_images)
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
return open_mnist(source, max_images=max_images)
elif file_ext(source) == 'zip':
return open_image_zip(source, max_images=max_images)
else:
assert False, 'unknown archive type'
else:
error(f'Missing input file or directory: {source}')
#----------------------------------------------------------------------------
def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
dest_ext = file_ext(dest)
if dest_ext == 'zip':
if os.path.dirname(dest) != '':
os.makedirs(os.path.dirname(dest), exist_ok=True)
zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
def zip_write_bytes(fname: str, data: Union[bytes, str]):
zf.writestr(fname, data)
return '', zip_write_bytes, zf.close
else:
# If the output folder already exists, check that it is
# empty.
#
# Note: creating the output directory is not strictly
# necessary as folder_write_bytes() also mkdirs, but it's better
# to give an error message earlier in case the dest folder
# somehow cannot be created.
if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
error('--dest folder must be empty')
os.makedirs(dest, exist_ok=True)
def folder_write_bytes(fname: str, data: Union[bytes, str]):
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, 'wb') as fout:
if isinstance(data, str):
data = data.encode('utf8')
fout.write(data)
return dest, folder_write_bytes, lambda: None
#----------------------------------------------------------------------------
@click.command()
@click.pass_context
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
def convert_dataset(
ctx: click.Context,
source: str,
dest: str,
max_images: Optional[int],
transform: Optional[str],
resolution: Optional[Tuple[int, int]]
):
"""Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
The input dataset format is guessed from the --source argument:
\b
--source *_lmdb/ Load LSUN dataset
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
--source train-images-idx3-ubyte.gz Load MNIST dataset
--source path/ Recursively load all images from path/
--source dataset.zip Recursively load all images from dataset.zip
Specifying the output format and path:
\b
--dest /path/to/dir Save output files under /path/to/dir
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
The output dataset format can be either an image folder or an uncompressed zip archive.
Zip archives make it easier to move datasets around file servers and clusters, and may
offer better training performance on network file systems.
Images within the dataset archive will be stored as uncompressed PNG.
Uncompressed PNGs can be efficiently decoded in the training loop.
Class labels are stored in a file called 'dataset.json' that is stored at the
dataset root folder. This file has the following structure:
\b
{
"labels": [
["00000/img00000000.png",6],
["00000/img00000001.png",9],
... repeated for every image in the dataset
["00049/img00049999.png",1]
]
}
If the 'dataset.json' file cannot be found, the dataset is interpreted as
not containing class labels.
Image scale/crop and resolution requirements:
Output images must be square-shaped and they must all have the same power-of-two
dimensions.
To scale arbitrary input image size to a specific width and height, use the
--resolution option. Output resolution will be either the original
input resolution (if resolution was not specified) or the one specified with
--resolution option.
Use the --transform=center-crop or --transform=center-crop-wide options to apply a
center crop transform on the input image. These options should be used with the
--resolution option. For example:
\b
python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
--transform=center-crop-wide --resolution=512x384
"""
PIL.Image.init() # type: ignore
if dest == '':
ctx.fail('--dest output filename or directory must not be an empty string')
num_files, input_iter = open_dataset(source, max_images=max_images)
archive_root_dir, save_bytes, close_dest = open_dest(dest)
if resolution is None: resolution = (None, None)
transform_image = make_transform(transform, *resolution)
dataset_attrs = None
labels = []
for idx, image in tqdm(enumerate(input_iter), total=num_files):
idx_str = f'{idx:08d}'
archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
# Apply crop and resize.
img = transform_image(image['img'])
# Transform may drop images.
if img is None:
continue
# Error check to require uniform image attributes across
# the whole dataset.
channels = img.shape[2] if img.ndim == 3 else 1
cur_image_attrs = {
'width': img.shape[1],
'height': img.shape[0],
'channels': channels
}
if dataset_attrs is None:
dataset_attrs = cur_image_attrs
width = dataset_attrs['width']
height = dataset_attrs['height']
if width != height:
error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
if dataset_attrs['channels'] not in [1, 3]:
error('Input images must be stored as RGB or grayscale')
if width != 2 ** int(np.floor(np.log2(width))):
error('Image width/height after scale and crop are required to be power-of-two')
elif dataset_attrs != cur_image_attrs:
err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
# Save the image as an uncompressed PNG.
img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
image_bits = io.BytesIO()
img.save(image_bits, format='png', compress_level=0, optimize=False)
save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
labels.append([archive_fname, image['label']] if image['label'] is not None else None)
metadata = {
'labels': labels if all(x is not None for x in labels) else None
}
save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
close_dest()
#----------------------------------------------------------------------------
if __name__ == "__main__":
convert_dataset() # pylint: disable=no-value-for-parameter
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from .util import EasyDict, make_cache_dir_path
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Miscellaneous utility classes and functions."""
import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid
from distutils.util import strtobool
from typing import Any, List, Tuple, Union
# Util classes
# ------------------------------------------------------------------------------------------
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
class Logger(object):
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
self.file = None
if file_name is not None:
self.file = open(file_name, file_mode)
self.should_flush = should_flush
self.stdout = sys.stdout
self.stderr = sys.stderr
sys.stdout = self
sys.stderr = self
def __enter__(self) -> "Logger":
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
def write(self, text: Union[str, bytes]) -> None:
"""Write text to stdout (and a file) and optionally flush."""
if isinstance(text, bytes):
text = text.decode()
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
return
if self.file is not None:
self.file.write(text)
self.stdout.write(text)
if self.should_flush:
self.flush()
def flush(self) -> None:
"""Flush written text to both stdout and a file, if open."""
if self.file is not None:
self.file.flush()
self.stdout.flush()
def close(self) -> None:
"""Flush, close possible files, and remove stdout/stderr mirroring."""
self.flush()
# if using multiple loggers, prevent closing in wrong order
if sys.stdout is self:
sys.stdout = self.stdout
if sys.stderr is self:
sys.stderr = self.stderr
if self.file is not None:
self.file.close()
self.file = None
# Cache directories
# ------------------------------------------------------------------------------------------
_dnnlib_cache_dir = None
def set_cache_dir(path: str) -> None:
global _dnnlib_cache_dir
_dnnlib_cache_dir = path
def make_cache_dir_path(*paths: str) -> str:
if _dnnlib_cache_dir is not None:
return os.path.join(_dnnlib_cache_dir, *paths)
if 'DNNLIB_CACHE_DIR' in os.environ:
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
if 'HOME' in os.environ:
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
if 'USERPROFILE' in os.environ:
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
# Small util functions
# ------------------------------------------------------------------------------------------
def format_time(seconds: Union[int, float]) -> str:
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
s = int(np.rint(seconds))
if s < 60:
return "{0}s".format(s)
elif s < 60 * 60:
return "{0}m {1:02}s".format(s // 60, s % 60)
elif s < 24 * 60 * 60:
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
else:
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
def format_time_brief(seconds: Union[int, float]) -> str:
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
s = int(np.rint(seconds))
if s < 60:
return "{0}s".format(s)
elif s < 60 * 60:
return "{0}m {1:02}s".format(s // 60, s % 60)
elif s < 24 * 60 * 60:
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
else:
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
def ask_yes_no(question: str) -> bool:
"""Ask the user the question until the user inputs a valid answer."""
while True:
try:
print("{0} [y/n]".format(question))
return strtobool(input().lower())
except ValueError:
pass
def tuple_product(t: Tuple) -> Any:
"""Calculate the product of the tuple elements."""
result = 1
for v in t:
result *= v
return result
_str_to_ctype = {
"uint8": ctypes.c_ubyte,
"uint16": ctypes.c_uint16,
"uint32": ctypes.c_uint32,
"uint64": ctypes.c_uint64,
"int8": ctypes.c_byte,
"int16": ctypes.c_int16,
"int32": ctypes.c_int32,
"int64": ctypes.c_int64,
"float32": ctypes.c_float,
"float64": ctypes.c_double
}
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
type_str = None
if isinstance(type_obj, str):
type_str = type_obj
elif hasattr(type_obj, "__name__"):
type_str = type_obj.__name__
elif hasattr(type_obj, "name"):
type_str = type_obj.name
else:
raise RuntimeError("Cannot infer type name from input")
assert type_str in _str_to_ctype.keys()
my_dtype = np.dtype(type_str)
my_ctype = _str_to_ctype[type_str]
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
return my_dtype, my_ctype
def is_pickleable(obj: Any) -> bool:
try:
with io.BytesIO() as stream:
pickle.dump(obj, stream)
return True
except:
return False
# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
"""Searches for the underlying module behind the name to some python object.
Returns the module and the object name (original name with module part removed)."""
# allow convenience shorthands, substitute them by full names
obj_name = re.sub("^np.", "numpy.", obj_name)
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
# list alternatives for (module_name, local_obj_name)
parts = obj_name.split(".")
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
# try each alternative in turn
for module_name, local_obj_name in name_pairs:
try:
module = importlib.import_module(module_name) # may raise ImportError
get_obj_from_module(module, local_obj_name) # may raise AttributeError
return module, local_obj_name
except:
pass
# maybe some of the modules themselves contain errors?
for module_name, _local_obj_name in name_pairs:
try:
importlib.import_module(module_name) # may raise ImportError
except ImportError:
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
raise
# maybe the requested attribute is missing?
for module_name, local_obj_name in name_pairs:
try:
module = importlib.import_module(module_name) # may raise ImportError
get_obj_from_module(module, local_obj_name) # may raise AttributeError
except ImportError:
pass
# we are out of luck, but we have no idea why
raise ImportError(obj_name)
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
"""Traverses the object name and returns the last (rightmost) python object."""
if obj_name == '':
return module
obj = module
for part in obj_name.split("."):
obj = getattr(obj, part)
return obj
def get_obj_by_name(name: str) -> Any:
"""Finds the python object with the given name."""
module, obj_name = get_module_from_obj_name(name)
return get_obj_from_module(module, obj_name)
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
"""Finds the python object with the given name and calls it as a function."""
assert func_name is not None
func_obj = get_obj_by_name(func_name)
assert callable(func_obj)
return func_obj(*args, **kwargs)
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
"""Finds the python class with the given name and constructs it with the given arguments."""
return call_func_by_name(*args, func_name=class_name, **kwargs)
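# Example (illustrative):
#   get_obj_by_name("np.random.normal")    # "np." expands to "numpy."
#   call_func_by_name(0.0, 1.0, func_name="numpy.random.normal")
#   construct_class_by_name(class_name="collections.OrderedDict", a=1)
# Each call resolves the dotted name with get_module_from_obj_name() and then
# walks the remaining attribute path with get_obj_from_module().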
def get_module_dir_by_obj_name(obj_name: str) -> str:
"""Get the directory path of the module containing the given object name."""
module, _ = get_module_from_obj_name(obj_name)
return os.path.dirname(inspect.getfile(module))
def is_top_level_function(obj: Any) -> bool:
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
def get_top_level_function_name(obj: Any) -> str:
"""Return the fully-qualified name of a top-level function."""
assert is_top_level_function(obj)
module = obj.__module__
if module == '__main__':
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
return module + "." + obj.__name__
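# Example (illustrative): for a function f defined at the top level of a
# hypothetical module foo, get_top_level_function_name(f) returns "foo.f",
# and get_obj_by_name("foo.f") resolves back to the same function object.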
# File system helpers
# ------------------------------------------------------------------------------------------
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
"""List all files recursively in a given directory while ignoring given file and directory names.
Returns list of tuples containing both absolute and relative paths."""
assert os.path.isdir(dir_path)
base_name = os.path.basename(os.path.normpath(dir_path))
if ignores is None:
ignores = []
result = []
for root, dirs, files in os.walk(dir_path, topdown=True):
for ignore_ in ignores:
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
# dirs need to be edited in-place
for d in dirs_to_remove:
dirs.remove(d)
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
absolute_paths = [os.path.join(root, f) for f in files]
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
if add_base_to_relative:
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
assert len(absolute_paths) == len(relative_paths)
result += zip(absolute_paths, relative_paths)
return result
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
"""Takes in a list of tuples of (src, dst) paths and copies files.
Will create all necessary directories."""
for file in files:
target_dir_name = os.path.dirname(file[1])
        # create all intermediate-level directories if they don't exist
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)
shutil.copyfile(file[0], file[1])
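# Example (illustrative; paths are hypothetical): mirror a source tree while
# skipping caches and compiled files.
#   files = list_dir_recursively_with_ignore("src", ignores=["__pycache__", "*.pyc"])
#   copy_files_and_create_dirs([(abs_path, os.path.join("/tmp/backup", rel_path))
#                               for abs_path, rel_path in files])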
# URL helpers
# ------------------------------------------------------------------------------------------
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
"""Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or "://" not in obj:
return False
if allow_file_urls and obj.startswith('file://'):
return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        return False
return True
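# Example (illustrative):
#   is_url("https://nvlabs.github.io/stylegan3")            -> True
#   is_url("file:///data/model.pkl")                        -> False
#   is_url("file:///data/model.pkl", allow_file_urls=True)  -> True
#   is_url("/local/path/model.pkl")                         -> False (no scheme)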
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
"""Download the given URL and return a binary-mode file object to access the data."""
assert num_attempts >= 1
assert not (return_filename and (not cache))
    # Doesn't look like a URL, so interpret it as a local filename.
if not re.match('^[a-z]+://', url):
return url if return_filename else open(url, "rb")
# Handle file URLs. This code handles unusual file:// patterns that
# arise on Windows:
#
# file:///c:/foo.txt
#
# which would translate to a local '/c:/foo.txt' filename that's
# invalid. Drop the forward slash for such pathnames.
#
# If you touch this code path, you should test it on both Linux and
# Windows.
#
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and causes its own
    # set of problems.
if url.startswith('file://'):
filename = urllib.parse.urlparse(url).path
if re.match(r'^/[a-zA-Z]:', filename):
filename = filename[1:]
return filename if return_filename else open(filename, "rb")
assert is_url(url)
# Lookup from cache.
if cache_dir is None:
cache_dir = make_cache_dir_path('downloads')
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
if cache:
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
if len(cache_files) == 1:
filename = cache_files[0]
return filename if return_filename else open(filename, "rb")
# Download.
url_name = None
url_data = None
with requests.Session() as session:
if verbose:
print("Downloading %s ..." % url, end="", flush=True)
for attempts_left in reversed(range(num_attempts)):
try:
with session.get(url) as res:
res.raise_for_status()
if len(res.content) == 0:
raise IOError("No data received")
                    if len(res.content) < 8192:
                        # Decode leniently: a small payload may be binary rather than an HTML error page.
                        content_str = res.content.decode("utf-8", errors="replace")
if "download_warning" in res.headers.get("Set-Cookie", ""):
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
if len(links) == 1:
url = requests.compat.urljoin(url, links[0])
raise IOError("Google Drive virus checker nag")
if "Google Drive - Quota exceeded" in content_str:
raise IOError("Google Drive download quota exceeded -- please try again later")
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
url_name = match[1] if match else url
url_data = res.content
if verbose:
print(" done")
break
except KeyboardInterrupt:
raise
            except Exception:
if not attempts_left:
if verbose:
print(" failed")
raise
if verbose:
print(".", end="", flush=True)
# Save to cache.
if cache:
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
os.makedirs(cache_dir, exist_ok=True)
with open(temp_file, "wb") as f:
f.write(url_data)
os.replace(temp_file, cache_file) # atomic
if return_filename:
return cache_file
# Return data as file object.
assert not return_filename
return io.BytesIO(url_data)
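# Example (illustrative; the URL is hypothetical):
#   with open_url("https://example.com/networks/model.pkl") as f:
#       data = f.read()
# The first call downloads and caches the file; subsequent calls with
# cache=True are served from the cache directory. Pass return_filename=True
# to get the cached path instead of a file object.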
Usage: dataset_tool.py [OPTIONS]
Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
PyTorch.
The input dataset format is guessed from the --source argument:
--source *_lmdb/ Load LSUN dataset
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
--source train-images-idx3-ubyte.gz Load MNIST dataset
--source path/ Recursively load all images from path/
--source dataset.zip Recursively load all images from dataset.zip
Specifying the output format and path:
--dest /path/to/dir Save output files under /path/to/dir
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
The output dataset format can be either an image folder or an uncompressed
zip archive. Zip archives make it easier to move datasets around file
servers and clusters, and may offer better training performance on network
file systems.
Images within the dataset archive will be stored as uncompressed PNG.
Uncompressed PNGs can be efficiently decoded in the training loop.
Class labels are stored in a file called 'dataset.json' that is stored at
the dataset root folder. This file has the following structure:
{
"labels": [
["00000/img00000000.png",6],
["00000/img00000001.png",9],
... repeated for every image in the dataset
["00049/img00049999.png",1]
]
}
If the 'dataset.json' file cannot be found, the dataset is interpreted as
not containing class labels.
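For example, the label map can be read back with a short Python snippet
(an illustrative sketch, not part of the tool; 'dataset_root' is a
placeholder for the dataset directory):

    import json, os
    with open(os.path.join(dataset_root, 'dataset.json')) as f:
        labels = dict(json.load(f)['labels'])  # filename -> class index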
Image scale/crop and resolution requirements:
Output images must be square-shaped and they must all have the same power-
of-two dimensions.
To scale arbitrary input images to a specific width and height, use the
--resolution option. The output resolution is the original input
resolution if --resolution is not specified, and the specified one
otherwise.
Use the --transform=center-crop or --transform=center-crop-wide options to
apply a center crop transform on the input image. These options should be
used with the --resolution option. For example:
python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
--transform=center-crop-wide --resolution=512x384
Options:
--source PATH Directory or archive name for input dataset
[required]
--dest PATH Output directory or archive name for output
dataset [required]
--max-images INTEGER Output only up to `max-images` images
--transform [center-crop|center-crop-wide]
Input crop/resize mode
--resolution WxH Output resolution (e.g., '512x512')
--help Show this message and exit.
Usage: train.py [OPTIONS]
Train a GAN using the techniques described in the paper "Alias-Free
Generative Adversarial Networks".
Examples:
# Train StyleGAN3-T for AFHQv2 using 8 GPUs.
python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
--gpus=8 --batch=32 --gamma=8.2 --mirror=1
# Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle.
python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
--gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
--resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
# Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \
--gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
Options:
--outdir DIR Where to save the results [required]
--cfg [stylegan3-t|stylegan3-r|stylegan2]
Base configuration [required]
--data [ZIP|DIR] Training data [required]
--gpus INT Number of GPUs to use [required]
--batch INT Total batch size [required]
--gamma FLOAT R1 regularization weight [required]
--cond BOOL Train conditional model [default: False]
--mirror BOOL Enable dataset x-flips [default: False]
--aug [noaug|ada|fixed] Augmentation mode [default: ada]
--resume [PATH|URL] Resume from given network pickle
--freezed INT Freeze first layers of D [default: 0]
--p FLOAT Probability for --aug=fixed [default: 0.2]
--target FLOAT Target value for --aug=ada [default: 0.6]
--batch-gpu INT Limit batch size per GPU
--cbase INT Capacity multiplier [default: 32768]
--cmax INT Max. feature maps [default: 512]
--glr FLOAT G learning rate [default: varies]
--dlr FLOAT D learning rate [default: 0.002]
--map-depth INT Mapping network depth [default: varies]
--mbstd-group INT Minibatch std group size [default: 4]
--desc STR String to include in result dir name
--metrics [NAME|A,B,C|none] Quality metrics [default: fid50k_full]
--kimg KIMG Total training duration [default: 25000]
--tick KIMG How often to print progress [default: 4]
--snap TICKS How often to save snapshots [default: 50]
--seed INT Random seed [default: 0]
--fp32 BOOL Disable mixed-precision [default: False]
--nobench BOOL Disable cuDNN benchmarking [default: False]
--workers INT DataLoader worker processes [default: 3]
-n, --dry-run Print training options and exit
--help Show this message and exit.
# Troubleshooting
Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle.
This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions.
## Before you start
1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running.
2. Can't use Docker? Read on.
## Installing dependencies
Make sure you've installed everything listed in the requirements section of [README.md](/README.md). The key components w.r.t. custom extensions are:
- **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda).
- PyTorch invokes `nvcc` to compile our CUDA kernels.
- **ninja**
- PyTorch uses [Ninja](https://ninja-build.org/) as its build system.
- **GCC** (Linux) or **Visual Studio** (Windows)
- GCC 7.x or later is required. Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2).
#### Why is CUDA toolkit installation necessary?
The PyTorch package contains the CUDA toolkit libraries needed to run PyTorch itself, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch invokes the CUDA compiler at run time to compile these kernels on first use. The tools and libraries required for this compilation are not bundled with PyTorch, so a host CUDA toolkit installation is required.
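A quick sanity check (a minimal sketch, assuming a standard PyTorch install) is to ask PyTorch which CUDA runtime it was built against and whether its extension builder can locate a host toolkit:

```python
# Print the CUDA runtime version PyTorch was built against, and the host
# CUDA toolkit (if any) that torch's extension builder has located.
import torch
from torch.utils import cpp_extension

print(torch.version.cuda)       # e.g. '11.1'; None for CPU-only builds
print(cpp_extension.CUDA_HOME)  # None means no host CUDA toolkit was found
```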
## Things to try
- Completely remove `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\<username>\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run the StyleGAN3 Python code.
- Run `ninja` in the build directories under `$HOME/.cache/torch_extensions` to verify that the extensions build.
- Inspect the `build.ninja` files in the build directories under `$HOME/.cache/torch_extensions` and check that the CUDA tools and versions are consistent with what you intended to use.