"...git@developer.sourcefind.cn:OpenDAS/mmdeploy.git" did not exist on "e4fb2aa4ea2a6347df67473bb2f02169510a1cab"
Commit 862dab6f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bugs from previous commit

parent 34e9363c
...@@ -315,7 +315,7 @@ class TemplatePairStack(nn.Module): ...@@ -315,7 +315,7 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4] expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx) mask = mask.expand(*expand_idx)
(t,) = checkpoint_blocks( t, = checkpoint_blocks(
blocks=[ blocks=[
partial( partial(
b, b,
......
...@@ -138,7 +138,6 @@ def _trace_module(module, batch_dims=None): ...@@ -138,7 +138,6 @@ def _trace_module(module, batch_dims=None):
) )
) )
} }
module = OPM(module)
else: else:
raise TypeError( raise TypeError(
f"tracing is not supported for modules of type {type(module)}" f"tracing is not supported for modules of type {type(module)}"
......
...@@ -52,7 +52,7 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -52,7 +52,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def _combine_projections( def _combine_projections(self,
a: torch.Tensor, a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -94,8 +94,7 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): ...@@ -94,8 +94,7 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
""" """
Implements Algorithm 11. Implements Algorithm 11.
""" """
def _combine_projections( def _combine_projections(self,
self,
a: torch.Tensor, # [*, N_i, N_k, C] a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C] b: torch.Tensor, # [*, N_j, N_k, C]
): ):
...@@ -113,8 +112,7 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): ...@@ -113,8 +112,7 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
""" """
Implements Algorithm 12. Implements Algorithm 12.
""" """
def _combine_projections( def _combine_projections(self,
self,
a: torch.Tensor, # [*, N_k, N_i, C] a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C] b: torch.Tensor, # [*, N_k, N_j, C]
): ):
......
...@@ -32,7 +32,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -32,7 +32,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c = 11 c = 11
outgoing = True outgoing = True
tm = TriangleMultiplicativeUpdate( tm = TriangleMultiplicationOutgoing(
c_z, c_z,
c, c,
outgoing, outgoing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment