Commit b88f8dac authored by DeepMind's avatar DeepMind Committed by Copybara-Service
Browse files

Add a comment on the TriangleMultiplication notation.

PiperOrigin-RevId: 387766802
Change-Id: Ic838537513fe1d5bf41facffffd44046e91c3fa3
parent c9ffb0bc
......@@ -1305,6 +1305,12 @@ class TriangleMultiplication(hk.Module):
left_proj_act *= left_gate_values
right_proj_act *= right_gate_values
# "Outgoing" edges equation: 'ikc,jkc->ijc'
# "Incoming" edges equation: 'kjc,kic->ijc'
# Note on the Suppl. Alg. 11 & 12 notation:
# For the "outgoing" edges, a = left_proj_act and b = right_proj_act
# For the "incoming" edges, it's swapped:
# b = left_proj_act and a = right_proj_act
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = hk.LayerNorm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment