<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Accelerated PyTorch 2.0 support in Diffusers

Starting from version `0.13.0`, Diffusers supports the latest optimizations from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
1. Support for accelerated transformers implementation with memory-efficient attention – no extra dependencies required.
2. [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for extra performance boost when individual models are compiled.


## Installation
To benefit from the accelerated attention implementation and `torch.compile`, install the latest version of PyTorch 2.0 from `pip` and make sure you are on `diffusers` 0.13.0 or later. As explained below, `diffusers` automatically uses the attention optimizations (but not `torch.compile`) when available.

```bash
pip install --upgrade torch torchvision diffusers
```
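
After installing, it can help to confirm that the installed PyTorch actually exposes `torch.nn.functional.scaled_dot_product_attention`. The snippet below is a minimal sketch of such a check; the helper name `sdpa_available` is hypothetical and is not part of Diffusers or PyTorch.

```python
import importlib.util


def sdpa_available() -> bool:
    """Return True if the installed PyTorch exposes scaled_dot_product_attention."""
    # If torch is not installed at all, the accelerated path cannot be used.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch.nn.functional as F

    # The function ships with PyTorch 2.0; older releases do not have it.
    return hasattr(F, "scaled_dot_product_attention")


print("SDPA available:", sdpa_available())
```

If this prints `False`, upgrading `torch` as shown in the install command is the likely fix.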

## Using accelerated transformers and torch.compile


1. **Accelerated Transformers implementation**

   PyTorch 2.0 includes an optimized and memory-efficient attention implementation through the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables several optimizations depending on the inputs and the GPU type. This is similar to the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers), but built natively into PyTorch. 

   These optimizations are enabled by default in Diffusers if PyTorch 2.0 is installed and `torch.nn.functional.scaled_dot_product_attention` is available. To use them, install `torch 2.0` as suggested above and use the pipeline as usual. For example:

    ```Python
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]
    ```

    If you want to enable it explicitly (which is not required), you can do so as shown below.

    ```Python
    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import AttnProcessor2_0

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipe.unet.set_attn_processor(AttnProcessor2_0())

    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]
    ```

    This should be as fast and memory efficient as `xFormers`. More details [in our benchmark](#benchmark).


2. **torch.compile**

    To get an additional speedup, we can use the new `torch.compile` feature. To do so, we simply wrap our `unet` with `torch.compile`. For more information and different options, refer to the [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).

    ```python
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
        "cuda"
    )
    pipe.unet = torch.compile(pipe.unet)

    batch_size = 10
    steps = 50  # number of denoising steps; 50 matches the benchmark below
    prompt = "A photo of an astronaut riding a horse on mars."
    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
    ```

    Depending on the type of GPU, `compile()` can yield an _additional speed-up_ of 2-9% over the accelerated transformer optimizations. Note, however, that compilation squeezes out larger performance improvements on more recent GPU architectures such as Ampere (A100, 3090), Ada (4090) and Hopper (H100).

    Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
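
The compile-once, run-many trade-off can be made visible by timing the first call, which includes compilation, separately from later calls. This is a rough sketch: `time_first_and_rest` is a hypothetical helper, and the `fake_pipeline` stand-in only simulates a slow first call so that the snippet runs without a GPU.

```python
import time
from typing import Callable, List, Tuple


def time_first_and_rest(fn: Callable[[], object], repeats: int = 3) -> Tuple[float, List[float]]:
    """Time the first call separately from `repeats` subsequent calls."""
    start = time.perf_counter()
    fn()  # with a compiled model, this call pays the compilation cost
    first = time.perf_counter() - start

    rest = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        rest.append(time.perf_counter() - start)
    return first, rest


# Stand-in for a compiled pipeline: slow on the first call, fast afterwards.
_state = {"warm": False}


def fake_pipeline() -> None:
    if not _state["warm"]:
        time.sleep(0.05)  # simulated one-off compilation delay
        _state["warm"] = True


first, rest = time_first_and_rest(fake_pipeline)
print(f"first call: {first:.3f}s, later calls: {[round(t, 3) for t in rest]}")
```

With a real pipeline, one could pass `lambda: pipe(prompt)` in place of `fake_pipeline`; if the first call dominates the total for your workload, compilation may not pay off.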


## Benchmark

We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. The `xFormers` benchmark was run with `torch==1.13.1`, while the accelerated transformers optimizations were tested with nightly versions of PyTorch 2.0. The tables below summarize the results.

Please refer to [our featured blog post in the PyTorch site](https://pytorch.org/blog/accelerated-diffusers-pt-20/) for more details.

### FP16 benchmark

The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is about as fast as `xFormers` (sometimes slightly faster, sometimes slightly slower) on all the GPUs we tested.
Using `torch.compile` gives a further speed-up of up to 10% over `xFormers`, but it's mostly noticeable on the A100 GPU.

___The time reported is in seconds.___

| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
| --- | --- | --- | --- | --- | --- | --- |
| A100 | 1 | 2.69 | 2.7 | 1.98 | 2.47 | 8.52 |
| A100 | 2 | 3.21 | 3.04 | 2.38 | 2.78 | 8.55 |
| A100 | 4 | 5.27 | 3.91 | 3.89 | 3.53 | 9.72 |
| A100 | 8 | 9.74 | 7.03 | 7.04 | 6.62 | 5.83 |
| A100 | 10 | 12.02 | 8.7 | 8.67 | 8.45 | 2.87 |
| A100 | 16 | 18.95 | 13.57 | 13.55 | 13.20 | 2.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 25.85 | 2.67 |
| A100 | 64 | | 52.51 | 53.03 | 50.93 | 3.01 |
| | | | | | | |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
| | | | | | | |
| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
| T4 | 16 | OOM | 111.47 | 113.26  | 106.93  | 4.07 |
| | | | | | | |
| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
| | | | | | | |
| 3090 | 1 | 2.94 | 2.5 | 2.42 | 2.33 | 6.80 |
| 3090 | 4 | 10.04 | 7.82 | 7.72 | 7.38 | 5.63 |
| 3090 | 8 | 19.27 | 14.97 | 14.88 | 14.15 | 5.48 |
| 3090 | 10| 24.08 | 18.7 | 18.62 | 18.12 | 3.10 |
| 3090 | 16 | OOM | 29.06 | 28.88 | 28.2 | 2.96 |
| 3090 | 32 (1) | | 58.05 | 57.42 | 56.28 | 3.05 |
| 3090 | 64 (1) | | 126.54 | 114.27 | 112.21 | 11.32 |
| | | | | | | |
| 3090 Ti | 1 | 2.7 | 2.26 | 2.19 | 2.12 | 6.19 |
| 3090 Ti | 4 | 9.07 | 7.14 | 7.00 | 6.71 | 6.02 |
| 3090 Ti | 8 | 17.51 | 13.65 | 13.53 | 12.94 | 5.20 |
| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.77 | 16.44 | 2.43 |
| 3090 Ti | 16 | OOM | 26.1 | 26.04 | 25.53 | 2.18 |
| 3090 Ti | 32 (1) | | 51.78 | 51.71 | 50.91 | 1.68 |
| 3090 Ti | 64 (1) | | 112.02 | 102.78 | 100.89 | 9.94 |
| | | | | | | |
| 4090 | 1 | 4.47 | 3.98 | 1.28 | 1.21 | 69.60 |
| 4090 | 4 | 10.48 | 8.37 | 3.76 | 3.56 | 57.47 |
| 4090 | 8 | 14.33 | 10.22 | 7.43 | 6.99 | 31.60 |
| 4090 | 16 | | 17.07 | 14.98 | 14.58 | 14.59 |
| 4090 | 32 (1) | | 39.03 | 30.18 | 29.49 | 24.44 |
| 4090 | 64 (1) | | 77.29 | 61.34 | 59.96 | 22.42 |
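
The last column appears to be the relative reduction in runtime of `SDPA + torch.compile` compared with xFormers. The sketch below reproduces that figure for two A100 rows from the table; the formula is inferred from the reported numbers rather than taken from the benchmark scripts, and `speedup_percent` is a hypothetical helper.

```python
def speedup_percent(baseline_s: float, optimized_s: float) -> float:
    """Relative runtime reduction of `optimized_s` versus `baseline_s`, in percent."""
    return (baseline_s - optimized_s) / baseline_s * 100.0


# (xFormers seconds, SDPA + torch.compile seconds) for the A100, from the FP16 table.
a100_rows = {1: (2.7, 2.47), 4: (3.91, 3.53)}

for batch_size, (xformers_s, compiled_s) in a100_rows.items():
    print(batch_size, round(speedup_percent(xformers_s, compiled_s), 2))
```

Rounded to two decimals, these evaluate to 8.52 and 9.72, matching the values reported for A100 at batch sizes 1 and 4.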


				
### FP32 benchmark

The table below shows the benchmark results for inference using `fp32`. In this case, `torch.nn.functional.scaled_dot_product_attention` is faster than `xFormers` on all the GPUs we tested.

Using `torch.compile` in addition to the accelerated transformers implementation can yield up to 19% performance improvement over `xFormers` in Ampere and Ada cards, and up to 20% (Ampere) or 28% (Ada) over vanilla attention.

| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| A100 | 1 | 4.97 | 3.86 | 2.6 | 2.86 | 25.91 | 42.45 |
| A100 | 2 | 9.03 | 6.76 | 4.41 | 4.21 | 37.72 | 53.38 |
| A100 | 4 | 16.70 | 12.42 | 7.94 | 7.54 | 39.29 | 54.85 |
| A100 | 10 | OOM | 29.93 | 18.70 | 18.46 | 38.32 | |
| A100 | 16 | | 47.08 | 29.41 | 29.04 | 38.32 | |
| A100 | 32 | | 92.89 | 57.55 | 56.67 | 38.99 | |
| A100 | 64 | | 185.3 | 114.8 | 112.98 | 39.03 | |
| | | | | | | |
| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
| A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | |
| A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | |
| A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | |
| A10 | 48 | | 333.23 | 264.84 | | 20.52 | |
| | | | | | | |
| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
| T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | |
| | | | | | | |
| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
| | | | | | | |
| 3090 | 1 | 7.09 | 6.78 | 5.34 | 5.35 | 21.09 | 24.54 |
| 3090 | 4 | 22.69 | 21.45 | 18.56 | 18.18 | 15.24 | 19.88 |
| 3090 | 8 | | 42.59 | 36.68 | 35.61 | 16.39 | |
| 3090 | 16 | | 85.35 | 72.93 | 70.18 | 17.77 | |
| 3090 | 32 (1) | | 162.05 | 143.46 | 138.67 | 14.43 | |
| | | | | | | |
| 3090 Ti | 1 | 6.45 | 6.19 | 4.99 | 4.89 | 21.00 | 24.19 |
| 3090 Ti | 4 | 20.32 | 19.31 | 17.02 | 16.48 | 14.66 | 18.90 |
| 3090 Ti | 8 | | 37.93 | 33.21 | 32.24 | 15.00 | |
| 3090 Ti | 16 | | 75.37 | 66.63 | 64.5 | 14.42 | |
| 3090 Ti | 32 (1) | | 142.55 | 128.89 | 124.92 | 12.37 | |
| | | | | | | |
| 4090 | 1 | 5.54 | 4.99 | 2.66 | 2.58 | 48.30 | 53.43 |
| 4090 | 4 | 13.67 | 11.4 | 8.81 | 8.46 | 25.79 | 38.11 |
| 4090 | 8 |  | 19.79 | 17.55 | 16.62 | 16.02 |  |
| 4090 | 16 |  | 38.62 | 35.65 | 34.07 | 11.78 |  |
| 4090 | 32 (1) |  | 76.57 | 69.48 | 65.35 | 14.65 |  |
| 4090 | 48 |  | 114.44 | 106.3 |  | 7.11 |  |


(1) Batch sizes >= 32 require `enable_vae_slicing()` because of https://github.com/pytorch/pytorch/issues/81665.
This applies to PyTorch 1.13.1, and also to PyTorch 2.0 at large batch sizes.
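
Footnote (1) can be applied in code by enabling VAE slicing only when the batch is large. This is a sketch under assumptions: `maybe_enable_vae_slicing` and `VAE_SLICING_MIN_BATCH` are hypothetical names, the threshold of 32 comes from the footnote, and `pipe` is assumed to be a pipeline exposing `enable_vae_slicing()`, such as the `StableDiffusionPipeline` used earlier.

```python
VAE_SLICING_MIN_BATCH = 32  # threshold taken from footnote (1)


def maybe_enable_vae_slicing(pipe, batch_size: int) -> bool:
    """Enable VAE slicing on `pipe` for large batches; return whether it was enabled."""
    if batch_size >= VAE_SLICING_MIN_BATCH:
        pipe.enable_vae_slicing()
        return True
    return False
```

Calling `maybe_enable_vae_slicing(pipe, batch_size)` before `pipe(prompt, num_images_per_prompt=batch_size)` would mirror the setup used for the rows marked (1).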

For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303) and to [the blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/).