Unverified Commit bc075004 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #404 from jnwei/multimer-small-edits

Type fixes and README changes for multimer branch
parents 5e0616b6 a2adb147
...@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile # The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to # We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data. # have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
...@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(), torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase): class TestAngleResnet(unittest.TestCase):
......
...@@ -191,7 +191,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -191,7 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False, _mask_trans=False,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase): class Template(unittest.TestCase):
...@@ -284,7 +284,7 @@ class Template(unittest.TestCase): ...@@ -284,7 +284,7 @@ class Template(unittest.TestCase):
out_repro = out_repro_all["template_pair_embedding"] out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None, chunk_size=None,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self): def test_tri_att_end_compare(self):
......
...@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe=True, _inplace_chunk_size=4, inplace_safe=True, _inplace_chunk_size=4,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self): def test_tri_mul_out_compare(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment