<p align="center">
    <br>
    <img src="docs/source/imgs/diffusers_library.jpg" width="400"/>
    <br>
</p>
<p align="center">
    <a href="https://github.com/huggingface/diffusers/blob/main/LICENSE">
        <img alt="GitHub" src="https://img.shields.io/github/license/huggingface/diffusers.svg?color=blue">
    </a>
    <a href="https://github.com/huggingface/diffusers/releases">
        <img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg">
    </a>
    <a href="CODE_OF_CONDUCT.md">
        <img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.0-4baaaa.svg">
    </a>
</p>

🤗 Diffusers provides pretrained diffusion models across multiple modalities, such as vision and audio, and serves
as a modular toolbox for inference and training of diffusion models.

More precisely, 🤗 Diffusers offers:

- State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)).
- Various noise schedulers that can be used interchangeably to choose the preferred speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)).
- Multiple types of models, such as UNet, that can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)).
- Training examples to show how to train the most popular diffusion models (see [examples](https://github.com/huggingface/diffusers/tree/main/examples)).

## Definitions

**Models**: Neural networks that model $p_\theta(x_{t-1} \mid x_t)$ (see image below) and are trained end-to-end to *denoise* a noisy input into an image.
*Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet

![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png)
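In the notation of the [DDPM paper](https://arxiv.org/abs/2006.11239), the forward (noising) process and the learned reverse process that such a model parameterizes can be written as follows. Here $\beta_t$ is the noise schedule, $\alpha_t = 1-\beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$; these symbols are standard DDPM notation, not names used by this library, and the reverse-process variance is fixed to $\sigma_t^2\mathbf{I}$ as in that paper:

```math
\begin{aligned}
q(x_t \mid x_0) &= \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,\mathbf{I}\big), \\
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2\,\mathbf{I}\big).
\end{aligned}
```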

**Schedulers**: Algorithm class for both **inference** and **training**.
The class provides functionality to compute the previous image according to the alpha/beta schedule, as well as to add noise to samples for training.
*Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)

![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png)
![training](https://user-images.githubusercontent.com/23423619/171608964-b3260cce-e6b4-4841-959d-7d8ba4b8d1b2.png)
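To make the scheduler's two roles concrete, here is a minimal NumPy sketch of a DDPM-style scheduler. It assumes a linear $\beta$ schedule from `1e-4` to `0.02` over 1000 steps (the values used in the DDPM paper) and the paper's noise-prediction parameterization. `ToyDDPMScheduler` and its methods are hypothetical illustrations, not the `DDPMScheduler` API of this library.

```python
import numpy as np


class ToyDDPMScheduler:
    """Hypothetical DDPM-style scheduler (illustration only, not the diffusers API)."""

    def __init__(self, num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        # Linear beta schedule, as in the DDPM paper
        self.betas = np.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        # alphas_cumprod[t] = prod_{s <= t} alpha_s
        self.alphas_cumprod = np.cumprod(self.alphas)

    def __len__(self):
        return len(self.betas)

    def add_noise(self, x0, noise, t):
        """Training: sample x_t ~ q(x_t | x_0) from a clean sample and Gaussian noise."""
        a_bar = self.alphas_cumprod[t]
        return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise

    def get_variance(self, t):
        """Posterior variance (1 - a_bar_{t-1}) / (1 - a_bar_t) * beta_t; zero at t = 0."""
        if t == 0:
            return 0.0
        a_bar, a_bar_prev = self.alphas_cumprod[t], self.alphas_cumprod[t - 1]
        return (1.0 - a_bar_prev) / (1.0 - a_bar) * self.betas[t]

    def step(self, residual, x_t, t):
        """Sampling: mean of x_{t-1} given x_t and the predicted noise residual."""
        beta, alpha, a_bar = self.betas[t], self.alphas[t], self.alphas_cumprod[t]
        return (x_t - beta / np.sqrt(1.0 - a_bar) * residual) / np.sqrt(alpha)


# Usage: noise a toy "image" at t = 500, then take one reverse step
rng = np.random.default_rng(0)
scheduler = ToyDDPMScheduler()
x0 = rng.standard_normal((1, 3, 8, 8))
noise = rng.standard_normal(x0.shape)
x_t = scheduler.add_noise(x0, noise, t=500)
prev_mean = scheduler.step(noise, x_t, t=500)
```

Keeping the schedule arithmetic in a separate object from the network is what lets the same model be paired with different sampling strategies.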

**Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possibly text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2

![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png)

## Philosophy

- Readability and clarity are preferred over highly optimized code. Strong emphasis is placed on readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continuous outputs**, *e.g.* vision and audio.
- Diffusion models and schedulers are provided as concise, elementary building blocks. Diffusion pipelines, in contrast, are a collection of end-to-end diffusion systems that can be used out-of-the-box; they should stay as close as possible to their original implementation and can include components of other libraries, such as text encoders. Examples of diffusion pipelines are [Glide](https://github.com/openai/glide-text2im) and [Latent Diffusion](https://github.com/CompVis/latent-diffusion).

## Quickstart

```
git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install -e .
```

### 1. `diffusers` as a central modular diffusion and sampler library

`diffusers` is more modular than `transformers`. The idea is that researchers and engineers can easily use only the parts of the library they need for their own use cases.
It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
Both models and schedulers should be loadable from and savable to the Hub.

#### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**

```python
import torch
from diffusers import UNetModel, DDPMScheduler
import PIL.Image
import numpy as np
import tqdm

generator = torch.manual_seed(0)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
noise_scheduler = DDPMScheduler.from_config("fusing/ddpm-lsun-church", tensor_format="pt")
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample Gaussian noise
image = torch.randn(
    (1, unet.in_channels, unet.resolution, unet.resolution),
    generator=generator,
)
image = image.to(torch_device)

# 3. Denoise
num_prediction_steps = len(noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
    # predict noise residual
    with torch.no_grad():
        residual = unet(image, t)

    # predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t)

    # optionally sample variance (no noise is added at the final step t = 0)
    variance = 0
    if t > 0:
        noise = torch.randn(image.shape, generator=generator).to(image.device)
        variance = noise_scheduler.get_variance(t).sqrt() * noise

    # set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 4. process image to PIL (clamp to [-1, 1] first so the uint8 cast cannot overflow)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed.clamp(-1.0, 1.0) + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# 5. save image
image_pil.save("test.png")
```
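The loop above follows the sampling procedure of the [DDPM paper](https://arxiv.org/abs/2006.11239). Assuming the scheduler uses that paper's noise-prediction parameterization, `residual` corresponds to $\epsilon_\theta(x_t, t)$, `pred_prev_image` to $\mu_\theta(x_t, t)$ and `variance` to $\sigma_t z$, where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$ are derived from the scheduler's $\beta$ schedule:

```math
\begin{aligned}
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big), \\
x_{t-1} &= \mu_\theta(x_t, t) + \sigma_t z, \qquad z \sim \mathcal{N}(0, \mathbf{I}) \text{ if } t > 0, \text{ else } z = 0.
\end{aligned}
```

The paper reports similar results for $\sigma_t^2 = \beta_t$ and $\sigma_t^2 = \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$; which of the two `get_variance` returns is not specified here.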

#### **Example for [DDIM](https://arxiv.org/abs/2010.02502):**

```python
import torch
from diffusers import UNetModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm

generator = torch.manual_seed(0)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)

# 2. Sample Gaussian noise
image = torch.randn(
    (1, unet.in_channels, unet.resolution, unet.resolution),
    generator=generator,
)
image = image.to(torch_device)

# 3. Denoise
num_inference_steps = 50
eta = 0.0  # <- deterministic sampling

for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
    # 1. predict noise residual (get_orig_t maps the inference step to a training timestep)
    orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
    with torch.no_grad():
        residual = unet(image, orig_t)

    # 2. predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t, num_inference_steps, eta)

    # 3. optionally sample variance (skipped when eta = 0)
    variance = 0
    if eta > 0:
        noise = torch.randn(image.shape, generator=generator).to(image.device)
        variance = noise_scheduler.get_variance(t).sqrt() * eta * noise

    # 4. set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 4. process image to PIL (clamp to [-1, 1] first so the uint8 cast cannot overflow)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed.clamp(-1.0, 1.0) + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# 5. save image
image_pil.save("test.png")
```
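The DDIM example runs only `num_inference_steps = 50` denoising steps. Judging by its name, `get_orig_t` maps each inference step back to a training timestep, and `eta` scales the noise injected at each step. The NumPy sketch below illustrates both ideas using the update rule from the [DDIM paper](https://arxiv.org/abs/2010.02502), under stated assumptions: 1000 training timesteps, a linear $\beta$ schedule and an evenly strided subsequence of timesteps. `orig_timestep` and `ddim_step` are hypothetical helpers, not the `DDIMScheduler` internals.

```python
import numpy as np

# Assumed training setup (hypothetical): 1000 timesteps, linear beta schedule
NUM_TRAIN_STEPS = 1000
betas = np.linspace(1e-4, 0.02, NUM_TRAIN_STEPS)
alphas_cumprod = np.cumprod(1.0 - betas)


def orig_timestep(t, num_inference_steps):
    """Map inference step t in [0, num_inference_steps) to an evenly strided training timestep."""
    return t * (NUM_TRAIN_STEPS // num_inference_steps)


def ddim_step(residual, x_t, t, num_inference_steps, eta, noise):
    """One DDIM update x_t -> x_{t-1}, with t counted in inference steps."""
    a_bar = alphas_cumprod[orig_timestep(t, num_inference_steps)]
    # alpha_bar of the previous step, defined as 1.0 before the first timestep
    a_bar_prev = alphas_cumprod[orig_timestep(t - 1, num_inference_steps)] if t > 0 else 1.0

    # 1. Estimate x_0 from the predicted noise residual
    pred_x0 = (x_t - np.sqrt(1.0 - a_bar) * residual) / np.sqrt(a_bar)

    # 2. sigma_t(eta): eta = 0 gives deterministic DDIM sampling
    sigma = eta * np.sqrt((1.0 - a_bar_prev) / (1.0 - a_bar) * (1.0 - a_bar / a_bar_prev))

    # 3. Deterministic "direction pointing to x_t" plus optional fresh noise
    dir_xt = np.sqrt(1.0 - a_bar_prev - sigma**2) * residual
    return np.sqrt(a_bar_prev) * pred_x0 + dir_xt + sigma * noise
```

With `eta = 0` and a perfect noise prediction, `ddim_step` lands exactly on the noisy sample at the previous strided timestep (same $x_0$, same noise), which is why DDIM can skip most of the training timesteps. With 50 inference steps and 1000 training steps, the assumed stride is 20.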

### 2. `diffusers` as a collection of the most important diffusion systems (GLIDE, DALL-E, ...)
The `models` directory in the repository hosts the complete code needed to run a diffusion system as well as to train it. The `DiffusionPipeline` class makes it easy to run a diffusion model in inference:

#### **Example image generation with DDPM**

```python
from diffusers import DiffusionPipeline
import PIL.Image
import numpy as np

# load model and scheduler
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")

# run pipeline in inference (sample random noise and denoise)
image = ddpm()

# process image to PIL (clamp to [-1, 1] first so the uint8 cast cannot overflow)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed.clamp(-1.0, 1.0) + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("test.png")
```
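Several examples repeat the same tensor-to-PIL conversion. A small helper can factor it out; the sketch below assumes an `(N, C, H, W)` NumPy array with values in `[-1, 1]` (which is what the `(x + 1.0) * 127.5` rescaling implies) and clips out-of-range values before casting, because a NumPy cast of out-of-range floats to `uint8` is not guaranteed to saturate. `to_pil_images` is a hypothetical name, not part of `diffusers`.

```python
import numpy as np
import PIL.Image


def to_pil_images(images):
    """Convert an (N, C, H, W) float array with values in [-1, 1] to a list of PIL images."""
    images = np.asarray(images, dtype=np.float32).transpose(0, 2, 3, 1)  # NCHW -> NHWC
    images = (np.clip(images, -1.0, 1.0) + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    images = np.round(images).astype(np.uint8)
    if images.shape[-1] == 1:
        images = images[..., 0]  # PIL expects (H, W) for single-channel images
    return [PIL.Image.fromarray(np.ascontiguousarray(img)) for img in images]
```

With a PyTorch output such as `image` above, one would call `to_pil_images(image.cpu().numpy())[0].save("test.png")`.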

#### **Text to Image generation with Latent Diffusion**

_Note: To use latent diffusion, install `transformers` from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._

```python
import torch
import PIL.Image
import numpy as np
from diffusers import DiffusionPipeline

ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(42)

prompt = "A painting of a squirrel eating a burger"
image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

# process image to PIL (outputs are assumed to lie in [0, 1], hence the * 255 scaling)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = image_processed.clamp(0.0, 1.0) * 255.0
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("test.png")
```
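The `guidance_scale` argument presumably refers to classifier-free guidance, in which the model predicts the noise both with and without the text conditioning and the two predictions are combined at every denoising step. The sketch below shows only that combination on plain NumPy arrays; how the pipeline obtains the unconditional prediction internally is not documented here, and `classifier_free_guidance` is a hypothetical helper name.

```python
import numpy as np


def classifier_free_guidance(noise_uncond, noise_text, guidance_scale):
    """Extrapolate from the unconditional toward the text-conditioned noise prediction."""
    # guidance_scale = 1.0 returns noise_text; larger values follow the prompt more strongly
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


# Toy predictions standing in for two model outputs at one denoising step
noise_uncond = np.array([0.0, 0.5, -1.0])
noise_text = np.array([1.0, 0.5, 0.0])
guided = classifier_free_guidance(noise_uncond, noise_text, guidance_scale=6.0)
```

Components where both predictions agree (the middle entry) are left unchanged, while disagreements are amplified by the guidance scale.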

#### **Text to speech with BDDM**

_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the Tacotron 2 model._

```python
import torch
from scipy.io.wavfile import write as wavwrite
from diffusers import DiffusionPipeline

torch_device = "cuda"

# load the BDDM pipeline
bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")

# load tacotron2 to get the mel spectrograms
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
tacotron2 = tacotron2.to(torch_device).eval()

text = "Hello world, I missed you so much."

utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
sequences, lengths = utils.prepare_input_sequence([text])

# generate mel spectrograms from the text
with torch.no_grad():
    mel_spec, _, _ = tacotron2.infer(sequences, lengths)

# generate the speech by passing the mel spectrograms to the BDDM pipeline
generator = torch.manual_seed(0)
audio = bddm(mel_spec, generator, torch_device)

# save generated audio
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```