Commit 427a6ee7 authored by Jennifer's avatar Jennifer
Browse files

update deprecated jax.numpy.DeviceArray to jax.Array

parent 91776cdf
...@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration" "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
key = jax.random.PRNGKey(42) key = jax.random.PRNGKey(42)
out_gt = f.apply(params, key, activations, masks) out_gt = f.apply(params, key, activations, masks)
...@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase): ...@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition" + "msa_transition"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
......
...@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention" + "msa_row_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply( out_gt = f.apply(
params, None, msa_act, msa_mask, pair_act params, None, msa_act, msa_mask, pair_act
...@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention" + "msa_column_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/" "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention" + "msa_column_global_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......
...@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/" "alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean" + "evoformer_iteration/outer_product_mean"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
......
...@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition" + "pair_transition"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......
...@@ -191,7 +191,9 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -191,7 +191,9 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False, _mask_trans=False,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) diff = torch.max(torch.abs(out_gt - out_repro))
self.assertTrue(diff < consts.eps,
msg=f"Found difference between ground truth and reproduction of {diff}")
class Template(unittest.TestCase): class Template(unittest.TestCase):
......
...@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
......
...@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment