Unverified Commit a39d42b9 authored by Pedro Cuenca, committed by GitHub

[docs] update torch 2 benchmark (#2764)

* Update benchmark for A100, 3090, 3090 Ti, 4090.

* Link to PyTorch blog.

* Update install instructions.
parent ca1e4072
@@ -18,11 +18,10 @@ Starting from version `0.13.0`, Diffusers supports the latest optimization from ...
## Installation
To benefit from the accelerated attention implementation and `torch.compile`, you just need to install the latest version of PyTorch 2.0 from `pip` and make sure you are on `diffusers` 0.13.0 or later. As explained below, `diffusers` automatically uses the attention optimizations (but not `torch.compile`) when they are available.
```bash
pip install --upgrade torch torchvision diffusers
```
## Using accelerated transformers and torch.compile.
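As a quick reference for the pattern this section describes, the sketch below (the model id and prompt are just examples) relies on `diffusers` picking up the accelerated attention automatically on PyTorch 2.0, while `torch.compile` is applied explicitly to the UNet:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline in half precision and move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# On PyTorch 2.0 with diffusers 0.13.0+, scaled_dot_product_attention is used
# automatically; torch.compile has to be requested explicitly. Compiling the
# UNet covers the bulk of the denoising compute.
pipe.unet = torch.compile(pipe.unet)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```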
@@ -91,8 +90,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl ...
We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile` + `torch.nn.functional.scaled_dot_product_attention`.
For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. The `xFormers` benchmark was run with `torch==1.13.1`, while the accelerated transformers optimizations were tested with nightly versions of PyTorch 2.0. The tables below summarize the results we got.
The `Speed over xformers` columns denote the speed-up gained over `xFormers` using `torch.compile` together with `torch.nn.functional.scaled_dot_product_attention`. Please refer to [our featured blog post on the PyTorch site](https://pytorch.org/blog/accelerated-diffusers-pt-20/) for more details.
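For reference, the reported percentages match the relative improvement of the `SDPA + torch.compile` timing over the `xFormers` timing. A quick sanity check against the A100, batch size 10, FP16 row below:

```python
# Speed-up over xFormers, as reported in the tables: relative improvement of
# the "SDPA + torch.compile" timing over the xFormers timing, in percent.
def speedup_over_xformers(xformers_seconds: float, compiled_seconds: float) -> float:
    return (xformers_seconds - compiled_seconds) / xformers_seconds * 100

# A100, batch size 10, FP16: 8.7 s with xFormers vs 8.45 s with SDPA + torch.compile.
print(round(speedup_over_xformers(8.7, 8.45), 2))  # 2.87
```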
### FP16 benchmark
@@ -103,10 +101,14 @@ ___The time reported is in seconds.___
| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
| --- | --- | --- | --- | --- | --- | --- |
| A100 | 1 | 2.69 | 2.7 | 1.98 | 2.47 | 8.52 |
| A100 | 2 | 3.21 | 3.04 | 2.38 | 2.78 | 8.55 |
| A100 | 4 | 5.27 | 3.91 | 3.89 | 3.53 | 9.72 |
| A100 | 8 | 9.74 | 7.03 | 7.04 | 6.62 | 5.83 |
| A100 | 10 | 12.02 | 8.7 | 8.67 | 8.45 | 2.87 |
| A100 | 16 | 18.95 | 13.57 | 13.55 | 13.20 | 2.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 25.85 | 2.67 |
| A100 | 64 | | 52.51 | 53.03 | 50.93 | 3.01 |
| | | | | | | |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
@@ -125,25 +127,28 @@ ___The time reported is in seconds.___
| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
| | | | | | | |
| 3090 | 1 | 2.94 | 2.5 | 2.42 | 2.33 | 6.80 |
| 3090 | 4 | 10.04 | 7.82 | 7.72 | 7.38 | 5.63 |
| 3090 | 8 | 19.27 | 14.97 | 14.88 | 14.15 | 5.48 |
| 3090 | 10 | 24.08 | 18.7 | 18.62 | 18.12 | 3.10 |
| 3090 | 16 | OOM | 29.06 | 28.88 | 28.2 | 2.96 |
| 3090 | 32 (1) | | 58.05 | 57.42 | 56.28 | 3.05 |
| 3090 | 64 (1) | | 126.54 | 114.27 | 112.21 | 11.32 |
| | | | | | | |
| 3090 Ti | 1 | 2.7 | 2.26 | 2.19 | 2.12 | 6.19 |
| 3090 Ti | 4 | 9.07 | 7.14 | 7.00 | 6.71 | 6.02 |
| 3090 Ti | 8 | 17.51 | 13.65 | 13.53 | 12.94 | 5.20 |
| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.77 | 16.44 | 2.43 |
| 3090 Ti | 16 | OOM | 26.1 | 26.04 | 25.53 | 2.18 |
| 3090 Ti | 32 (1) | | 51.78 | 51.71 | 50.91 | 1.68 |
| 3090 Ti | 64 (1) | | 112.02 | 102.78 | 100.89 | 9.94 |
| | | | | | | |
| 4090 | 1 | 4.47 | 3.98 | 1.28 | 1.21 | 69.60 |
| 4090 | 4 | 10.48 | 8.37 | 3.76 | 3.56 | 57.47 |
| 4090 | 8 | 14.33 | 10.22 | 7.43 | 6.99 | 31.60 |
| 4090 | 16 | | 17.07 | 14.98 | 14.58 | 14.59 |
| 4090 | 32 (1) | | 39.03 | 30.18 | 29.49 | 24.44 |
| 4090 | 64 (1) | | 77.29 | 61.34 | 59.96 | 22.42 |
@@ -155,11 +160,13 @@ Using `torch.compile` in addition to the accelerated transformers implementation ...
| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| A100 | 1 | 4.97 | 3.86 | 2.6 | 2.86 | 25.91 | 42.45 |
| A100 | 2 | 9.03 | 6.76 | 4.41 | 4.21 | 37.72 | 53.38 |
| A100 | 4 | 16.70 | 12.42 | 7.94 | 7.54 | 39.29 | 54.85 |
| A100 | 10 | OOM | 29.93 | 18.70 | 18.46 | 38.32 | |
| A100 | 16 | | 47.08 | 29.41 | 29.04 | 38.32 | |
| A100 | 32 | | 92.89 | 57.55 | 56.67 | 38.99 | |
| A100 | 64 | | 185.3 | 114.8 | 112.98 | 39.03 | |
| | | | | | | | |
| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
@@ -179,30 +186,27 @@ Using `torch.compile` in addition to the accelerated transformers implementation ...
| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
| | | | | | | | |
| 3090 | 1 | 7.09 | 6.78 | 5.34 | 5.35 | 21.09 | 24.54 |
| 3090 | 4 | 22.69 | 21.45 | 18.56 | 18.18 | 15.24 | 19.88 |
| 3090 | 8 | | 42.59 | 36.68 | 35.61 | 16.39 | |
| 3090 | 16 | | 85.35 | 72.93 | 70.18 | 17.77 | |
| 3090 | 32 (1) | | 162.05 | 143.46 | 138.67 | 14.43 | |
| | | | | | | | |
| 3090 Ti | 1 | 6.45 | 6.19 | 4.99 | 4.89 | 21.00 | 24.19 |
| 3090 Ti | 4 | 20.32 | 19.31 | 17.02 | 16.48 | 14.66 | 18.90 |
| 3090 Ti | 8 | | 37.93 | 33.21 | 32.24 | 15.00 | |
| 3090 Ti | 16 | | 75.37 | 66.63 | 64.5 | 14.42 | |
| 3090 Ti | 32 (1) | | 142.55 | 128.89 | 124.92 | 12.37 | |
| | | | | | | | |
| 4090 | 1 | 5.54 | 4.99 | 2.66 | 2.58 | 48.30 | 53.43 |
| 4090 | 4 | 13.67 | 11.4 | 8.81 | 8.46 | 25.79 | 38.11 |
| 4090 | 8 | | 19.79 | 17.55 | 16.62 | 16.02 | |
| 4090 | 16 | | 38.62 | 35.65 | 34.07 | 11.78 | |
| 4090 | 32 (1) | | 76.57 | 69.48 | 65.35 | 14.65 | |
| 4090 | 48 | | 114.44 | 106.3 | | 7.11 | |
(1) Batch Size >= 32 requires `enable_vae_slicing()` because of https://github.com/pytorch/pytorch/issues/81665.
This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and large batch sizes.
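As a rough illustration of the workaround (continuing the pipeline sketch from the usage section; the prompt and batch size are arbitrary), sliced VAE decoding is enabled on the pipeline before generating a large batch:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Decode the latents one image at a time instead of all at once,
# which avoids the large-batch decoder issue referenced above.
pipe.enable_vae_slicing()

images = pipe(["a photo of an astronaut riding a horse on mars"] * 32).images
```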
For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303) and to [the blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/).