Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
56d5e39c
Commit
56d5e39c
authored
Jun 17, 2023
by
Geoffrey Yu
Browse files
Merge remote-tracking branch 'upstream/multimer' into multimer
parents
56b86074
51556d52
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
742 additions
and
168 deletions
+742
-168
setup.py
setup.py
+2
-0
tests/compare_utils.py
tests/compare_utils.py
+4
-4
tests/config.py
tests/config.py
+5
-1
tests/data_utils.py
tests/data_utils.py
+32
-0
tests/test_data_pipeline.py
tests/test_data_pipeline.py
+25
-15
tests/test_data_transforms.py
tests/test_data_transforms.py
+1
-5
tests/test_embedders.py
tests/test_embedders.py
+25
-5
tests/test_evoformer.py
tests/test_evoformer.py
+10
-1
tests/test_feats.py
tests/test_feats.py
+99
-28
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+284
-28
tests/test_model.py
tests/test_model.py
+32
-6
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_pair_transition.py
tests/test_pair_transition.py
+1
-1
tests/test_primitives.py
tests/test_primitives.py
+7
-13
tests/test_structure_module.py
tests/test_structure_module.py
+77
-23
tests/test_template.py
tests/test_template.py
+116
-23
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+2
-2
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+17
-10
tests/test_utils.py
tests/test_utils.py
+1
-1
No files found.
setup.py
View file @
56d5e39c
...
...
@@ -75,6 +75,8 @@ for major, minor in list(compute_capabilities):
extra_cuda_flags
+=
cc_flag
cc_flag
=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
]
if
bare_metal_major
!=
-
1
:
modules
=
[
CUDAExtension
(
name
=
"attn_core_inplace_cuda"
,
...
...
tests/compare_utils.py
View file @
56d5e39c
...
...
@@ -46,26 +46,26 @@ def import_alphafold():
def
get_alphafold_config
():
config
=
alphafold
.
model
.
config
.
model_config
(
"model_1_ptm"
)
# noqa
config
=
alphafold
.
model
.
config
.
model_config
(
consts
.
model
)
# noqa
config
.
model
.
global_config
.
deterministic
=
True
return
config
_param_path
=
"openfold/resources/params/params_
model_1_ptm
.npz"
_param_path
=
f
"openfold/resources/params/params_
{
consts
.
model
}
.npz"
_model
=
None
def
get_global_pretrained_openfold
():
global
_model
if
_model
is
None
:
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
))
_model
=
AlphaFold
(
model_config
(
consts
.
model
))
_model
=
_model
.
eval
()
if
not
os
.
path
.
exists
(
_param_path
):
raise
FileNotFoundError
(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
import_jax_weights_
(
_model
,
_param_path
,
version
=
"model_1_ptm"
)
import_jax_weights_
(
_model
,
_param_path
,
version
=
consts
.
model
)
_model
=
_model
.
cuda
()
return
_model
...
...
tests/config.py
View file @
56d5e39c
...
...
@@ -2,8 +2,11 @@ import ml_collections as mlc
consts
=
mlc
.
ConfigDict
(
{
"model"
:
"model_1_multimer_v3"
,
# monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer"
:
True
,
# monomer: False, multimer: True
"chunk_size"
:
4
,
"batch_size"
:
2
,
"n_res"
:
11
,
"n_res"
:
22
,
"n_seq"
:
13
,
"n_templ"
:
3
,
"n_extra"
:
17
,
...
...
@@ -16,6 +19,7 @@ consts = mlc.ConfigDict(
"c_s"
:
384
,
"c_t"
:
64
,
"c_e"
:
64
,
"msa_logits"
:
22
# monomer: 23, multimer: 22
}
)
...
...
tests/data_utils.py
View file @
56d5e39c
...
...
@@ -12,9 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
random
import
randint
import
numpy
as
np
from
scipy.spatial.transform
import
Rotation
from
tests.config
import
consts
def
random_asym_ids
(
n_res
,
split_chains
=
True
,
min_chain_len
=
4
):
n_chain
=
randint
(
1
,
n_res
//
min_chain_len
)
if
consts
.
is_multimer
else
1
if
not
split_chains
:
return
[
0
]
*
n_res
assert
n_res
>=
n_chain
pieces
=
[]
asym_ids
=
[]
final_idx
=
n_chain
-
1
for
idx
in
range
(
n_chain
-
1
):
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
if
n_stop
<=
min_chain_len
:
final_idx
=
idx
break
piece
=
randint
(
min_chain_len
,
n_stop
)
pieces
.
append
(
piece
)
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
b
=
[]
...
...
@@ -39,6 +66,11 @@ def random_template_feats(n_templ, n, batch_size=None):
}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
if
consts
.
is_multimer
:
asym_ids
=
np
.
array
(
random_asym_ids
(
n
))
batch
[
"asym_id"
]
=
np
.
tile
(
asym_ids
[
np
.
newaxis
,
:],
(
*
b
,
n_templ
,
1
))
return
batch
...
...
tests/test_data_pipeline.py
View file @
56d5e39c
...
...
@@ -15,19 +15,13 @@
import
pickle
import
shutil
import
torch
import
numpy
as
np
import
unittest
from
openfold.data.data_pipeline
import
DataPipeline
from
openfold.data.templates
import
TemplateHitFeaturizer
from
openfold.model.embedders
import
(
InputEmbedder
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
)
from
openfold.data.templates
import
HhsearchHitFeaturizer
,
HmmsearchHitFeaturizer
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
...
...
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
with
open
(
"tests/test_data/alphafold_feature_dict.pickle"
,
"rb"
)
as
fp
:
alphafold_feature_dict
=
pickle
.
load
(
fp
)
template_featurizer
=
TemplateHitFeaturizer
(
mmcif_dir
=
"tests/test_data/mmcifs"
,
max_template_date
=
"2021-12-20"
,
max_hits
=
20
,
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
_zero_center_positions
=
False
,
)
if
consts
.
is_multimer
:
# template_featurizer = HmmsearchHitFeaturizer(
# mmcif_dir="tests/test_data/mmcifs",
# max_template_date="2021-12-20",
# max_hits=20,
# kalign_binary_path=shutil.which("kalign"),
# _zero_center_positions=False,
# )
template_featurizer
=
HhsearchHitFeaturizer
(
mmcif_dir
=
"tests/test_data/mmcifs"
,
max_template_date
=
"2021-12-20"
,
max_hits
=
20
,
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
_zero_center_positions
=
False
,
)
else
:
template_featurizer
=
HhsearchHitFeaturizer
(
mmcif_dir
=
"tests/test_data/mmcifs"
,
max_template_date
=
"2021-12-20"
,
max_hits
=
20
,
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
_zero_center_positions
=
False
,
)
data_pipeline
=
DataPipeline
(
template_featurizer
=
template_featurizer
,
...
...
tests/test_data_transforms.py
View file @
56d5e39c
import
copy
import
gzip
import
os
import
pickle
import
numpy
as
np
...
...
@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
}
protein
=
make_hhblits_profile
(
protein
)
masked_msa_config
=
config
.
data
.
common
.
masked_msa
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
)
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
,
seed
=
42
)
assert
'bert_mask'
in
protein
assert
'true_msa'
in
protein
assert
'msa'
in
protein
...
...
tests/test_embedders.py
View file @
56d5e39c
...
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
torch
import
numpy
as
np
import
unittest
from
tests.config
import
consts
from
tests.data_utils
import
random_asym_ids
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
TemplatePairEmbedder
)
...
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res
=
17
n_clust
=
19
max_relative_chain
=
2
max_relative_idx
=
32
use_chain_relative
=
True
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
ri
=
torch
.
rand
((
b
,
n_res
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
asym_ids_flat
=
torch
.
Tensor
(
random_asym_ids
(
n_res
))
asym_id
=
torch
.
tile
(
asym_ids_flat
.
unsqueeze
(
0
),
(
b
,
1
))
entity_id
=
asym_id
sym_id
=
torch
.
zeros_like
(
entity_id
)
if
consts
.
is_multimer
:
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
max_relative_idx
=
max_relative_idx
,
use_chain_relative
=
use_chain_relative
,
max_relative_chain
=
max_relative_chain
)
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
,
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
}
msa_emb
,
pair_emb
=
ie
(
batch
)
else
:
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
=
tf
,
ri
=
ri
,
msa
=
msa
,
inplace_safe
=
False
)
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
,
ri
,
msa
)
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
...
...
tests/test_evoformer.py
View file @
56d5e39c
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
import
torch
import
numpy
as
np
import
unittest
...
...
@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
=
2
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
eps
=
1e-10
...
...
@@ -65,6 +68,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
,
msa_dropout
,
pair_stack_dropout
,
opm_first
,
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -174,6 +179,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
=
5
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
eps
=
1e-10
...
...
@@ -190,6 +197,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
msa_dropout
,
pair_stack_dropout
,
opm_first
,
fuse_projection_weights
,
ckpt
=
False
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -277,7 +286,7 @@ class TestMSATransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
model
.
evoformer
.
blocks
[
0
].
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
...
...
tests/test_feats.py
View file @
56d5e39c
...
...
@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
)
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_4x4
from
tests.data_utils
import
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
...
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
...
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_atom37_to_frames_compare
(
self
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
return
alphafold
.
model
.
all_atom
.
atom37_to_frames
(
if
consts
.
is_multimer
:
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
all_atom_positions
)
return
self
.
am_atom
.
atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
)
...
...
@@ -150,9 +168,23 @@ class TestFeats(unittest.TestCase):
}
out_gt
=
f
.
apply
({},
None
,
**
batch
)
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
())))
else
:
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
def
rigid3x4_to_4x4
(
rigid3arr
):
four_by_four
=
torch
.
zeros
(
*
rigid3arr
.
shape
[:
-
2
],
4
,
4
)
four_by_four
[...,
:
3
,
:
4
]
=
rigid3arr
four_by_four
[...,
3
,
3
]
=
1
return
four_by_four
def
flat12_to_4x4
(
flat12
):
rot
=
flat12
[...,
:
9
].
view
(
*
flat12
.
shape
[:
-
1
],
3
,
3
)
trans
=
flat12
[...,
9
:]
...
...
@@ -164,10 +196,12 @@ class TestFeats(unittest.TestCase):
return
four_by_four
out_gt
[
"rigidgroups_gt_frames"
]
=
flat12_to_4x4
(
convert_func
=
rigid3x4_to_4x4
if
consts
.
is_multimer
else
flat12_to_4x4
out_gt
[
"rigidgroups_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_gt_frames"
]
)
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
flat12_to_4x4
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
)
...
...
@@ -187,7 +221,13 @@ class TestFeats(unittest.TestCase):
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
...
@@ -208,7 +248,7 @@ class TestFeats(unittest.TestCase):
def
run_torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
):
return
alphafold
.
mod
el
.
a
ll
_atom
.
torsion_angles_to_frames
(
return
s
el
f
.
a
m
_atom
.
torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
,
...
...
@@ -221,10 +261,17 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
...
@@ -240,13 +287,21 @@ class TestFeats(unittest.TestCase):
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
rot
))
trans_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
trans
)
)
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
out_gt_rot
=
out_gt
.
rot
if
not
consts
.
is_multimer
else
out_gt
.
rotation
.
to_array
()
out_gt_trans
=
out_gt
.
trans
if
not
consts
.
is_multimer
else
out_gt
.
translation
.
to_array
()
if
consts
.
is_multimer
:
rots_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_rot
))
trans_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_trans
))
else
:
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_rot
))
trans_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_trans
)
)
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
transforms_gt
=
torch
.
cat
([
rots_gt
,
trans_gt
.
unsqueeze
(
-
1
)],
dim
=-
1
)
bottom_row
=
torch
.
zeros
((
*
rots_gt
.
shape
[:
-
2
],
1
,
4
))
bottom_row
[...,
3
]
=
1
...
...
@@ -264,7 +319,13 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
...
@@ -282,8 +343,7 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
run_f
(
aatype
,
affines
):
am
=
alphafold
.
model
return
am
.
all_atom
.
frames_and_literature_positions_to_atom14_pos
(
return
self
.
am_atom
.
frames_and_literature_positions_to_atom14_pos
(
aatype
,
affines
)
...
...
@@ -294,16 +354,27 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
if
consts
.
is_multimer
:
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
to_array
()))
else
:
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
transformations
.
cuda
(),
...
...
tests/test_import_weights.py
View file @
56d5e39c
...
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
][
1
].
transpose
(
-
1
,
-
2
)
),
model
.
evoformer
.
blocks
[
1
].
core
.
outer_product_mean
.
linear_1
.
weight
,
model
.
evoformer
.
blocks
[
1
].
outer_product_mean
.
linear_1
.
weight
,
),
]
...
...
tests/test_loss.py
View file @
56d5e39c
...
...
@@ -13,18 +13,18 @@
# limitations under the License.
import
os
import
math
import
torch
import
numpy
as
np
from
pathlib
import
Path
import
unittest
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.np
import
residue_constants
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
torsion_angle_loss
,
compute_fape
,
...
...
@@ -43,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss
,
tm_loss
,
compute_plddt
,
compute_tm
,
chain_center_of_mass_loss
)
from
openfold.utils.tensor_utils
import
(
tree_map
,
...
...
@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
)
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
...
...
@@ -64,7 +66,30 @@ def affine_vector_to_4x4(affine):
return
r
.
to_tensor_4x4
()
def
affine_vector_to_rigid
(
am_rigid
,
affine
):
rigid_flat
=
np
.
split
(
affine
,
7
,
axis
=-
1
)
rigid_flat
=
[
r
.
squeeze
(
-
1
)
for
r
in
rigid_flat
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
trans
=
rigid_flat
[
4
:]
rotation
=
am_rigid
.
Rot3Array
.
from_quaternion
(
qw
,
qx
,
qy
,
qz
,
normalize
=
True
)
translation
=
am_rigid
.
Vec3Array
(
*
trans
)
return
am_rigid
.
Rigid3Array
(
rotation
,
translation
)
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
...
...
@@ -127,7 +152,10 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_bond_loss_compare
(
self
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
return
alphafold
.
model
.
all_atom
.
between_residue_bond_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_bond_loss
(
pred_pos
,
pred_atom_mask
,
residue_index
,
...
...
@@ -184,12 +212,22 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_clash_loss_compare
(
self
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
):
return
alphafold
.
model
.
all_atom
.
between_residue_clash_loss
(
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
):
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
res_ind
)
f
=
hk
.
transform
(
run_brcl
)
...
...
@@ -198,10 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
res_ind
=
np
.
arange
(
n_res
,
)
residx_atom14_to_atom37
=
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)).
astype
(
np
.
int64
)
atomtype_radius
=
[
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
]
atomtype_radius
=
np
.
array
(
atomtype_radius
).
astype
(
np
.
float32
)
atom_radius
=
(
atom_exists
*
atomtype_radius
[
residx_atom14_to_atom37
]
)
asym_id
=
None
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
{},
...
...
@@ -210,6 +262,7 @@ class TestLoss(unittest.TestCase):
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
...
...
@@ -219,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
asym_id
).
cuda
()
if
asym_id
is
not
None
else
None
,
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
...
...
@@ -242,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compute_ptm_compare
(
self
):
n_res
=
consts
.
n_res
max_bin
=
31
no_bins
=
64
logits
=
np
.
random
.
rand
(
n_res
,
n_res
,
no_bins
)
boundaries
=
np
.
linspace
(
0
,
max_bin
,
num
=
(
no_bins
-
1
))
ptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
)
ptm_gt
=
torch
.
tensor
(
ptm_gt
)
logits_t
=
torch
.
tensor
(
logits
)
ptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
ptm_gt
-
ptm_repro
))
<
consts
.
eps
)
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
iptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
,
asym_id
=
asym_id
,
interface
=
True
)
iptm_gt
=
torch
.
tensor
(
iptm_gt
)
iptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
,
asym_id
=
torch
.
tensor
(
asym_id
),
interface
=
True
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
iptm_gt
-
iptm_repro
))
<
consts
.
eps
)
def
test_find_structural_violations
(
self
):
n
=
consts
.
n_res
...
...
@@ -265,8 +349,21 @@ class TestLoss(unittest.TestCase):
def
test_find_structural_violations_compare
(
self
):
def
run_fsv
(
batch
,
pos
,
config
):
cwd
=
os
.
getcwd
()
os
.
chdir
(
"tests/test_data"
)
loss
=
alphafold
.
model
.
folding
.
find_structural_violations
(
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data"
os
.
chdir
(
str
(
fpath
))
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
return
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
config
,
batch
[
'asym_id'
]
)
loss
=
self
.
am_fold
.
find_structural_violations
(
batch
,
pos
,
config
,
...
...
@@ -287,6 +384,9 @@ class TestLoss(unittest.TestCase):
).
astype
(
np
.
int64
),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
config
=
mlc
.
ConfigDict
(
...
...
@@ -380,14 +480,14 @@ class TestLoss(unittest.TestCase):
n_seq
=
consts
.
n_seq
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
consts
.
msa_logits
).
astype
(
np
.
float32
),
}
batch
=
{
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
np
.
float32
)
,
)
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
...
...
@@ -399,7 +499,9 @@ class TestLoss(unittest.TestCase):
with
torch
.
no_grad
():
out_repro
=
masked_msa_loss
(
value
[
"logits"
],
**
batch
,
batch
[
"true_msa"
],
batch
[
"bert_mask"
],
consts
.
msa_logits
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
...
...
@@ -506,10 +608,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss
=
config
.
model
.
heads
.
structure_module
def
run_supervised_chi_loss
(
value
,
batch
):
if
consts
.
is_multimer
:
pred_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
unnormed_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'unnormalized_angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
chi_loss
,
_
,
_
=
self
.
am_fold
.
supervised_chi_loss
(
batch
[
'seq_mask'
],
batch
[
'chi_mask'
],
batch
[
'aatype'
],
batch
[
'chi_angles'
],
pred_angles
,
unnormed_angles
,
c_chi_loss
)
return
chi_loss
ret
=
{
"loss"
:
jax
.
numpy
.
array
(
0.0
),
}
alphafold
.
model
.
fold
ing
.
supervised_chi_loss
(
self
.
am_
fold
.
supervised_chi_loss
(
ret
,
batch
,
value
,
c_chi_loss
)
return
ret
[
"loss"
]
...
...
@@ -561,6 +681,40 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_viol
=
config
.
model
.
heads
.
structure_module
n_res
=
consts
.
n_res
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
).
cuda
(),
batch
,
np
.
ndarray
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
batch
=
data_transforms
.
make_atom14_masks
(
batch
)
loss_sum_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
False
,
**
batch
)
loss_sum_clash
=
loss_sum_clash
.
cpu
()
loss_avg_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
True
,
**
batch
)
loss_avg_clash
=
loss_avg_clash
.
cpu
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
...
...
@@ -570,15 +724,31 @@ class TestLoss(unittest.TestCase):
ret
=
{
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
}
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_pos
)
viol
=
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
c_viol
,
batch
[
'asym_id'
]
)
return
self
.
am_fold
.
structural_violation_loss
(
mask
=
batch
[
'atom14_atom_exists'
],
violations
=
viol
,
config
=
c_viol
)
value
=
{}
value
[
"violations"
]
=
alphafold
.
model
.
fold
ing
.
find_structural_violations
(
]
=
self
.
am_
fold
.
find_structural_violations
(
batch
,
atom14_pred_pos
,
c_viol
,
)
alphafold
.
model
.
folding
.
structural_violation_loss
(
self
.
am_fold
.
structural_violation_loss
(
ret
,
batch
,
value
,
...
...
@@ -593,13 +763,17 @@ class TestLoss(unittest.TestCase):
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
,
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
}
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
@@ -676,10 +850,31 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_bb_loss
(
batch
,
value
):
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
astype
(
np
.
float32
)
gt_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
batch
[
"backbone_affine_tensor"
])
target_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
value
[
'traj'
])
intra_chain_bb_loss
,
intra_chain_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
intra_chain_fape
,
pair_mask
=
intra_chain_mask
)
interface_bb_loss
,
interface_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
interface_fape
,
pair_mask
=
1.
-
intra_chain_mask
)
return
intra_chain_bb_loss
+
interface_bb_loss
ret
=
{
"loss"
:
np
.
array
(
0.0
),
}
alphafold
.
model
.
fold
ing
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
self
.
am_
fold
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_bb_loss
)
...
...
@@ -691,7 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"use_clamped_fape"
:
np
.
array
(
0.0
)
,
"use_clamped_fape"
:
np
.
array
(
0.0
)
}
value
=
{
...
...
@@ -703,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
@@ -715,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
to
(
dtype
=
value
[
"traj"
].
dtype
)
intra_chain_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
intra_chain_fape
})
interface_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
interface_fape
})
out_repro
=
intra_chain_out
+
interface_out
out_repro
=
out_repro
.
cpu
()
else
:
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
...
@@ -726,9 +934,29 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
if
consts
.
is_multimer
:
atom14_pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_positions
)
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
batch
[
"all_atom_positions"
])
gt_positions
,
gt_mask
,
alt_naming_is_better
=
self
.
am_fold
.
compute_atom14_gt
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
pred_pos
=
atom14_pred_positions
)
pred_frames
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
value
[
"sidechains"
][
"frames"
])
pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
value
[
"sidechains"
][
"atom_pos"
])
gt_sc_frames
,
gt_sc_frames_mask
=
self
.
am_fold
.
compute_frames
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
use_alt
=
alt_naming_is_better
)
return
self
.
am_fold
.
sidechain_loss
(
gt_sc_frames
,
gt_sc_frames_mask
,
gt_positions
,
gt_mask
,
pred_frames
,
pred_positions
,
c_sm
)[
'loss'
]
batch
=
{
**
batch
,
**
alphafold
.
mod
el
.
a
ll
_atom
.
atom37_to_frames
(
**
s
el
f
.
a
m
_atom
.
atom37_to_frames
(
batch
[
"aatype"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
...
...
@@ -738,21 +966,21 @@ class TestLoss(unittest.TestCase):
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
][
"frames"
]
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
]
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
value
[
"sidechains"
][
"frames"
]
)
v
[
"sidechains"
][
"atom_pos"
]
=
alphafold
.
model
.
r3
.
vecs_from_tensor
(
v
[
"sidechains"
][
"atom_pos"
]
=
self
.
am_rigid
.
vecs_from_tensor
(
value
[
"sidechains"
][
"atom_pos"
]
)
v
.
update
(
alphafold
.
model
.
fold
ing
.
compute_renamed_ground_truth
(
self
.
am_
fold
.
compute_renamed_ground_truth
(
batch
,
atom14_pred_positions
,
)
)
value
=
v
ret
=
alphafold
.
model
.
fold
ing
.
sidechain_loss
(
batch
,
value
,
c_sm
)
ret
=
self
.
am_
fold
.
sidechain_loss
(
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_sidechain_loss
)
...
...
@@ -816,6 +1044,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
...
@@ -882,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_chain_center_of_mass_loss
(
self
):
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
batch
=
{
"all_atom_positions"
:
np
.
random
.
rand
(
batch_size
,
n_res
,
37
,
3
).
astype
(
np
.
float32
)
*
10.0
,
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
batch_size
,
n_res
,
37
)).
astype
(
np
.
float32
),
"asym_id"
:
np
.
stack
([
random_asym_ids
(
n_res
)
for
_
in
range
(
batch_size
)])
}
config
=
{
"weight"
:
0.05
,
"clamp_distance"
:
-
4.0
,
}
final_atom_positions
=
torch
.
rand
(
batch_size
,
n_res
,
37
,
3
).
cuda
()
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
).
cuda
()
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
chain_center_of_mass_loss
(
all_atom_pred_pos
=
final_atom_positions
,
**
{
**
batch
,
**
config
},
)
out_repro
=
out_repro
.
cpu
()
if
__name__
==
"__main__"
:
unittest
.
main
()
tests/test_model.py
View file @
56d5e39c
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
pathlib
import
Path
import
pickle
import
torch
import
torch.nn
as
nn
...
...
@@ -20,8 +21,7 @@ import unittest
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -36,13 +36,26 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"model_1"
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
...
...
@@ -56,6 +69,7 @@ class TestModel(unittest.TestCase):
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
...
...
@@ -68,6 +82,12 @@ class TestModel(unittest.TestCase):
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"entity_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"sym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
...
...
@@ -77,10 +97,14 @@ class TestModel(unittest.TestCase):
out
=
model
(
batch
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
consts
.
is_multimer
,
"Additional changes required for multimer."
)
def
test_compare
(
self
):
#TODO: Fix test data for multimer MSA features
def
run_alphafold
(
batch
):
config
=
compare_utils
.
get_alphafold_config
()
model
=
alphafold
.
model
.
modules
.
AlphaFold
(
config
.
model
)
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
return
model
(
batch
=
batch
,
is_training
=
False
,
...
...
@@ -91,7 +115,8 @@ class TestModel(unittest.TestCase):
params
=
compare_utils
.
fetch_alphafold_module_weights
(
""
)
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data/sample_feats.pickle"
with
open
(
str
(
fpath
),
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
...
...
@@ -100,7 +125,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
self
.
am_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
...
...
tests/test_outer_product_mean.py
View file @
56d5e39c
...
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
core
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
...
...
tests/test_pair_transition.py
View file @
56d5e39c
...
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
...
...
tests/test_primitives.py
View file @
56d5e39c
...
...
@@ -13,20 +13,17 @@
# limitations under the License.
import
torch
import
numpy
as
np
import
unittest
from
openfold.model.primitives
import
(
Attention
,
)
from
openfold.model.primitives
import
Attention
from
tests.config
import
consts
class
TestLMA
(
unittest
.
TestCase
):
def
test_lma_vs_attention
(
self
):
batch_size
=
consts
.
batch_size
c_hidden
=
32
n
=
2
**
12
c_hidden
=
32
n
=
2
**
12
no_heads
=
4
q
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
...
...
@@ -34,20 +31,17 @@ class TestLMA(unittest.TestCase):
bias
=
[
torch
.
rand
(
no_heads
,
1
,
n
)]
bias
=
[
b
.
cuda
()
for
b
in
bias
]
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
a
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
with
torch
.
no_grad
():
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_lma
=
True
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tests/test_structure_module.py
View file @
56d5e39c
...
...
@@ -18,21 +18,19 @@ import unittest
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom37_mask
,
)
from
openfold.model.structure_module
import
(
StructureModule
,
StructureModuleTransition
,
BackboneUpdate
,
AngleResnet
,
InvariantPointAttention
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
n
=
consts
.
n_res
...
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor
,
ar_epsilon
,
inf
,
is_multimer
=
consts
.
is_multimer
)
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
...
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
({
"single"
:
s
,
"pair"
:
z
},
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
else
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
...
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global
=
config
.
model
.
global_config
def
run_sm
(
representations
,
batch
):
sm
=
alphafold
.
model
.
fold
ing
.
StructureModule
(
c_sm
,
c_global
)
sm
=
self
.
am_
fold
.
StructureModule
(
c_sm
,
c_global
)
representations
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
return
sm
(
representations
,
batch
,
is_training
=
False
,
compute_loss
=
True
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
f
=
hk
.
transform
(
run_sm
)
...
...
@@ -181,6 +200,19 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
c_m
=
13
c_z
=
17
...
...
@@ -197,13 +229,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
r
=
Rigid
(
rots
,
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rot_mats
)
translation
=
Vec3Array
.
from_array
(
trans
)
r
=
Rigid3Array
(
rotation
,
translation
)
else
:
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
,
is_multimer
=
consts
.
is_multimer
)
shape_before
=
s
.
shape
...
...
@@ -215,16 +252,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def
test_ipa_compare
(
self
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
config
=
compare_utils
.
get_alphafold_config
()
ipa
=
alphafold
.
model
.
fold
ing
.
InvariantPointAttention
(
ipa
=
self
.
am_
fold
.
InvariantPointAttention
(
config
.
model
.
heads
.
structure_module
,
config
.
model
.
global_config
,
)
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
,
)
if
consts
.
is_multimer
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
rigid
=
affine
)
else
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
)
return
attn
f
=
hk
.
transform
(
run_ipa
)
...
...
@@ -238,13 +285,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask
=
np
.
ones
((
n_res
,
1
))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
rigids
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
quats
=
self
.
am_rigid
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/structure_module/"
...
...
tests/test_template.py
View file @
56d5e39c
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
import
torch
import
numpy
as
np
import
unittest
...
...
@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention
,
TemplatePairStack
,
)
from
openfold.utils.tensor_utils
import
tree_map
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_template_feats
...
...
@@ -54,6 +54,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
c_t
=
consts
.
c_t
...
...
@@ -65,6 +78,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout
=
0.25
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
blocks_per_ckpt
=
None
chunk_size
=
4
inf
=
1e7
...
...
@@ -78,6 +93,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads
=
no_heads
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -96,12 +113,40 @@ class TestTemplatePairStack(unittest.TestCase):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
config
=
compare_utils
.
get_alphafold_config
()
c_ee
=
config
.
model
.
embeddings_and_evoformer
tps
=
alphafold
.
model
.
modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
if
consts
.
is_multimer
:
safe_key
=
alphafold
.
model
.
prng
.
SafeKey
(
hk
.
next_rng_key
())
template_iteration
=
self
.
am_modules
.
TemplateEmbeddingIteration
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
'template_embedding_iteration'
)
def
template_iteration_fn
(
x
):
act
,
safe_key
=
x
safe_key
,
safe_subkey
=
safe_key
.
split
()
act
=
template_iteration
(
act
=
act
,
pair_mask
=
pair_mask
,
is_training
=
False
,
safe_key
=
safe_subkey
)
return
(
act
,
safe_key
)
if
config
.
model
.
global_config
.
use_remat
:
template_iteration_fn
=
hk
.
remat
(
template_iteration_fn
)
safe_key
,
safe_subkey
=
safe_key
.
split
()
template_stack
=
alphafold
.
model
.
layer_stack
.
layer_stack
(
c_ee
.
template
.
template_pair_stack
.
num_block
)(
template_iteration_fn
)
act
,
_
=
template_stack
((
pair_act
,
safe_subkey
))
else
:
tps
=
self
.
am_modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
act
=
ln
(
act
)
return
act
...
...
@@ -115,10 +160,16 @@ class TestTemplatePairStack(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
).
astype
(
np
.
float32
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
if
consts
.
is_multimer
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_embedding_iteration"
)
else
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
params
.
update
(
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
...
...
@@ -132,7 +183,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
out_repro
=
model
.
template_
embedder
.
template_
pair_stack
(
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
...
...
@@ -143,15 +194,32 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
,
mc_
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
te
=
alphafold
.
model
.
modules
.
TemplateEmbedding
(
te
=
self
.
am_
modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
global_config
,
)
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
mc_mask_2d
,
is_training
=
False
)
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
f
=
hk
.
transform
(
test_template_embedding
)
...
...
@@ -162,6 +230,14 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
multichain_mask_2d
=
None
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
...
@@ -169,7 +245,7 @@ class Template(unittest.TestCase):
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
,
multichain_mask_2d
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
@@ -177,13 +253,30 @@ class Template(unittest.TestCase):
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
inplace_safe
=
False
)
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
use_lma
=
False
,
inplace_safe
=
False
)
else
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
use_lma
=
False
,
inplace_safe
=
False
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
tests/test_triangular_attention.py
View file @
56d5e39c
...
...
@@ -86,9 +86,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_start
if
starting
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
# To save memory, the full model transposes inputs outside of the
...
...
tests/test_triangular_multiplicative_update.py
View file @
56d5e39c
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
torch
import
re
import
numpy
as
np
import
unittest
from
openfold.model.triangular_multiplicative_update
import
*
...
...
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z
=
consts
.
c_z
c
=
11
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
)
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
):
tm
=
FusedTriangleMultiplicationOutgoing
(
c_z
,
c
,
)
else
:
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
)
n_res
=
consts
.
c_z
batch_size
=
consts
.
batch_size
...
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config
.
model
.
global_config
,
name
=
name
,
)
act
=
tri_mul
(
act
=
pair_act
,
mask
=
pair_mask
)
act
=
tri_mul
(
pair_act
,
pair_mask
)
return
act
f
=
hk
.
transform
(
run_tri_mul
)
...
...
@@ -85,10 +92,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
))
pair_mask
=
pair_mask
.
astype
(
np
.
float32
)
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_stock
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_utils.py
View file @
56d5e39c
...
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot
,
rot_to_quat
,
)
from
openfold.utils.
tensor
_utils
import
chunk_layer
,
_chunk_slice
from
openfold.utils.
chunk
_utils
import
chunk_layer
,
_chunk_slice
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment