# DALL-E 2 ## 论文 - https://arxiv.org/pdf/2204.06125 ## 模型结构 OpenAI的首篇从CLIP的image embedding生成图像的方法,实验证明这种方法生成的图像能够保留丰富的语义与风格分布。
## 算法原理 算法主要包括CLIP、Prior和Decoder三个部分,对三个部分进行分开训练: - CLIP训练: 使用图文配对数据,基于对比损失训练CLIP的text encoder和img encoder编码器,目的是想在潜在空间中对文本和图象进行统一。也可以直接使用OpenAI预训练的CLIP模型; - Prior训练: Prior结构是论文的一个创新点,输入是文本通过CLIP的text encoder得到的文本特征,输出是预测的对应图像特征,训练时的Ground Truth是文本对应图像通过CLIP的image encoder得到的图像特征,论文中prior结构尝试使用了自回归和扩散模型两种结构,最后扩散模型的效果较好。 - Decoder训练: Decoder将Prior生成的图像特征解码为高分辨率的图像,和Prior结构一样采用了扩散模型。Decoder由多个unet组成,从低分辨率生成高分辨率图像。在训练Prior和Decoder时,CLIP模型的参数是冻结的。 ## 环境配置 ### Docker(方法一) 从[光源](https://www.sourcefind.cn/#/service-list)中拉取docker镜像: ``` docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 ``` 创建容器并挂载目录进行开发: ``` docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash # 修改1 {name} 需要改为自定义名称,建议命名{框架_dtk版本_使用者姓名},如果有特殊用途可在命名框架前添加命名 # 修改2 {docker_image} 需要需要创建容器的对应镜像名称,如: pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest【镜像名称:tag名称】 # 修改3 -v 挂载路径到容器指定路径 pip install -r requirements.txt ``` ### Dockerfile(方法二) ``` cd docker docker build --no-cache -t dalle2_pytorch:1.0 . docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash pip install -r requirements.txt ``` ### Anaconda(方法三) 线上节点推荐使用conda进行环境配置。 创建python=3.10的conda环境并激活 ``` conda create -n dalle2 python=3.10 conda activate dalle2 ``` 关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。 ``` DTK驱动:dtk24.04.1 python:python3.10 pytorch:2.1.0 torchvision:0.16.0 ``` 安装其他依赖包 ``` pip install -r requirements.txt ``` ## 数据集 原项目中并未提供训练数据集,我们这里使用laion2B的中文数据集进行训练,数据集的准备包括以下步骤: - 1、从huggingface下载laion2B中文数据集,下载parquet文件,里面是图片url+caption huggingface数据地址:[https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main](https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main) 可以通过huggingface镜像进行下载: ``` # 安装配置huggingface镜像 pip install -U huggingface_hub export HF_ENDPOINT=https://hf-mirror.com # 下载数据集保存在laion2B-multi-chinese文件夹中 huggingface-cli download --repo-type dataset --resume-download IDEA-CCNL/laion2B-multi-chinese-subset --local-dir ./laion2B-multi-chinese ``` - 2、使用img2dataset项目将parquet文件转换为image+caption格式: img2dataset项目地址:[https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset) 使用方法: ``` # 安装img2dataset pip install img2dataset # 数据集转换 img2dataset --url_list laion2B-multi-chinese --input_format "parquet"\ --url_col "URL" --caption_col "TEXT" --output_format webdataset\ --output_folder laion2B-multi-chinese-data --processes_count 16 --thread_count 128 --image_siz 256\ --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True ``` - 3、生成img_path和prompt配对的json文件 ``` python create_json.py ``` 整个数据集转换下来需要三天的时间,数据集有10个T,本项目提供小数据集用于快速实验: [test-data](https://pan.baidu.com/s/1IlSb_J88cgTNkRmnG0wm_Q?pwd=1234) [data.json](https://pan.baidu.com/s/1kpBIWOwxE8HWPXB-a4kWCA?pwd=1234) ## 训练 dalle2的三个组件CLIP、Prior和Decoder是单独训练的,CLIP可以使用OpenAI的预训练模型,这里先训练Prior,然后训练Decoder: ### Prior组件训练 ``` python train_prior.py ``` ### Decoder组件训练 ``` python train_decoder.py ``` ## 推理 下载预训练权重文件并解压: [model.zip](https://pan.baidu.com/s/1GdDN8zt8mrqvbJELtcF3ng?pwd=1234) [model.z01](https://pan.baidu.com/s/1hRLiDZE28jigEriFcQe0BQ?pwd=1234) [model.z02](https://pan.baidu.com/s/1B9VnzzXBP549EIO6aAP_sw?pwd=1234) [model.z03](https://pan.baidu.com/s/1RoTFTIkRHJw34sKpsHrT1w?pwd=1234) [model.z04](https://pan.baidu.com/s/1UCnXLKreoNqR7lXFw291LA?pwd=1234) 可通过SCNet快速下载链接[http://113.200.138.88:18080/aimodels/dalle2_pytorch/-/tree/main](http://113.200.138.88:18080/aimodels/dalle2_pytorch/-/tree/main)进行下载 ``` # 文本生成图片 python example_inference.py dream ``` ## result 输入提示词为: ``` A field of flowers 5 ``` 模型生成图片:
## 应用场景 ### 算法类别 多模态 ### 热点应用行业 AIGC,设计,教育 ## 源码仓库及问题反馈 [https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch](https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch) ## 参考资料 [https://github.com/LAION-AI/dalle2-laion](https://github.com/LAION-AI/dalle2-laion) [https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset)