Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d8ee9c5f
Commit
d8ee9c5f
authored
Feb 17, 2023
by
Christina Floristean
Browse files
All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.
parent
260db67f
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
470 additions
and
102 deletions
+470
-102
tests/test_embedders.py
tests/test_embedders.py
+26
-6
tests/test_evoformer.py
tests/test_evoformer.py
+5
-1
tests/test_feats.py
tests/test_feats.py
+75
-18
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+145
-17
tests/test_model.py
tests/test_model.py
+25
-5
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_pair_transition.py
tests/test_pair_transition.py
+1
-1
tests/test_primitives.py
tests/test_primitives.py
+3
-5
tests/test_structure_module.py
tests/test_structure_module.py
+77
-23
tests/test_template.py
tests/test_template.py
+107
-20
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+2
-2
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+2
-2
No files found.
tests/test_embedders.py
View file @
d8ee9c5f
...
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
torch
import
numpy
as
np
import
unittest
from
tests.config
import
consts
from
tests.data_utils
import
random_asym_ids
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
TemplatePairEmbedder
)
...
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res
=
17
n_clust
=
19
max_relative_chain
=
2
max_relative_idx
=
32
use_chain_relative
=
True
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
ri
=
torch
.
rand
((
b
,
n_res
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
,
ri
,
msa
)
asym_ids_flat
=
torch
.
Tensor
(
random_asym_ids
(
n_res
))
asym_id
=
torch
.
tile
(
asym_ids_flat
.
unsqueeze
(
0
),
(
b
,
1
))
entity_id
=
asym_id
sym_id
=
torch
.
zeros_like
(
entity_id
)
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
}
if
consts
.
is_multimer
:
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
max_relative_idx
=
max_relative_idx
,
use_chain_relative
=
use_chain_relative
,
max_relative_chain
=
max_relative_chain
)
batch
.
update
({
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
})
else
:
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
batch
)
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
...
...
tests/test_evoformer.py
View file @
d8ee9c5f
...
...
@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
=
2
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
inf
=
1e9
eps
=
1e-10
...
...
@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
,
msa_dropout
,
pair_stack_dropout
,
opm_first
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
=
5
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
inf
=
1e9
eps
=
1e-10
...
...
@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
msa_dropout
,
pair_stack_dropout
,
opm_first
,
ckpt
=
False
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
model
.
evoformer
.
blocks
[
0
].
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
...
...
tests/test_feats.py
View file @
d8ee9c5f
...
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
...
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_atom37_to_frames_compare
(
self
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
return
alphafold
.
model
.
all_atom
.
atom37_to_frames
(
if
consts
.
is_multimer
:
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
all_atom_positions
)
return
self
.
am_atom
.
atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
)
...
...
@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase):
}
out_gt
=
f
.
apply
({},
None
,
**
batch
)
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
consts
.
is_multimer
:
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
())).
view
(
*
t
.
shape
[:
2
],
12
))
else
:
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
def
flat12_to_4x4
(
flat12
):
...
...
@@ -187,7 +212,13 @@ class TestFeats(unittest.TestCase):
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
...
@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase):
def
run_torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
):
return
alphafold
.
mod
el
.
a
ll
_atom
.
torsion_angles_to_frames
(
return
s
el
f
.
a
m
_atom
.
torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
,
...
...
@@ -221,10 +252,17 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
...
@@ -264,7 +302,13 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
...
@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase):
torch
.
tensor
(
restype_atom14_rigid_group_positions
),
)
if
consts
.
is_multimer
:
xyz
=
xyz
.
to_tensor
()
self
.
assertTrue
(
xyz
.
shape
==
(
batch_size
,
n_res
,
14
,
3
))
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
run_f
(
aatype
,
affines
):
am
=
alphafold
.
model
return
am
.
all_atom
.
frames_and_literature_positions_to_atom14_pos
(
return
self
.
am_atom
.
frames_and_literature_positions_to_atom14_pos
(
aatype
,
affines
)
...
...
@@ -294,16 +340,27 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
if
consts
.
is_multimer
:
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
to_array
()))
else
:
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
transformations
.
cuda
(),
...
...
tests/test_import_weights.py
View file @
d8ee9c5f
...
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
][
1
].
transpose
(
-
1
,
-
2
)
),
model
.
evoformer
.
blocks
[
1
].
core
.
outer_product_mean
.
linear_1
.
weight
,
model
.
evoformer
.
blocks
[
1
].
outer_product_mean
.
linear_1
.
weight
,
),
]
...
...
tests/test_loss.py
View file @
d8ee9c5f
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
os
import
math
import
torch
import
numpy
as
np
import
unittest
...
...
@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import (
Rotation
,
Rigid
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
torsion_angle_loss
,
compute_fape
,
...
...
@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import (
)
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
...
...
@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine):
return
r
.
to_tensor_4x4
()
def
affine_vector_to_rigid
(
am_rigid
,
affine
):
rigid_flat
=
np
.
split
(
affine
,
7
,
axis
=-
1
)
rigid_flat
=
[
r
.
squeeze
(
-
1
)
for
r
in
rigid_flat
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
trans
=
rigid_flat
[
4
:]
rotation
=
am_rigid
.
Rot3Array
.
from_quaternion
(
qw
,
qx
,
qy
,
qz
,
normalize
=
True
)
translation
=
am_rigid
.
Vec3Array
(
*
trans
)
return
am_rigid
.
Rigid3Array
(
rotation
,
translation
)
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
...
...
@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_bond_loss_compare
(
self
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
return
alphafold
.
model
.
all_atom
.
between_residue_bond_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_bond_loss
(
pred_pos
,
pred_atom_mask
,
residue_index
,
...
...
@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_clash_loss_compare
(
self
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
):
return
alphafold
.
model
.
all_atom
.
between_residue_clash_loss
(
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
):
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
res_ind
)
f
=
hk
.
transform
(
run_brcl
)
...
...
@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase):
res_ind
=
np
.
arange
(
n_res
,
)
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
{},
...
...
@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase):
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
...
...
@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase):
def
run_fsv
(
batch
,
pos
,
config
):
cwd
=
os
.
getcwd
()
os
.
chdir
(
"tests/test_data"
)
loss
=
alphafold
.
model
.
folding
.
find_structural_violations
(
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
return
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
config
,
batch
[
'asym_id'
]
)
loss
=
self
.
am_fold
.
find_structural_violations
(
batch
,
pos
,
config
,
...
...
@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)
).
astype
(
np
.
int64
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
...
...
@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase):
n_seq
=
consts
.
n_seq
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
consts
.
msa_logits
).
astype
(
np
.
float32
),
}
batch
=
{
...
...
@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss
=
config
.
model
.
heads
.
structure_module
def
run_supervised_chi_loss
(
value
,
batch
):
if
consts
.
is_multimer
:
pred_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
unnormed_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'unnormalized_angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
chi_loss
,
_
,
_
=
self
.
am_fold
.
supervised_chi_loss
(
batch
[
'seq_mask'
],
batch
[
'chi_mask'
],
batch
[
'aatype'
],
batch
[
'chi_angles'
],
pred_angles
,
unnormed_angles
,
c_chi_loss
)
return
chi_loss
ret
=
{
"loss"
:
jax
.
numpy
.
array
(
0.0
),
}
alphafold
.
model
.
fold
ing
.
supervised_chi_loss
(
self
.
am_
fold
.
supervised_chi_loss
(
ret
,
batch
,
value
,
c_chi_loss
)
return
ret
[
"loss"
]
...
...
@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase):
ret
=
{
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
}
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_pos
)
viol
=
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
c_viol
,
batch
[
'asym_id'
]
)
return
self
.
am_fold
.
structural_violation_loss
(
mask
=
batch
[
'atom14_atom_exists'
],
violations
=
viol
,
config
=
c_viol
)
value
=
{}
value
[
"violations"
]
=
alphafold
.
model
.
fold
ing
.
find_structural_violations
(
]
=
self
.
am_
fold
.
find_structural_violations
(
batch
,
atom14_pred_pos
,
c_viol
,
)
alphafold
.
model
.
folding
.
structural_violation_loss
(
self
.
am_fold
.
structural_violation_loss
(
ret
,
batch
,
value
,
...
...
@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase):
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"asym_id"
:
random_asym_ids
(
n_res
)
}
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_bb_loss
(
batch
,
value
):
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
astype
(
np
.
float32
)
gt_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
batch
[
"backbone_affine_tensor"
])
target_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
value
[
'traj'
])
intra_chain_bb_loss
,
intra_chain_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
intra_chain_fape
,
pair_mask
=
intra_chain_mask
)
interface_bb_loss
,
interface_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
interface_fape
,
pair_mask
=
1.
-
intra_chain_mask
)
return
intra_chain_bb_loss
+
interface_bb_loss
ret
=
{
"loss"
:
np
.
array
(
0.0
),
}
alphafold
.
model
.
fold
ing
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
self
.
am_
fold
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_bb_loss
)
...
...
@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase):
np
.
float32
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
value
=
{
...
...
@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
if
consts
.
is_multimer
:
atom14_pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_positions
)
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
batch
[
"all_atom_positions"
])
gt_positions
,
gt_mask
,
alt_naming_is_better
=
self
.
am_fold
.
compute_atom14_gt
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
pred_pos
=
atom14_pred_positions
)
pred_frames
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
value
[
"sidechains"
][
"frames"
])
pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
value
[
"sidechains"
][
"atom_pos"
])
gt_sc_frames
,
gt_sc_frames_mask
=
self
.
am_fold
.
compute_frames
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
use_alt
=
alt_naming_is_better
)
return
self
.
am_fold
.
sidechain_loss
(
gt_sc_frames
,
gt_sc_frames_mask
,
gt_positions
,
gt_mask
,
pred_frames
,
pred_positions
,
c_sm
)[
'loss'
]
batch
=
{
**
batch
,
**
alphafold
.
mod
el
.
a
ll
_atom
.
atom37_to_frames
(
**
s
el
f
.
a
m
_atom
.
atom37_to_frames
(
batch
[
"aatype"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
...
...
@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase):
)
value
=
v
ret
=
alphafold
.
model
.
fold
ing
.
sidechain_loss
(
batch
,
value
,
c_sm
)
ret
=
self
.
am_
fold
.
sidechain_loss
(
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_sidechain_loss
)
...
...
@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
...
tests/test_model.py
View file @
d8ee9c5f
...
...
@@ -20,8 +20,7 @@ import unittest
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"
model
_1"
)
c
=
model_config
(
consts
.
model
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
...
...
@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase):
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"entity_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"sym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
...
...
@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase):
def
test_compare
(
self
):
def
run_alphafold
(
batch
):
config
=
compare_utils
.
get_alphafold_config
()
model
=
alphafold
.
model
.
modules
.
AlphaFold
(
config
.
model
)
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
return
model
(
batch
=
batch
,
is_training
=
False
,
...
...
@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
self
.
am_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
...
...
tests/test_outer_product_mean.py
View file @
d8ee9c5f
...
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
core
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
...
...
tests/test_pair_transition.py
View file @
d8ee9c5f
...
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
...
...
tests/test_primitives.py
View file @
d8ee9c5f
...
...
@@ -13,12 +13,10 @@
# limitations under the License.
import
torch
import
numpy
as
np
import
unittest
from
openfold.model.primitives
import
(
Attention
,
LowMemoryAttention
,
Attention
)
from
tests.config
import
consts
...
...
@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase):
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
lma
=
LowMemory
Attention
(
lma
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
a
=
Attention
(
...
...
@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase):
m
.
linear_o
.
weight
.
copy_
(
o_fill
)
with
torch
.
no_grad
():
l
=
lma
(
q
,
k
,
v
,
1024
,
4096
,
biases
=
bias
)
l
=
lma
(
q
,
k
,
v
,
biases
=
bias
,
use_lma
=
True
,
q_chunk_size
=
1024
,
kv_chunk_size
=
4096
)
real
=
a
(
q
,
k
,
v
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
...
...
tests/test_structure_module.py
View file @
d8ee9c5f
...
...
@@ -18,21 +18,19 @@ import unittest
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom37_mask
,
)
from
openfold.model.structure_module
import
(
StructureModule
,
StructureModuleTransition
,
BackboneUpdate
,
AngleResnet
,
InvariantPointAttention
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
n
=
consts
.
n_res
...
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor
,
ar_epsilon
,
inf
,
is_multimer
=
consts
.
is_multimer
)
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
...
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
(
s
,
z
,
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
else
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
...
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global
=
config
.
model
.
global_config
def
run_sm
(
representations
,
batch
):
sm
=
alphafold
.
model
.
fold
ing
.
StructureModule
(
c_sm
,
c_global
)
sm
=
self
.
am_
fold
.
StructureModule
(
c_sm
,
c_global
)
representations
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
return
sm
(
representations
,
batch
,
is_training
=
False
,
compute_loss
=
True
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
f
=
hk
.
transform
(
run_sm
)
...
...
@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
c_m
=
13
c_z
=
17
...
...
@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
r
=
Rigid
(
rots
,
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rot_mats
)
translation
=
Vec3Array
.
from_array
(
trans
)
r
=
Rigid3Array
(
rotation
,
translation
)
else
:
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
,
is_multimer
=
consts
.
is_multimer
)
shape_before
=
s
.
shape
...
...
@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def
test_ipa_compare
(
self
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
config
=
compare_utils
.
get_alphafold_config
()
ipa
=
alphafold
.
model
.
fold
ing
.
InvariantPointAttention
(
ipa
=
self
.
am_
fold
.
InvariantPointAttention
(
config
.
model
.
heads
.
structure_module
,
config
.
model
.
global_config
,
)
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
,
)
if
consts
.
is_multimer
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
rigid
=
affine
)
else
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
)
return
attn
f
=
hk
.
transform
(
run_ipa
)
...
...
@@ -235,13 +282,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask
=
np
.
ones
((
n_res
,
1
))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
sample_affine
=
rigids
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
quats
=
self
.
am_rigid
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/structure_module/"
...
...
tests/test_template.py
View file @
d8ee9c5f
...
...
@@ -19,7 +19,6 @@ from openfold.model.template import (
TemplatePointwiseAttention
,
TemplatePairStack
,
)
from
openfold.utils.tensor_utils
import
tree_map
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_template_feats
...
...
@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
c_t
=
consts
.
c_t
...
...
@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase):
dropout
=
0.25
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
blocks_per_ckpt
=
None
chunk_size
=
4
inf
=
1e7
...
...
@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads
=
no_heads
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -96,12 +110,40 @@ class TestTemplatePairStack(unittest.TestCase):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
config
=
compare_utils
.
get_alphafold_config
()
c_ee
=
config
.
model
.
embeddings_and_evoformer
tps
=
alphafold
.
model
.
modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
if
consts
.
is_multimer
:
safe_key
=
alphafold
.
model
.
prng
.
SafeKey
(
hk
.
next_rng_key
())
template_iteration
=
self
.
am_modules
.
TemplateEmbeddingIteration
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
'template_embedding_iteration'
)
def
template_iteration_fn
(
x
):
act
,
safe_key
=
x
safe_key
,
safe_subkey
=
safe_key
.
split
()
act
=
template_iteration
(
act
=
act
,
pair_mask
=
pair_mask
,
is_training
=
False
,
safe_key
=
safe_subkey
)
return
(
act
,
safe_key
)
if
config
.
model
.
global_config
.
use_remat
:
template_iteration_fn
=
hk
.
remat
(
template_iteration_fn
)
safe_key
,
safe_subkey
=
safe_key
.
split
()
template_stack
=
alphafold
.
model
.
layer_stack
.
layer_stack
(
c_ee
.
template
.
template_pair_stack
.
num_block
)(
template_iteration_fn
)
act
,
_
=
template_stack
((
pair_act
,
safe_subkey
))
else
:
tps
=
self
.
am_modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
act
=
ln
(
act
)
return
act
...
...
@@ -115,10 +157,16 @@ class TestTemplatePairStack(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
).
astype
(
np
.
float32
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
if
consts
.
is_multimer
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_embedding_iteration"
)
else
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
params
.
update
(
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
...
...
@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
out_repro
=
model
.
template_
embedder
.
template_
pair_stack
(
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
...
...
@@ -143,15 +191,32 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
te
=
alphafold
.
model
.
modules
.
TemplateEmbedding
(
te
=
self
.
am_
modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
global_config
,
)
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
multichain_mask_2d
,
is_training
=
False
)
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
f
=
hk
.
transform
(
test_template_embedding
)
...
...
@@ -162,6 +227,14 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
batch
[
"multichain_mask_2d"
]
=
multichain_mask_2d
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
...
@@ -177,12 +250,26 @@ class Template(unittest.TestCase):
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
)
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
)
else
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
tests/test_triangular_attention.py
View file @
d8ee9c5f
...
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_start
if
starting
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_triangular_multiplicative_update.py
View file @
d8ee9c5f
...
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment