Unverified Commit 930c8fdc authored by Allen Zhang, committed by GitHub

fix incorrect attention head dimension in AttnProcessor2_0 (#4154)

fix inner_dim
parent 6b1abba1
@@ -1096,7 +1096,6 @@ class AttnProcessor2_0:
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
-        inner_dim = hidden_states.shape[-1]

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -1117,6 +1116,7 @@ class AttnProcessor2_0:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

+        inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
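For context, here is a minimal standalone sketch of why `inner_dim` must be read from the projected `key` rather than from `hidden_states`. This is not the diffusers source; the module and the dimension values below are illustrative assumptions. The `to_q`/`to_k`/`to_v` linear layers project to `heads * dim_head`, which need not equal the channel dimension of `hidden_states`, so the old `inner_dim = hidden_states.shape[-1]` produces a wrong `head_dim` (and a failing `.view()`) whenever the two differ.

```python
# Illustrative sketch only: shapes and layer sizes are assumptions, not the
# diffusers Attention module itself.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len = 2, 16
query_dim = 320          # channel dim of hidden_states
heads, dim_head = 8, 64  # projection width: heads * dim_head = 512 != 320

to_q = nn.Linear(query_dim, heads * dim_head)
to_k = nn.Linear(query_dim, heads * dim_head)
to_v = nn.Linear(query_dim, heads * dim_head)

hidden_states = torch.randn(batch_size, seq_len, query_dim)
query, key, value = to_q(hidden_states), to_k(hidden_states), to_v(hidden_states)

# Buggy: inner_dim taken from the *input*, so head_dim = 320 // 8 = 40, and the
# .view() below cannot reshape a 512-wide projection into 8 heads of 40.
# inner_dim = hidden_states.shape[-1]

# Fixed (as in this commit): inner_dim from the *projected* key,
# so head_dim = 512 // 8 = 64.
inner_dim = key.shape[-1]
head_dim = inner_dim // heads

query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

When `heads * dim_head` happens to equal `query_dim` (a common configuration), both formulas agree, which is why the bug only surfaced for attention modules whose projection width differs from their input width.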