Unverified Commit 1c47d1fc authored by Fabio Rigano, committed by GitHub

Fix head_to_batch_dim for IPAdapterAttnProcessor (#7077)

* Fix IPAdapterAttnProcessor

* Fix batch_to_head_dim and revert reshape
parent bbf70c87
@@ -559,12 +559,16 @@ class Attention(nn.Module):
             `torch.Tensor`: The reshaped tensor.
         """
         head_size = self.heads
-        batch_size, seq_len, dim = tensor.shape
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
         tensor = tensor.permute(0, 2, 1, 3)
         if out_dim == 3:
-            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
         return tensor
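For context, a minimal standalone sketch of the patched reshape logic. The free function head_to_batch_dim, the heads value of 8, and the example shapes below are illustrative assumptions for this sketch only; in diffusers this logic is a method on Attention, and per the commit title the 4-D case (the extra_dim branch) arises from inputs passed by IPAdapterAttnProcessor.

import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int, out_dim: int = 3) -> torch.Tensor:
    # Mirrors the patched method: accepts 3-D [batch, seq, dim] or
    # 4-D [batch, extra, seq, dim] inputs and splits `dim` across heads.
    head_size = heads
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        # The extra dimension is folded into the sequence dimension.
        batch_size, extra_dim, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3)
    if out_dim == 3:
        tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    return tensor

# 3-D input (the pre-existing path): [2, 77, 320] with 8 heads -> [16, 77, 40]
print(head_to_batch_dim(torch.randn(2, 77, 320), heads=8).shape)

# 4-D input (the case this commit fixes): [2, 4, 77, 320] -> [16, 308, 40]
# Before the fix, unpacking tensor.shape into three names raised a ValueError here.
print(head_to_batch_dim(torch.randn(2, 4, 77, 320), heads=8).shape)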