README.md 6.24 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# DALL-E 2
## 论文
- https://arxiv.org/pdf/2204.06125
## 模型结构
OpenAI的首篇从CLIP的image embedding生成图像的方法,实验证明这种方法生成的图像能够保留丰富的语义与风格分布。
<div align=center>
    <img src="./images/dalle2.png"/>
</div>

## 算法原理
算法主要包括CLIP、Prior和Decoder三个部分,对三个部分进行分开训练:

- CLIP训练:  
使用图文配对数据,基于对比损失训练CLIP的text encoder和img encoder编码器,目的是想在潜在空间中对文本和图象进行统一。也可以直接使用OpenAI预训练的CLIP模型;

- Prior训练:  
Prior结构是论文的一个创新点,输入是文本通过CLIP的text encoder得到的文本特征,输出是预测的对应图像特征,训练时的Ground Truth是文本对应图像通过CLIP的image encoder得到的图像特征,论文中prior结构尝试使用了自回归和扩散模型两种结构,最后扩散模型的效果较好。

- Decoder训练:  
Decoder将Prior生成的图像特征解码为高分辨率的图像,和Prior结构一样采用了扩散模型。Decoder由多个unet组成,从低分辨率生成高分辨率图像。在训练Prior和Decoder时,CLIP模型的参数是冻结的。

## 环境配置
### Docker(方法一)
[光源](https://www.sourcefind.cn/#/service-list)中拉取docker镜像:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
```
创建容器并挂载目录进行开发:
```
docker run -it --name {name} --shm-size=1024G  --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
# 修改1 {name} 需要改为自定义名称,建议命名{框架_dtk版本_使用者姓名},如果有特殊用途可在命名框架前添加命名
# 修改2 {docker_image} 需要需要创建容器的对应镜像名称,如: pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest【镜像名称:tag名称】
# 修改3 -v 挂载路径到容器指定路径
pip install -r requirements.txt
```
### Dockerfile(方法二)
```
cd docker
docker build --no-cache -t dalle2_pytorch:1.0 .
docker run -it --name {name} --shm-size=1024G  --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash 
pip install -r requirements.txt
```
### Anaconda(方法三)
线上节点推荐使用conda进行环境配置。
创建python=3.10的conda环境并激活
```
conda create -n dalle2 python=3.10
conda activate dalle2
```

关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动:dtk24.04
python:python3.10
pytorch:2.1.0
torchvision:0.16.0
```
安装其他依赖包
```
pip install -r requirements.txt
```
## 数据集 
原项目中并未提供训练数据集,我们这里使用laion2B的中文数据集进行训练,数据集的准备包括以下步骤:  
- 1、从huggingface下载laion2B中文数据集,下载parquet文件,里面是图片url+caption  
huggingface数据地址:[https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main](https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main)  
可以通过huggingface镜像进行下载: 
```
# 安装配置huggingface镜像
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
# 下载数据集保存在laion2B-multi-chinese文件夹中
huggingface-cli download --repo-type dataset --resume-download IDEA-CCNL/laion2B-multi-chinese-subset --local-dir ./laion2B-multi-chinese
```

- 2、使用img2dataset项目将parquet文件转换为image+caption格式:  
img2dataset项目地址:[https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset)
使用方法:  
```
# 安装img2dataset
pip install img2dataset
# 数据集转换
img2dataset --url_list laion2B-multi-chinese --input_format "parquet"\
         --url_col "URL" --caption_col "TEXT" --output_format webdataset\
           --output_folder laion2B-multi-chinese-data --processes_count 16 --thread_count 128 --image_siz 256\
             --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True
```

- 3、生成img_path和prompt配对的json文件  
```
python create_json.py
```
整个数据集转换下来需要三天的时间,数据集有10个T,本项目提供小数据集用于快速实验:  
[test-data](https://pan.baidu.com/s/1IlSb_J88cgTNkRmnG0wm_Q?pwd=1234)  
[data.json](https://pan.baidu.com/s/1kpBIWOwxE8HWPXB-a4kWCA?pwd=1234)

## 训练
dalle2的三个组件CLIP、Prior和Decoder是单独训练的,CLIP可以使用OpenAI的预训练模型,这里先训练Prior,然后训练Decoder:  
### Prior组件训练
```
python train_prior.py
```

### Decoder组件训练
```
python train_decoder.py
```

## 推理
下载预训练权重文件并解压:  
dongchy920's avatar
dongchy920 committed
110
111
112
113
114
[model.zip](https://pan.baidu.com/s/1GdDN8zt8mrqvbJELtcF3ng?pwd=1234)  
[model.z01](https://pan.baidu.com/s/1hRLiDZE28jigEriFcQe0BQ?pwd=1234)  
[model.z02](https://pan.baidu.com/s/1B9VnzzXBP549EIO6aAP_sw?pwd=1234)  
[model.z03](https://pan.baidu.com/s/1RoTFTIkRHJw34sKpsHrT1w?pwd=1234)  
[model.z04](https://pan.baidu.com/s/1UCnXLKreoNqR7lXFw291LA?pwd=1234)  
dongchy920's avatar
dongchy920 committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
```
# 文本生成图片
python example_inference.py dream
```

## result
输入提示词为:
```
A field of flowers
5
```

模型生成图片:
<div align=center>
    <img src="images/a field of flowers_0.png"/>
    <img src="images/a field of flowers_1.png"/>
    <img src="images/a field of flowers_2.png"/>
    <img src="images/a field of flowers_3.png"/>
    <img src="images/a field of flowers_4.png"/>
</div>

## 应用场景
### 算法类别
多模态

### 热点应用行业
AIGC,设计,教育


## 源码仓库及问题反馈
[https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch](https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch)
## 参考资料
[https://github.com/LAION-AI/dalle2-laion](https://github.com/LAION-AI/dalle2-laion)
[https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset)