README.md 11.6 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# taming-transformers

VQGAN

## 论文

**Taming Transformers for High-Resolution Image Synthesis**

* https://arxiv.org/abs/2012.09841

## 模型结构

该模型的主体结构为加入了离散编码本的`AutoEncoder`结构,具体如下图所示,其中`E`(CNN Encoder)用于压缩提取图像中的信息,`Z`(编码本Codebook)记录图像特征,`G`(Decoder)用于生成图像,`D`(CNN Discriminator)判断生成图像的真假,`Transformer`根据现有特征预测接下来的特征(编码本索引)。

![Alt text](readme_imgs/image-1.png)

## 算法原理

该算法结合了CNN与Transformer,可用于高清图像的生成,具体如下,

1、CNN + Transformer

使用卷积神经网络(CNN)架构对成分(纹理、形状、物体以及其他视觉特征)进行建模,并使用Transformer架构对它们的组合进行建模,充分发挥了它们的互补优势。

![Alt text](readme_imgs/R-C.gif)

![Alt text](readme_imgs/image-2.png)

## 环境配置

### Docker(方法一)
    docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py38-latest
    docker run --shm-size 10g --network=host --name=taming_transformer --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -it <your IMAGE ID> bash
mashun1's avatar
mashun1 committed
34

mashun1's avatar
mashun1 committed
35
36
    pip install -r requirements.txt

mashun1's avatar
mashun1 committed
37
38
    pip install torchvision==0.11.1 --no-deps

mashun1's avatar
mashun1 committed
39
40
41
42
43
44
### Dockerfile(方法二)

    # 需要在对应的目录下
    docker build -t <IMAGE_NAME>:<TAG> .
    # <your IMAGE ID>用以上拉取的docker的镜像ID替换
    docker run -it --shm-size 10g --network=host --name=taming_transformer --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined <your IMAGE ID> bash
mashun1's avatar
mashun1 committed
45

mashun1's avatar
mashun1 committed
46
47
    pip install -r requirements.txt

mashun1's avatar
mashun1 committed
48
49
    pip install torchvision==0.11.1 --no-deps

mashun1's avatar
mashun1 committed
50
51
52
53
54
55
56
57
58
59
60
61
62
### Anaconda (方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装:
https://developer.hpccube.com/tool/

    DTK驱动:dtk23.04
    python:python3.8
    torch:1.10.0

Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应

2、其它非特殊库参照requirements.txt安装

    pip install -r requirements.txt
mashun1's avatar
mashun1 committed
63
    pip install torchvision==0.11.1 --no-deps
mashun1's avatar
mashun1 committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246

## 数据集

|名称|URL|存放位置|
|:----|:----|:----|
|ImageNet|https://openxlab.org.cn/datasets/OpenDataLab/ImageNet-1K/tree/main|
|CelebA-HQ|https://aistudio.baidu.com/datasetdetail/49050/0|data/celebahq|
|FFHQ|https://github.com/NVlabs/ffhq-dataset|data/ffhq|
|COCO|https://openxlab.org.cn/datasets/OpenDataLab/COCO_2017/tree/main|data/coco|
|COCO-Stuff|http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip|data/cocostuffthings|
|ADE20k|https://openxlab.org.cn/datasets/OpenDataLab/ADE20K_2016|data/ade20k_root|
|OpenImages-V7|https://github.com/cvdfoundation/open-images-dataset#download-full-dataset-with-google-storage-transfer|

我们也提供了用于测试的tiny数据集,链接:https://pan.baidu.com/s/1UFK-CsMBrnOEsGCNN_P9sA 
提取码:kwai

上述数据集按需下载并放入`data`中。

### 数据处理

#### ImageNet

首先将数据以该结构进行存放,其中${XDG_CACHE}默认为`~/.cache``split`表示`train``validation`

    ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
    ├── n01440764
    │   ├── n01440764_10026.JPEG
    │   ├── n01440764_10027.JPEG
    │   ├── ...
    ├── n01443537
    │   ├── n01443537_10007.JPEG
    │   ├── n01443537_10014.JPEG
    │   ├── ...
    ├── ...

注意:您可以将`ILSVRC2012_img_train.tar或ILSVRC2012_img_val.tar`放入`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/ 或 ${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`,仅当文件夹 `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` 和文件 `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` 都不存在时,数据才会自动处理为上述结构。

您需要使用MiDaS准备深度数据。创建一个符号链接data/imagenet_depth,指向一个包含两个子文件夹train和val的文件夹,每个子文件夹的结构都与上面描述的相应的ImageNet文件夹相同,并包含一个为ImageNet的每个JPEG文件编码的png文件。该png文件作为RGBA图像,编码从MiDaS获得的float32深度值。我们提供了生成此数据的脚本scripts/extract_depth.py。

    imagenet_depth/
    ├── train
    │   ├── n01440764
    │   │   └── n01440764_10043.png
    │   ├── n01443537
    │   │   └── n01443537_10482.png
    │   ├── n01484850
    │   │   └── n01484850_10160.png
    │   ├── n01491361
    │   │   └── n01491361_10353.png

    ├── val
        ├── n01440764
        │   └── ILSVRC2012_val_00000293.png
        ├── n01443537
        │   └── ILSVRC2012_val_00002848.png
        ├── n01484850
        │   └── ILSVRC2012_val_00002338.png
        ├── n01491361
        │   └── ILSVRC2012_val_00002922.png
        ├── n01494475
        │   └── ILSVRC2012_val_00004417.png
        ├── n01496331
        │   └── ILSVRC2012_val_00004698.png
        ├── n01498041
        │   └── ILSVRC2012_val_00002284.png
        ├── n01514668
        │   └── ILSVRC2012_val_00000329.png

注意:若无法正常运行,您还需要下载`https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1` 并放入 `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/synset_human.txt`,下载`https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1`并放入`${XDG_CACHE}/autoencoders/data/index_synset.yaml`,同时还需要创建`filelist.txt`其中包含`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data`中的相对文件路径(`n01234/n01234_xxxx.JPEG`)。

#### FacesHQ

该数据集包含`CelebaHQ``FFHQ`,并无特定文件结构,仅需生成相应的文件列表文件即可。具体参考`data`目录下的`celebahqtrain.txt``celebahqvalidation.txt``ffhqtrain.txt`以及`ffhqvalidation.txt`,如需修改数据加载方式,可在`taming/data/faceshq.py`中进行修改。

## 训练

### FacesHQ

1.训练VQGAN

    python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,

2.训练transformer

    python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,

### D-RIN

1.训练VQGAN

    python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,

    python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,

2.训练transformer

    python main.py --base configs/drin_transformer.yaml -t True --gpus 0,

注意:在训练transformer前需要修改`configs``model.params.first_stage_config.params.ckpt_path`以及`model.params.cond_stage_config.params.ckpt_path`(如果存在)的值(vqgan路径)。

## 推理

### 模型下载

|名称|URL|
|:-------|:----|
|S-FLCKR | https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/ |
|ImageNet | https://k00.fr/s511rwcv|
|FFHQ |https://k00.fr/yndvfu95|
|CelebA-HQ |https://k00.fr/2xkmielf|
|FacesHQ |https://k00.fr/qqfl2do8|
|D-RIN |https://k00.fr/39jcugc5|
|COCO |https://k00.fr/2zz6i2ce|
|ADE20k |https://k00.fr/ot46cksa|
|Scene Image Synthesis| https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing  <br> https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/  <br> https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/


文件结构:

    logs/
    └── 2020-11-09T13-31-51_sflckr
        ├── checkpoints
        │   └── last.ckpt
        └── configs
            ├── 2020-11-09T13-31-51-lightning.yaml
            └── 2020-11-09T13-31-51-project.yaml

注意:上述模型按需下载。

### 命令

    # S-FLCKR
    streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/

    # ImageNet
    # 为ImageNet的每个1000个类别生成50个样本,使用top-k采样中的k=600,nucleus采样中的p=0.92,温度t=1.0。
    python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25   

    # classes用于指定生成类型(如猫,狗,鸟等)
    python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901   

    # FFHQ
    # 为了生成50000个样本,使用top-k采样中的k=250,nucleus采样中的p=1.0,温度t=1.0。
    python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/   

    # CelebA-HQ
    python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/   

    # FacesHQ
    streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/

    # D-RIN
    # demo数据
    streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"

    # 所有validation数据(需要准备相应的ImageNet数据集,参考`训练/数据准备`)
    streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
    
    # COCO
    streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"

    # ADE20k
    streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"

### Scene Image Synthesis

#### 模型下载

|名称|URL|
|:----|:-----|
|COCO-8k-VQGAN|https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/|
| COCO/Open-Images-8k-VQGAN | https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/|
|Open Images 1 billion parameter model| https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing|
|Open Images distilled version|https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO|
|COCO 30 epochs| https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/|
|COCO 60 epochs|https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/|

注意:上述模型中`COCO-8k-VQGAN`以及`COCO/Open-Images-8k-VQGAN`为第一阶段预训练模型,需要进一步训练后使用,这需要下载`COCO/OpenImage`的所有数据集,并修改相应配置文件`configs/xxx.yaml`中的模型和数据路径。

#### 命令

    # 继续训练
    # coco
mashun1's avatar
mashun1 committed
247
    python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,
mashun1's avatar
mashun1 committed
248
249

    # openimage
mashun1's avatar
mashun1 committed
250
    python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,
mashun1's avatar
mashun1 committed
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289

    # 推理
    python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512

注意:需要下载`arialuni.ttf`字体文件,并放入`taming/data/conditional_builder/font/`目录下。

## result

S-FLCKR

![Alt text](readme_imgs/image-3.png)

### 精度

指标:FID(越小越好)

数据集:FacesHQ

||nopix(随机采样)|half(补全)|
|:---|:---:|:---:|
|DCU|48.98|24.39|
|GPU|49.57|28.03|

## 应用场景

### 算法类别

`AIGC`

### 热点应用行业

`教育,科研,媒体`

## 源码仓库及问题反馈

* https://developer.hpccube.com/codes/modelzoo/taming-transformers_pytorch

## 参考资料

mashun1's avatar
mashun1 committed
290
* https://github.com/CompVis/taming-transformers