Commit 2726892a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove more undefined references

parent 9285d94d
......@@ -552,7 +552,7 @@ def _lma(
]
a = torch.einsum(
"...qhd,...khd->...hqk", query, key
"...qhd,...khd->...hqk", q_chunk, k_chunk,
)
for b in small_bias_chunks:
......@@ -562,7 +562,7 @@ def _lma(
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", value, exp_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment