Commit 8187865a authored by ydshieh

Fix CrossAttention._sliced_attention

parent 0c0c2224
@@ -267,7 +267,7 @@ class CrossAttention(nn.Module):
 ...
         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
             hidden_states = self._attention(query, key, value)
         else:
-            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim=query.shape[-1] * self.heads)
         return self.to_out(hidden_states)
 ...
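
The change suggests that _sliced_attention expects dim to be the full per-token hidden size, which the bare name dim no longer provided at this point in forward. Below is a minimal sketch of the shape arithmetic behind the new keyword argument, using assumed example sizes and assuming CrossAttention folds the heads into the batch dimension before attending; it is an illustration, not the diffusers implementation.

import torch

# Assumed example sizes for illustration only.
batch, heads, seq_len, head_dim = 2, 8, 64, 40
hidden_dim = heads * head_dim  # full per-token hidden size, 320 here

# After the heads are folded into the batch dimension, the query tensor
# only carries the per-head size in its last dimension:
query = torch.randn(batch * heads, seq_len, head_dim)

# The corrected call rebuilds the full hidden size from the reshaped query:
dim = query.shape[-1] * heads
assert dim == hidden_dim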