Unverified Commit 4d35d7fe authored by Patrick von Platen, committed by GitHub

Allow disabling torch 2_0 attention (#3273)

* Allow disabling torch 2_0 attention

* make style

* Update src/diffusers/models/attention.py
parent a7b0671c
@@ -71,6 +71,7 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, bias=True)

         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None

     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ class AttentionBlock(nn.Module):
         scale = 1 / math.sqrt(self.channels / self.num_heads)

-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
...
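With this change, the PyTorch 2.0 F.scaled_dot_product_attention path can be switched off per block through the new _use_2_0_attn flag, forcing the baseline attention math even when torch 2.0 is installed. A minimal sketch of toggling the flag follows; the constructor arguments and tensor shape are illustrative assumptions, not part of this diff:

import torch
from diffusers.models.attention import AttentionBlock

# Illustrative setup: channels/num_head_channels are assumed values, not from this commit.
block = AttentionBlock(channels=64, num_head_channels=32)

# By default the block prefers torch 2.0 scaled_dot_product_attention when it exists.
assert block._use_2_0_attn is True

# Flip the new flag to fall back to the plain attention implementation.
block._use_2_0_attn = False

hidden_states = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
out = block(hidden_states)
print(out.shape)  # torch.Size([1, 64, 16, 16])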