Unverified commit 16bb5058 authored by Takuma Mori, committed by GitHub

xFormers attention op arg (#2049)



* allow passing op to xFormers attention

original code by @patil-suraj
huggingface/diffusers@ae0cc0b71f28c0f2c5c27026b18f1bea98b505f1

* correct style by `make style`

* add attention_op arg documentation

* add usage example to docstring
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add usage example to docstring
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* code style correction by `make style`

* Update docstring code to a valid python example
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update docstring code to a valid python example
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style correction by `make style`

* Update code example to be fully functional
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 7533e3d7
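
In short, the PR threads an optional `attention_op` argument from `enable_xformers_memory_efficient_attention()` down to `xformers.ops.memory_efficient_attention(..., op=...)`, so callers can pin a specific xFormers operator instead of relying on automatic dispatch. A minimal usage sketch, mirroring the docstring examples added below (it assumes xformers is installed and a CUDA device is available):

```py
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# Load a pipeline in fp16 and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# New in this PR: force a specific operator instead of xFormers' automatic dispatch.
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Passing None (the default) keeps the previous behaviour and lets xFormers choose.
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```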
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Optional
+from typing import Callable, Optional
import torch
import torch.nn.functional as F
......@@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
self.proj_attn = nn.Linear(channels, channels, 1)
self._use_memory_efficient_attention_xformers = False
self._attention_op = None
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
......@@ -87,7 +88,9 @@ class AttentionBlock(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
-def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+def set_use_memory_efficient_attention_xformers(
+    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
......@@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
def forward(self, hidden_states):
residual = hidden_states
......@@ -136,7 +140,9 @@ class AttentionBlock(nn.Module):
if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
-hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+hidden_states = xformers.ops.memory_efficient_attention(
+    query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
......
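
`AttentionBlock` (above) is the self-attention layer used by the VAE, which is why the pipeline docstring later in this diff resets the op to `None` for the VAE: Flash-Attention does not accept the attention shapes the VAE produces. A hedged sketch of enabling the op on a standalone model, assuming the `stabilityai/stable-diffusion-2-1` checkpoint, xformers, and a CUDA device:

```py
import torch
from diffusers import AutoencoderKL
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# The op is stored on each AttentionBlock as `_attention_op` and forwarded as
# `op=` to xformers.ops.memory_efficient_attention on every forward pass.
vae.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Resetting to None restores xFormers' automatic operator selection.
vae.enable_xformers_memory_efficient_attention(attention_op=None)
```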
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Union
+from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
......@@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
processor = processor if processor is not None else CrossAttnProcessor()
self.set_processor(processor)
-def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+def set_use_memory_efficient_attention_xformers(
+    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+):
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
......@@ -127,7 +129,7 @@ class CrossAttention(nn.Module):
except Exception as e:
raise e
-processor = XFormersCrossAttnProcessor()
+processor = XFormersCrossAttnProcessor(attention_op=attention_op)
else:
processor = CrossAttnProcessor()
......@@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:
class XFormersCrossAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
......@@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
-hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+hidden_states = xformers.ops.memory_efficient_attention(
+    query, key, value, attn_bias=attention_mask, op=self.attention_op
+)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
......
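
Because the op now lives on the processor, it can also be set per attention module instead of model-wide. A sketch of both granularities; the `diffusers.models.cross_attention` import path and the exact attribute path to a `CrossAttention` layer are assumptions for illustration:

```py
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import XFormersCrossAttnProcessor  # assumed module path
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Model-wide: the recursive setter hands (True, attention_op) to every child that
# implements set_use_memory_efficient_attention_xformers.
unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Per-module: attach a processor carrying the op to a single CrossAttention layer
# (illustrative attribute path).
attn = unet.mid_block.attentions[0].transformer_blocks[0].attn1
attn.set_processor(XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp))
```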
......@@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
-def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+def set_use_memory_efficient_attention_xformers(
+    self, valid: bool, attention_op: Optional[Callable] = None
+) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-module.set_use_memory_efficient_attention_xformers(valid)
+module.set_use_memory_efficient_attention_xformers(valid, attention_op)
for child in module.children():
fn_recursive_set_mem_eff(child)
......@@ -205,7 +207,7 @@ class ModelMixin(torch.nn.Module):
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
-def enable_xformers_memory_efficient_attention(self):
+def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.
......@@ -214,8 +216,28 @@ class ModelMixin(torch.nn.Module):
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.
Examples:
```py
>>> import torch
>>> from diffusers import UNet2DConditionModel
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> model = UNet2DConditionModel.from_pretrained(
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
... )
>>> model = model.to("cuda")
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
```
"""
-self.set_use_memory_efficient_attention_xformers(True)
+self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self):
r"""
......
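
The operator is not limited to Flash-Attention: whatever is passed ends up verbatim as the `op=` argument of `memory_efficient_attention`, so any operator exposed by the installed xformers build can be used. A sketch with the Cutlass operator (assuming `MemoryEfficientAttentionCutlassOp` is available in your xformers version):

```py
import torch
from diffusers import UNet2DConditionModel
from xformers.ops import MemoryEfficientAttentionCutlassOp  # availability depends on the xformers build

model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Stored per attention module and passed straight through as `op=` at forward time.
model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionCutlassOp)
```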
......@@ -19,7 +19,7 @@ import inspect
import os
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
......@@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
-def enable_xformers_memory_efficient_attention(self):
+def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.
......@@ -851,8 +851,28 @@ class DiffusionPipeline(ConfigMixin):
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.
Examples:
```py
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
>>> # Workaround for not accepting attention shape using VAE for Flash Attention
>>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```
"""
-self.set_use_memory_efficient_attention_xformers(True)
+self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self):
r"""
......@@ -860,13 +880,15 @@ class DiffusionPipeline(ConfigMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
-def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+def set_use_memory_efficient_attention_xformers(
+    self, valid: bool, attention_op: Optional[Callable] = None
+) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-module.set_use_memory_efficient_attention_xformers(valid)
+module.set_use_memory_efficient_attention_xformers(valid, attention_op)
for child in module.children():
fn_recursive_set_mem_eff(child)
......
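
Disabling is untouched by this PR: `disable_xformers_memory_efficient_attention()` still calls the recursive setter with `False`, and the new `attention_op` simply defaults to `None`, so every `CrossAttention` module falls back to the plain `CrossAttnProcessor` and the operator override is discarded. A short round-trip sketch (same assumptions as above):

```py
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# Enable with a custom operator...
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# ...and disable again: the recursive setter receives (False, None) and the
# default attention path is restored.
pipe.disable_xformers_memory_efficient_attention()
```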