<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Memory and speed

We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed. As a general rule, we recommend the use of [xFormers](https://github.com/facebookresearch/xformers) for memory efficient attention; please see the recommended [installation instructions](xformers).

We'll discuss how the following settings impact performance and memory.

|                            | Latency | Speedup |
| -------------------------- | ------- | ------- |
| original                   | 9.50s   | x1      |
| fp16                       | 3.61s   | x2.63   |
| channels last              | 3.30s   | x2.88   |
| traced UNet                | 3.21s   | x2.96   |
| memory efficient attention | 2.63s   | x3.61   |

<em>
  obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
  the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM
  steps.
</em>

## Use tf32 instead of fp32 (on Ampere and later CUDA devices)

On Ampere and later CUDA devices, matrix multiplications and convolutions can use the TensorFloat-32 (TF32) mode for faster but slightly less accurate computations. By default, PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is add this before your inference:

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True
```
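
The paragraph above mentions two separate backends. If you ever need to control both explicitly (for example, to restore full float32 precision for a numerically sensitive run), the two flags can be toggled independently. This is a minimal sketch of the two PyTorch backend settings involved:

```python
import torch

# TF32 for matrix multiplications (disabled by default in PyTorch)
torch.backends.cuda.matmul.allow_tf32 = True
# TF32 for cuDNN convolutions (enabled by default in PyTorch)
torch.backends.cudnn.allow_tf32 = True
```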

## Half precision weights

To save more GPU memory and gain more speed, you can load and run the model weights directly in half precision. This involves loading the float16 version of the weights (for many checkpoints these are published separately, for example on a branch named `fp16`) and telling PyTorch to use the `float16` type when loading them:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
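
If the checkpoint also publishes a dedicated set of float16 weights (as mentioned above for the `fp16` branch), you can download those files directly so that the initial download is smaller. The following is a minimal sketch; recent `diffusers` releases expose this through the `variant` argument, while older releases used `revision="fp16"` instead:

```python
import torch
from diffusers import DiffusionPipeline

# Download the float16 weight files directly instead of casting float32 weights at load time.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe = pipe.to("cuda")
```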

<Tip warning={true}>

  It is strongly discouraged to make use of [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than using pure 
  float16 precision.
  
</Tip>

## Sliced VAE decode for larger batches

To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.

You likely want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.

To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```

You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
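
Sliced decoding stays enabled on the pipeline once you turn it on. If you later want to decode whole batches in a single pass again, the matching `disable_vae_slicing` call should undo the setting (a small sketch, reusing the `pipe` object from above):

```python
# Decode the full batch with one VAE forward pass again.
pipe.disable_vae_slicing()
```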


## Tiled VAE decode and encode for large images

Tiled VAE processing makes it possible to work with large images on limited VRAM, for example generating 4k images with 8GB of VRAM. The tiled VAE decoder splits the image into overlapping tiles, decodes each tile, and blends the outputs to compose the final image.

You may want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.

To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:

```python
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
pipe.enable_xformers_memory_efficient_attention()

image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
```

The output image will have some tile-to-tile tone variation because the tiles are decoded separately, but you shouldn't see sharp seams between the tiles. Tiling is turned off for images that are 512x512 or smaller.


<a name="sequential_offloading"></a>
## Offloading to CPU with accelerate for memory savings

For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.

To perform CPU offloading, all you have to do is invoke [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
image = pipe(prompt).images[0]
```

This brings the memory consumption down to less than 3GB.

Note that this method works at the submodule level, not on whole models. This is the best way to minimize memory consumption, but inference is much slower due to the iterative nature of the process. The UNet component of the pipeline runs several times (as many as `num_inference_steps`); each time, the different submodules of the UNet are sequentially onloaded and then offloaded as they are needed, so the number of memory transfers is large.

<Tip>
Consider using <a href="#model_offloading">model offloading</a> as another point in the optimization space: it will be much faster, but memory savings won't be as large.
</Tip>

It is also possible to chain offloading with attention slicing for minimal memory consumption (< 2GB).

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)  # slice size of 1 for the lowest memory footprint

image = pipe(prompt).images[0]
```

**Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.

**Note**: `enable_sequential_cpu_offload()` is a stateful operation that installs hooks on the models.


<a name="model_offloading"></a>
## Model offloading for fast inference and memory savings

[Sequential CPU offloading](#sequential_offloading), as discussed in the previous section, preserves a lot of memory but makes inference slower, because submodules are moved to GPU as needed, and immediately returned to CPU when a new module runs.

Full-model offloading is an alternative that moves whole models to the GPU, instead of handling each model's constituent _modules_. This results in a negligible impact on inference time (compared with moving the pipeline to `cuda`), while still providing some memory savings.

In this scenario, only one of the main components of the pipeline (typically the text encoder, UNet and VAE) is on the GPU while the others wait on the CPU. Components like the UNet that run for multiple iterations stay on the GPU until they are no longer needed.

This feature can be enabled by invoking `enable_model_cpu_offload()` on the pipeline, as shown below.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]
```

This is also compatible with attention slicing for additional memory savings.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```

<Tip>
This feature requires `accelerate` version 0.17.0 or higher.
</Tip>

**Note**: `enable_model_cpu_offload()` is a stateful operation that installs hooks on the models and state on the pipeline. In order to properly offload
models after they are called, it is required that the entire pipeline is run and models are called in the order the pipeline expects them to be. Exercise caution
if models are re-used outside the context of the pipeline after hooks have been installed. See [accelerate](https://huggingface.co/docs/accelerate/v0.18.0/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module)
for further docs on removing hooks.
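
If you do need to reuse one of the models outside the pipeline after offloading has been enabled, one option is to strip the hooks first. The following is a minimal sketch using `accelerate.hooks.remove_hook_from_module`, assuming you are done with offloaded inference and want to call the UNet manually on the GPU:

```python
from accelerate.hooks import remove_hook_from_module

# Remove the offloading hooks from the UNet and all of its submodules,
# then place the module on the GPU explicitly.
remove_hook_from_module(pipe.unet, recurse=True)
pipe.unet.to("cuda")
```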

## Using Channels Last memory format

Channels last memory format is an alternative way of laying out NCHW tensors in memory while preserving their dimension ordering. Channels last tensors are ordered so that the channels become the densest dimension (i.e., images are stored pixel by pixel). Since not all operators currently support the channels last format, using it may result in worse performance, so it's best to try it and see whether it helps for your model.

For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:

```python
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
print(
    pipe.unet.conv_out.state_dict()["weight"].stride()
)  # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
```
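
As a quick sanity check, you can also ask PyTorch directly whether a given weight tensor is laid out in channels last format (a small sketch, reusing the `pipe` object from above):

```python
import torch

weight = pipe.unet.conv_out.weight
# True after the conversion above; False for the default contiguous (NCHW) layout.
print(weight.is_contiguous(memory_format=torch.channels_last))
```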

## Tracing

Tracing runs an example input tensor through the model and captures the operations that are invoked as that input makes its way through the model's layers. The returned executable or `ScriptFunction` is then optimized using just-in-time compilation.

To trace our UNet model, we can use the following:

```python
import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# torch disable grad
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50


# load inputs
def generate_inputs():
    sample = torch.randn(2, 4, 64, 64).half().cuda()
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
    return sample, timestep, encoder_hidden_states


pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
276
    torch_dtype=torch.float16,
277
    use_safetensors=True,
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")


# warmup and optimize graph
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)


# benchmarking
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the model
unet_traced.save("unet_traced.pt")
```

Then we can replace the `unet` attribute of the pipeline with the traced model as follows:

```python
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass


@dataclass
class UNet2DConditionOutput:
    sample: torch.FloatTensor


pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")


# del pipe.unet
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipe.unet.in_channels
        self.device = pipe.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)


pipe.unet = TracedUNet()

prompt = "a photo of an astronaut riding a horse on mars"

with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```


## Memory Efficient Attention

Recent work on optimizing the bandwidth in the attention block has generated huge speedups and reductions in GPU memory usage. The most recent of these is Flash Attention from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf).

Here are the speedups we obtain on a few NVIDIA GPUs when running inference at 512x512 with a batch size of 1 (one prompt):

| GPU              | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------|---------------------|---------------------------------|
| NVIDIA Tesla T4  | 3.5it/s             | 5.5it/s                         |
| NVIDIA 3060 RTX  | 4.6it/s             | 7.8it/s                         |
| NVIDIA A10G      | 8.88it/s            | 15.6it/s                        |
| NVIDIA RTX A6000 | 11.7it/s            | 21.09it/s                       |
| NVIDIA TITAN RTX | 12.51it/s           | 18.22it/s                       |
| A100-SXM4-40GB   | 18.6it/s            | 29it/s                          |
| A100-SXM-80GB    | 18.7it/s            | 29.5it/s                        |

To leverage it, just make sure you have:

<Tip warning={true}>

If you have PyTorch 2.0 installed, you shouldn't use xFormers!

</Tip>

 - PyTorch > 1.12
 - CUDA available
 - The [xformers library](xformers) installed.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")

# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```