"ppocr/vscode:/vscode.git/clone" did not exist on "49895d097a242f3cb61e01ae7914f5711c09e9cb"
Commit bfa3fb86 authored by dongchy920's avatar dongchy920
Browse files

dalle2_pytorch

parents
Pipeline #1495 canceled with stages
# DALL-E 2
## Paper
- https://arxiv.org/pdf/2204.06125
## Model Structure
This is OpenAI's first method for generating images from CLIP image embeddings; experiments show that images generated this way preserve rich semantic and style distributions.
<div align=center>
<img src="./images/dalle2.png"/>
</div>
## Algorithm
The algorithm consists of three components, CLIP, Prior, and Decoder, which are trained separately:
- CLIP training:
The CLIP text encoder and image encoder are trained on paired image-text data with a contrastive loss so that text and images are aligned in a shared latent space. OpenAI's pretrained CLIP model can also be used directly.
- Prior training:
The Prior is one of the paper's innovations. Its input is the text feature produced by CLIP's text encoder, and its output is the predicted image feature; during training the ground truth is the image feature obtained by passing the paired image through CLIP's image encoder. The paper tries both an autoregressive model and a diffusion model for the Prior, with the diffusion model giving better results.
- Decoder training:
The Decoder decodes the image feature produced by the Prior into a high-resolution image and, like the Prior, uses a diffusion model. It consists of multiple U-Nets that progressively upsample from low resolution to high resolution. The CLIP parameters are frozen while the Prior and Decoder are trained.
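At inference time the three components are simply chained: CLIP encodes the text, the Prior maps the text embedding to an image embedding, and the Decoder renders the image. A minimal sketch of that data flow, where the callables are illustrative stand-ins rather than this project's actual API:
```python
import torch

def dalle2_generate(prompt_tokens: torch.Tensor, clip_text_encoder, prior, decoder) -> torch.Tensor:
    """Illustrative flow of the three components described above."""
    text_embed = clip_text_encoder(prompt_tokens)  # frozen CLIP: tokens -> text embedding
    image_embed = prior(text_embed)                # Prior: text embedding -> image embedding
    return decoder(image_embed)                    # Decoder: image embedding -> image (low-res, then upsampled)
```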
## Environment Setup
### Docker (Option 1)
Pull the Docker image from [光源](https://www.sourcefind.cn/#/service-list):
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
```
Create a container and mount your working directory for development:
```
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
# Note 1: replace {name} with a custom container name; the suggested convention is {framework_dtkversion_username}, optionally prefixed with the intended purpose
# Note 2: replace {docker_image} with the image name used to create the container, e.g. pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest (image name:tag)
# Note 3: -v mounts a host path to the given path inside the container
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```
cd docker
docker build --no-cache -t dalle2_pytorch:1.0 .
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
Conda is recommended for setting up the environment on online nodes.
Create and activate a conda environment with python=3.10:
```
conda create -n dalle2 python=3.10
conda activate dalle2
```
The DCU-specific deep learning libraries required by this project can be downloaded from the [光合](https://developer.hpccube.com/tool/) developer community.
```
DTK driver: dtk24.04
python: 3.10
pytorch: 2.1.0
torchvision: 0.16.0
```
Install the remaining dependencies:
```
pip install -r requirements.txt
```
## Dataset
The original project does not provide a training dataset, so we train on the Chinese subset of laion2B. Preparing the dataset involves the following steps:
- 1. Download the laion2B Chinese dataset from Hugging Face; the parquet files contain image URLs and captions.
Hugging Face dataset: [https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main](https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset/tree/main)
The dataset can be downloaded through the Hugging Face mirror:
```
# Install and configure the Hugging Face mirror
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
# Download the dataset into the laion2B-multi-chinese folder
huggingface-cli download --repo-type dataset --resume-download IDEA-CCNL/laion2B-multi-chinese-subset --local-dir ./laion2B-multi-chinese
```
- 2. Use the img2dataset project to convert the parquet files into image + caption format:
img2dataset project: [https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset)
Usage:
```
# Install img2dataset
pip install img2dataset
# Convert the dataset
img2dataset --url_list laion2B-multi-chinese --input_format "parquet"\
    --url_col "URL" --caption_col "TEXT" --output_format webdataset\
    --output_folder laion2B-multi-chinese-data --processes_count 16 --thread_count 128 --image_size 256\
    --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True
```
- 3. Generate the json file that pairs each img_path with its prompt:
```
python create_json.py
```
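Each line of the resulting data.json is a standalone JSON object with the `image_path` and `text` keys written by create_json.py. A quick sanity check of the generated file:
```python
# Read the first image/caption pair back from data.json (run next to the generated file)
import json

with open('data.json', encoding='utf-8') as f:
    sample = json.loads(f.readline())
print(sample['image_path'], sample['text'])
```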
The full conversion takes about three days and the resulting dataset is roughly 10 TB. This project provides a small dataset for quick experiments:
[test-data](https://pan.baidu.com/s/1IlSb_J88cgTNkRmnG0wm_Q?pwd=1234)
[data.json](https://pan.baidu.com/s/1kpBIWOwxE8HWPXB-a4kWCA?pwd=1234)
## Training
The three DALL-E 2 components, CLIP, Prior, and Decoder, are trained separately. CLIP can use OpenAI's pretrained model; here we first train the Prior and then the Decoder:
### Prior training
```
python train_prior.py
```
### Decoder training
```
python train_decoder.py
```
## Inference
Download the pretrained weight files and extract them:
[model.zip](https://pan.baidu.com/s/1jVr4mlnANQU0F1H-y1LiZw?pwd=1234)
[model.z01](https://pan.baidu.com/s/10hyZ1EeWx00OYMJ1vKnkNg?pwd=1234)
[model.z02](https://pan.baidu.com/s/16BOWpeR5qMbcc-5gFJQyXw?pwd=1234)
[model.z03](https://pan.baidu.com/s/1l2Ga_3_QHU3vAoo5AtiHiQ?pwd=1234)
[model.z04](https://pan.baidu.com/s/1koEYvyX2bmaOLVnaGBER5g?pwd=1234)
```
# Text-to-image generation
python example_inference.py dream
```
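The same text-to-image flow can also be driven programmatically through the `dalle2_laion` wrappers bundled in this repository. A minimal sketch, assuming `BasicInference` is importable from `dalle2_laion.scripts` and using a placeholder config path; adjust both to your local layout:
```python
# Sketch: run text-to-image through the model manager instead of the CLI
from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import BasicInference  # assumed export; adjust if needed

model_config = ModelLoadConfig.from_json_path("path/to/config.json")  # your model-loading config
manager = DalleModelManager(model_config)
inference = BasicInference(manager)

# run() returns a map: prompt -> image-embedding index -> list of PIL images
output = inference.run("A field of flowers", decoder_sample_count=5)
for prompt, per_embedding in output.items():
    for embedding_index, images in per_embedding.items():
        for i, image in enumerate(images):
            image.save(f"{prompt}_{embedding_index}_{i}.png")
```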
## Results
The input prompt is:
```
A field of flowers
5
```
Generated images:
<div align=center>
<img src="images/a field of flowers_0.png"/>
<img src="images/a field of flowers_1.png"/>
<img src="images/a field of flowers_2.png"/>
<img src="images/a field of flowers_3.png"/>
<img src="images/a field of flowers_4.png"/>
</div>
## Application Scenarios
### Algorithm Category
Multimodal
### Key Application Industries
AIGC, design, education
## Source Repository and Issue Feedback
[https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch](https://developer.hpccube.com/codes/modelzoo/dalle2_pytorch)
## References
[https://github.com/LAION-AI/dalle2-laion](https://github.com/LAION-AI/dalle2-laion)
[https://github.com/rom1504/img2dataset](https://github.com/rom1504/img2dataset)
# Configuration
The root configuration defines the global properties of how the models will be loaded.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
{
"decoder": {
"unet_sources": [
{
"unet_numbers": [1],
"default_cond_scale": [1.7],
"load_model_from": {
"load_type": "url",
"path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/v1.0.2/latest.pth",
"cache_dir": "./models",
"filename_override": "new_decoder.pth"
}
},
{
"unet_numbers": [2],
"load_model_from": {
"load_type": "url",
"path": "https://huggingface.co/Veldrovive/upsamplers/resolve/main/working/latest.pth",
"cache_dir": "./models",
"filename_override": "second_decoder.pth"
},
"load_config_from": {
"load_type": "url",
"path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"cache_dir": "./models",
"filename_override": "second_decoder_config.json"
}
}
]
},
"prior": {
"load_model_from": {
"load_type": "url",
"path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/prior/latest.pth",
"cache_dir": "./models",
"filename_override": "prior.pth"
},
"load_config_from": {
"load_type": "url",
"path": "https://huggingface.co/laion/DALLE2-PyTorch/raw/main/prior/prior_config.json",
"checksum_file_path": "https://huggingface.co/laion/DALLE2-PyTorch/raw/main/prior/prior_config.json",
"cache_dir": "./models"
}
},
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"devices": "cuda:0",
"strict_loading": false
}
{
"decoder": {
"unet_sources": [
{
"unet_numbers": [1],
"default_cond_scale": [1.7],
"load_model_from": {
"load_type": "local",
"path": "./model/new_decoder.pth",
"cache_dir": "./models",
"filename_override": "new_decoder.pth"
}
},
{
"unet_numbers": [2],
"load_model_from": {
"load_type": "local",
"path": "./model/second_decoder.pth",
"cache_dir": "./models",
"filename_override": "second_decoder.pth"
},
"load_config_from": {
"load_type": "local",
"path": "./model/second_decoder_config.json",
"checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"cache_dir": "./models",
"filename_override": "second_decoder_config.json"
}
}
]
},
"prior": {
"load_model_from": {
"load_type": "local",
"path": "./model/prior.pth",
"cache_dir": "./models",
"filename_override": "prior.pth"
},
"load_config_from": {
"load_type": "local",
"path": "./model/prior_config.json",
"checksum_file_path": "https://huggingface.co/laion/DALLE2-PyTorch/raw/main/prior/prior_config.json",
"cache_dir": "./models"
}
},
"clip": {
"make": "openai",
"model": "./model/ViT-L-14.pt"
},
"devices": "cuda:0",
"strict_loading": false
}
{
"decoder": {
"unet_sources": [
{
"unet_numbers": [1],
"default_cond_scale": [1.7],
"load_model_from": {
"load_type": "local",
"path": "/public/home/dongchy920/dalle2-laion-main/models/new_decoder.pth",
"cache_dir": "./models",
"filename_override": "new_decoder.pth"
}
},
{
"unet_numbers": [2],
"load_model_from": {
"load_type": "local",
"path": "/public/home/dongchy920/dalle2-laion-main/models/second_decoder.pth",
"cache_dir": "./models",
"filename_override": "second_decoder.pth"
},
"load_config_from": {
"load_type": "local",
"path": "/public/home/dongchy920/dalle2-laion-main/models/second_decoder_config.json",
"checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"cache_dir": "./models",
"filename_override": "second_decoder_config.json"
}
}
]
},
"prior": {
"load_model_from": {
"load_type": "local",
"path": "/public/home/dongchy920/dalle2-laion-main/models/prior.pth",
"cache_dir": "./models",
"filename_override": "prior.pth"
},
"load_config_from": {
"load_type": "local",
"path": "/public/home/dongchy920/dalle2-laion-main/models/prior_config.json",
"checksum_file_path": "https://huggingface.co/laion/DALLE2-PyTorch/raw/main/prior/prior_config.json",
"cache_dir": "./models"
}
},
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"devices": "cuda:0",
"strict_loading": false
}
{
"decoder": {
"unet_sources": [
{
"unet_numbers": [1],
"default_cond_scale": [1.7],
"load_model_from": {
"load_type": "local",
"path": "/mnt/dalle2-laion-main/model/new_decoder.pth",
"cache_dir": "./models",
"filename_override": "new_decoder.pth"
}
},
{
"unet_numbers": [2],
"load_model_from": {
"load_type": "local",
"path": "/mnt/dalle2-laion-main/model/second_decoder.pth",
"cache_dir": "./models",
"filename_override": "second_decoder.pth"
},
"load_config_from": {
"load_type": "local",
"path": "/mnt/dalle2-laion-main/model/second_decoder_config.json",
"checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"cache_dir": "./models",
"filename_override": "second_decoder_config.json"
}
}
]
},
"clip": {
"make": "openai",
"model": "/mnt/dalle2-laion-main/model/ViT-L-14.pt"
},
"devices": "cuda:0",
"strict_loading": false
}
import glob
import json
import os


def find_images(path):
    # Image formats to search for
    image_formats = ['*.jpg']
    # Collect every matching image file under the given path
    images = []
    for fmt in image_formats:
        for filename in glob.glob(os.path.join(path, fmt)):
            images.append(filename)
    return images


# Directory produced by the img2dataset conversion (replace with your own path)
path_to_search = './laion2B-multi-chinese-data/image-txt-all'
# images = find_images(path_to_search)

# Create (or overwrite) the data.json file
num = 0
with open('data.json', 'w', encoding='utf-8') as jsonl_file:
    # Pair every image with the caption stored in its .txt file
    for img in os.listdir(path_to_search):
        if img.endswith('jpg'):
            num += 1
            if num % 1000 == 0:
                print(f'Processing {num}')
            text_path = img[:-3] + 'txt'
            with open(os.path.join(path_to_search, text_path), 'r', encoding='utf-8') as text_file:
                description = text_file.readlines()
            if len(description) == 0:
                print('++' * 20, text_path)
                continue
            else:
                description = description[0].strip()
            # Build the JSON object for this image/caption pair
            data = {
                'image_path': os.path.join(path_to_search, img),
                'text': description
            }
            # Write one JSON object per line
            jsonl_file.write(json.dumps(data) + '\n')
print("data.json has been generated.")
# DALLE2 LAION Inferencing
In order to simplify running generalized inferences against a dalle2 model, we have created a three stage process to make any inference possible.
## Simple inference
If you are not interested in the details, there are two scripts that just run text to image. In either case, you will need a powerful graphics card (with at least 16 GB of VRAM).
The easiest method is the gradio interface, which you can start by navigating to the root folder and running `python gradio_inference.py`.
For lower-level control and a starting point for developing your own scripts, use the CLI with `python example_inference.py`.
## Configuring the Model
Dalle2 is a multistage model constructed out of an encoder, a prior, and some number of decoders. In order to run inference, we must join together all these separately trained components into one model. To do that, we must point the model manager to the model files and tell it how to load them. This is done through a configuration `.json` file.
Generally, an end user should not attempt to write their own config unless they have trained the models themselves: all of the components must have been trained to be compatible, and stitching together incompatible components will produce nonsensical output.
The general structure of the config is a dictionary with the following keys:
`clip`, `prior`, and `decoder`. Each of these contains a dictionary with more specific information.
One repeating pattern in the config is the `File` type. Many of the config options are of type `File`, and any option using this pattern will be referred to as such.
A file has the following configuration:
| Key | Description |
| --- | --- |
| `load_type` | Either `local` or `url`. |
| `path` | The path to the file or the url of the file. |
| `checksum_file_path` | **Optional**: The path or url of the checksum of the file. This is generated automatically for files stored using huggingface repositories and does not need to be specified. |
| `cache_dir` | **Optional**: The directory to cache the file in. Should be used for `url` load type files. If not provided the file will be re-downloaded every time it is used. |
| `filename_override` | **Optional**: The name of the file to use instead of the default filename. |
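As a short sketch of how this pattern is resolved in code, a `File` (the model defined in `dalle2_laion/config.py`) pointing at a huggingface `resolve` URL should get its matching `raw` checksum path filled in automatically; the URL below is the prior checkpoint from the example configs:
```python
from dalle2_laion.config import File

remote = File(
    load_type="url",
    path="https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/prior/latest.pth",
    cache_dir="./models",
    filename_override="prior.pth",
)
# For huggingface "resolve" URLs the checksum path is inferred automatically
print(remote.checksum_file_path)
```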
### CLIP
Dalle2 uses CLIP as the encoder to turn an image or text into the encoded representation. CLIP produces embeddings for images and text in a shared representation space, but these embeddings are not equal for images and text that match.
Under the `clip` configuration, there are the following options:
| Key | Description |
| --- | --- |
| `make` | The make of the CLIP model to use. Options are `openai`, `x-clip`, or `coca`. |
| `model` | The specific model to use. These are defined by which option you choose for `make`. This should be the same model as the one used during training. |
### Diffusion Prior
The decoders will take an image embedding and turn it into an image, but CLIP has only produced a text embedding. The purpose of the prior is to take the text embeddings and convert them into image embeddings.
Under the `prior` configuration, there are the following options:
| Key | Description |
| --- | --- |
| `load_model_from` | A `File` configuration that points to the model to load. |
| `load_config_from` | **Optional**: If this is an old model, the config must be loaded separately with this `File` configuration. For the vast majority of cases this is not necessary to specify. |
| `default_sample_timesteps` | **Optional**: The number of sampling timesteps to use by default. If not specified this uses the number the prior was trained with. |
| `default_cond_scale` | **Optional**: The default conditioning scale. If not specified 1 is used. |
### Decoders
The initial decoder's purpose is to take the image embedding and turn it into a low resolution image (generally 64x64). Further decoders act as upsamplers and take the low resolution image and turn it into a higher resolution image (generally from 64x64 to 256x256 and then 256x256 to 1024x1024).
This is the most complex configuration since it involves loading multiple models. Each individual decoder has the following configuration which we will call a `SingleDecoderConfig`:
| Key | Description |
| --- | --- |
| `load_model_from` | A `File` configuration that points to the model to load. |
| `load_config_from` | **Optional**: If this is an old model, the config must be loaded separately with this `File` configuration. |
| `unet_numbers` | An array of integers that specify which unet numbers this decoder should be used for. Together, all `SingleDecoderConfig`s must include all numbers in the range [1, max unet number]. |
| `default_sample_timesteps` | **Optional**: An array of numbers that specify the sample timesteps to use for each unet being loaded from this model. |
| `default_cond_scale` | **Optional**: An array of numbers that specify the conditioning scale to use for each unet being loaded from this model. |
Under the `decoder` configuration, there are the following options:
| Key | Description |
| --- | --- |
| `unet_sources` | An array of `SingleDecoderConfig` configurations that point to the models to use for each unet. |
## Using the Configuration
The configuration is used to load the models with the `ModelManager`.
```python
from dalle2_laion import ModelLoadConfig, DalleModelManager
model_config = ModelLoadConfig.from_json_path("path/to/config.json")
model_manager = DalleModelManager(model_config)
```
This will download the requested models, check for updates using the checksums if provided, and load the models into RAM. For larger models, the RAM requirements may exceed what most consumer machines can handle.
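Once loaded, the individual components can be read off the manager directly; a short follow-up to the example above, using the attribute names defined in `dalle2_laion/dalle2_laion.py`:
```python
# Continuing from the snippet above: each loaded component is wrapped in a
# ModelInfo. prior_info / decoder_info are None when the corresponding section
# is missing from the config, and `clip` is only set when a model requires CLIP.
prior = model_manager.prior_info.model
decoder = model_manager.decoder_info.model
```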
## Inference Scripts
Inference scripts are convenient wrappers that make it easy to run a specific task. In the [scripts](dalle2_laion/scripts) folder there are a few basic scripts ready to run inference, but you can also make your own by implementing the `InferenceScript` abstract class.
In general, an inference script will take a model manager as the first argument to the constructor and then any other arguments that are specific to the task.
When inheriting from `InferenceScript`, the most important methods are `_sample_prior` and `_sample_decoder`.
`_sample_prior` runs the prior sampling loop and returns image embeddings.
It takes the following arguments:
| Argument | Description |
| --- | --- |
| `text_or_tokens` | A list of strings or tokenized strings to use as the conditioning. Encodings are generated automatically using the specified CLIP. |
| `cond_scale` | **Optional**: A conditioning scale to use. If not specified the default from the config is used. |
| `sample_count` | **Optional**: The number of samples to take for each text input. If not specified the default is 1. |
| `batch_size` | **Optional**: The batch size to use when sampling. If not specified the default is 100. |
| `num_samples_per_batch` | **Optional**: The number of samples to rerank when generating an image embedding. You should usually not touch this and the default is 2. |
`_sample_decoder` runs the decoder sampling loop and returns an image.
It takes the following arguments:
| Argument | Description |
| --- | --- |
| `images` or `image_embed` | Exactly one of these must be passed. `images` is an array of PIL images. If it is passed, the image embeddings will be generated using these images so variations of them will be generated. `image_embed` is an array of tensors representing precomputed image embeddings generated by the prior or by CLIP. |
| `text` or `text_encoding` | Exactly one of these must be passed if the decoder has been conditioned on text. |
| `inpaint_images` | **Optional**: If the inpainting feature is being used, this is an array of PIL images that will be masked and inpainted. |
| `inpaint_image_masks` | A list of 2D boolean tensors that indicate which pixels in the inpaint images should be inpainted. |
| `cond_scale` | **Optional**: A conditioning scale to use. If not specified the default from the config is used. |
| `sample_count` | **Optional**: The number of samples to take for each text input. If not specified the default is 1. |
| `batch_size` | **Optional**: The batch size to use when sampling. If not specified the default is 10. |
A simple implementation of an inference script is:
```python
from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import InferenceScript
from PIL import Image as PILImage
class ExampleInference(InferenceScript):
def run(self, text: str) -> PILImage.Image:
"""
Takes a string and returns a single image.
"""
text = [text]
image_embedding_map = self._sample_prior(text)
image_embedding = image_embedding_map[0][0]
image_map = self._sample_decoder(text=text, image_embed=image_embedding)
return image_map[0][0]
model_config = ModelLoadConfig.from_json_path("path/to/config.json")
model_manager = DalleModelManager(model_config)
inference = ExampleInference(model_manager)
image = inference.run("Hello World")
```
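As the table above notes, passing `images` instead of `image_embed` makes the decoder generate variations of existing images. A sketch of such a script under those assumptions; the class name and argument values are illustrative:
```python
from typing import List
from dalle2_laion.scripts import InferenceScript
from PIL import Image as PILImage

class VariationInference(InferenceScript):
    def run(self, image_path: str, text: str, sample_count: int = 4) -> List[PILImage.Image]:
        """Generate variations of an existing image."""
        source = PILImage.open(image_path)
        # With `images`, the image embeddings are computed from the inputs,
        # so the decoder samples variations of them.
        image_map = self._sample_decoder(images=[source], text=[text], sample_count=sample_count)
        return image_map[0]
```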
from dalle2_laion.dalle2_laion import DalleModelManager
from dalle2_laion.config import ModelLoadConfig
import dalle2_laion.scripts
from pathlib import Path
from dalle2_pytorch.train_configs import AdapterConfig as ClipConfig
from typing import List, Optional, Union
from enum import Enum
from pydantic import BaseModel, root_validator, ValidationError
from contextlib import contextmanager
import tempfile
import urllib.request
import json
class LoadLocation(str, Enum):
"""
Enum for the possible locations of the data.
"""
local = "local"
url = "url"
class File(BaseModel):
load_type: LoadLocation
path: str
checksum_file_path: Optional[str] = None
cache_dir: Optional[Path] = None
filename_override: Optional[str] = None
@root_validator(pre=True)
def add_default_checksum(cls, values):
"""
When loading from url, the checksum is the best way to see if there is an update to the model.
If we are loading from specific places, we know it is already storing a checksum and we can read and compare those to check for updates.
Sources we can do this with:
1. Huggingface: If model is at https://huggingface.co/[ORG?]/[REPO]/resolve/main/[PATH_TO_MODEL.pth] we know the checksum is at https://huggingface.co/[ORG?]/[REPO]/raw/main/[PATH_TO_MODEL.pth]
"""
if values["load_type"] == LoadLocation.url:
filepath = values["path"]
existing_checksum = values["checksum_file_path"] if "checksum_file_path" in values else None
if filepath.startswith("https://huggingface.co/") and "resolve" in filepath and existing_checksum is None:
values["checksum_file_path"] = filepath.replace("resolve/main/", "raw/main/")
return values
def download_to(self, path: Path):
"""
Downloads the file to the given path
"""
assert self.load_type == LoadLocation.url
urllib.request.urlretrieve(self.path, path)
if self.checksum_file_path is not None:
urllib.request.urlretrieve(self.checksum_file_path, str(path) + ".checksum")
def download_checksum_to(self, path: Path):
"""
Downloads the checksum to the given path
"""
assert self.load_type == LoadLocation.url
assert self.checksum_file_path is not None, "No checksum file path specified"
urllib.request.urlretrieve(self.checksum_file_path, path)
def get_remote_checksum(self):
"""
Downloads the remote checksum as a tempfile and returns its content
"""
with tempfile.TemporaryDirectory() as tmpdir:
self.download_checksum_to(tmpdir + "/checksum")
with open(tmpdir + "/checksum", "r") as f:
checksum = f.read()
return checksum
@property
def filename(self):
if self.filename_override is not None:
return self.filename_override
# The filename is everything after the last '/' but before the '?' if it exists
filename = self.path.split('/')[-1]
if '?' in filename:
filename = filename.split('?')[0]
return filename
@contextmanager
def as_local_file(self, check_update: bool = True):
"""
Loads the file as a local file.
If check_update is True, it will download a new version if the checksum is different.
"""
if self.load_type == LoadLocation.local:
yield self.path
elif self.cache_dir is not None:
# Then we are caching the data in a local directory
self.cache_dir.mkdir(parents=True, exist_ok=True)
file_path = self.cache_dir / self.filename
cached_checksum_path = self.cache_dir / (self.filename + ".checksum")
if not file_path.exists():
print(f"Downloading {self.path} to {file_path}")
self.download_to(file_path)
else:
# Then we should download and compare the checksums
if self.checksum_file_path is None:
print(f'{file_path} already exists. Skipping download. No checksum found so if you think this file should be re-downloaded, delete it and try again.')
elif not cached_checksum_path.exists():
# Then we don't know if the file is up to date so we should download it
if check_update:
print(f"Checksum not found for {file_path}. Downloading it again.")
self.download_to(file_path)
else:
print(f"Checksum not found for {file_path}, but updates are disabled. Skipping download.")
else:
new_checksum = self.get_remote_checksum()
with open(cached_checksum_path, "r") as f:
old_checksum = f.read()
should_update = new_checksum != old_checksum
if should_update:
if check_update:
print(f"Checksum mismatch. Deleting {file_path} and downloading again.")
file_path.unlink()
self.download_to(file_path) # This automatically overwrites the checksum file
else:
print(f"Checksums mismatched, but updates are disabled. Skipping download.")
yield file_path
else:
# Then we are not caching and the file should be stored in a temporary directory
with tempfile.TemporaryDirectory() as tmpdir:
tmpfile = tmpdir + "/" + self.filename
self.download_to(tmpfile)
yield tmpfile
class SingleDecoderLoadConfig(BaseModel):
"""
Configuration for the single decoder load.
"""
unet_numbers: List[int]
default_sample_timesteps: Optional[List[int]] = None
default_cond_scale: Optional[List[float]] = None
load_model_from: File
# load_config_from: Optional[File] # The config may be defined within the model file if the version is high enough
load_config_from: Optional[File] = None # The config may be defined within the model file if the version is high enough
class DecoderLoadConfig(BaseModel):
"""
Configuration for the decoder load.
"""
unet_sources: List[SingleDecoderLoadConfig]
final_unet_number: int
@root_validator(pre=True)
def compute_num_unets(cls, values):
"""
Gets the final unet number
"""
unet_numbers = []
assert "unet_sources" in values, "No unet sources defined. Make sure `unet_sources` is defined in the decoder config."
for value in values["unet_sources"]:
unet_numbers.extend(value["unet_numbers"])
final_unet_number = max(unet_numbers)
values["final_unet_number"] = final_unet_number
return values
@root_validator(pre=True)
def verify_unet_numbers_valid(cls, values):
"""
The unets must go from 1 to some positive number not skipping any and not repeating any.
"""
unet_numbers = []
for value in values["unet_sources"]:
unet_numbers.extend(value["unet_numbers"])
unet_numbers.sort()
if len(unet_numbers) != len(set(unet_numbers)):
raise ValidationError("The decoder unet numbers must not repeat.")
if unet_numbers[0] != 1:
raise ValidationError("The decoder unet numbers must start from 1.")
differences = [unet_numbers[i] - unet_numbers[i - 1] for i in range(1, len(unet_numbers))]
if any(diff != 1 for diff in differences):
raise ValidationError("The decoder unet numbers must not skip any.")
return values
class PriorLoadConfig(BaseModel):
"""
Configuration for the prior load.
"""
default_sample_timesteps: Optional[int] = None
default_cond_scale: Optional[float] = None
load_model_from: File
    load_config_from: Optional[File] = None # The config may be defined within the model file if the version is high enough
class ModelLoadConfig(BaseModel):
"""
Configuration for the model load.
"""
decoder: Optional[DecoderLoadConfig] = None
prior: Optional[PriorLoadConfig] = None
clip: Optional[ClipConfig] = None
devices: Union[List[str], str] = 'cuda:0' # The device(s) to use for model inference. If a list, the first device is used for loading.
load_on_cpu: bool = True # Whether to load the state_dict on the first device or on the cpu
strict_loading: bool = True # Whether to error on loading if the model is not compatible with the current version of the code
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
from dataclasses import dataclass
from typing import Any, Tuple, Optional, TypeVar, Generic, List
from dalle2_laion.config import DecoderLoadConfig, SingleDecoderLoadConfig, PriorLoadConfig, ModelLoadConfig
from dalle2_pytorch import __version__ as Dalle2Version, Decoder, DiffusionPrior, Unet
from dalle2_pytorch.train_configs import TrainDecoderConfig, TrainDiffusionPriorConfig, DecoderConfig, UnetConfig, DiffusionPriorConfig
import torch
import torch.nn as nn
from packaging import version
def exists(obj: Any) -> bool:
return obj is not None
@dataclass
class DataRequirements:
image_embedding: bool
text_encoding: bool
image: bool
text: bool
can_generate_embedding: bool
image_size: int
def has_clip(self):
self.can_generate_embedding = True
def is_valid(
self,
has_image_emb: bool = False, has_text_encoding: bool = False,
has_image: bool = False, has_text: bool = False,
image_size: Optional[int] = None
):
# The image size must be equal to or greater than the required size
# Verify that the text input is valid
errors = []
is_valid = True
if self.text_encoding:
# Then we need to some way to get the text encoding
if not (has_text_encoding or (self.can_generate_embedding and has_text)):
errors.append('Text encoding is required, but no text encoding or text was provided')
is_valid = False
if self.text:
# Then this requires text be passed in explicitly
if not has_text:
errors.append('Text is required, but no text was provided')
is_valid = False
# Verify that the image input is valid
image_size_greater = exists(image_size) and image_size >= self.image_size
if self.image_embedding:
# Then we need to some way to get the image embedding
# In this case, we also need to make sure that the image size is big enough to generate the embedding
if not (has_image_emb or (self.can_generate_embedding and has_image and image_size_greater)):
errors.append('Image embedding is required, but no image embedding or image was provided or the image was too small')
is_valid = False
if self.image:
# Then this requires an image be passed in explicitly
# In this case we also need to make sure the image is big enough to be used
if not (has_image and image_size_greater):
errors.append('Image is required, but no image was provided or the image was too small')
is_valid = False
return is_valid, errors
def __add__(self, other: 'DataRequirements') -> 'DataRequirements':
return DataRequirements(
image_embedding=self.image_embedding or other.image_embedding, # If either needs an image embedding, the combination needs one
            text_encoding=self.text_encoding or other.text_encoding, # If either needs a text encoding, the combination needs one
image=self.image or other.image, # If either needs an image, the combination needs it
text=self.text or other.text, # If either needs a text, the combination needs it
can_generate_embedding=self.can_generate_embedding and other.can_generate_embedding, # If either cannot generate an embedding, we know that trying to replace an embedding with raw data will not work
image_size=max(self.image_size, other.image_size) # We can downsample without loss of information, so we use the larger image size
)
ModelType = TypeVar('ModelType', Decoder, DiffusionPrior)
@dataclass
class ModelInfo(Generic[ModelType]):
model: ModelType
model_version: Optional[version.Version]
requires_clip: bool
data_requirements: DataRequirements
class DalleModelManager:
"""
Used to load priors and decoders and to provide a simple interface to run general scripts against
"""
def __init__(self, model_load_config: ModelLoadConfig, check_updates: bool = True):
"""
Downloads the models and loads them into memory.
If check_updates is True, then the models will be re-downloaded if checksums do not match.
"""
self.check_updates = check_updates
self.model_config = model_load_config
self.current_version = version.parse(Dalle2Version)
self.single_device = isinstance(model_load_config.devices, str)
self.devices = [torch.device(model_load_config.devices)] if self.single_device else [torch.device(d) for d in model_load_config.devices]
self.load_device = torch.device('cpu') if model_load_config.load_on_cpu else self.devices[0]
self.strict_loading = model_load_config.strict_loading
if model_load_config.decoder is not None:
self.decoder_info = self.load_decoder(model_load_config.decoder)
else:
self.decoder_info = None
if model_load_config.prior is not None:
self.prior_info = self.load_prior(model_load_config.prior)
else:
self.prior_info = None
if (exists(self.decoder_info) and self.decoder_info.requires_clip) or (exists(self.prior_info) and self.prior_info.requires_clip):
assert model_load_config.clip is not None, 'Your model requires clip to be loaded. Please provide a clip config.'
self.clip = model_load_config.clip.create()
# Update the data requirements to include the clip model
if exists(self.decoder_info):
self.decoder_info.data_requirements.has_clip()
if exists(self.prior_info):
self.prior_info.data_requirements.has_clip()
else:
if model_load_config.clip is not None:
print(f'WARNING: Your model does not require clip, but you provided a clip config. This will be ignored.')
def _get_decoder_data_requirements(self, decoder_config: DecoderConfig, min_unet_number: int = 1) -> DataRequirements:
"""
Returns the data requirements for a decoder
"""
return DataRequirements(
image_embedding=True,
text_encoding=any(unet_config.cond_on_text_encodings for unet_config in decoder_config.unets[min_unet_number - 1:]),
image=min_unet_number > 1, # If this is an upsampler we need an image
text=False, # Text is never required for anything
can_generate_embedding=False, # This might be added later if clip is being used
image_size=decoder_config.image_sizes[min_unet_number - 1] # The input image size is the input to the first unet we are using
)
def _load_single_decoder(self, load_config: SingleDecoderLoadConfig) -> Tuple[Decoder, DecoderConfig, Optional[version.Version], bool]:
"""
Loads a single decoder from a model and a config file
"""
unet_sample_timesteps = load_config.default_sample_timesteps
def apply_default_config(config: DecoderConfig):
if unet_sample_timesteps is not None:
base_sample_timesteps = [None] * len(config.unets)
for unet_number, timesteps in zip(load_config.unet_numbers, unet_sample_timesteps):
base_sample_timesteps[unet_number - 1] = timesteps
config.sample_timesteps = base_sample_timesteps
with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
model_state_dict = torch.load(model_file, map_location=self.load_device)
if 'version' in model_state_dict:
model_version = model_state_dict['version']
if model_version != self.current_version:
print(f'WARNING: This decoder was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
else:
print(f'WARNING: This decoder was trained on an old version of Dalle2. This may result in the model failing to load or it may lead to producing garbage results.')
model_version = None # No version info in the model
requires_clip = False
if 'config' in model_state_dict:
# Then we define the decoder config from this object
decoder_config = TrainDecoderConfig(**model_state_dict['config']).decoder
apply_default_config(decoder_config)
if decoder_config.clip is not None:
# We don't want to load clip with the model
requires_clip = True
decoder_config.clip = None
decoder = decoder_config.create().eval()
decoder.load_state_dict(model_state_dict['model'], strict=self.strict_loading) # If the model has a config included, then we know the model_state_dict['model'] is the actual model
else:
# In this case, the state_dict is the model itself. This means we also must load the config from an external file
assert load_config.load_config_from is not None
with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
decoder_config = TrainDecoderConfig.from_json_path(config_file).decoder
apply_default_config(decoder_config)
if decoder_config.clip is not None:
# We don't want to load clip with the model
requires_clip = True
decoder_config.clip = None
decoder = decoder_config.create().eval()
decoder.load_state_dict(model_state_dict, strict=self.strict_loading)
return decoder, decoder_config, model_version, requires_clip
def load_decoder(self, load_config: DecoderLoadConfig) -> 'ModelInfo[Decoder]':
"""
Loads a decoder from a model and a config file
"""
if len(load_config.unet_sources) == 1:
# Then we are loading only one model
decoder, decoder_config, decoder_version, requires_clip = self._load_single_decoder(load_config.unet_sources[0])
decoder_data_requirements = self._get_decoder_data_requirements(decoder_config)
decoder.to(torch.float32)
return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements)
else:
true_unets: List[Unet] = [None] * load_config.final_unet_number # Stores the unets that will replace the ones in the true decoder
true_unet_configs: List[UnetConfig] = [None] * load_config.final_unet_number # Stores the unet configs that will replace the ones in the true decoder config
true_upsampling_sizes: List[Tuple[int, int]] = [None] * load_config.final_unet_number # Stores the progression of upsampling sizes for each unet so that we can validate these unets actually work together
true_train_timesteps: List[int] = [None] * load_config.final_unet_number # Stores the number of timesteps that each unet trained with
true_beta_schedules: List[str] = [None] * load_config.final_unet_number # Stores the beta scheduler that each unet used
true_uses_learned_variance: List[bool] = [None] * load_config.final_unet_number # Stores whether each unet uses learned variance
true_sample_timesteps: List[int] = [None] * load_config.final_unet_number # Stores the number of timesteps that each unet used to sample
requires_clip = False
for source in load_config.unet_sources:
decoder, decoder_config, decoder_version, unets_requires_clip = self._load_single_decoder(source)
if unets_requires_clip:
requires_clip = True
if source.default_sample_timesteps is not None:
assert len(source.default_sample_timesteps) == len(source.unet_numbers)
for i, unet_number in enumerate(source.unet_numbers):
unet_index = unet_number - 1
# Now we need to insert the unet into the true unets and the unet config into the true config
if source.default_sample_timesteps is not None:
true_sample_timesteps[unet_index] = source.default_sample_timesteps[i]
true_unets[unet_index] = decoder.unets[unet_index]
true_unet_configs[unet_index] = decoder_config.unets[unet_index]
true_upsampling_sizes[unet_index] = None if unet_index == 0 else decoder_config.image_sizes[unet_index - 1], decoder_config.image_sizes[unet_index]
true_train_timesteps[unet_index] = decoder_config.timesteps
true_beta_schedules[unet_index] = decoder_config.beta_schedule[unet_index]
true_uses_learned_variance[unet_index] = decoder_config.learned_variance if isinstance(decoder_config.learned_variance, bool) else decoder_config.learned_variance[unet_index]
true_decoder_config_obj = {}
# Insert the true configs into the true decoder config
true_decoder_config_obj['unets'] = true_unet_configs
true_image_sizes = []
for i in range(load_config.final_unet_number):
if i == 0:
true_image_sizes.append(true_upsampling_sizes[i][1])
else:
assert true_upsampling_sizes[i - 1][1] == true_upsampling_sizes[i][0], f"The upsampling sizes for unet {i} are not compatible with unet {i - 1}."
true_image_sizes.append(true_upsampling_sizes[i][1])
true_decoder_config_obj['image_sizes'] = true_image_sizes
# All unets must have been trained with the same number of sampling timesteps in order to be compatible
            assert all(true_train_timesteps[0] == t for t in true_train_timesteps), "All unets must have been trained with the same number of training timesteps in order to be compatible."
true_decoder_config_obj['timesteps'] = true_train_timesteps[0]
true_decoder_config_obj['beta_schedule'] = true_beta_schedules
true_decoder_config_obj['learned_variance'] = true_uses_learned_variance
# If any of the sample_timesteps are not None, then we need to insert them into the true decoder config
if any(true_sample_timesteps):
true_decoder_config_obj['sample_timesteps'] = true_sample_timesteps
# Now we can create the decoder and substitute the unets
true_decoder_config = DecoderConfig(**true_decoder_config_obj)
decoder_data_requirements = self._get_decoder_data_requirements(true_decoder_config)
decoder = true_decoder_config.create().eval()
decoder.unets = nn.ModuleList(true_unets)
decoder.to(torch.float32)
return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements)
def _get_prior_data_requirements(self, config: DiffusionPriorConfig) -> DataRequirements:
"""
Returns the data requirements for a diffusion prior
"""
return DataRequirements(
image_embedding=False, # This is kinda the whole point
text_encoding=True, # This is also kinda the whole point
image=False, # The prior is never conditioned on the image
text=False, # Text is never required for anything
can_generate_embedding=False, # This might be added later if clip is being used
image_size=[-1, -1] # This is not used
)
def load_prior(self, load_config: PriorLoadConfig) -> 'ModelInfo[DiffusionPrior]':
"""
Loads a prior from a model and a config file
"""
sample_timesteps = load_config.default_sample_timesteps
def apply_default_config(config: DiffusionPriorConfig) -> DiffusionPriorConfig:
"""
Applies the default config to the given config
"""
if sample_timesteps is not None:
config.sample_timesteps = sample_timesteps
with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
model_state_dict = torch.load(model_file, map_location=self.load_device)
if 'version' in model_state_dict:
model_version = model_state_dict['version']
if model_version != self.current_version:
print(f'WARNING: This prior was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
else:
print('WARNING: This prior was trained on an old version of Dalle2. This may result in the model failing to load or it may produce garbage results.')
model_version = None
requires_clip = False
if 'config' in model_state_dict:
# Then we define the prior config from this object
prior_config = TrainDiffusionPriorConfig(**model_state_dict['config']).prior
apply_default_config(prior_config)
if prior_config.clip is not None:
# We don't want to load clip with the model
prior_config.clip = None
requires_clip = True
prior = prior_config.create().eval()
prior.load_state_dict(model_state_dict['model'], strict=self.strict_loading)
else:
# In this case, the state_dict is the model itself. This means we also must load the config from an external file
assert load_config.load_config_from is not None
with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
prior_config = TrainDiffusionPriorConfig.from_json_path(config_file).prior
apply_default_config(prior_config)
if prior_config.clip is not None:
# We don't want to load clip with the model
prior_config.clip = None
requires_clip = True
prior = prior_config.create().eval()
prior.load_state_dict(model_state_dict, strict=self.strict_loading)
data_requirements = self._get_prior_data_requirements(prior_config)
prior.to(torch.float32)
return ModelInfo(prior, model_version, requires_clip, data_requirements)
"""
This inference script is used to do basic inference without any bells and whistles.
Pass in text, get out image.
"""
from dalle2_laion.scripts import InferenceScript
from typing import Dict, List, Union
from PIL import Image as PILImage
import torch
class BasicInference(InferenceScript):
def run(
self,
text: Union[str, List[str]],
prior_cond_scale: float = None, decoder_cond_scale: float = None, # Use defaults from config by default
prior_sample_count: int = 1, decoder_sample_count: int = 1,
prior_batch_size: int = 100, decoder_batch_size: int = 10,
prior_num_samples_per_batch: int = 2
) -> Dict[str, Dict[int, List[PILImage.Image]]]:
"""
Takes text and generates images.
Returns a map from the text index to the image embedding index to a list of images generated for that embedding.
"""
if isinstance(text, str):
text = [text]
self.print(f"Generating images for texts: {text}")
self.print("Generating prior embeddings...")
image_embedding_map = self._sample_prior(text, cond_scale=prior_cond_scale, sample_count=prior_sample_count, batch_size=prior_batch_size, num_samples_per_batch=prior_num_samples_per_batch)
self.print("Finished generating prior embeddings.")
# image_embedding_map is a map between the text index and the generated image embeddings
image_embeddings: List[torch.Tensor] = []
decoder_text = [] # The decoder also needs the text, but since we have repeated the text embeddings, we also need to repeat the text
for i, original_text in enumerate(text):
decoder_text.extend([original_text] * len(image_embedding_map[i]))
image_embeddings.extend(image_embedding_map[i])
# In order to get the original text from the image embeddings, we need to reverse the map
image_embedding_index_reverse_map = {i: [] for i in range(len(text))}
current_count = 0
for i in range(len(text)):
for _ in range(len(image_embedding_map[i])):
image_embedding_index_reverse_map[i].append(current_count)
current_count += 1
# Now we can use the image embeddings to generate the images
self.print(f"Grouped {len(text)} texts into {len(image_embeddings)} embeddings.")
self.print("Sampling from decoder...")
image_map = self._sample_decoder(text=decoder_text, image_embed=image_embeddings, cond_scale=decoder_cond_scale, sample_count=decoder_sample_count, batch_size=decoder_batch_size)
self.print("Finished sampling from decoder.")
# Now we will reconstruct a map from text to a map of img_embedding indices to list of images
output_map: Dict[str, Dict[int, List[PILImage.Image]]] = {}
for i, prompt in enumerate(text):
output_map[prompt] = {}
embedding_indices = image_embedding_index_reverse_map[i]
for embedding_index in embedding_indices:
output_map[prompt][embedding_index] = image_map[embedding_index]
return output_map