<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Reduce memory usage

Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipelines/wan) have billions of parameters that take up a lot of memory on your hardware for inference. This is challenging because common GPUs often don't have sufficient memory. To overcome the memory limitations, you can use more than one GPU (if available), offload some of the pipeline components to the CPU, and more.

This guide will show you how to reduce your memory usage. 

> [!TIP]
> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit from these memory optimizations to the same degree as a UNet-based model.

## Multiple GPUs

If you have access to more than one GPU, there are a few options for efficiently loading and distributing a large model across your hardware. These features are supported by the [Accelerate](https://huggingface.co/docs/accelerate/index) library, so make sure it is installed first.

```bash
pip install -U accelerate
```

### Sharded checkpoints

Loading large checkpoints in several shards is useful because the shards are loaded one at a time. This keeps memory usage low, only requiring enough memory for the model size and the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. The default shard size is 5GB.

Shard a checkpoint in [`~DiffusionPipeline.save_pretrained`] with the `max_shard_size` parameter.

```py
from diffusers import AutoModel

unet = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
```

Now you can use the sharded checkpoint, instead of the regular checkpoint, to save memory.

```py
import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

unet = AutoModel.from_pretrained(
    "username/sdxl-unet-sharded", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16
).to("cuda")
```

### Device placement

> [!WARNING]
> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.

The `device_map` parameter controls how the model components in a pipeline are distributed across devices. The `balanced` device placement strategy evenly splits the pipeline across all available devices.

```py
import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced"
)
```

You can inspect a pipeline's device map with `hf_device_map`.

```py
print(pipeline.hf_device_map)
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
```

The `device_map` parameter also works at the model level. This is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Instead of `balanced`, set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```

For more fine-grained control, pass a dictionary to enforce the maximum GPU memory to use on each device. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.

```py
import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

max_memory = {0: "1GB", 1: "1GB"}
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
    max_memory=max_memory
)
```

Diffusers uses the maximum memory of all devices by default, but if the models don't fit on the GPUs, then you'll need to use a single GPU and offload to the CPU with the methods below.

- [`~DiffusionPipeline.enable_model_cpu_offload`] only works on a single GPU but a very large model may not fit on it
- [`~DiffusionPipeline.enable_sequential_cpu_offload`] may work but it is extremely slow and also limited to a single GPU

Use the [`~DiffusionPipeline.reset_device_map`] method to reset the `device_map`. This is necessary if you want to use methods like `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.

```py
pipeline.reset_device_map()
```

## VAE slicing

VAE slicing saves memory by splitting a large batch of inputs into single-image batches and processing them one at a time. This method works best when generating more than one image at a time.

For example, if you're generating 4 images at once, decoding would increase peak activation memory by 4x. VAE slicing reduces this by only decoding 1 image at a time instead of all 4 images at once.

Call [`~StableDiffusionPipeline.enable_vae_slicing`] to enable sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.

```py
import torch
from diffusers import AutoModel, StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_vae_slicing()
pipeline(["An astronaut riding a horse on Mars"]*32).images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

> [!WARNING]
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.

## VAE tiling

VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time.

Call [`~StableDiffusionPipeline.enable_vae_tiling`] to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they're decoded separately, but there shouldn't be any obvious seams between the tiles. Tiling is disabled for resolutions lower than a pre-specified (but configurable) limit. For example, this limit is 512x512 for the VAE in [`StableDiffusionPipeline`].

```py
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_vae_tiling()

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, image=init_image, strength=0.5).images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

> [!WARNING]
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.

## CPU offloading

CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.

CPU offloading dramatically reduces memory usage, but it is also **extremely slow** because submodules are passed back and forth multiple times between devices, which often makes it impractical.

> [!WARNING]
> Don't move the pipeline to CUDA before calling [`~DiffusionPipeline.enable_sequential_cpu_offload`], otherwise the amount of memory saved is only minimal (refer to this [issue](https://github.com/huggingface/diffusers/issues/1934) for more details). This is a stateful operation that installs hooks on the model.

Call [`~DiffusionPipeline.enable_sequential_cpu_offload`] to enable it on a pipeline.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()

pipeline(
    prompt="An astronaut riding a horse on Mars",
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

## Model offloading

Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, or VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is that the memory savings won't be as large.

> [!WARNING]
> Keep in mind that if models are reused outside the pipeline after hooks have been installed (see [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) for more details), you need to run the entire pipeline and models in the expected order to properly offload them. This is a stateful operation that installs hooks on the model.

Call [`~DiffusionPipeline.enable_model_cpu_offload`] to enable it on a pipeline.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()

pipeline(
    prompt="An astronaut riding a horse on Mars",
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden states.
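
For example, here's a minimal sketch (the prompt and argument values are illustrative) of computing the prompt embeddings with offloading enabled so only the text encoders are onloaded to the GPU.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()

# only the text encoders are moved to the GPU while encoding the prompt
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
    prompt="An astronaut riding a horse on Mars",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
```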

## Group offloading

Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.

> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.

Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.

The `offload_type` parameter can be set to `block_level` or `leaf_level`.

- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). It can be made faster by using streams, without giving up inference speed.

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Use the enable_group_offload method for Diffusers model implementations
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type="leaf_level")

# Use the apply_group_offloading method for other model components
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```

### CUDA stream

The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation with layer prefetching: the next layer to be executed is loaded onto the GPU while the current layer is still being executed. Prefetching can increase CPU memory usage significantly, so make sure you have about 2x the model size in CPU memory.

Set `record_stream=True` for more of a speedup at the cost of slightly increased memory usage. Refer to the [torch.Tensor.record_stream](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) docs to learn more.

> [!TIP]
> When `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possible with dummy inputs as well) before inference to avoid device mismatch errors. This may not work on all implementations, so feel free to open an issue if you encounter any problems.

If you're using `block_level` group offloading with `use_stream` enabled, the `num_blocks_per_group` parameter should be set to `1`, otherwise a warning will be raised.

```py
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
```
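
For the `block_level` case noted above, a sketch of the call looks like this (reusing the `onload_device` and `offload_device` variables from the earlier example).

```py
pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=1,  # required when use_stream=True with block_level offloading
    use_stream=True,
)
```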

The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
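
A minimal sketch of enabling it (the component and device variables follow the earlier group offloading example):

```py
pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,  # pin tensors on the fly to reduce CPU memory usage
)
```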

<Tip>

The offloading strategies can be combined with [quantization](../quantization/overview) for further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) often gives the best trade-off between quality, speed, and memory. For video generation, where the models are more compute-bound, [group offloading](#group-offloading) tends to be better because weight transfers can be overlapped with computation (streams must be used). When applying group offloading with quantization to image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers because the compute kernels finish faster, leaving the workload communication-bound between the CPU and GPU (due to device synchronizations).

</Tip>
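
As a rough sketch of the first combination, the snippet below quantizes the Flux transformer to 4-bit with bitsandbytes and enables model offloading on the pipeline. It assumes the `bitsandbytes` library is installed, and the model and parameter choices are illustrative.

```py
import torch
from diffusers import AutoModel, BitsAndBytesConfig, DiffusionPipeline

# quantize the transformer, typically the largest pipeline component
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
# combine quantization with model offloading for additional savings
pipeline.enable_model_cpu_offload()
```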

## Layerwise casting

Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.

> [!WARNING]
> Layerwise casting may not work with all models if the forward implementation contains internal typecasting of weights. The current implementation of layerwise casting assumes the forward pass is independent of the weight precision and the input datatypes are always specified in `compute_dtype` (see [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299) for an incompatible implementation).
>
> Layerwise casting may also fail on custom modeling implementations with [PEFT](https://huggingface.co/docs/peft/index) layers. There are some checks available but they are not extensively tested or guaranteed to work in all cases.

Call [`~ModelMixin.enable_layerwise_casting`] to set the storage and computation datatypes.

```py
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

pipeline = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",
    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```

The [`~hooks.apply_layerwise_casting`] method can also be used if you need more control and flexibility. It can be partially applied to model layers by calling it on specific internal modules. Use the `skip_modules_pattern` or `skip_modules_classes` parameters to specify modules to avoid, such as the normalization and modulation layers.

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)

# skip the normalization layers by matching their module names
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["norm"],
    non_blocking=True,
)
```

## torch.channels_last

[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `(batch size, channels, height, width)` to `(batch size, height, width, channels)`. This aligns the tensors with how the hardware sequentially accesses memory and avoids skipping around to access the pixel values.

Not all operators currently support the channels-last format, which may result in worse performance, but it is still worth trying.

```py
print(pipeline.unet.conv_out.state_dict()["weight"].stride())  # (2880, 9, 3, 1)
pipeline.unet.to(memory_format=torch.channels_last)  # in-place operation
print(
    pipeline.unet.conv_out.state_dict()["weight"].stride()
)  # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
```

## torch.jit.trace

[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. The returned executable or [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) can be compiled.

```py
import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# torch disable grad
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50

# load sample inputs
def generate_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states


pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
unet = pipeline.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")

# warmup and optimize graph
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)

# benchmarking
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the model
unet_traced.save("unet_traced.pt")
```

Replace the pipeline's UNet with the traced version.

```py
import torch
from diffusers import StableDiffusionPipeline
from dataclasses import dataclass

@dataclass
class UNet2DConditionOutput:
    sample: torch.Tensor

pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")

# del pipeline.unet
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipeline.unet.config.in_channels
        self.device = pipeline.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)

pipeline.unet = TracedUNet()

prompt = "An astronaut riding a horse on Mars"
with torch.inference_mode():
    image = pipeline([prompt] * 1, num_inference_steps=50).images[0]
```

## Memory-efficient attention

> [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!

The Transformer attention mechanism is memory-intensive, especially for long sequences, so you can try using a different and more memory-efficient attention type.

By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.

SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the optimal implementation based on your input.
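
If you want to restrict SDPA to a specific backend, wrap the call in PyTorch's backend context manager. This is a sketch assuming PyTorch 2.3 or newer and an already loaded `pipeline`.

```py
from torch.nn.attention import SDPBackend, sdpa_kernel

# force the memory-efficient backend for this call
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    image = pipeline("An astronaut riding a horse on Mars").images[0]
```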

You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.

```py
# pip install xformers
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_xformers_memory_efficient_attention()
```

Call [`~ModelMixin.disable_xformers_memory_efficient_attention`] to disable it.

```py
pipeline.disable_xformers_memory_efficient_attention()
```