Commit 862dab6f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bugs from previous commit

parent 34e9363c
......@@ -315,7 +315,7 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
(t,) = checkpoint_blocks(
t, = checkpoint_blocks(
blocks=[
partial(
b,
......
......@@ -138,7 +138,6 @@ def _trace_module(module, batch_dims=None):
)
)
}
module = OPM(module)
else:
raise TypeError(
f"tracing is not supported for modules of type {type(module)}"
......
......@@ -52,7 +52,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.sigmoid = nn.Sigmoid()
def _combine_projections(
def _combine_projections(self,
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
......@@ -94,8 +94,7 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
def _combine_projections(
self,
def _combine_projections(self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
......@@ -113,8 +112,7 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
def _combine_projections(
self,
def _combine_projections(self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
......
......@@ -32,7 +32,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c = 11
outgoing = True
tm = TriangleMultiplicativeUpdate(
tm = TriangleMultiplicationOutgoing(
c_z,
c,
outgoing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment