Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
427a6ee7
Commit
427a6ee7
authored
Jan 23, 2024
by
Jennifer
Browse files
update deprecated jax.numpy.DeviceArray to jax.Array
parent
91776cdf
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
12 additions
and
10 deletions
+12
-10
tests/test_evoformer.py
tests/test_evoformer.py
+2
-2
tests/test_msa.py
tests/test_msa.py
+3
-3
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_pair_transition.py
tests/test_pair_transition.py
+1
-1
tests/test_template.py
tests/test_template.py
+3
-1
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+1
-1
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+1
-1
No files found.
tests/test_evoformer.py
View file @
427a6ee7
...
@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
params
=
compare_utils
.
fetch_alphafold_module_weights
(
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
key
=
jax
.
random
.
PRNGKey
(
42
)
key
=
jax
.
random
.
PRNGKey
(
42
)
out_gt
=
f
.
apply
(
params
,
key
,
activations
,
masks
)
out_gt
=
f
.
apply
(
params
,
key
,
activations
,
masks
)
...
@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
...
@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
"msa_transition"
+
"msa_transition"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
tests/test_msa.py
View file @
427a6ee7
...
@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
...
@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
"msa_row_attention"
+
"msa_row_attention"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
,
pair_act
params
,
None
,
msa_act
,
msa_mask
,
pair_act
...
@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
...
@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
"msa_column_attention"
+
"msa_column_attention"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
...
@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+
"msa_column_global_attention"
+
"msa_column_global_attention"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
tests/test_outer_product_mean.py
View file @
427a6ee7
...
@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase):
...
@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/"
"alphafold/alphafold_iteration/evoformer/"
+
"evoformer_iteration/outer_product_mean"
+
"evoformer_iteration/outer_product_mean"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
msa_act
,
msa_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
tests/test_pair_transition.py
View file @
427a6ee7
...
@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase):
...
@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
"pair_transition"
+
"pair_transition"
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
tests/test_template.py
View file @
427a6ee7
...
@@ -191,7 +191,9 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -191,7 +191,9 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans
=
False
,
_mask_trans
=
False
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
diff
=
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
self
.
assertTrue
(
diff
<
consts
.
eps
,
msg
=
f"Found difference between ground truth and reproduction of {diff}"
)
class
Template
(
unittest
.
TestCase
):
class
Template
(
unittest
.
TestCase
):
...
...
tests/test_triangular_attention.py
View file @
427a6ee7
...
@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
name
+
name
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
tests/test_triangular_multiplicative_update.py
View file @
427a6ee7
...
@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
name
+
name
)
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
DeviceArray
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment