"deploy/examples/vllm/utils/chat_processor.py" did not exist on "8435b993b3311c2725893cc0a010da757e2540a1"
Unverified Commit bc075004 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #404 from jnwei/multimer-small-edits

Type fixes and README changes for multimer branch
parents 5e0616b6 a2adb147
......@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase):
......@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase):
......
......@@ -191,7 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase):
......@@ -284,7 +284,7 @@ class Template(unittest.TestCase):
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
......@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment