Commit 01db7703 authored by mashun1

taming-transformer

*pyc*
celebahq
ffhq
sample_for_test
nohup*
logs/
results
*.egg*
sample*
imagenet_depth
\ No newline at end of file
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py38-latest
\ No newline at end of file
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.
# taming-transformers
VQGAN
## Paper
**Taming Transformers for High-Resolution Image Synthesis**
* https://arxiv.org/abs/2012.09841
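## Model Architecture
The model is built around an `AutoEncoder` with a discrete codebook, as shown in the figure below. `E` (CNN encoder) compresses the image and extracts its information, `Z` (codebook) stores the image features, `G` (decoder) generates images, `D` (CNN discriminator) judges whether a generated image is real or fake, and the `Transformer` predicts the next features (codebook indices) from the features it already has.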
![Alt text](readme_imgs/image-1.png)
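The role of the codebook `Z` can be illustrated with a nearest-neighbour lookup. The following is a minimal pure-Python sketch, not the code in this repository: the function names `quantize` and `lookup` and the toy values are assumptions made for illustration.

```python
def quantize(features, codebook):
    """Map each feature vector to the index of its nearest codebook entry."""
    indices = []
    for f in features:
        # Squared Euclidean distance from this feature to every codebook entry.
        dists = [sum((a - b) ** 2 for a, b in zip(f, entry)) for entry in codebook]
        indices.append(dists.index(min(dists)))
    return indices


def lookup(indices, codebook):
    """Replace indices by their codebook vectors (what a decoder would consume)."""
    return [codebook[i] for i in indices]


codebook = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]   # toy codebook with 3 entries
features = [[0.1, -0.1], [0.9, 1.2], [0.2, 0.8]]  # toy encoder output
indices = quantize(features, codebook)
print(indices)                    # [0, 1, 2]
print(lookup(indices, codebook))  # quantized features
```

The list of indices is the discrete representation that the `Transformer` models.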
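## Algorithm
The algorithm combines a CNN with a Transformer and can be used to generate high-resolution images, as follows:
1. CNN + Transformer
A convolutional neural network (CNN) models the constituents (textures, shapes, objects, and other visual features), and a Transformer models how they are composed, which exploits the complementary strengths of the two architectures.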
![Alt text](readme_imgs/R-C.gif)
![Alt text](readme_imgs/image-2.png)
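The composition step can be pictured as predicting codebook indices one at a time and then arranging them back into a grid for the decoder. This is a toy sketch under stated assumptions: `toy_predictor` is a hypothetical stand-in for the trained Transformer, and the row-major ordering is chosen only for illustration.

```python
def generate(predict_next, start, length):
    """Extend a sequence of codebook indices one index at a time."""
    seq = list(start)
    while len(seq) < length:
        seq.append(predict_next(seq))
    return seq


def to_grid(seq, width):
    """Arrange a flat index sequence into rows of the given width."""
    return [seq[i:i + width] for i in range(0, len(seq), width)]


def toy_predictor(seq):
    # Hypothetical stand-in for the Transformer: next index = (last + 1) mod 4.
    return (seq[-1] + 1) % 4


seq = generate(toy_predictor, [0], 8)
print(to_grid(seq, 4))  # [[0, 1, 2, 3], [0, 1, 2, 3]]
```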
## Environment Setup
### Docker (Option 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py38-latest
docker run --shm-size 10g --network=host --name=taming_transformer --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -it <your IMAGE ID> bash
pip install -r requirements.txt
### Dockerfile (Option 2)
# Run this from the directory that contains the Dockerfile
docker build -t <IMAGE_NAME>:<TAG> .
# Replace <your IMAGE ID> with the ID of the Docker image obtained above
docker run -it --shm-size 10g --network=host --name=taming_transformer --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined <your IMAGE ID> bash
pip install -r requirements.txt
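### Anaconda (Option 3)
1. The special deep learning libraries that this project needs for DCU cards can be downloaded from the developer community (光合开发者社区) and installed: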
https://developer.hpccube.com/tool/
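DTK driver: dtk23.04
python: python3.8
torch: 1.10.0
torchvision: 0.11.1
Tips: the versions of the DTK driver, python, torch, and the other DCU-related tools listed above must match each other exactly.
2. Install the remaining, non-special libraries from requirements.txt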
pip install -r requirements.txt
## Datasets
|Name|URL|Location|
|:----|:----|:----|
|ImageNet|https://openxlab.org.cn/datasets/OpenDataLab/ImageNet-1K/tree/main| |
|CelebA-HQ|https://aistudio.baidu.com/datasetdetail/49050/0|data/celebahq|
|FFHQ|https://github.com/NVlabs/ffhq-dataset|data/ffhq|
|COCO|https://openxlab.org.cn/datasets/OpenDataLab/COCO_2017/tree/main|data/coco|
|COCO-Stuff|http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip|data/cocostuffthings|
|ADE20k|https://openxlab.org.cn/datasets/OpenDataLab/ADE20K_2016|data/ade20k_root|
|OpenImages-V7|https://github.com/cvdfoundation/open-images-dataset#download-full-dataset-with-google-storage-transfer| |
We also provide a tiny dataset for testing: https://pan.baidu.com/s/1UFK-CsMBrnOEsGCNN_P9sA
Extraction code: kwai
Download the datasets above as needed and place them in `data`.
### Data Preparation
#### ImageNet
First arrange the data in the following structure, where `${XDG_CACHE}` defaults to `~/.cache` and `split` is either `train` or `validation`:
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
├── n01440764
│ ├── n01440764_10026.JPEG
│ ├── n01440764_10027.JPEG
│ ├── ...
├── n01443537
│ ├── n01443537_10007.JPEG
│ ├── n01443537_10014.JPEG
│ ├── ...
├── ...
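Note: alternatively, you can place `ILSVRC2012_img_train.tar` or `ILSVRC2012_img_val.tar` in `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` or `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, respectively. The data is converted to the structure above automatically only if neither the folder `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor the file `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exists.
You need to prepare the depth data with MiDaS. Create a symlink `data/imagenet_depth` that points to a folder with two subfolders, `train` and `val`. Each subfolder mirrors the structure of the corresponding ImageNet folder described above and contains one png file for each of ImageNet's JPEG files. Each png is stored as an RGBA image that encodes the float32 depth values obtained from MiDaS. The script `scripts/extract_depth.py` generates this data.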
imagenet_depth/
├── train
│   ├── n01440764
│   │   └── n01440764_10043.png
│   ├── n01443537
│   │   └── n01443537_10482.png
│   ├── n01484850
│   │   └── n01484850_10160.png
│   ├── n01491361
│   │   └── n01491361_10353.png
├── val
│   ├── n01440764
│   │   └── ILSVRC2012_val_00000293.png
│   ├── n01443537
│   │   └── ILSVRC2012_val_00002848.png
│   ├── n01484850
│   │   └── ILSVRC2012_val_00002338.png
│   ├── n01491361
│   │   └── ILSVRC2012_val_00002922.png
│   ├── n01494475
│   │   └── ILSVRC2012_val_00004417.png
│   ├── n01496331
│   │   └── ILSVRC2012_val_00004698.png
│   ├── n01498041
│   │   └── ILSVRC2012_val_00002284.png
│   ├── n01514668
│   │   └── ILSVRC2012_val_00000329.png
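One way to picture "a float32 depth value encoded as an RGBA pixel" is to store the four raw bytes of the float in the four 8-bit channels. The sketch below assumes little-endian byte order and uses the hypothetical helper names `float_to_rgba` and `rgba_to_float`; check `scripts/extract_depth.py` for the layout the project actually uses.

```python
import struct


def float_to_rgba(depth):
    """Pack one float32 value into four 8-bit channel values (assumed little-endian)."""
    return tuple(struct.pack("<f", depth))


def rgba_to_float(rgba):
    """Recover the float32 value from four 8-bit channel values."""
    return struct.unpack("<f", bytes(rgba))[0]


pixel = float_to_rgba(1.0)
print(pixel)                 # (0, 0, 128, 63)
print(rgba_to_float(pixel))  # 1.0
```

Because all 32 bits are kept, the round trip is lossless for any value that is representable as float32.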
Note: if the code fails to run, you also need to download `https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1` and save it as `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/synset_human.txt`, download `https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1` and save it as `${XDG_CACHE}/autoencoders/data/index_synset.yaml`, and create a `filelist.txt` that lists the file paths relative to `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data` (for example `n01234/n01234_xxxx.JPEG`).
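A `filelist.txt` of relative paths can be produced with a short script. This is a minimal sketch: the helper names `build_filelist` and `write_filelist` are hypothetical, the `.JPEG` suffix filter is an assumption based on the example path above, and the output location is left as a parameter because it is not specified here.

```python
import os


def build_filelist(data_root, suffix=".JPEG"):
    """Return sorted paths of image files, relative to data_root."""
    paths = []
    for dirpath, _dirnames, filenames in os.walk(data_root):
        for name in filenames:
            if name.endswith(suffix):
                rel = os.path.relpath(os.path.join(dirpath, name), data_root)
                # Use forward slashes so entries look like n01234/n01234_xxxx.JPEG.
                paths.append(rel.replace(os.sep, "/"))
    return sorted(paths)


def write_filelist(data_root, out_path):
    """Write one relative path per line and return the number of entries."""
    paths = build_filelist(data_root)
    with open(out_path, "w") as f:
        f.write("\n".join(paths) + "\n")
    return len(paths)
```

For example, `write_filelist(data_dir, "filelist.txt")` with `data_dir` set to the `ILSVRC2012_{split}/data` folder writes the list to the current directory.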
#### FacesHQ
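This dataset consists of `CelebaHQ` and `FFHQ`. It needs no particular file structure; you only have to generate the corresponding file-list files. See `celebahqtrain.txt`, `celebahqvalidation.txt`, `ffhqtrain.txt`, and `ffhqvalidation.txt` in the `data` directory for reference. To change how the data is loaded, edit `taming/data/faceshq.py`.
## Training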
### FacesHQ
1. Train the VQGAN
python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
2. Train the transformer
python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
### D-RIN
1. Train the VQGAN
python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
2. Train the transformer
python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
Note: before training the transformer, set `model.params.first_stage_config.params.ckpt_path` and, if it exists, `model.params.cond_stage_config.params.ckpt_path` in the relevant file under `configs` to the path of the trained VQGAN checkpoint.
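If the config has been loaded into nested dictionaries, the two keys can be updated programmatically. The sketch below is an illustration only: `set_if_present` is a hypothetical helper, the dictionary is a trimmed stand-in for a real config, and the checkpoint path is a placeholder.

```python
def set_if_present(config, dotted_key, value):
    """Set config[a][b]...[leaf] = value if the full key path exists.

    Returns True when the value was set and False when any part of the
    path is missing, which covers the "if it exists" case for cond_stage_config.
    """
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        if not isinstance(node, dict) or key not in node:
            return False
        node = node[key]
    if not isinstance(node, dict) or leaf not in node:
        return False
    node[leaf] = value
    return True


# Trimmed stand-in for a transformer config; the path below is a placeholder.
config = {"model": {"params": {"first_stage_config": {"params": {"ckpt_path": None}}}}}
vqgan_ckpt = "path/to/vqgan/last.ckpt"

print(set_if_present(config, "model.params.first_stage_config.params.ckpt_path", vqgan_ckpt))  # True
print(set_if_present(config, "model.params.cond_stage_config.params.ckpt_path", vqgan_ckpt))   # False
```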
## Inference
### Model Downloads
|Name|URL|
|:-------|:----|
|S-FLCKR | https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/ |
|ImageNet | https://k00.fr/s511rwcv|
|FFHQ |https://k00.fr/yndvfu95|
|CelebA-HQ |https://k00.fr/2xkmielf|
|FacesHQ |https://k00.fr/qqfl2do8|
|D-RIN |https://k00.fr/39jcugc5|
|COCO |https://k00.fr/2zz6i2ce|
|ADE20k |https://k00.fr/ot46cksa|
|Scene Image Synthesis| https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing <br> https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/ <br> https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/ |
File structure:
logs/
└── 2020-11-09T13-31-51_sflckr
    ├── checkpoints
    │   └── last.ckpt
    └── configs
        ├── 2020-11-09T13-31-51-lightning.yaml
        └── 2020-11-09T13-31-51-project.yaml
Note: download the models above as needed.
### Commands
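Before running a sampling script, it can help to confirm that a downloaded model directory matches the layout shown above. This is a small sketch; `check_logdir` and the `EXPECTED` patterns are hypothetical names derived from the example tree, not part of the repository.

```python
import glob
import os

# Patterns taken from the example tree: one checkpoint and two config files.
EXPECTED = (
    "checkpoints/last.ckpt",
    "configs/*-lightning.yaml",
    "configs/*-project.yaml",
)


def check_logdir(logdir):
    """Return the expected patterns that have no matching file in logdir."""
    missing = []
    for pattern in EXPECTED:
        if not glob.glob(os.path.join(logdir, *pattern.split("/"))):
            missing.append(pattern)
    return missing


# An empty list means the directory matches the layout.
print(check_logdir("logs/2020-11-09T13-31-51_sflckr"))
```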
# S-FLCKR
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
# ImageNet
# Generate 50 samples for each of the 1000 ImageNet classes, with k=600 for top-k sampling, p=0.92 for nucleus sampling, and temperature t=1.0.
python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
# --classes selects which classes to generate (for example cat, dog, bird)
python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
# FFHQ
# Generate 50000 samples, with k=250 for top-k sampling, p=1.0 for nucleus sampling, and temperature t=1.0.
python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
# CelebA-HQ
python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
# FacesHQ
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
# D-RIN
# Demo data
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
# All validation data (requires the corresponding ImageNet dataset; see the ImageNet data preparation steps above)
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
# COCO
streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
# ADE20k
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
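The `-k`, `-p`, and `-t` options in the ImageNet commands above refer to top-k sampling, nucleus (top-p) sampling, and temperature. The sketch below shows the usual formulation of these three steps in pure Python. It is not the code in `scripts/sample_fast.py`, and the function names are hypothetical.

```python
import math
import random


def filtered_probs(logits, top_k=None, top_p=1.0, temperature=1.0):
    """Return {index: probability} after temperature, top-k and top-p filtering."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Most likely indices first; top-k keeps only the first k of them.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]

    # Nucleus: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample(dist, rng=random):
    """Draw one index from a {index: probability} mapping."""
    indices = list(dist)
    return rng.choices(indices, weights=[dist[i] for i in indices], k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
print(sorted(filtered_probs(logits, top_k=2)))  # [0, 1]
print(filtered_probs(logits, top_p=0.5))        # {0: 1.0}
```

A lower temperature sharpens the distribution before filtering, while a smaller `k` or `p` discards more of the unlikely indices.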
### Scene Image Synthesis
#### Model Downloads
|Name|URL|
|:----|:-----|
|COCO-8k-VQGAN|https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/|
| COCO/Open-Images-8k-VQGAN | https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/|
|Open Images 1 billion parameter model| https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing|
|Open Images distilled version|https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO|
|COCO 30 epochs| https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/|
|COCO 60 epochs|https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/|
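Note: of the models above, `COCO-8k-VQGAN` and `COCO/Open-Images-8k-VQGAN` are first-stage pretrained models that must be trained further before use. This requires downloading the full `COCO/OpenImage` datasets and updating the model and data paths in the corresponding config file `configs/xxx.yaml`.
#### Commands
# Continue training
# coco
python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,
# openimage
python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,
# Inference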
python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512
Note: you need to download the `arialuni.ttf` font file and place it in the `taming/data/conditional_builder/font/` directory.
## Results
S-FLCKR
![Alt text](readme_imgs/image-3.png)
### Accuracy
Metric: FID (lower is better)
Dataset: FacesHQ
||nopix (random sampling)|half (completion)|
|:---|:---:|:---:|
|DCU|48.98|24.39|
|GPU|49.57|28.03|
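For reference, FID is commonly defined as the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The symbols below (mean and covariance of the real features, $\mu_r, \Sigma_r$, and of the generated features, $\mu_g, \Sigma_g$) are notation introduced here and do not come from this repository:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Identical feature statistics give a value of 0, which is why lower scores in the table indicate samples closer to the real data.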
## Application Scenarios
### Algorithm Category
`AIGC`
### Key Application Industries
`Education, scientific research, media`
## Source Repository and Issue Feedback
* https://developer.hpccube.com/codes/modelzoo/taming-transformers_pytorch
## References
* https://github.com/CompVis/taming-transformers
\ No newline at end of file
model:
  base_learning_rate: 4.5e-06
  target: taming.models.vqgan.VQSegmentationModel
  params:
    embed_dim: 256
    n_embed: 1024
    image_key: "segmentation"
    n_labels: 183
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 183
      out_ch: 183
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.segmentation.BCELossWithQuant
      params:
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    train:
      target: taming.data.coco.CocoImagesAndCaptionsTrain
      params:
        size: 296
        crop_size: 256
        onehot_segmentation: true
        use_stuffthing: true
    validation:
      target: taming.data.coco.CocoImagesAndCaptionsValidation
      params:
        size: 256
        crop_size: 256
        onehot_segmentation: true
        use_stuffthing: true
model:
  base_learning_rate: 4.5e-06
  target: taming.models.cond_transformer.Net2NetTransformer
  params:
    cond_stage_key: objects_bbox
    transformer_config:
      target: taming.modules.transformer.mingpt.GPT
      params:
        vocab_size: 8192
        block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
        n_layer: 40
        n_head: 16
        n_embd: 1408
        embd_pdrop: 0.1
        resid_pdrop: 0.1
        attn_pdrop: 0.1
    first_stage_config:
      target: taming.models.vqgan.VQModel
      params:
        ckpt_path: ./logs/coco_8k_vqgan/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
        embed_dim: 256
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.DummyLoss
    cond_stage_config:
      target: taming.models.dummy_cond_stage.DummyCondStage
      params:
        conditional_key: objects_bbox
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 6
    train:
      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
      params:
        data_path: data/coco_annotations_100 # substitute with path to full dataset
        split: train
        keys: [image, objects_bbox, file_name, annotations]
        no_tokens: 8192
        target_image_size: 256
        min_object_area: 0.00001
        min_objects_per_image: 2
        max_objects_per_image: 30
        crop_method: random-1d
        random_flip: true
        use_group_parameter: true
        encode_crop: true
    validation:
      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
      params:
        data_path: data/coco_annotations_100 # substitute with path to full dataset
        split: validation
        keys: [image, objects_bbox, file_name, annotations]
        no_tokens: 8192
        target_image_size: 256
        min_object_area: 0.00001
        min_objects_per_image: 2
        max_objects_per_image: 30
        crop_method: center
        random_flip: false
        use_group_parameter: true
        encode_crop: true
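The `block_size` comment in the transformer config above states that 348 is the sum of 256 latent tokens (a 16x16 grid) and 92 conditioning tokens. A quick sanity check of such values before a long training run could look like the sketch below. `check_gpt_config` is a hypothetical helper, and the divisibility rule assumes the usual multi-head attention constraint that the embedding width splits evenly across heads.

```python
def check_gpt_config(n_embd, n_head, block_size, latent_tokens, cond_tokens):
    """Return a list of human-readable problems found in the transformer settings."""
    problems = []
    if n_embd % n_head != 0:
        problems.append(f"n_embd={n_embd} is not divisible by n_head={n_head}")
    if block_size != latent_tokens + cond_tokens:
        problems.append(
            f"block_size={block_size} != {latent_tokens} + {cond_tokens}"
        )
    return problems


latent_tokens = 16 * 16  # 16x16 VQGAN latent grid, per the block_size comment
cond_tokens = 92         # conditioning length, per the block_size comment
print(check_gpt_config(1408, 16, 348, latent_tokens, cond_tokens))  # []
```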