Unverified commit 16d56c4b authored by Juan Acevedo, committed by GitHub

F/flax split head dim (#5181)



* split_head_dim flax attn

* Make split_head_dim non default

* make style and make quality

* add description for split_head_dim flag

* Update src/diffusers/models/attention_flax.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent c82f7baf
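For context, a minimal usage sketch of the flag this commit introduces. This is an illustration only, assuming a diffusers checkout that includes this change; the sizes and the standard Flax `init`/`apply` calls are chosen for the example, not taken from the PR.

```python
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

# Illustrative sizes; any (batch, sequence, query_dim) input works.
attn = FlaxAttention(
    query_dim=1024,
    heads=16,
    dim_head=64,
    split_head_dim=True,  # the flag added in this PR; defaults to False
)

hidden_states = jnp.ones((1, 64, 1024))  # (batch, sequence, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # output keeps the input shape: (1, 64, 1024)
```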
src/diffusers/models/attention_flax.py

@@ -131,6 +131,8 @@ class FlaxAttention(nn.Module):
             Dropout rate
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases, enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
     dim_head: int = 64
     dropout: float = 0.0
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -177,9 +180,15 @@ class FlaxAttention(nn.Module):
         key_proj = self.key(context)
         value_proj = self.value(context)
 
-        query_states = self.reshape_heads_to_batch_dim(query_proj)
-        key_states = self.reshape_heads_to_batch_dim(key_proj)
-        value_states = self.reshape_heads_to_batch_dim(value_proj)
+        if self.split_head_dim:
+            b = hidden_states.shape[0]
+            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
+            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
+            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
+        else:
+            query_states = self.reshape_heads_to_batch_dim(query_proj)
+            key_states = self.reshape_heads_to_batch_dim(key_proj)
+            value_states = self.reshape_heads_to_batch_dim(value_proj)
 
         if self.use_memory_efficient_attention:
             query_states = query_states.transpose(1, 0, 2)
@@ -206,14 +215,23 @@ class FlaxAttention(nn.Module):
             hidden_states = hidden_states.transpose(1, 0, 2)
         else:
             # compute attentions
-            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            if self.split_head_dim:
+                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
+            else:
+                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+
             attention_scores = attention_scores * self.scale
-            attention_probs = nn.softmax(attention_scores, axis=2)
+            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
 
             # attend to values
-            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+            if self.split_head_dim:
+                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
+                b = hidden_states.shape[0]
+                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
+            else:
+                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return self.dropout_layer(hidden_states, deterministic=deterministic)
...
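The split-head path changes the einsum layout but not the math. Below is a self-contained numerical sketch (not the diffusers implementation; the shapes, the helper reshape, and the tolerance are assumptions) showing that attending with the head dimension kept as its own axis matches the batched-heads baseline:

```python
import jax
import jax.numpy as jnp

b, seq, heads, dim_head = 2, 16, 8, 64
scale = dim_head**-0.5

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query_proj = jax.random.normal(kq, (b, seq, heads * dim_head))
key_proj = jax.random.normal(kk, (b, seq, heads * dim_head))
value_proj = jax.random.normal(kv, (b, seq, heads * dim_head))

# Baseline path: fold heads into the batch axis and attend per (batch * head).
def heads_to_batch(x):
    return x.reshape(b, seq, heads, dim_head).transpose(0, 2, 1, 3).reshape(b * heads, seq, dim_head)

q, k, v = map(heads_to_batch, (query_proj, key_proj, value_proj))
probs = jax.nn.softmax(jnp.einsum("b i d, b j d -> b i j", q, k) * scale, axis=2)
out = jnp.einsum("b i j, b j d -> b i d", probs, v)
out = out.reshape(b, heads, seq, dim_head).transpose(0, 2, 1, 3).reshape(b, seq, heads * dim_head)

# Split-head path: keep heads as a separate axis, as split_head_dim=True does.
qs = query_proj.reshape(b, -1, heads, dim_head)
ks = key_proj.reshape(b, -1, heads, dim_head)
vs = value_proj.reshape(b, -1, heads, dim_head)
probs_s = jax.nn.softmax(jnp.einsum("b t n h, b f n h -> b n f t", ks, qs) * scale, axis=-1)
out_s = jnp.einsum("b n f t, b t n h -> b f n h", probs_s, vs).reshape(b, -1, heads * dim_head)

assert jnp.allclose(out, out_s, atol=1e-4)  # both paths produce the same attention output
```

Keeping the head axis separate avoids the two reshape/transpose round-trips of the baseline path, which is presumably where the speedup mentioned in the docstring comes from.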