Unverified Commit 5adb0a7b authored by Suraj Patil, committed by GitHub

use torch.matmul instead of einsum in attention. (#445)

* use torch.matmul instead of einsum

* fix softmax
parent b2b3b1a8
...
@@ -275,11 +275,9 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
             hidden_states[start_idx:end_idx] = attn_slice
...
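For reference, a minimal sketch (not part of the commit) checking that the new matmul formulation matches the einsum it replaces. The tensor shapes and scale are assumed to follow the sliced query/key/value layout of (batch * heads, seq_len, head_dim) used in the diff above; the variable names here are illustrative only.

import torch

# Assumed shapes: (batch * heads, seq_len, head_dim), mirroring one slice
# of query/key/value in the attention loop shown in the diff.
q = torch.randn(2, 8, 64)
k = torch.randn(2, 8, 64)
v = torch.randn(2, 8, 64)
scale = 64 ** -0.5

# Old formulation: einsum contracting over the shared head_dim axis.
attn_einsum = torch.einsum("b i d, b j d -> b i j", q, k) * scale
out_einsum = torch.einsum("b i j, b j d -> b i d", attn_einsum.softmax(dim=-1), v)

# New formulation: plain batched matmul with an explicit transpose of key.
attn_matmul = torch.matmul(q, k.transpose(1, 2)) * scale
out_matmul = torch.matmul(attn_matmul.softmax(dim=-1), v)

# Both paths compute the same scaled dot-product attention output.
assert torch.allclose(out_einsum, out_matmul, atol=1e-6)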