"src/vscode:/vscode.git/clone" did not exist on "b7af94613816e590e09eb536897de37dd2a07e10"
Unverified Commit 1c47d1fc authored by Fabio Rigano's avatar Fabio Rigano Committed by GitHub
Browse files

Fix head_to_batch_dim for IPAdapterAttnProcessor (#7077)

* Fix IPAdapterAttnProcessor

* Fix batch_to_head_dim and revert reshape
parent bbf70c87
......@@ -559,12 +559,16 @@ class Attention(nn.Module):
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment