Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d8ee9c5f
You need to sign in or sign up before continuing.
Commit
d8ee9c5f
authored
Feb 17, 2023
by
Christina Floristean
Browse files
All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.
parent
260db67f
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
470 additions
and
102 deletions
+470
-102
tests/test_embedders.py
tests/test_embedders.py
+26
-6
tests/test_evoformer.py
tests/test_evoformer.py
+5
-1
tests/test_feats.py
tests/test_feats.py
+75
-18
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+145
-17
tests/test_model.py
tests/test_model.py
+25
-5
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_pair_transition.py
tests/test_pair_transition.py
+1
-1
tests/test_primitives.py
tests/test_primitives.py
+3
-5
tests/test_structure_module.py
tests/test_structure_module.py
+77
-23
tests/test_template.py
tests/test_template.py
+107
-20
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+2
-2
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+2
-2
No files found.
tests/test_embedders.py
View file @
d8ee9c5f
...
@@ -12,14 +12,17 @@
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
random
import
torch
import
torch
import
numpy
as
np
import
unittest
import
unittest
from
tests.config
import
consts
from
tests.data_utils
import
random_asym_ids
from
openfold.model.embedders
import
(
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
TemplatePairEmbedder
)
)
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res
=
17
n_res
=
17
n_clust
=
19
n_clust
=
19
max_relative_chain
=
2
max_relative_idx
=
32
use_chain_relative
=
True
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
ri
=
torch
.
rand
((
b
,
n_res
))
ri
=
torch
.
rand
((
b
,
n_res
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
asym_ids_flat
=
torch
.
Tensor
(
random_asym_ids
(
n_res
))
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
asym_id
=
torch
.
tile
(
asym_ids_flat
.
unsqueeze
(
0
),
(
b
,
1
))
entity_id
=
asym_id
msa_emb
,
pair_emb
=
ie
(
tf
,
ri
,
msa
)
sym_id
=
torch
.
zeros_like
(
entity_id
)
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
}
if
consts
.
is_multimer
:
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
max_relative_idx
=
max_relative_idx
,
use_chain_relative
=
use_chain_relative
,
max_relative_chain
=
max_relative_chain
)
batch
.
update
({
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
})
else
:
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
batch
)
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
...
...
tests/test_evoformer.py
View file @
d8ee9c5f
...
@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
=
2
transition_n
=
2
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
=
5
transition_n
=
5
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
ckpt
=
False
,
ckpt
=
False
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase):
...
@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
model
.
evoformer
.
blocks
[
0
].
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
)
...
...
tests/test_feats.py
View file @
d8ee9c5f
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
)
)
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_atom37_to_frames_compare
(
self
):
def
test_atom37_to_frames_compare
(
self
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
return
alphafold
.
model
.
all_atom
.
atom37_to_frames
(
if
consts
.
is_multimer
:
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
all_atom_positions
)
return
self
.
am_atom
.
atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
aatype
,
all_atom_positions
,
all_atom_mask
)
)
...
@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase):
...
@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase):
}
}
out_gt
=
f
.
apply
({},
None
,
**
batch
)
out_gt
=
f
.
apply
({},
None
,
**
batch
)
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
consts
.
is_multimer
:
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
())).
view
(
*
t
.
shape
[:
2
],
12
))
else
:
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
def
flat12_to_4x4
(
flat12
):
def
flat12_to_4x4
(
flat12
):
...
@@ -187,7 +212,13 @@ class TestFeats(unittest.TestCase):
...
@@ -187,7 +212,13 @@ class TestFeats(unittest.TestCase):
n
=
5
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase):
...
@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase):
def
run_torsion_angles_to_frames
(
def
run_torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
aatype
,
backb_to_global
,
torsion_angles_sin_cos
):
):
return
alphafold
.
mod
el
.
a
ll
_atom
.
torsion_angles_to_frames
(
return
s
el
f
.
a
m
_atom
.
torsion_angles_to_frames
(
aatype
,
aatype
,
backb_to_global
,
backb_to_global
,
torsion_angles_sin_cos
,
torsion_angles_sin_cos
,
...
@@ -221,10 +252,17 @@ class TestFeats(unittest.TestCase):
...
@@ -221,10 +252,17 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
if
consts
.
is_multimer
:
torch
.
as_tensor
(
affines
).
float
()
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
@@ -264,7 +302,13 @@ class TestFeats(unittest.TestCase):
...
@@ -264,7 +302,13 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase):
...
@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase):
torch
.
tensor
(
restype_atom14_rigid_group_positions
),
torch
.
tensor
(
restype_atom14_rigid_group_positions
),
)
)
if
consts
.
is_multimer
:
xyz
=
xyz
.
to_tensor
()
self
.
assertTrue
(
xyz
.
shape
==
(
batch_size
,
n_res
,
14
,
3
))
self
.
assertTrue
(
xyz
.
shape
==
(
batch_size
,
n_res
,
14
,
3
))
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
run_f
(
aatype
,
affines
):
def
run_f
(
aatype
,
affines
):
am
=
alphafold
.
model
return
self
.
am_atom
.
frames_and_literature_positions_to_atom14_pos
(
return
am
.
all_atom
.
frames_and_literature_positions_to_atom14_pos
(
aatype
,
affines
aatype
,
affines
)
)
...
@@ -294,16 +340,27 @@ class TestFeats(unittest.TestCase):
...
@@ -294,16 +340,27 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,
8
))
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
if
consts
.
is_multimer
:
torch
.
as_tensor
(
affines
).
float
()
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
if
consts
.
is_multimer
:
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
to_array
()))
else
:
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
transformations
.
cuda
(),
transformations
.
cuda
(),
...
...
tests/test_import_weights.py
View file @
d8ee9c5f
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
)
][
1
].
transpose
(
-
1
,
-
2
)
][
1
].
transpose
(
-
1
,
-
2
)
),
),
model
.
evoformer
.
blocks
[
1
].
core
.
outer_product_mean
.
linear_1
.
weight
,
model
.
evoformer
.
blocks
[
1
].
outer_product_mean
.
linear_1
.
weight
,
),
),
]
]
...
...
tests/test_loss.py
View file @
d8ee9c5f
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
math
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import (
...
@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import (
Rotation
,
Rotation
,
Rigid
,
Rigid
,
)
)
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
torsion_angle_loss
,
torsion_angle_loss
,
compute_fape
,
compute_fape
,
...
@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import (
...
@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import (
)
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
...
@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine):
...
@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine):
return
r
.
to_tensor_4x4
()
return
r
.
to_tensor_4x4
()
def
affine_vector_to_rigid
(
am_rigid
,
affine
):
rigid_flat
=
np
.
split
(
affine
,
7
,
axis
=-
1
)
rigid_flat
=
[
r
.
squeeze
(
-
1
)
for
r
in
rigid_flat
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
trans
=
rigid_flat
[
4
:]
rotation
=
am_rigid
.
Rot3Array
.
from_quaternion
(
qw
,
qx
,
qy
,
qz
,
normalize
=
True
)
translation
=
am_rigid
.
Vec3Array
(
*
trans
)
return
am_rigid
.
Rigid3Array
(
rotation
,
translation
)
class
TestLoss
(
unittest
.
TestCase
):
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
...
@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase):
...
@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_bond_loss_compare
(
self
):
def
test_between_residue_bond_loss_compare
(
self
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
return
alphafold
.
model
.
all_atom
.
between_residue_bond_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_bond_loss
(
pred_pos
,
pred_pos
,
pred_atom_mask
,
pred_atom_mask
,
residue_index
,
residue_index
,
...
@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase):
...
@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_clash_loss_compare
(
self
):
def
test_between_residue_clash_loss_compare
(
self
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
):
return
alphafold
.
model
.
all_atom
.
between_residue_clash_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
pred_pos
,
atom_exists
,
atom_exists
,
atom_radius
,
atom_radius
,
res_ind
,
res_ind
)
)
f
=
hk
.
transform
(
run_brcl
)
f
=
hk
.
transform
(
run_brcl
)
...
@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase):
...
@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase):
res_ind
=
np
.
arange
(
res_ind
=
np
.
arange
(
n_res
,
n_res
,
)
)
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
{},
{},
...
@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase):
...
@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase):
atom_exists
,
atom_exists
,
atom_radius
,
atom_radius
,
res_ind
,
res_ind
,
asym_id
)
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
...
@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase):
...
@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase):
def
run_fsv
(
batch
,
pos
,
config
):
def
run_fsv
(
batch
,
pos
,
config
):
cwd
=
os
.
getcwd
()
cwd
=
os
.
getcwd
()
os
.
chdir
(
"tests/test_data"
)
os
.
chdir
(
"tests/test_data"
)
loss
=
alphafold
.
model
.
folding
.
find_structural_violations
(
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
return
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
config
,
batch
[
'asym_id'
]
)
loss
=
self
.
am_fold
.
find_structural_violations
(
batch
,
batch
,
pos
,
pos
,
config
,
config
,
...
@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase):
...
@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)
0
,
37
,
(
n_res
,
14
)
).
astype
(
np
.
int64
),
).
astype
(
np
.
int64
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
...
@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase):
...
@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
value
=
{
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
consts
.
msa_logits
).
astype
(
np
.
float32
),
}
}
batch
=
{
batch
=
{
...
@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase):
...
@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss
=
config
.
model
.
heads
.
structure_module
c_chi_loss
=
config
.
model
.
heads
.
structure_module
def
run_supervised_chi_loss
(
value
,
batch
):
def
run_supervised_chi_loss
(
value
,
batch
):
if
consts
.
is_multimer
:
pred_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
unnormed_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'unnormalized_angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
chi_loss
,
_
,
_
=
self
.
am_fold
.
supervised_chi_loss
(
batch
[
'seq_mask'
],
batch
[
'chi_mask'
],
batch
[
'aatype'
],
batch
[
'chi_angles'
],
pred_angles
,
unnormed_angles
,
c_chi_loss
)
return
chi_loss
ret
=
{
ret
=
{
"loss"
:
jax
.
numpy
.
array
(
0.0
),
"loss"
:
jax
.
numpy
.
array
(
0.0
),
}
}
alphafold
.
model
.
fold
ing
.
supervised_chi_loss
(
self
.
am_
fold
.
supervised_chi_loss
(
ret
,
batch
,
value
,
c_chi_loss
ret
,
batch
,
value
,
c_chi_loss
)
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
...
@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase):
...
@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase):
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
}
}
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_pos
)
viol
=
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
c_viol
,
batch
[
'asym_id'
]
)
return
self
.
am_fold
.
structural_violation_loss
(
mask
=
batch
[
'atom14_atom_exists'
],
violations
=
viol
,
config
=
c_viol
)
value
=
{}
value
=
{}
value
[
value
[
"violations"
"violations"
]
=
alphafold
.
model
.
fold
ing
.
find_structural_violations
(
]
=
self
.
am_
fold
.
find_structural_violations
(
batch
,
batch
,
atom14_pred_pos
,
atom14_pred_pos
,
c_viol
,
c_viol
,
)
)
alphafold
.
model
.
folding
.
structural_violation_loss
(
self
.
am_fold
.
structural_violation_loss
(
ret
,
ret
,
batch
,
batch
,
value
,
value
,
...
@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase):
...
@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase):
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase):
...
@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_bb_loss
(
batch
,
value
):
def
run_bb_loss
(
batch
,
value
):
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
astype
(
np
.
float32
)
gt_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
batch
[
"backbone_affine_tensor"
])
target_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
value
[
'traj'
])
intra_chain_bb_loss
,
intra_chain_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
intra_chain_fape
,
pair_mask
=
intra_chain_mask
)
interface_bb_loss
,
interface_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
interface_fape
,
pair_mask
=
1.
-
intra_chain_mask
)
return
intra_chain_bb_loss
+
interface_bb_loss
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.0
),
"loss"
:
np
.
array
(
0.0
),
}
}
alphafold
.
model
.
fold
ing
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
self
.
am_
fold
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_bb_loss
)
f
=
hk
.
transform
(
run_bb_loss
)
...
@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase):
...
@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase):
np
.
float32
np
.
float32
),
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
value
=
{
value
=
{
...
@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase):
...
@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
if
consts
.
is_multimer
:
atom14_pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_positions
)
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
batch
[
"all_atom_positions"
])
gt_positions
,
gt_mask
,
alt_naming_is_better
=
self
.
am_fold
.
compute_atom14_gt
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
pred_pos
=
atom14_pred_positions
)
pred_frames
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
value
[
"sidechains"
][
"frames"
])
pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
value
[
"sidechains"
][
"atom_pos"
])
gt_sc_frames
,
gt_sc_frames_mask
=
self
.
am_fold
.
compute_frames
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
use_alt
=
alt_naming_is_better
)
return
self
.
am_fold
.
sidechain_loss
(
gt_sc_frames
,
gt_sc_frames_mask
,
gt_positions
,
gt_mask
,
pred_frames
,
pred_positions
,
c_sm
)[
'loss'
]
batch
=
{
batch
=
{
**
batch
,
**
batch
,
**
alphafold
.
mod
el
.
a
ll
_atom
.
atom37_to_frames
(
**
s
el
f
.
a
m
_atom
.
atom37_to_frames
(
batch
[
"aatype"
],
batch
[
"aatype"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
batch
[
"all_atom_mask"
],
...
@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase):
...
@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase):
)
)
value
=
v
value
=
v
ret
=
alphafold
.
model
.
fold
ing
.
sidechain_loss
(
batch
,
value
,
c_sm
)
ret
=
self
.
am_
fold
.
sidechain_loss
(
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_sidechain_loss
)
f
=
hk
.
transform
(
run_sidechain_loss
)
...
@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase):
...
@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
...
tests/test_model.py
View file @
d8ee9c5f
...
@@ -20,8 +20,7 @@ import unittest
...
@@ -20,8 +20,7 @@ import unittest
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
(
from
tests.data_utils
import
(
...
@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed():
...
@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"
model
_1"
)
c
=
model_config
(
consts
.
model
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
# deepspeed for this test
...
@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase):
...
@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase):
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"entity_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"sym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
)
...
@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase):
...
@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase):
def
test_compare
(
self
):
def
test_compare
(
self
):
def
run_alphafold
(
batch
):
def
run_alphafold
(
batch
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
model
=
alphafold
.
model
.
modules
.
AlphaFold
(
config
.
model
)
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
return
model
(
return
model
(
batch
=
batch
,
batch
=
batch
,
is_training
=
False
,
is_training
=
False
,
...
@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase):
...
@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
# atom37_to_atom14 doesn't like batches
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
self
.
am_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
...
...
tests/test_outer_product_mean.py
View file @
d8ee9c5f
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
core
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
...
...
tests/test_pair_transition.py
View file @
d8ee9c5f
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
pair_transition
(
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
...
...
tests/test_primitives.py
View file @
d8ee9c5f
...
@@ -13,12 +13,10 @@
...
@@ -13,12 +13,10 @@
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
numpy
as
np
import
unittest
import
unittest
from
openfold.model.primitives
import
(
from
openfold.model.primitives
import
(
Attention
,
Attention
LowMemoryAttention
,
)
)
from
tests.config
import
consts
from
tests.config
import
consts
...
@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase):
...
@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase):
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
lma
=
LowMemory
Attention
(
lma
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
).
cuda
()
a
=
Attention
(
a
=
Attention
(
...
@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase):
...
@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase):
m
.
linear_o
.
weight
.
copy_
(
o_fill
)
m
.
linear_o
.
weight
.
copy_
(
o_fill
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
l
=
lma
(
q
,
k
,
v
,
1024
,
4096
,
biases
=
bias
)
l
=
lma
(
q
,
k
,
v
,
biases
=
bias
,
use_lma
=
True
,
q_chunk_size
=
1024
,
kv_chunk_size
=
4096
)
real
=
a
(
q
,
k
,
v
,
biases
=
bias
)
real
=
a
(
q
,
k
,
v
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
...
...
tests/test_structure_module.py
View file @
d8ee9c5f
...
@@ -18,21 +18,19 @@ import unittest
...
@@ -18,21 +18,19 @@ import unittest
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom37_mask
,
restype_atom37_mask
,
)
)
from
openfold.model.structure_module
import
(
from
openfold.model.structure_module
import
(
StructureModule
,
StructureModule
,
StructureModuleTransition
,
StructureModuleTransition
,
BackboneUpdate
,
AngleResnet
,
AngleResnet
,
InvariantPointAttention
,
InvariantPointAttention
,
)
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
(
from
tests.data_utils
import
(
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
n
=
consts
.
n_res
n
=
consts
.
n_res
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor
,
trans_scale_factor
,
ar_epsilon
,
ar_epsilon
,
inf
,
inf
,
is_multimer
=
consts
.
is_multimer
)
)
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
(
s
,
z
,
f
)
out
=
sm
(
s
,
z
,
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
else
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
)
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global
=
config
.
model
.
global_config
c_global
=
config
.
model
.
global_config
def
run_sm
(
representations
,
batch
):
def
run_sm
(
representations
,
batch
):
sm
=
alphafold
.
model
.
fold
ing
.
StructureModule
(
c_sm
,
c_global
)
sm
=
self
.
am_
fold
.
StructureModule
(
c_sm
,
c_global
)
representations
=
{
representations
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
}
}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
return
sm
(
representations
,
batch
,
is_training
=
False
,
compute_loss
=
True
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
f
=
hk
.
transform
(
run_sm
)
f
=
hk
.
transform
(
run_sm
)
...
@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase):
...
@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
c_m
=
13
c_m
=
13
c_z
=
17
c_z
=
17
...
@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask
=
torch
.
ones
((
batch_size
,
n_res
))
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
r
=
Rigid
(
rots
,
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rot_mats
)
translation
=
Vec3Array
.
from_array
(
trans
)
r
=
Rigid3Array
(
rotation
,
translation
)
else
:
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
,
is_multimer
=
consts
.
is_multimer
)
)
shape_before
=
s
.
shape
shape_before
=
s
.
shape
...
@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def
test_ipa_compare
(
self
):
def
test_ipa_compare
(
self
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
ipa
=
alphafold
.
model
.
fold
ing
.
InvariantPointAttention
(
ipa
=
self
.
am_
fold
.
InvariantPointAttention
(
config
.
model
.
heads
.
structure_module
,
config
.
model
.
heads
.
structure_module
,
config
.
model
.
global_config
,
config
.
model
.
global_config
,
)
)
attn
=
ipa
(
inputs_1d
=
act
,
if
consts
.
is_multimer
:
inputs_2d
=
static_feat_2d
,
attn
=
ipa
(
mask
=
mask
,
inputs_1d
=
act
,
affine
=
affine
,
inputs_2d
=
static_feat_2d
,
)
mask
=
mask
,
rigid
=
affine
)
else
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
)
return
attn
return
attn
f
=
hk
.
transform
(
run_ipa
)
f
=
hk
.
transform
(
run_ipa
)
...
@@ -235,13 +282,20 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -235,13 +282,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask
=
np
.
ones
((
n_res
,
1
))
sample_mask
=
np
.
ones
((
n_res
,
1
))
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
sample_affine
=
rigids
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
quats
=
self
.
am_rigid
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/structure_module/"
"alphafold/alphafold_iteration/structure_module/"
...
...
tests/test_template.py
View file @
d8ee9c5f
...
@@ -19,7 +19,6 @@ from openfold.model.template import (
...
@@ -19,7 +19,6 @@ from openfold.model.template import (
TemplatePointwiseAttention
,
TemplatePointwiseAttention
,
TemplatePairStack
,
TemplatePairStack
,
)
)
from
openfold.utils.tensor_utils
import
tree_map
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_template_feats
from
tests.data_utils
import
random_template_feats
...
@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
...
@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
c_t
=
consts
.
c_t
c_t
=
consts
.
c_t
...
@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase):
dropout
=
0.25
dropout
=
0.25
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
blocks_per_ckpt
=
None
blocks_per_ckpt
=
None
chunk_size
=
4
chunk_size
=
4
inf
=
1e7
inf
=
1e7
...
@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads
=
no_heads
,
no_heads
=
no_heads
,
pair_transition_n
=
pt_inner_dim
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -96,12 +110,40 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -96,12 +110,40 @@ class TestTemplatePairStack(unittest.TestCase):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_ee
=
config
.
model
.
embeddings_and_evoformer
c_ee
=
config
.
model
.
embeddings_and_evoformer
tps
=
alphafold
.
model
.
modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
if
consts
.
is_multimer
:
config
.
model
.
global_config
,
safe_key
=
alphafold
.
model
.
prng
.
SafeKey
(
hk
.
next_rng_key
())
name
=
"template_pair_stack"
,
template_iteration
=
self
.
am_modules
.
TemplateEmbeddingIteration
(
)
c_ee
.
template
.
template_pair_stack
,
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
config
.
model
.
global_config
,
name
=
'template_embedding_iteration'
)
def
template_iteration_fn
(
x
):
act
,
safe_key
=
x
safe_key
,
safe_subkey
=
safe_key
.
split
()
act
=
template_iteration
(
act
=
act
,
pair_mask
=
pair_mask
,
is_training
=
False
,
safe_key
=
safe_subkey
)
return
(
act
,
safe_key
)
if
config
.
model
.
global_config
.
use_remat
:
template_iteration_fn
=
hk
.
remat
(
template_iteration_fn
)
safe_key
,
safe_subkey
=
safe_key
.
split
()
template_stack
=
alphafold
.
model
.
layer_stack
.
layer_stack
(
c_ee
.
template
.
template_pair_stack
.
num_block
)(
template_iteration_fn
)
act
,
_
=
template_stack
((
pair_act
,
safe_subkey
))
else
:
tps
=
self
.
am_modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
act
=
ln
(
act
)
act
=
ln
(
act
)
return
act
return
act
...
@@ -115,10 +157,16 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -115,10 +157,16 @@ class TestTemplatePairStack(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
).
astype
(
np
.
float32
)
).
astype
(
np
.
float32
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
if
consts
.
is_multimer
:
"alphafold/alphafold_iteration/evoformer/template_embedding/"
params
=
compare_utils
.
fetch_alphafold_module_weights
(
+
"single_template_embedding/template_pair_stack"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
)
+
"single_template_embedding/template_embedding_iteration"
)
else
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
params
.
update
(
params
.
update
(
compare_utils
.
fetch_alphafold_module_weights
(
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
...
@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
out_repro
=
model
.
template_
embedder
.
template_
pair_stack
(
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
chunk_size
=
None
,
...
@@ -143,15 +191,32 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -143,15 +191,32 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
te
=
alphafold
.
model
.
modules
.
TemplateEmbedding
(
te
=
self
.
am_
modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
global_config
,
config
.
model
.
global_config
,
)
)
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
multichain_mask_2d
,
is_training
=
False
)
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
return
act
f
=
hk
.
transform
(
test_template_embedding
)
f
=
hk
.
transform
(
test_template_embedding
)
...
@@ -162,6 +227,14 @@ class Template(unittest.TestCase):
...
@@ -162,6 +227,14 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
batch
[
"multichain_mask_2d"
]
=
multichain_mask_2d
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
@@ -177,12 +250,26 @@ class Template(unittest.TestCase):
...
@@ -177,12 +250,26 @@ class Template(unittest.TestCase):
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
torch
.
as_tensor
(
pair_act
).
cuda
(),
if
consts
.
is_multimer
:
torch
.
as_tensor
(
pair_mask
).
cuda
(),
out_repro
=
model
.
template_embedder
(
templ_dim
=
0
,
template_feats
,
)
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
)
else
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
...
...
tests/test_triangular_attention.py
View file @
d8ee9c5f
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_start
if
starting
if
starting
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_triangular_multiplicative_update.py
View file @
d8ee9c5f
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment