Commit 28334db3 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for no column attention Evoformer

parent a7c0d0d1
...@@ -66,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -66,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout, msa_dropout,
pair_stack_dropout, pair_stack_dropout,
blocks_per_ckpt=None, blocks_per_ckpt=None,
no_column_attention=False,
inf=inf, inf=inf,
eps=eps, eps=eps,
).eval() ).eval()
...@@ -86,6 +87,62 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -86,6 +87,62 @@ class TestEvoformerStack(unittest.TestCase):
self.assertTrue(z.shape == shape_z_before) self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s)) self.assertTrue(s.shape == (batch_size, n_res, c_s))
def test_shape_without_column_attention(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c_z = consts.c_z
c_hidden_msa_att = 12
c_hidden_opm = 17
c_hidden_mul = 19
c_hidden_pair_att = 14
c_s = consts.c_s
no_heads_msa = 3
no_heads_pair = 7
no_blocks = 2
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
inf = 1e9
eps = 1e-10
es = EvoformerStack(
c_m,
c_z,
c_hidden_msa_att,
c_hidden_opm,
c_hidden_mul,
c_hidden_pair_att,
c_s,
no_heads_msa,
no_heads_pair,
no_blocks,
transition_n,
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
no_column_attention=True,
inf=inf,
eps=eps,
).eval()
m_init = torch.rand((batch_size, n_seq, n_res, c_m))
z_init = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_m_before = m_init.shape
shape_z_before = z_init.shape
m, z, s = es(
m_init, z_init, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
)
self.assertTrue(m.shape == shape_m_before)
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_compare(self): def test_compare(self):
def run_ei(activations, masks): def run_ei(activations, masks):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment