Unverified Commit 84616b5d authored by Yih-Dar, committed by GitHub

Fix `CrossAttention._sliced_attention` (#563)



* Fix CrossAttention._sliced_attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8d36d5ad
@@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
         return tensor
 
     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
...
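Why the fix matters: `forward` passes `dim` on to `_sliced_attention`, which splits the attention computation per head. CrossAttention projects its input from `query_dim` channels to an inner width of `dim_head * heads`, so `hidden_states.shape[-1]` and `query.shape[-1]` disagree whenever those two widths differ, and the sliced buffers end up with the wrong channel count. Below is a minimal sketch of the mismatch; the specific sizes and the `dim // heads` buffer detail are illustrative assumptions, not quoted from the repository:

import torch
from torch import nn

# Assumed example sizes: the query projection maps width query_dim to an
# inner width of heads * dim_head, which need not equal query_dim.
query_dim, heads, dim_head = 320, 8, 80
inner_dim = heads * dim_head            # 640
to_q = nn.Linear(query_dim, inner_dim, bias=False)

hidden_states = torch.randn(2, 77, query_dim)
query = to_q(hidden_states)

dim_before = hidden_states.shape[-1]    # 320: the pre-fix value of `dim`
dim_after = query.shape[-1]             # 640: the post-fix value of `dim`

# _sliced_attention (sketched) allocates its per-head output with
# dim // heads channels, so only the post-fix value matches dim_head.
print(dim_before // heads)              # 40 -- wrong per-head width
print(dim_after // heads)               # 80 == dim_head, as required

With the old value, the per-head output buffer would be sized from `query_dim // heads` while each attention slice produces `dim_head` channels, failing as soon as `inner_dim != query_dim`; reading `dim` from `query.shape[-1]` after the projection keeps the two consistent.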