Unverified Commit 16d56c4b authored by Juan Acevedo, committed by GitHub

F/flax split head dim (#5181)



* split_head_dim flax attn

* Make split_head_dim non default

* make style and make quality

* add description for split_head_dim flag

* Update src/diffusers/models/attention_flax.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent c82f7baf
@@ -131,6 +131,8 @@ class FlaxAttention(nn.Module):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
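For reference, a minimal usage sketch of the new flag. The `query_dim` and `heads` attributes and the `__call__(hidden_states, context=None, deterministic=True)` signature are assumed from the surrounding class and are not shown in this hunk:

```python
# Hedged sketch: instantiate FlaxAttention with the new split_head_dim flag.
# `query_dim` and `heads` are assumed attributes not visible in this diff hunk.
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(
    query_dim=320,
    heads=8,
    dim_head=40,
    split_head_dim=True,  # flag added in this PR; defaults to False
    dtype=jnp.float32,
)

hidden_states = jnp.ones((1, 64, 320), dtype=jnp.float32)  # (batch, seq_len, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)
print(out.shape)  # (1, 64, 320)
```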
@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -177,6 +180,12 @@ class FlaxAttention(nn.Module):
        key_proj = self.key(context)
        value_proj = self.value(context)

        if self.split_head_dim:
            b = hidden_states.shape[0]
            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
        else:
            query_states = self.reshape_heads_to_batch_dim(query_proj)
            key_states = self.reshape_heads_to_batch_dim(key_proj)
            value_states = self.reshape_heads_to_batch_dim(value_proj)
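To make the two layouts above concrete, here is a small sketch; the shapes and the inline approximation of `reshape_heads_to_batch_dim` are illustrative assumptions, not taken from this diff:

```python
# Hedged illustration of the two tensor layouts produced above, with stand-in shapes.
# split_head_dim=True keeps a separate head axis: (batch, seq, heads, dim_head).
# The classic path packs heads into the batch axis: (batch * heads, seq, dim_head).
import jax.numpy as jnp

batch, seq, heads, dim_head = 2, 16, 8, 40
proj = jnp.ones((batch, seq, heads * dim_head))  # output of self.query / self.key / self.value

# split_head_dim=True: a single reshape, no transpose
split = jnp.reshape(proj, (batch, -1, heads, dim_head))  # (2, 16, 8, 40)

# split_head_dim=False (roughly what reshape_heads_to_batch_dim does):
# expose the head axis, move it next to batch, then fold it into the batch axis
packed = jnp.reshape(proj, (batch, seq, heads, dim_head))
packed = jnp.transpose(packed, (0, 2, 1, 3)).reshape(batch * heads, seq, dim_head)  # (16, 16, 40)
```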
@@ -206,14 +215,23 @@ class FlaxAttention(nn.Module):
            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            if self.split_head_dim:
                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
            else:
                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)

            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

            # attend to values
            if self.split_head_dim:
                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
                b = hidden_states.shape[0]
                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
            else:
                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)
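The split-head path above is just scaled dot-product attention with the scores laid out as (batch, heads, query, key). A compact, self-contained sketch of that einsum math follows; the shapes and fill values are arbitrary assumptions for illustration:

```python
# Hedged sketch of the split_head_dim einsum attention used above.
import jax
import jax.numpy as jnp

b, f, t, n, h = 2, 16, 16, 8, 40  # batch, query len, key len, heads, head dim
q = jnp.ones((b, f, n, h)) * 0.01
k = jnp.ones((b, t, n, h)) * 0.02
v = jnp.ones((b, t, n, h)) * 0.03
scale = h ** -0.5

# scores laid out as (batch, heads, query, key); softmax over the key axis (-1)
scores = jnp.einsum("b t n h, b f n h -> b n f t", k, q) * scale
probs = jax.nn.softmax(scores, axis=-1)
out = jnp.einsum("b n f t, b t n h -> b f n h", probs, v)
out = jnp.reshape(out, (b, -1, n * h))  # merge heads back: (batch, seq, heads * head_dim)
print(out.shape)  # (2, 16, 320)
```

The speedup mentioned in the docstring comes from avoiding the transpose/fold into the batch axis and letting XLA keep the head dimension as a separate axis throughout the two einsums.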