Commit 3d9d977a authored by Gustaf Ahdritz

Remove vestigial unit tests

parent de0fa7b1
@@ -349,32 +349,3 @@ class ExtraMSAEmbedder(nn.Module):
         x = self.linear(x)
         return x
-if __name__ == "__main__":
-    tf_dim = 21
-    msa_dim = 49
-    c_z = 128
-    c_m = 256
-    relpos_k = 32
-    b = 16
-    n_res = 200
-    n_clust = 10
-
-    tf = torch.rand((b, n_res, tf_dim))
-    ri = torch.rand((b, n_res))
-    msa = torch.rand((b, n_clust, n_res, msa_dim))
-
-    batch = {}
-    batch["target_feat"] = tf
-    batch["residue_index"] = ri
-    batch["msa_feat"] = msa
-
-    ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
-
-    msa_emb, pair_emb = ie(batch)
-    assert(msa_emb.shape == (b, n_clust, n_res, c_m))
-    assert(pair_emb.shape == (b, n_res, n_res, c_z))
@@ -439,62 +439,3 @@ class ExtraMSAStack(nn.Module):
             _mask_trans=_mask_trans
         )
         return z
-if __name__ == "__main__":
-    batch_size = 2
-    s_t = 3
-    n_res = 100
-    c_m = 128
-    c_z = 64
-    c_hidden_att = 32
-    c_hidden_opm = 31
-    c_hidden_mul = 30
-    c_s = 29
-    no_heads_msa = 4
-    no_heads_pair = 8
-    no_blocks = 2
-    transition_n = 5
-    msa_dropout = 0.15
-    pair_dropout = 0.25
-
-    es = EvoformerStack(
-        c_m,
-        c_z,
-        c_hidden_att,
-        c_hidden_opm,
-        c_hidden_mul,
-        c_s,
-        no_heads_msa,
-        no_heads_pair,
-        no_blocks,
-        transition_n,
-        msa_dropout,
-        pair_dropout,
-    )
-
-    m = torch.rand((batch_size, s_t, n_res, c_m))
-    z = torch.rand((batch_size, n_res, n_res, c_z))
-
-    shape_m_before = m.shape
-    shape_z_before = z.shape
-
-    m, z, s = es(m, z)
-
-    assert(m.shape == shape_m_before)
-    assert(z.shape == shape_z_before)
-    assert(s.shape == (batch_size, n_res, c_s))
-
-    batch_size = 2
-    s = 5
-    n_res = 100
-    c_m = 256
-    c = 32
-    c_z = 128
-
-    opm = OuterProductMean(c_m, c_z, c)
-
-    m = torch.rand((batch_size, s, n_res, c_m))
-    m = opm(m)
-    assert(m.shape == (batch_size, n_res, n_res, c_z))
@@ -285,22 +285,3 @@ class MSAColumnGlobalAttention(nn.Module):
         m = m.transpose(-2, -3)
         return m
-if __name__ == "__main__":
-    batch_size = 2
-    s_t = 3
-    n = 100
-    c_in = 128
-    c = 32
-    no_heads = 4
-
-    msaca = MSAColumnAttention(c_in, c, no_heads)
-
-    x = torch.rand((batch_size, s_t, n, c_in))
-    shape_before = x.shape
-    x = msaca(x)
-    shape_after = x.shape
-
-    assert(shape_before == shape_after)
@@ -110,19 +110,3 @@ class OuterProductMean(nn.Module):
         outer = outer / (self.eps + norm)
         return outer
-if __name__ == "__main__":
-    batch_size = 2
-    s = 5
-    n_res = 100
-    c_m = 256
-    c = 32
-    c_z = 128
-
-    opm = OuterProductMean(c_m, c_z, c)
-
-    m = torch.rand((batch_size, s, n_res, c_m))
-    m = opm(m)
-    assert(m.shape == (batch_size, n_res, n_res, c_z))
@@ -84,20 +84,3 @@ class PairTransition(nn.Module):
         z = self._transition(**inp)
         return z
-if __name__ == "__main__":
-    n = 4
-    c_in = 128
-
-    pt = PairTransition(n, c_in)
-
-    batch_size = 4
-    n_res = 256
-
-    z = torch.rand((batch_size, n_res, n_res, c_in))
-    shape_before = z.shape
-    z = pt(z)
-    shape_after = z.shape
-
-    assert(shape_before == shape_after)
@@ -810,29 +810,3 @@ class StructureModule(nn.Module):
             self.atom_mask,
             self.lit_positions,
         )
-if __name__ == "__main__":
-    c_m = 11
-    c_z = 13
-    c_hidden = 17
-    no_heads = 3
-    no_qp = 5
-    no_vp = 7
-
-    batch_size = 2
-    n_res = 100  # assumed value; n_res was left undefined in this test
-
-    s = torch.rand((batch_size, n_res, c_m))
-    z = torch.rand((batch_size, n_res, n_res, c_z))
-    rots = torch.rand((batch_size, n_res, 3, 3))
-    trans = torch.rand((batch_size, n_res, 3))
-    t = (rots, trans)
-
-    ipa = InvariantPointAttention(c_m, c_z, c_hidden, no_heads, no_qp, no_vp)
-
-    shape_before = s.shape
-    s = ipa(s, z, t)
-
-    assert(s.shape == shape_before)
@@ -271,40 +271,3 @@ class TemplatePairStack(nn.Module):
         t = self.layer_norm(t)
         return t
-if __name__ == "__main__":
-    template_angle_dim = 51
-    c_m = 256
-    batch_size = 4
-    n_templ = 4
-    n_res = 256
-
-    tae = TemplateAngleEmbedder(
-        template_angle_dim,
-        c_m,
-    )
-
-    x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
-    x = tae(x)
-    # the embedder projects the angle features to c_m, so the last dim changes
-    assert(x.shape == (batch_size, n_templ, n_res, c_m))
-
-    batch_size = 2
-    s_t = 4
-    c_t = 64
-    c_z = 128
-    c = 32
-    no_heads = 3
-    n = 100
-
-    tpa = TemplatePointwiseAttention(c_t, c_z, c, no_heads)
-
-    t = torch.rand((batch_size, s_t, n, n, c_t))
-    z = torch.rand((batch_size, n, n, c_z))
-
-    z_update = tpa(t, z)
-
-    assert(z_update.shape == z.shape)
@@ -130,27 +130,3 @@ class TriangleAttentionEndingNode(TriangleAttention):
     Implements Algorithm 14.
     """
     __init__ = partialmethod(TriangleAttention.__init__, starting=False)
-if __name__ == "__main__":
-    c_in = 256
-    c = 32
-    no_heads = 4
-    starting = True
-
-    tan = TriangleAttention(
-        c_in,
-        c,
-        no_heads,
-        starting
-    )
-
-    batch_size = 16
-    n_res = 256
-
-    x = torch.rand((batch_size, n_res, n_res, c_in))
-    shape_before = x.shape
-    x = tan(x)
-    shape_after = x.shape
-
-    assert(shape_before == shape_after)
@@ -125,24 +125,3 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
     __init__ = partialmethod(
         TriangleMultiplicativeUpdate.__init__, _outgoing=False,
     )
-if __name__ == "__main__":
-    c_in = 256  # doubled to make shape changes more apparent
-    c = 128
-    outgoing = True
-
-    # the class is TriangleMultiplicativeUpdate (see context above),
-    # which takes the direction as the _outgoing keyword
-    tm = TriangleMultiplicativeUpdate(
-        c_in,
-        c,
-        _outgoing=outgoing,
-    )
-
-    n_res = 300
-    batch_size = 16
-
-    x = torch.rand((batch_size, n_res, n_res, c_in))
-    shape_before = x.shape
-    x = tm(x)
-    shape_after = x.shape
-
-    assert(shape_before == shape_after)
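
Every hunk above deletes an ad-hoc `if __name__ == "__main__":` smoke test that duplicated what a test runner does. Where such shape checks are still wanted, they are usually migrated to a standalone test module instead. Below is a minimal pytest sketch of that migration for the InputEmbedder check removed in the first hunk; the `openfold.model.embedders` import path is an assumption about the repository layout, not something this diff confirms.

# test_embedders.py: pytest counterpart of the removed InputEmbedder smoke test.
# The openfold.model.embedders import path is assumed, not taken from this diff.
import torch

from openfold.model.embedders import InputEmbedder


def test_input_embedder_shapes():
    tf_dim, msa_dim = 21, 49
    c_z, c_m, relpos_k = 128, 256, 32
    b, n_res, n_clust = 16, 200, 10

    # random stand-ins for the target, residue index, and MSA features
    batch = {
        "target_feat": torch.rand((b, n_res, tf_dim)),
        "residue_index": torch.rand((b, n_res)),
        "msa_feat": torch.rand((b, n_clust, n_res, msa_dim)),
    }

    ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
    msa_emb, pair_emb = ie(batch)

    assert msa_emb.shape == (b, n_clust, n_res, c_m)
    assert pair_emb.shape == (b, n_res, n_res, c_z)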