Commit 2726892a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove more undefined references

parent 9285d94d
...@@ -552,7 +552,7 @@ def _lma( ...@@ -552,7 +552,7 @@ def _lma(
] ]
a = torch.einsum( a = torch.einsum(
"...qhd,...khd->...hqk", query, key "...qhd,...khd->...hqk", q_chunk, k_chunk,
) )
for b in small_bias_chunks: for b in small_bias_chunks:
...@@ -562,7 +562,7 @@ def _lma( ...@@ -562,7 +562,7 @@ def _lma(
max_a = torch.max(a, dim=-1, keepdim=True)[0] max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a) exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", value, exp_a) exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1)) maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1)) weights.append(torch.sum(exp_a, dim=-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment