Unverified Commit 98c42134 authored by MatthieuTPHR, committed by GitHub

Up to 2x speedup on GPUs using memory efficient attention (#532)



* 2x speedup using memory efficient attention

* remove einops dependency

* Swap K, M in op instantiation

* Simplify code, remove unnecessary maybe_init call and function, remove unused self.scale parameter

* make xformers a soft dependency

* remove one-liner functions

* change one letter variable to appropriate names

* Remove Env variable dependency, remove MemoryEfficientCrossAttention class and use enable_xformers_memory_efficient_attention method

* Add memory efficient attention toggle to img2img and inpaint pipelines

* Clearer management of xformers' availability

* update optimizations markdown to add info about memory efficient attention

* add benchmarks for TITAN RTX

* More detailed explanation of how the mem eff benchmarks were run

* Removing autocast from optimization markdown

* import_utils: import torch only if is available
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
parent a793b1fe
@@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
| fp16 | 3.61s | x2.63 |
| channels last | 3.30s | x2.88 |
| traced UNet | 3.21s | x2.96 |
| memory efficient attention | 2.63s | x3.61 |
<em>
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
@@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```
## Memory Efficient Attention
Recent work on optimizing bandwidth in the attention block has generated huge speedups and savings in GPU memory usage. The most recent example is Flash Attention (from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)).

Here are the speedups we obtain on a few NVIDIA GPUs when running inference at 512x512 with a batch size of 1 (one prompt):

| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------ |--------------------- |--------------------------------- |
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
| NVIDIA A10G | 8.88it/s | 15.6it/s |
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
| A100-SXM4-40GB | 18.6it/s | 29it/s |
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
To leverage it, just make sure you have:
- PyTorch >= 1.12
- CUDA available
- The [xformers](https://github.com/facebookresearch/xformers) library installed

Then enable it on the pipeline as shown below:
```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# switch the UNet's attention blocks to xformers' memory efficient attention
pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")

# optional: you can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```
\ No newline at end of file
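
Because xformers is only a soft dependency after this change, it can be worth probing for it before flipping the toggle. The snippet below is an illustrative sketch, not part of this PR: it reuses the `is_xformers_available` helper added to `diffusers.utils.import_utils` in this diff and falls back to the existing attention slicing optimization when xformers is not installed.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

if is_xformers_available() and torch.cuda.is_available():
    # xformers installed and a GPU is present: use the memory efficient kernel.
    pipe.enable_xformers_memory_efficient_attention()
else:
    # Fallback that only needs plain PyTorch: slice the attention computation.
    pipe.enable_attention_slicing()

with torch.inference_mode():
    image = pipe("a small cat").images[0]
```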
@@ -18,6 +18,15 @@ import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

class AttentionBlock(nn.Module):
    """
@@ -150,6 +159,10 @@ class SpatialTransformer(nn.Module):
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for block in self.transformer_blocks:
            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, weight = hidden_states.shape
@@ -206,6 +219,32 @@ class BasicTransformerBlock(nn.Module):
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU"
            )
        else:
            try:
                # Make sure we can actually run the memory efficient attention kernel on this
                # machine by issuing a tiny dummy call on the GPU.
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, context=None):
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
@@ -239,6 +278,7 @@ class CrossAttention(nn.Module):
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -279,11 +319,13 @@ class CrossAttention(nn.Module):
        # TODO(PVP) - mask is currently never used. Remember to re-implement when used

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
        # linear proj
        hidden_states = self.to_out[0](hidden_states)
@@ -341,6 +383,11 @@ class CrossAttention(nn.Module):
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value):
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

class FeedForward(nn.Module):
    r"""
......
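
For intuition about what the new `_memory_efficient_attention_xformers` helper delegates to: `xformers.ops.memory_efficient_attention` computes ordinary scaled dot-product attention over the `(batch * heads, seq_len, head_dim)` tensors produced by `reshape_heads_to_batch_dim`, it just avoids materializing the full attention matrix. A naive PyTorch reference of the same math (a sketch for illustration, not the xformers kernel) would be:

```python
import math

import torch


def naive_attention(query, key, value):
    # query, key, value: (batch * heads, seq_len, head_dim), the layout used by
    # CrossAttention in this diff. xformers scales by 1/sqrt(head_dim) by default.
    scale = 1 / math.sqrt(query.shape[-1])
    attn = torch.softmax(query @ key.transpose(-1, -2) * scale, dim=-1)
    return attn @ value
```

The memory efficient kernel returns the same result up to numerical precision while keeping peak memory roughly linear in the sequence length, which is where the speedups in the table above come from.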
@@ -367,6 +367,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
        for attn in self.attentions:
            attn._set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -542,6 +546,10 @@ class CrossAttnDownBlock2D(nn.Module):
        for attn in self.attentions:
            attn._set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()
@@ -1117,6 +1125,10 @@ class CrossAttnUpBlock2D(nn.Module):
        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(
        self,
        hidden_states,
......
@@ -225,6 +225,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        for block in self.down_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

        self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

        for block in self.up_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value
......
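
The `enable_xformers_memory_efficient_attention` pipeline methods below simply forward to this UNet-level setter, so the flag can also be flipped on a standalone `UNet2DConditionModel`. A minimal sketch (assuming the usual Stable Diffusion checkpoint layout with a `unet` subfolder):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")

# Recurses through down_blocks, mid_block and up_blocks and sets
# _use_memory_efficient_attention_xformers on every CrossAttention module.
unet.set_use_memory_efficient_attention_xformers(True)
```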
@@ -113,6 +113,24 @@ class StableDiffusionPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
......
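
As the docstring warns, attention slicing and xformers attention are independent flags, and the xformers path wins when both are set because `CrossAttention.forward` checks `_use_memory_efficient_attention_xformers` first. A short sketch of that interaction (illustrative only):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Both optimizations can be toggled on at the same time ...
pipe.enable_attention_slicing()
pipe.enable_xformers_memory_efficient_attention()

# ... but the memory efficient path takes precedence, so slicing is only used
# again after: pipe.disable_xformers_memory_efficient_attention()
with torch.inference_mode():
    image = pipe("a small cat").images[0]
```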
@@ -151,6 +151,24 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        # set slice_size = `None` to disable `set_attention_slice`
        self.enable_attention_slicing(None)

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    @torch.no_grad()
    def __call__(
        self,
......
@@ -151,6 +151,24 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    @torch.no_grad()
    def __call__(
        self,
......
@@ -168,6 +168,18 @@ try:
except importlib_metadata.PackageNotFoundError:
    _accelerate_available = False

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    _xformers_version = importlib_metadata.version("xformers")
    if _torch_available:
        import torch

        # xformers' memory efficient attention requires PyTorch >= 1.12
        if version.Version(torch.__version__) < version.Version("1.12"):
            raise ValueError("PyTorch should be >= 1.12")
    logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
    _xformers_available = False

def is_torch_available():
    return _torch_available
@@ -205,6 +217,10 @@ def is_scipy_available():
    return _scipy_available

def is_xformers_available():
    return _xformers_available

def is_accelerate_available():
    return _accelerate_available
......
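
The `_xformers_available` flag above is evaluated once at import time. Here is a hypothetical standalone re-implementation of the same checks (sketch only; the `xformers_usable` helper is not part of this PR), using `packaging.version` and `importlib` metadata the way the module already does:

```python
import importlib.util

try:
    import importlib_metadata
except ImportError:  # Python >= 3.8 ships this in the standard library
    import importlib.metadata as importlib_metadata

from packaging import version


def xformers_usable() -> bool:
    # xformers must be importable and report a version ...
    if importlib.util.find_spec("xformers") is None:
        return False
    try:
        importlib_metadata.version("xformers")
    except importlib_metadata.PackageNotFoundError:
        return False
    # ... and, if torch is installed, it must be at least 1.12.
    if importlib.util.find_spec("torch") is not None:
        import torch

        if version.Version(torch.__version__) < version.Version("1.12"):
            return False
    return True
```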