# Diffusers

## Definitions

**Models**: A single neural network that models $p_\theta(x_{t-1} \mid x_t)$ and is trained to “denoise” an image.
*Examples: UNet, Conditioned UNet, 3D UNet, Transformer UNet*

![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png)
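
In the DDPM formulation (Ho et al., 2020), the network usually does not output $x_{t-1}$ directly. Instead it is parameterized as a noise predictor $\epsilon_\theta(x_t, t)$, which appears to be what the `residual` returned by `unet(image, t)` in the examples is. The mean of $p_\theta$ is derived from that prediction, and the simplified training loss from the paper compares predicted and true noise:

```math
\begin{aligned}
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 \mathbf{I}\big) \\
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) \\
L_{\text{simple}}(\theta) &= \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\lVert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big)\big\rVert^2\Big]
\end{aligned}
```

Here $\alpha_t = 1 - \beta_t$, $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ and $\epsilon \sim \mathcal{N}(0, \mathbf{I})$. The variance $\sigma_t^2$ is fixed rather than learned in that paper (either $\beta_t$ or the posterior variance).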

**Schedulers**: Algorithm that computes the previous image according to an alpha/beta schedule and samples noise. Should be used for both *training* and *inference*.
*Examples: Gaussian DDPM, DDIM, PMLS, DEIN*

![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png)
![training](https://user-images.githubusercontent.com/23423619/171608964-b3260cce-e6b4-4841-959d-7d8ba4b8d1b2.png)
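
To illustrate what an "alpha, beta schedule" means in practice, here is a minimal NumPy sketch (not the `diffusers` scheduler API) of the linear $\beta$ schedule used in the DDPM paper ($T = 1000$, $\beta$ from `1e-4` to `0.02`) and of the closed-form forward noising $q(x_t \mid x_0)$ that a scheduler uses at *training* time. The names `linear_beta_schedule` and `add_noise` are hypothetical, and timesteps are 0-indexed to match the `range(...)` loops in the examples.

```python
import numpy as np


def linear_beta_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (values assumed from the DDPM paper's setup)."""
    return np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float64)


betas = linear_beta_schedule()
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)  # \bar{alpha}_t, decreasing towards 0


def add_noise(x0, noise, t):
    """Sample x_t ~ q(x_t | x_0) in closed form for a 0-indexed timestep t."""
    return np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise


# One training example: a toy "image" in [-1, 1], noised at a random timestep.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 3, 8, 8))
noise = rng.standard_normal(x0.shape)
t = int(rng.integers(0, len(betas)))
x_t = add_noise(x0, noise, t)
# A model would then be trained to predict `noise` from (x_t, t), e.g. with an MSE loss.
```

At inference, the same `alphas_cumprod` table drives the reverse steps, which is why the scheduler is shared between training and sampling.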

**Diffusion Pipeline**: End-to-end pipeline that combines multiple diffusion models and, possibly, text encoders and CLIP.
*Examples: GLIDE, CompVis/Latent-Diffusion, Imagen, DALL-E*

![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png)
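
To make the "end-to-end" idea concrete, here is a toy sketch of how a pipeline might wire a text encoder, a denoising loop and a decoder together. Everything in it (`ToyPipeline`, the lambda components, the latent shape) is hypothetical and merely stands in for real components such as a CLIP text encoder, a UNet plus scheduler, or an image decoder; it is not the `DiffusionPipeline` API.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class ToyPipeline:
    """Hypothetical pipeline: it owns no math itself, it only wires components."""

    text_encoder: Callable[[str], np.ndarray]                      # prompt -> conditioning
    denoiser: Callable[[np.ndarray, int, np.ndarray], np.ndarray]  # (x_t, t, cond) -> x_{t-1}
    decoder: Callable[[np.ndarray], np.ndarray]                    # latent -> image
    num_steps: int = 10
    latent_shape: tuple = (4, 8, 8)

    def __call__(self, prompt: str, seed: int = 0) -> np.ndarray:
        rng = np.random.default_rng(seed)
        cond = self.text_encoder(prompt)
        x = rng.standard_normal(self.latent_shape)  # start from Gaussian noise
        for t in reversed(range(self.num_steps)):
            x = self.denoiser(x, t, cond)
        return self.decoder(x)


# Dummy components, only to show the data flow.
pipe = ToyPipeline(
    text_encoder=lambda p: np.full(4, float(len(p))),
    denoiser=lambda x, t, cond: 0.9 * x,                               # shrinks the noise
    decoder=lambda z: np.repeat(np.repeat(z, 2, axis=1), 2, axis=2),   # 2x "upsampling"
)
img = pipe("a squirrel")
```

Swapping any one component (e.g. a different denoiser) leaves the other two untouched, which is the point of keeping pipelines as thin wiring.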

## Quickstart

```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install -e .
```

### 1. `diffusers` as a central modular diffusion and sampler library

`diffusers` is more modular than `transformers`. The idea is that researchers and engineers can easily use only the parts of the library they need for their own use cases.
It could become a central place for all kinds of models, schedulers, training utilities and processors that one can mix and match for one's own use case.
Both models and schedulers should be loadable from and saveable to the Hub.
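
Because a scheduler holds no trained weights, saving it amounts to storing its hyperparameters. Below is a minimal sketch of such a save/load round trip, assuming a JSON file; the class `ToyScheduler`, its field names and the `scheduler_config.json` file name are hypothetical and not necessarily what `diffusers` uses.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ToyScheduler:
    """Hypothetical scheduler, fully described by its hyperparameters."""

    num_train_timesteps: int = 1000
    beta_start: float = 1e-4
    beta_end: float = 0.02

    def save_config(self, directory: str) -> None:
        # Write the hyperparameters as JSON; there are no weights to save.
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "scheduler_config.json"), "w") as f:
            json.dump(asdict(self), f)

    @classmethod
    def from_config(cls, directory: str) -> "ToyScheduler":
        # Rebuild the scheduler from the stored hyperparameters.
        with open(os.path.join(directory, "scheduler_config.json")) as f:
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    original = ToyScheduler(num_train_timesteps=50)
    original.save_config(tmp)
    restored = ToyScheduler.from_config(tmp)
```

A model would additionally store its weights next to such a config, which is why models are loaded with `from_pretrained` while schedulers only need `from_config` in the examples.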

#### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**

```python
import torch
from diffusers import UNetModel, GaussianDDPMScheduler
import PIL.Image
import numpy as np
import tqdm

generator = torch.manual_seed(0)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample Gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)

# 3. Denoise
num_prediction_steps = len(noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
    # predict noise residual
    with torch.no_grad():
        residual = unet(image, t)

    # predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t)

    # optionally sample variance (no noise is added at the last step, t == 0)
    variance = 0
    if t > 0:
        noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
        variance = noise_scheduler.get_variance(t).sqrt() * noise

    # set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 4. process image to PIL: map [-1, 1] to [0, 255], clamping to avoid uint8 wrap-around
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# 5. save image
image_pil.save("test.png")
```
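
The `noise_scheduler.step(residual, image, t)` call above returns the mean of $p_\theta(x_{t-1} \mid x_t)$. The plain-NumPy sketch below shows what such a step computes under the DDPM noise-prediction parameterization, assuming the linear schedule from the DDPM paper. It is not the `GaussianDDPMScheduler` source; `ddpm_step_mean` and `ddpm_variance` are hypothetical names, and the variance shown is the posterior variance $\tilde\beta_t$, one of the two fixed choices discussed in the paper (the other is $\beta_t$).

```python
import numpy as np

# Assumed linear schedule from the DDPM paper (T = 1000, 0-indexed t).
betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def ddpm_step_mean(residual, x_t, t):
    """Mean of p_theta(x_{t-1} | x_t), given the predicted noise `residual`."""
    coef = betas[t] / np.sqrt(1.0 - alphas_cumprod[t])
    return (x_t - coef * residual) / np.sqrt(alphas[t])


def ddpm_variance(t):
    """Posterior variance beta_tilde_t, one common choice for sigma_t^2."""
    alpha_cumprod_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    return (1.0 - alpha_cumprod_prev) / (1.0 - alphas_cumprod[t]) * betas[t]


# Sanity check: with a perfect noise prediction, one step from t = 0 recovers x_0.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 3, 4, 4))
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(alphas_cumprod[0]) * x0 + np.sqrt(1.0 - alphas_cumprod[0]) * eps
recovered = ddpm_step_mean(eps, x_t, 0)
```

Since $\tilde\beta_0 = 0$, this also explains why the loop only adds noise for `t > 0`.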

#### **Example for [DDIM](https://arxiv.org/abs/2010.02502):**

```python
import torch
from diffusers import UNetModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm

generator = torch.manual_seed(0)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq")
unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)

# 2. Sample Gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)

# 3. Denoise
num_inference_steps = 50
eta = 0.0  # <- deterministic sampling

for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
    # 3.1 predict noise residual at the corresponding training timestep
    orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
    with torch.no_grad():
        residual = unet(image, orig_t)

    # 3.2 predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t, num_inference_steps, eta)

    # 3.3 optionally sample variance (only needed when eta > 0)
    variance = 0
    if eta > 0:
        noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
        variance = noise_scheduler.get_variance(t).sqrt() * eta * noise

    # 3.4 set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 4. process image to PIL: map [-1, 1] to [0, 255], clamping to avoid uint8 wrap-around
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# 5. save image
image_pil.save("test.png")
```
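
DDIM runs only `num_inference_steps` updates by jumping between a subset of the training timesteps, which is what `get_orig_t` maps to. The NumPy sketch below shows one DDIM update as described by Song et al. (2020), assuming evenly spaced timesteps and the linear DDPM schedule. `orig_timestep` is a hypothetical stand-in for `get_orig_t` (the real mapping may differ), and `ddim_step` is a hypothetical name, not `DDIMScheduler.step`.

```python
import numpy as np

T = 1000  # assumed number of training timesteps
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))


def orig_timestep(t, num_inference_steps):
    """Hypothetical stand-in for get_orig_t: evenly spaced training timesteps."""
    return t * (T // num_inference_steps)


def ddim_step(residual, x_t, t, num_inference_steps, eta=0.0):
    """DDIM update from inference step t to t - 1.

    Returns the deterministic part of x_{t-1} and sigma_t; the caller adds
    sigma_t * noise when eta > 0, mirroring the loop above.
    """
    a_t = alphas_cumprod[orig_timestep(t, num_inference_steps)]
    a_prev = alphas_cumprod[orig_timestep(t - 1, num_inference_steps)] if t > 0 else 1.0
    # x_0 as predicted from the noise estimate
    x0_pred = (x_t - np.sqrt(1.0 - a_t) * residual) / np.sqrt(a_t)
    # eta = 0: deterministic DDIM; eta = 1: DDPM-like stochasticity
    sigma = eta * np.sqrt((1.0 - a_prev) / (1.0 - a_t) * (1.0 - a_t / a_prev))
    # "direction pointing to x_t"; clipped at 0 only to guard against rounding
    direction = np.sqrt(max(1.0 - a_prev - sigma**2, 0.0)) * residual
    return np.sqrt(a_prev) * x0_pred + direction, sigma


# Sanity check with a perfect noise prediction and eta = 0 (deterministic).
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 2))
eps = rng.standard_normal(x0.shape)
a_10 = alphas_cumprod[orig_timestep(10, 50)]
x_10 = np.sqrt(a_10) * x0 + np.sqrt(1.0 - a_10) * eps
x_9, sigma = ddim_step(eps, x_10, 10, 50, eta=0.0)
```

With `eta = 0` no noise is ever added, so the same starting noise always yields the same image, matching the "deterministic sampling" comment in the example.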

### 2. `diffusers` as a collection of the most important diffusion systems (GLIDE, DALL-E, ...)
The `models` directory in the repository hosts the complete code needed to run a diffusion system as well as to train it. The `DiffusionPipeline` class makes it easy to run a diffusion model in inference:

#### **Example image generation with DDPM**

```python
from diffusers import DiffusionPipeline
import PIL.Image
import numpy as np

# load model and scheduler
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")

# run pipeline in inference (sample random noise and denoise)
image = ddpm()

# process image to PIL: map [-1, 1] to [0, 255], clamping to avoid uint8 wrap-around
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("test.png")
```
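
The same post-processing is repeated in every example. One way to factor it out is a small helper like the sketch below; `to_uint8_images` is a hypothetical name, not part of `diffusers`, and the `value_range` argument is an assumption covering both the `[-1, 1]` output of the DDPM examples and the `[0, 1]` scaling used in the Latent Diffusion example. Unlike the `astype` truncation above, it rounds to the nearest integer.

```python
import numpy as np


def to_uint8_images(images, value_range=(-1.0, 1.0)):
    """Convert a float batch in NCHW layout to NHWC uint8 arrays."""
    low, high = value_range
    images = np.asarray(images, dtype=np.float64)
    images = (images - low) / (high - low)   # map value_range to [0, 1]
    images = np.clip(images, 0.0, 1.0)       # avoid uint8 wrap-around
    images = np.round(images * 255.0).astype(np.uint8)
    return images.transpose(0, 2, 3, 1)      # NCHW -> NHWC


# Toy batch with one channel; 2.0 lies outside [-1, 1] and is clipped to 255.
out = to_uint8_images(np.array([-1.0, 0.0, 1.0, 2.0]).reshape(1, 1, 2, 2))
```

For a 3-channel model output, `PIL.Image.fromarray(to_uint8_images(image.cpu().numpy())[0])` would then replace the four post-processing lines.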

#### **Example text-to-image generation with Latent Diffusion**

```python
import torch
import PIL.Image
import numpy as np
from diffusers import DiffusionPipeline

ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)

prompt = "A painting of a squirrel eating a burger"
image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

# process image to PIL: the output is scaled by 255, i.e. assumed to lie in [0, 1]
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = image_processed * 255.0
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("test.png")
```
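
The `guidance_scale` argument is, in the Latent Diffusion text-to-image setup, a classifier-free guidance scale: the model predicts noise both with and without the text conditioning, and the two predictions are extrapolated. The sketch below assumes this is what the pipeline does internally; `classifier_free_guidance` and `dummy_model` are hypothetical names, and the toy conditioning has the same shape as the image only to keep the example short.

```python
import numpy as np


def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    """Extrapolate from the unconditional towards the conditional prediction.

    guidance_scale = 1 returns eps_cond; larger values follow the prompt more strongly.
    """
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def dummy_model(x, cond):
    """Hypothetical stand-in for the text-conditioned UNet."""
    return 0.1 * x + cond


x = np.ones((1, 4, 8, 8))
uncond, cond = np.zeros_like(x), np.full_like(x, 0.5)  # toy "empty" and "prompt" conditioning
# Run both branches in a single batch, then split the predictions again.
eps_both = dummy_model(np.concatenate([x, x]), np.concatenate([uncond, cond]))
eps_uncond, eps_cond = np.split(eps_both, 2)
eps = classifier_free_guidance(eps_uncond, eps_cond, guidance_scale=6.0)
```

Batching both branches doubles the work per step but keeps it to one forward pass, which is a common way to implement this.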

## Library structure

```
├── models
│   ├── audio
│   │   └── fastdiff
│   │       ├── modeling_fastdiff.py
│   │       ├── README.md
│   │       └── run_fastdiff.py
│   ├── __init__.py
│   └── vision
│       ├── dalle2
│       │   ├── modeling_dalle2.py
│       │   ├── README.md
│       │   └── run_dalle2.py
│       ├── ddpm
│       │   ├── example.py
│       │   ├── modeling_ddpm.py
│       │   ├── README.md
│       │   └── run_ddpm.py
│       ├── glide
│       │   ├── modeling_glide.py
│       │   ├── modeling_vqvae.py.py
│       │   ├── README.md
│       │   └── run_glide.py
│       ├── imagen
│       │   ├── modeling_dalle2.py
│       │   ├── README.md
│       │   └── run_dalle2.py
│       ├── __init__.py
│       └── latent_diffusion
│           ├── modeling_latent_diffusion.py
│           ├── README.md
│           └── run_latent_diffusion.py
├── pyproject.toml
├── README.md
├── setup.cfg
├── setup.py
├── src
│   └── diffusers
│       ├── configuration_utils.py
│       ├── __init__.py
│       ├── modeling_utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── unet_glide.py
│       │   └── unet.py
│       ├── pipeline_utils.py
│       └── schedulers
│           ├── gaussian_ddpm.py
│           ├── __init__.py
├── tests
│   └── test_modeling_utils.py
```