"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bdeff4d64a57e556c2b62f887da03a2c37c54d54"
Unverified commit a5f35ee4, authored by Juan Acevedo, committed by GitHub
Browse files

add reshape to fix use_memory_efficient_attention in flax (#7918)


Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 63243406
...@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module): ...@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
hidden_states = jax_memory_efficient_attention( hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
) )
hidden_states = hidden_states.transpose(1, 0, 2) hidden_states = hidden_states.transpose(1, 0, 2)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
else: else:
# compute attentions # compute attentions
if self.split_head_dim: if self.split_head_dim:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment