Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
56d5e39c
"deploy/dynemo/operator/pkg/compoundai/reqcli/http.go" did not exist on "5ddc7f7df5ab77c4efae9fd6ca299c3040c91533"
Commit
56d5e39c
authored
Jun 17, 2023
by
Geoffrey Yu
Browse files
Merge remote-tracking branch 'upstream/multimer' into multimer
parents
56b86074
51556d52
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
742 additions
and
168 deletions
+742
-168
setup.py
setup.py
+2
-0
tests/compare_utils.py
tests/compare_utils.py
+4
-4
tests/config.py
tests/config.py
+5
-1
tests/data_utils.py
tests/data_utils.py
+32
-0
tests/test_data_pipeline.py
tests/test_data_pipeline.py
+25
-15
tests/test_data_transforms.py
tests/test_data_transforms.py
+1
-5
tests/test_embedders.py
tests/test_embedders.py
+25
-5
tests/test_evoformer.py
tests/test_evoformer.py
+10
-1
tests/test_feats.py
tests/test_feats.py
+99
-28
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+284
-28
tests/test_model.py
tests/test_model.py
+32
-6
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_pair_transition.py
tests/test_pair_transition.py
+1
-1
tests/test_primitives.py
tests/test_primitives.py
+7
-13
tests/test_structure_module.py
tests/test_structure_module.py
+77
-23
tests/test_template.py
tests/test_template.py
+116
-23
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+2
-2
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+17
-10
tests/test_utils.py
tests/test_utils.py
+1
-1
No files found.
setup.py
View file @
56d5e39c
...
@@ -75,6 +75,8 @@ for major, minor in list(compute_capabilities):
...
@@ -75,6 +75,8 @@ for major, minor in list(compute_capabilities):
extra_cuda_flags
+=
cc_flag
extra_cuda_flags
+=
cc_flag
cc_flag
=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
]
if
bare_metal_major
!=
-
1
:
if
bare_metal_major
!=
-
1
:
modules
=
[
CUDAExtension
(
modules
=
[
CUDAExtension
(
name
=
"attn_core_inplace_cuda"
,
name
=
"attn_core_inplace_cuda"
,
...
...
tests/compare_utils.py
View file @
56d5e39c
...
@@ -46,26 +46,26 @@ def import_alphafold():
...
@@ -46,26 +46,26 @@ def import_alphafold():
def
get_alphafold_config
():
def
get_alphafold_config
():
config
=
alphafold
.
model
.
config
.
model_config
(
"model_1_ptm"
)
# noqa
config
=
alphafold
.
model
.
config
.
model_config
(
consts
.
model
)
# noqa
config
.
model
.
global_config
.
deterministic
=
True
config
.
model
.
global_config
.
deterministic
=
True
return
config
return
config
_param_path
=
"openfold/resources/params/params_
model_1_ptm
.npz"
_param_path
=
f
"openfold/resources/params/params_
{
consts
.
model
}
.npz"
_model
=
None
_model
=
None
def
get_global_pretrained_openfold
():
def
get_global_pretrained_openfold
():
global
_model
global
_model
if
_model
is
None
:
if
_model
is
None
:
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
))
_model
=
AlphaFold
(
model_config
(
consts
.
model
))
_model
=
_model
.
eval
()
_model
=
_model
.
eval
()
if
not
os
.
path
.
exists
(
_param_path
):
if
not
os
.
path
.
exists
(
_param_path
):
raise
FileNotFoundError
(
raise
FileNotFoundError
(
"""Cannot load pretrained parameters. Make sure to run the
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
installation script before running tests."""
)
)
import_jax_weights_
(
_model
,
_param_path
,
version
=
"model_1_ptm"
)
import_jax_weights_
(
_model
,
_param_path
,
version
=
consts
.
model
)
_model
=
_model
.
cuda
()
_model
=
_model
.
cuda
()
return
_model
return
_model
...
...
tests/config.py
View file @
56d5e39c
...
@@ -2,8 +2,11 @@ import ml_collections as mlc
...
@@ -2,8 +2,11 @@ import ml_collections as mlc
consts
=
mlc
.
ConfigDict
(
consts
=
mlc
.
ConfigDict
(
{
{
"model"
:
"model_1_multimer_v3"
,
# monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer"
:
True
,
# monomer: False, multimer: True
"chunk_size"
:
4
,
"batch_size"
:
2
,
"batch_size"
:
2
,
"n_res"
:
11
,
"n_res"
:
22
,
"n_seq"
:
13
,
"n_seq"
:
13
,
"n_templ"
:
3
,
"n_templ"
:
3
,
"n_extra"
:
17
,
"n_extra"
:
17
,
...
@@ -16,6 +19,7 @@ consts = mlc.ConfigDict(
...
@@ -16,6 +19,7 @@ consts = mlc.ConfigDict(
"c_s"
:
384
,
"c_s"
:
384
,
"c_t"
:
64
,
"c_t"
:
64
,
"c_e"
:
64
,
"c_e"
:
64
,
"msa_logits"
:
22
# monomer: 23, multimer: 22
}
}
)
)
...
...
tests/data_utils.py
View file @
56d5e39c
...
@@ -12,9 +12,36 @@
...
@@ -12,9 +12,36 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
random
import
randint
import
numpy
as
np
import
numpy
as
np
from
scipy.spatial.transform
import
Rotation
from
scipy.spatial.transform
import
Rotation
from
tests.config
import
consts
def
random_asym_ids
(
n_res
,
split_chains
=
True
,
min_chain_len
=
4
):
n_chain
=
randint
(
1
,
n_res
//
min_chain_len
)
if
consts
.
is_multimer
else
1
if
not
split_chains
:
return
[
0
]
*
n_res
assert
n_res
>=
n_chain
pieces
=
[]
asym_ids
=
[]
final_idx
=
n_chain
-
1
for
idx
in
range
(
n_chain
-
1
):
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
if
n_stop
<=
min_chain_len
:
final_idx
=
idx
break
piece
=
randint
(
min_chain_len
,
n_stop
)
pieces
.
append
(
piece
)
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
b
=
[]
b
=
[]
...
@@ -39,6 +66,11 @@ def random_template_feats(n_templ, n, batch_size=None):
...
@@ -39,6 +66,11 @@ def random_template_feats(n_templ, n, batch_size=None):
}
}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
if
consts
.
is_multimer
:
asym_ids
=
np
.
array
(
random_asym_ids
(
n
))
batch
[
"asym_id"
]
=
np
.
tile
(
asym_ids
[
np
.
newaxis
,
:],
(
*
b
,
n_templ
,
1
))
return
batch
return
batch
...
...
tests/test_data_pipeline.py
View file @
56d5e39c
...
@@ -15,19 +15,13 @@
...
@@ -15,19 +15,13 @@
import
pickle
import
pickle
import
shutil
import
shutil
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
openfold.data.data_pipeline
import
DataPipeline
from
openfold.data.data_pipeline
import
DataPipeline
from
openfold.data.templates
import
TemplateHitFeaturizer
from
openfold.data.templates
import
HhsearchHitFeaturizer
,
HmmsearchHitFeaturizer
from
openfold.model.embedders
import
(
InputEmbedder
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
if
compare_utils
.
alphafold_is_installed
():
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
...
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
...
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
with
open
(
"tests/test_data/alphafold_feature_dict.pickle"
,
"rb"
)
as
fp
:
with
open
(
"tests/test_data/alphafold_feature_dict.pickle"
,
"rb"
)
as
fp
:
alphafold_feature_dict
=
pickle
.
load
(
fp
)
alphafold_feature_dict
=
pickle
.
load
(
fp
)
template_featurizer
=
TemplateHitFeaturizer
(
if
consts
.
is_multimer
:
mmcif_dir
=
"tests/test_data/mmcifs"
,
# template_featurizer = HmmsearchHitFeaturizer(
max_template_date
=
"2021-12-20"
,
# mmcif_dir="tests/test_data/mmcifs",
max_hits
=
20
,
# max_template_date="2021-12-20",
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
# max_hits=20,
_zero_center_positions
=
False
,
# kalign_binary_path=shutil.which("kalign"),
)
# _zero_center_positions=False,
# )
template_featurizer
=
HhsearchHitFeaturizer
(
mmcif_dir
=
"tests/test_data/mmcifs"
,
max_template_date
=
"2021-12-20"
,
max_hits
=
20
,
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
_zero_center_positions
=
False
,
)
else
:
template_featurizer
=
HhsearchHitFeaturizer
(
mmcif_dir
=
"tests/test_data/mmcifs"
,
max_template_date
=
"2021-12-20"
,
max_hits
=
20
,
kalign_binary_path
=
shutil
.
which
(
"kalign"
),
_zero_center_positions
=
False
,
)
data_pipeline
=
DataPipeline
(
data_pipeline
=
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
...
...
tests/test_data_transforms.py
View file @
56d5e39c
import
copy
import
gzip
import
gzip
import
os
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
...
@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
}
}
protein
=
make_hhblits_profile
(
protein
)
protein
=
make_hhblits_profile
(
protein
)
masked_msa_config
=
config
.
data
.
common
.
masked_msa
masked_msa_config
=
config
.
data
.
common
.
masked_msa
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
)
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
,
seed
=
42
)
assert
'bert_mask'
in
protein
assert
'bert_mask'
in
protein
assert
'true_msa'
in
protein
assert
'true_msa'
in
protein
assert
'msa'
in
protein
assert
'msa'
in
protein
...
...
tests/test_embedders.py
View file @
56d5e39c
...
@@ -12,14 +12,17 @@
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
random
import
torch
import
torch
import
numpy
as
np
import
unittest
import
unittest
from
tests.config
import
consts
from
tests.data_utils
import
random_asym_ids
from
openfold.model.embedders
import
(
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
RecyclingEmbedder
,
TemplateAngleEmbedder
,
TemplateAngleEmbedder
,
TemplatePairEmbedder
,
TemplatePairEmbedder
)
)
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
...
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res
=
17
n_res
=
17
n_clust
=
19
n_clust
=
19
max_relative_chain
=
2
max_relative_idx
=
32
use_chain_relative
=
True
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
tf
=
torch
.
rand
((
b
,
n_res
,
tf_dim
))
ri
=
torch
.
rand
((
b
,
n_res
))
ri
=
torch
.
rand
((
b
,
n_res
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
msa
=
torch
.
rand
((
b
,
n_clust
,
n_res
,
msa_dim
))
asym_ids_flat
=
torch
.
Tensor
(
random_asym_ids
(
n_res
))
asym_id
=
torch
.
tile
(
asym_ids_flat
.
unsqueeze
(
0
),
(
b
,
1
))
entity_id
=
asym_id
sym_id
=
torch
.
zeros_like
(
entity_id
)
if
consts
.
is_multimer
:
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
max_relative_idx
=
max_relative_idx
,
use_chain_relative
=
use_chain_relative
,
max_relative_chain
=
max_relative_chain
)
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
,
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
}
msa_emb
,
pair_emb
=
ie
(
batch
)
else
:
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
=
tf
,
ri
=
ri
,
msa
=
msa
,
inplace_safe
=
False
)
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
,
ri
,
msa
)
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
...
...
tests/test_evoformer.py
View file @
56d5e39c
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
=
2
transition_n
=
2
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -65,6 +68,8 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -65,6 +68,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -174,6 +179,8 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -174,6 +179,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
=
5
transition_n
=
5
msa_dropout
=
0.15
msa_dropout
=
0.15
pair_stack_dropout
=
0.25
pair_stack_dropout
=
0.25
opm_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
inf
=
1e9
inf
=
1e9
eps
=
1e-10
eps
=
1e-10
...
@@ -190,6 +197,8 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -190,6 +197,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
opm_first
,
fuse_projection_weights
,
ckpt
=
False
,
ckpt
=
False
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -277,7 +286,7 @@ class TestMSATransition(unittest.TestCase):
...
@@ -277,7 +286,7 @@ class TestMSATransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
model
.
evoformer
.
blocks
[
0
].
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
)
...
...
tests/test_feats.py
View file @
56d5e39c
...
@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
...
@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
)
)
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
)
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_4x4
from
tests.data_utils
import
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
...
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
def
test_pbf
(
aatype
,
all_atom_pos
,
all_atom_mask
):
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
...
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_atom37_to_frames_compare
(
self
):
def
test_atom37_to_frames_compare
(
self
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
def
run_atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
):
return
alphafold
.
model
.
all_atom
.
atom37_to_frames
(
if
consts
.
is_multimer
:
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
all_atom_positions
)
return
self
.
am_atom
.
atom37_to_frames
(
aatype
,
all_atom_positions
,
all_atom_mask
aatype
,
all_atom_positions
,
all_atom_mask
)
)
...
@@ -150,9 +168,23 @@ class TestFeats(unittest.TestCase):
...
@@ -150,9 +168,23 @@ class TestFeats(unittest.TestCase):
}
}
out_gt
=
f
.
apply
({},
None
,
**
batch
)
out_gt
=
f
.
apply
({},
None
,
**
batch
)
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
())))
else
:
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
def
rigid3x4_to_4x4
(
rigid3arr
):
four_by_four
=
torch
.
zeros
(
*
rigid3arr
.
shape
[:
-
2
],
4
,
4
)
four_by_four
[...,
:
3
,
:
4
]
=
rigid3arr
four_by_four
[...,
3
,
3
]
=
1
return
four_by_four
def
flat12_to_4x4
(
flat12
):
def
flat12_to_4x4
(
flat12
):
rot
=
flat12
[...,
:
9
].
view
(
*
flat12
.
shape
[:
-
1
],
3
,
3
)
rot
=
flat12
[...,
:
9
].
view
(
*
flat12
.
shape
[:
-
1
],
3
,
3
)
trans
=
flat12
[...,
9
:]
trans
=
flat12
[...,
9
:]
...
@@ -164,10 +196,12 @@ class TestFeats(unittest.TestCase):
...
@@ -164,10 +196,12 @@ class TestFeats(unittest.TestCase):
return
four_by_four
return
four_by_four
out_gt
[
"rigidgroups_gt_frames"
]
=
flat12_to_4x4
(
convert_func
=
rigid3x4_to_4x4
if
consts
.
is_multimer
else
flat12_to_4x4
out_gt
[
"rigidgroups_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_gt_frames"
]
out_gt
[
"rigidgroups_gt_frames"
]
)
)
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
flat12_to_4x4
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
out_gt
[
"rigidgroups_alt_gt_frames"
]
)
)
...
@@ -187,7 +221,13 @@ class TestFeats(unittest.TestCase):
...
@@ -187,7 +221,13 @@ class TestFeats(unittest.TestCase):
n
=
5
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
@@ -208,7 +248,7 @@ class TestFeats(unittest.TestCase):
...
@@ -208,7 +248,7 @@ class TestFeats(unittest.TestCase):
def
run_torsion_angles_to_frames
(
def
run_torsion_angles_to_frames
(
aatype
,
backb_to_global
,
torsion_angles_sin_cos
aatype
,
backb_to_global
,
torsion_angles_sin_cos
):
):
return
alphafold
.
mod
el
.
a
ll
_atom
.
torsion_angles_to_frames
(
return
s
el
f
.
a
m
_atom
.
torsion_angles_to_frames
(
aatype
,
aatype
,
backb_to_global
,
backb_to_global
,
torsion_angles_sin_cos
,
torsion_angles_sin_cos
,
...
@@ -221,10 +261,17 @@ class TestFeats(unittest.TestCase):
...
@@ -221,10 +261,17 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
if
consts
.
is_multimer
:
torch
.
as_tensor
(
affines
).
float
()
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
@@ -240,13 +287,21 @@ class TestFeats(unittest.TestCase):
...
@@ -240,13 +287,21 @@ class TestFeats(unittest.TestCase):
)
)
# Convert the Rigids to 4x4 transformation tensors
# Convert the Rigids to 4x4 transformation tensors
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
rot
))
out_gt_rot
=
out_gt
.
rot
if
not
consts
.
is_multimer
else
out_gt
.
rotation
.
to_array
()
trans_gt
=
list
(
out_gt_trans
=
out_gt
.
trans
if
not
consts
.
is_multimer
else
out_gt
.
translation
.
to_array
()
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
trans
)
)
if
consts
.
is_multimer
:
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_rot
))
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_trans
))
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
else
:
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_rot
))
trans_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_trans
)
)
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
transforms_gt
=
torch
.
cat
([
rots_gt
,
trans_gt
.
unsqueeze
(
-
1
)],
dim
=-
1
)
transforms_gt
=
torch
.
cat
([
rots_gt
,
trans_gt
.
unsqueeze
(
-
1
)],
dim
=-
1
)
bottom_row
=
torch
.
zeros
((
*
rots_gt
.
shape
[:
-
2
],
1
,
4
))
bottom_row
=
torch
.
zeros
((
*
rots_gt
.
shape
[:
-
2
],
1
,
4
))
bottom_row
[...,
3
]
=
1
bottom_row
[...,
3
]
=
1
...
@@ -264,7 +319,13 @@ class TestFeats(unittest.TestCase):
...
@@ -264,7 +319,13 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rots
)
translation
=
Vec3Array
.
from_array
(
trans
)
ts
=
Rigid3Array
(
rotation
,
translation
)
else
:
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
),
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
@@ -282,8 +343,7 @@ class TestFeats(unittest.TestCase):
...
@@ -282,8 +343,7 @@ class TestFeats(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
test_frames_and_literature_positions_to_atom14_pos_compare
(
self
):
def
run_f
(
aatype
,
affines
):
def
run_f
(
aatype
,
affines
):
am
=
alphafold
.
model
return
self
.
am_atom
.
frames_and_literature_positions_to_atom14_pos
(
return
am
.
all_atom
.
frames_and_literature_positions_to_atom14_pos
(
aatype
,
affines
aatype
,
affines
)
)
...
@@ -294,16 +354,27 @@ class TestFeats(unittest.TestCase):
...
@@ -294,16 +354,27 @@ class TestFeats(unittest.TestCase):
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
aatype
=
np
.
random
.
randint
(
0
,
21
,
size
=
(
n_res
,))
affines
=
random_affines_4x4
((
n_res
,
8
))
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
if
consts
.
is_multimer
:
torch
.
as_tensor
(
affines
).
float
()
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
if
consts
.
is_multimer
:
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
to_array
()))
else
:
out_gt
=
torch
.
stack
(
[
torch
.
as_tensor
(
np
.
array
(
x
))
for
x
in
out_gt
],
dim
=-
1
)
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
out_repro
=
feats
.
frames_and_literature_positions_to_atom14_pos
(
transformations
.
cuda
(),
transformations
.
cuda
(),
...
...
tests/test_import_weights.py
View file @
56d5e39c
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
)
][
1
].
transpose
(
-
1
,
-
2
)
][
1
].
transpose
(
-
1
,
-
2
)
),
),
model
.
evoformer
.
blocks
[
1
].
core
.
outer_product_mean
.
linear_1
.
weight
,
model
.
evoformer
.
blocks
[
1
].
outer_product_mean
.
linear_1
.
weight
,
),
),
]
]
...
...
tests/test_loss.py
View file @
56d5e39c
...
@@ -13,18 +13,18 @@
...
@@ -13,18 +13,18 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
math
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
pathlib
import
Path
import
unittest
import
unittest
import
ml_collections
as
mlc
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.np
import
residue_constants
from
openfold.utils.rigid_utils
import
(
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rotation
,
Rigid
,
Rigid
,
)
)
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
torsion_angle_loss
,
torsion_angle_loss
,
compute_fape
,
compute_fape
,
...
@@ -43,6 +43,8 @@ from openfold.utils.loss import (
...
@@ -43,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss
,
sidechain_loss
,
tm_loss
,
tm_loss
,
compute_plddt
,
compute_plddt
,
compute_tm
,
chain_center_of_mass_loss
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
...
@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
...
@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
)
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
...
@@ -64,7 +66,30 @@ def affine_vector_to_4x4(affine):
...
@@ -64,7 +66,30 @@ def affine_vector_to_4x4(affine):
return
r
.
to_tensor_4x4
()
return
r
.
to_tensor_4x4
()
def
affine_vector_to_rigid
(
am_rigid
,
affine
):
rigid_flat
=
np
.
split
(
affine
,
7
,
axis
=-
1
)
rigid_flat
=
[
r
.
squeeze
(
-
1
)
for
r
in
rigid_flat
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
trans
=
rigid_flat
[
4
:]
rotation
=
am_rigid
.
Rot3Array
.
from_quaternion
(
qw
,
qx
,
qy
,
qz
,
normalize
=
True
)
translation
=
am_rigid
.
Vec3Array
(
*
trans
)
return
am_rigid
.
Rigid3Array
(
rotation
,
translation
)
class
TestLoss
(
unittest
.
TestCase
):
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
...
@@ -127,7 +152,10 @@ class TestLoss(unittest.TestCase):
...
@@ -127,7 +152,10 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_bond_loss_compare
(
self
):
def
test_between_residue_bond_loss_compare
(
self
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
def
run_brbl
(
pred_pos
,
pred_atom_mask
,
residue_index
,
aatype
):
return
alphafold
.
model
.
all_atom
.
between_residue_bond_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_bond_loss
(
pred_pos
,
pred_pos
,
pred_atom_mask
,
pred_atom_mask
,
residue_index
,
residue_index
,
...
@@ -184,12 +212,22 @@ class TestLoss(unittest.TestCase):
...
@@ -184,12 +212,22 @@ class TestLoss(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_between_residue_clash_loss_compare
(
self
):
def
test_between_residue_clash_loss_compare
(
self
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
):
def
run_brcl
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
):
return
alphafold
.
model
.
all_atom
.
between_residue_clash_loss
(
if
consts
.
is_multimer
:
pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pred_pos
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
atom_exists
,
atom_radius
,
res_ind
,
asym_id
)
return
self
.
am_atom
.
between_residue_clash_loss
(
pred_pos
,
pred_pos
,
atom_exists
,
atom_exists
,
atom_radius
,
atom_radius
,
res_ind
,
res_ind
)
)
f
=
hk
.
transform
(
run_brcl
)
f
=
hk
.
transform
(
run_brcl
)
...
@@ -198,10 +236,24 @@ class TestLoss(unittest.TestCase):
...
@@ -198,10 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
res_ind
=
np
.
arange
(
res_ind
=
np
.
arange
(
n_res
,
n_res
,
)
)
residx_atom14_to_atom37
=
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)).
astype
(
np
.
int64
)
atomtype_radius
=
[
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
]
atomtype_radius
=
np
.
array
(
atomtype_radius
).
astype
(
np
.
float32
)
atom_radius
=
(
atom_exists
*
atomtype_radius
[
residx_atom14_to_atom37
]
)
asym_id
=
None
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
{},
{},
...
@@ -210,6 +262,7 @@ class TestLoss(unittest.TestCase):
...
@@ -210,6 +262,7 @@ class TestLoss(unittest.TestCase):
atom_exists
,
atom_exists
,
atom_radius
,
atom_radius
,
res_ind
,
res_ind
,
asym_id
)
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
...
@@ -219,6 +272,7 @@ class TestLoss(unittest.TestCase):
...
@@ -219,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
asym_id
).
cuda
()
if
asym_id
is
not
None
else
None
,
)
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
...
@@ -242,6 +296,36 @@ class TestLoss(unittest.TestCase):
...
@@ -242,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compute_ptm_compare
(
self
):
n_res
=
consts
.
n_res
max_bin
=
31
no_bins
=
64
logits
=
np
.
random
.
rand
(
n_res
,
n_res
,
no_bins
)
boundaries
=
np
.
linspace
(
0
,
max_bin
,
num
=
(
no_bins
-
1
))
ptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
)
ptm_gt
=
torch
.
tensor
(
ptm_gt
)
logits_t
=
torch
.
tensor
(
logits
)
ptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
ptm_gt
-
ptm_repro
))
<
consts
.
eps
)
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
iptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
,
asym_id
=
asym_id
,
interface
=
True
)
iptm_gt
=
torch
.
tensor
(
iptm_gt
)
iptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
,
asym_id
=
torch
.
tensor
(
asym_id
),
interface
=
True
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
iptm_gt
-
iptm_repro
))
<
consts
.
eps
)
def
test_find_structural_violations
(
self
):
def
test_find_structural_violations
(
self
):
n
=
consts
.
n_res
n
=
consts
.
n_res
...
@@ -265,8 +349,21 @@ class TestLoss(unittest.TestCase):
...
@@ -265,8 +349,21 @@ class TestLoss(unittest.TestCase):
def
test_find_structural_violations_compare
(
self
):
def
test_find_structural_violations_compare
(
self
):
def
run_fsv
(
batch
,
pos
,
config
):
def
run_fsv
(
batch
,
pos
,
config
):
cwd
=
os
.
getcwd
()
cwd
=
os
.
getcwd
()
os
.
chdir
(
"tests/test_data"
)
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data"
loss
=
alphafold
.
model
.
folding
.
find_structural_violations
(
os
.
chdir
(
str
(
fpath
))
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
return
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
config
,
batch
[
'asym_id'
]
)
loss
=
self
.
am_fold
.
find_structural_violations
(
batch
,
batch
,
pos
,
pos
,
config
,
config
,
...
@@ -287,6 +384,9 @@ class TestLoss(unittest.TestCase):
...
@@ -287,6 +384,9 @@ class TestLoss(unittest.TestCase):
).
astype
(
np
.
int64
),
).
astype
(
np
.
int64
),
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
config
=
mlc
.
ConfigDict
(
config
=
mlc
.
ConfigDict
(
...
@@ -380,14 +480,14 @@ class TestLoss(unittest.TestCase):
...
@@ -380,14 +480,14 @@ class TestLoss(unittest.TestCase):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
value
=
{
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
consts
.
msa_logits
).
astype
(
np
.
float32
),
}
}
batch
=
{
batch
=
{
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
np
.
float32
np
.
float32
)
,
)
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
...
@@ -399,7 +499,9 @@ class TestLoss(unittest.TestCase):
...
@@ -399,7 +499,9 @@ class TestLoss(unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out_repro
=
masked_msa_loss
(
out_repro
=
masked_msa_loss
(
value
[
"logits"
],
value
[
"logits"
],
**
batch
,
batch
[
"true_msa"
],
batch
[
"bert_mask"
],
consts
.
msa_logits
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
...
@@ -506,10 +608,28 @@ class TestLoss(unittest.TestCase):
...
@@ -506,10 +608,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss
=
config
.
model
.
heads
.
structure_module
c_chi_loss
=
config
.
model
.
heads
.
structure_module
def
run_supervised_chi_loss
(
value
,
batch
):
def
run_supervised_chi_loss
(
value
,
batch
):
if
consts
.
is_multimer
:
pred_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
unnormed_angles
=
np
.
reshape
(
value
[
'sidechains'
][
'unnormalized_angles_sin_cos'
],
[
-
1
,
consts
.
n_res
,
7
,
2
])
chi_loss
,
_
,
_
=
self
.
am_fold
.
supervised_chi_loss
(
batch
[
'seq_mask'
],
batch
[
'chi_mask'
],
batch
[
'aatype'
],
batch
[
'chi_angles'
],
pred_angles
,
unnormed_angles
,
c_chi_loss
)
return
chi_loss
ret
=
{
ret
=
{
"loss"
:
jax
.
numpy
.
array
(
0.0
),
"loss"
:
jax
.
numpy
.
array
(
0.0
),
}
}
alphafold
.
model
.
fold
ing
.
supervised_chi_loss
(
self
.
am_
fold
.
supervised_chi_loss
(
ret
,
batch
,
value
,
c_chi_loss
ret
,
batch
,
value
,
c_chi_loss
)
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
...
@@ -561,6 +681,40 @@ class TestLoss(unittest.TestCase):
...
@@ -561,6 +681,40 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_viol
=
config
.
model
.
heads
.
structure_module
n_res
=
consts
.
n_res
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
).
cuda
(),
batch
,
np
.
ndarray
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
batch
=
data_transforms
.
make_atom14_masks
(
batch
)
loss_sum_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
False
,
**
batch
)
loss_sum_clash
=
loss_sum_clash
.
cpu
()
loss_avg_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
True
,
**
batch
)
loss_avg_clash
=
loss_avg_clash
.
cpu
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss_compare
(
self
):
def
test_violation_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
...
@@ -570,15 +724,31 @@ class TestLoss(unittest.TestCase):
...
@@ -570,15 +724,31 @@ class TestLoss(unittest.TestCase):
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
"loss"
:
np
.
array
(
0.0
).
astype
(
np
.
float32
),
}
}
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_pos
)
viol
=
self
.
am_fold
.
find_structural_violations
(
batch
[
'aatype'
],
batch
[
'residue_index'
],
batch
[
'atom14_atom_exists'
],
atom14_pred_pos
,
c_viol
,
batch
[
'asym_id'
]
)
return
self
.
am_fold
.
structural_violation_loss
(
mask
=
batch
[
'atom14_atom_exists'
],
violations
=
viol
,
config
=
c_viol
)
value
=
{}
value
=
{}
value
[
value
[
"violations"
"violations"
]
=
alphafold
.
model
.
fold
ing
.
find_structural_violations
(
]
=
self
.
am_
fold
.
find_structural_violations
(
batch
,
batch
,
atom14_pred_pos
,
atom14_pred_pos
,
c_viol
,
c_viol
,
)
)
alphafold
.
model
.
folding
.
structural_violation_loss
(
self
.
am_fold
.
structural_violation_loss
(
ret
,
ret
,
batch
,
batch
,
value
,
value
,
...
@@ -593,13 +763,17 @@ class TestLoss(unittest.TestCase):
...
@@ -593,13 +763,17 @@ class TestLoss(unittest.TestCase):
batch
=
{
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
,
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
}
}
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
@@ -676,10 +850,31 @@ class TestLoss(unittest.TestCase):
...
@@ -676,10 +850,31 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_bb_loss
(
batch
,
value
):
def
run_bb_loss
(
batch
,
value
):
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
astype
(
np
.
float32
)
gt_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
batch
[
"backbone_affine_tensor"
])
target_rigid
=
affine_vector_to_rigid
(
self
.
am_rigid
,
value
[
'traj'
])
intra_chain_bb_loss
,
intra_chain_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
intra_chain_fape
,
pair_mask
=
intra_chain_mask
)
interface_bb_loss
,
interface_fape
=
self
.
am_fold
.
backbone_loss
(
gt_rigid
=
gt_rigid
,
gt_frames_mask
=
batch
[
"backbone_affine_mask"
],
gt_positions_mask
=
batch
[
"backbone_affine_mask"
],
target_rigid
=
target_rigid
,
config
=
c_sm
.
interface_fape
,
pair_mask
=
1.
-
intra_chain_mask
)
return
intra_chain_bb_loss
+
interface_bb_loss
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.0
),
"loss"
:
np
.
array
(
0.0
),
}
}
alphafold
.
model
.
fold
ing
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
self
.
am_
fold
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_bb_loss
)
f
=
hk
.
transform
(
run_bb_loss
)
...
@@ -691,7 +886,7 @@ class TestLoss(unittest.TestCase):
...
@@ -691,7 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
np
.
float32
),
),
"use_clamped_fape"
:
np
.
array
(
0.0
)
,
"use_clamped_fape"
:
np
.
array
(
0.0
)
}
}
value
=
{
value
=
{
...
@@ -703,6 +898,9 @@ class TestLoss(unittest.TestCase):
...
@@ -703,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
),
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
@@ -715,8 +913,18 @@ class TestLoss(unittest.TestCase):
...
@@ -715,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
if
consts
.
is_multimer
:
out_repro
=
out_repro
.
cpu
()
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
to
(
dtype
=
value
[
"traj"
].
dtype
)
intra_chain_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
intra_chain_fape
})
interface_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
interface_fape
})
out_repro
=
intra_chain_out
+
interface_out
out_repro
=
out_repro
.
cpu
()
else
:
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
@@ -726,9 +934,29 @@ class TestLoss(unittest.TestCase):
...
@@ -726,9 +934,29 @@ class TestLoss(unittest.TestCase):
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
if
consts
.
is_multimer
:
atom14_pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
atom14_pred_positions
)
all_atom_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
batch
[
"all_atom_positions"
])
gt_positions
,
gt_mask
,
alt_naming_is_better
=
self
.
am_fold
.
compute_atom14_gt
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
pred_pos
=
atom14_pred_positions
)
pred_frames
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
value
[
"sidechains"
][
"frames"
])
pred_positions
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
value
[
"sidechains"
][
"atom_pos"
])
gt_sc_frames
,
gt_sc_frames_mask
=
self
.
am_fold
.
compute_frames
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
all_atom_positions
,
all_atom_mask
=
batch
[
"all_atom_mask"
],
use_alt
=
alt_naming_is_better
)
return
self
.
am_fold
.
sidechain_loss
(
gt_sc_frames
,
gt_sc_frames_mask
,
gt_positions
,
gt_mask
,
pred_frames
,
pred_positions
,
c_sm
)[
'loss'
]
batch
=
{
batch
=
{
**
batch
,
**
batch
,
**
alphafold
.
mod
el
.
a
ll
_atom
.
atom37_to_frames
(
**
s
el
f
.
a
m
_atom
.
atom37_to_frames
(
batch
[
"aatype"
],
batch
[
"aatype"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
batch
[
"all_atom_mask"
],
...
@@ -738,21 +966,21 @@ class TestLoss(unittest.TestCase):
...
@@ -738,21 +966,21 @@ class TestLoss(unittest.TestCase):
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
][
v
[
"sidechains"
][
"frames"
"frames"
]
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
]
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
value
[
"sidechains"
][
"frames"
]
value
[
"sidechains"
][
"frames"
]
)
)
v
[
"sidechains"
][
"atom_pos"
]
=
alphafold
.
model
.
r3
.
vecs_from_tensor
(
v
[
"sidechains"
][
"atom_pos"
]
=
self
.
am_rigid
.
vecs_from_tensor
(
value
[
"sidechains"
][
"atom_pos"
]
value
[
"sidechains"
][
"atom_pos"
]
)
)
v
.
update
(
v
.
update
(
alphafold
.
model
.
fold
ing
.
compute_renamed_ground_truth
(
self
.
am_
fold
.
compute_renamed_ground_truth
(
batch
,
batch
,
atom14_pred_positions
,
atom14_pred_positions
,
)
)
)
)
value
=
v
value
=
v
ret
=
alphafold
.
model
.
fold
ing
.
sidechain_loss
(
batch
,
value
,
c_sm
)
ret
=
self
.
am_
fold
.
sidechain_loss
(
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_sidechain_loss
)
f
=
hk
.
transform
(
run_sidechain_loss
)
...
@@ -816,6 +1044,7 @@ class TestLoss(unittest.TestCase):
...
@@ -816,6 +1044,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
@@ -882,6 +1111,33 @@ class TestLoss(unittest.TestCase):
...
@@ -882,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_chain_center_of_mass_loss
(
self
):
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
batch
=
{
"all_atom_positions"
:
np
.
random
.
rand
(
batch_size
,
n_res
,
37
,
3
).
astype
(
np
.
float32
)
*
10.0
,
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
batch_size
,
n_res
,
37
)).
astype
(
np
.
float32
),
"asym_id"
:
np
.
stack
([
random_asym_ids
(
n_res
)
for
_
in
range
(
batch_size
)])
}
config
=
{
"weight"
:
0.05
,
"clamp_distance"
:
-
4.0
,
}
final_atom_positions
=
torch
.
rand
(
batch_size
,
n_res
,
37
,
3
).
cuda
()
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
).
cuda
()
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
chain_center_of_mass_loss
(
all_atom_pred_pos
=
final_atom_positions
,
**
{
**
batch
,
**
config
},
)
out_repro
=
out_repro
.
cpu
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
tests/test_model.py
View file @
56d5e39c
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
pathlib
import
Path
import
pickle
import
pickle
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -20,8 +21,7 @@ import unittest
...
@@ -20,8 +21,7 @@ import unittest
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
(
from
tests.data_utils
import
(
...
@@ -36,13 +36,26 @@ if compare_utils.alphafold_is_installed():
...
@@ -36,13 +36,26 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"model_1"
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
# deepspeed for this test
...
@@ -56,6 +69,7 @@ class TestModel(unittest.TestCase):
...
@@ -56,6 +69,7 @@ class TestModel(unittest.TestCase):
).
float
()
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
...
@@ -68,6 +82,12 @@ class TestModel(unittest.TestCase):
...
@@ -68,6 +82,12 @@ class TestModel(unittest.TestCase):
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"entity_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"sym_id"
]
=
torch
.
randint
(
0
,
1
,
size
=
(
n_res
,))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
)
...
@@ -77,10 +97,14 @@ class TestModel(unittest.TestCase):
...
@@ -77,10 +97,14 @@ class TestModel(unittest.TestCase):
out
=
model
(
batch
)
out
=
model
(
batch
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
consts
.
is_multimer
,
"Additional changes required for multimer."
)
def
test_compare
(
self
):
def
test_compare
(
self
):
#TODO: Fix test data for multimer MSA features
def
run_alphafold
(
batch
):
def
run_alphafold
(
batch
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
model
=
alphafold
.
model
.
modules
.
AlphaFold
(
config
.
model
)
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
return
model
(
return
model
(
batch
=
batch
,
batch
=
batch
,
is_training
=
False
,
is_training
=
False
,
...
@@ -91,7 +115,8 @@ class TestModel(unittest.TestCase):
...
@@ -91,7 +115,8 @@ class TestModel(unittest.TestCase):
params
=
compare_utils
.
fetch_alphafold_module_weights
(
""
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
""
)
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data/sample_feats.pickle"
with
open
(
str
(
fpath
),
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
batch
=
pickle
.
load
(
fp
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
...
@@ -100,7 +125,8 @@ class TestModel(unittest.TestCase):
...
@@ -100,7 +125,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
# atom37_to_atom14 doesn't like batches
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"residx_atom14_to_atom37"
]
=
batch
[
"residx_atom14_to_atom37"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
batch
[
"atom14_atom_exists"
]
=
batch
[
"atom14_atom_exists"
][
0
]
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
self
.
am_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
...
...
tests/test_outer_product_mean.py
View file @
56d5e39c
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
.
core
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
...
...
tests/test_pair_transition.py
View file @
56d5e39c
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
].
core
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
pair_transition
(
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
...
...
tests/test_primitives.py
View file @
56d5e39c
...
@@ -13,20 +13,17 @@
...
@@ -13,20 +13,17 @@
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
numpy
as
np
import
unittest
import
unittest
from
openfold.model.primitives
import
(
from
openfold.model.primitives
import
Attention
Attention
,
)
from
tests.config
import
consts
from
tests.config
import
consts
class
TestLMA
(
unittest
.
TestCase
):
class
TestLMA
(
unittest
.
TestCase
):
def
test_lma_vs_attention
(
self
):
def
test_lma_vs_attention
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
c_hidden
=
32
c_hidden
=
32
n
=
2
**
12
n
=
2
**
12
no_heads
=
4
no_heads
=
4
q
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
q
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
...
@@ -34,20 +31,17 @@ class TestLMA(unittest.TestCase):
...
@@ -34,20 +31,17 @@ class TestLMA(unittest.TestCase):
bias
=
[
torch
.
rand
(
no_heads
,
1
,
n
)]
bias
=
[
torch
.
rand
(
no_heads
,
1
,
n
)]
bias
=
[
b
.
cuda
()
for
b
in
bias
]
bias
=
[
b
.
cuda
()
for
b
in
bias
]
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
a
=
Attention
(
a
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
).
cuda
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_lma
=
True
)
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_lma
=
True
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
tests/test_structure_module.py
View file @
56d5e39c
...
@@ -18,21 +18,19 @@ import unittest
...
@@ -18,21 +18,19 @@ import unittest
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom37_mask
,
restype_atom37_mask
,
)
)
from
openfold.model.structure_module
import
(
from
openfold.model.structure_module
import
(
StructureModule
,
StructureModule
,
StructureModuleTransition
,
StructureModuleTransition
,
BackboneUpdate
,
AngleResnet
,
AngleResnet
,
InvariantPointAttention
,
InvariantPointAttention
,
)
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
(
from
tests.data_utils
import
(
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
...
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
n
=
consts
.
n_res
n
=
consts
.
n_res
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
...
@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor
,
trans_scale_factor
,
ar_epsilon
,
ar_epsilon
,
inf
,
inf
,
is_multimer
=
consts
.
is_multimer
)
)
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
...
@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
({
"single"
:
s
,
"pair"
:
z
},
f
)
out
=
sm
({
"single"
:
s
,
"pair"
:
z
},
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
else
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
)
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
...
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global
=
config
.
model
.
global_config
c_global
=
config
.
model
.
global_config
def
run_sm
(
representations
,
batch
):
def
run_sm
(
representations
,
batch
):
sm
=
alphafold
.
model
.
fold
ing
.
StructureModule
(
c_sm
,
c_global
)
sm
=
self
.
am_
fold
.
StructureModule
(
c_sm
,
c_global
)
representations
=
{
representations
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
}
}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
return
sm
(
representations
,
batch
,
is_training
=
False
,
compute_loss
=
True
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
f
=
hk
.
transform
(
run_sm
)
f
=
hk
.
transform
(
run_sm
)
...
@@ -181,6 +200,19 @@ class TestStructureModule(unittest.TestCase):
...
@@ -181,6 +200,19 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
c_m
=
13
c_m
=
13
c_z
=
17
c_z
=
17
...
@@ -197,13 +229,18 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -197,13 +229,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask
=
torch
.
ones
((
batch_size
,
n_res
))
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
r
=
Rigid
(
rots
,
trans
)
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rot_mats
)
translation
=
Vec3Array
.
from_array
(
trans
)
r
=
Rigid3Array
(
rotation
,
translation
)
else
:
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
,
is_multimer
=
consts
.
is_multimer
)
)
shape_before
=
s
.
shape
shape_before
=
s
.
shape
...
@@ -215,16 +252,26 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -215,16 +252,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def
test_ipa_compare
(
self
):
def
test_ipa_compare
(
self
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
ipa
=
alphafold
.
model
.
fold
ing
.
InvariantPointAttention
(
ipa
=
self
.
am_
fold
.
InvariantPointAttention
(
config
.
model
.
heads
.
structure_module
,
config
.
model
.
heads
.
structure_module
,
config
.
model
.
global_config
,
config
.
model
.
global_config
,
)
)
attn
=
ipa
(
inputs_1d
=
act
,
if
consts
.
is_multimer
:
inputs_2d
=
static_feat_2d
,
attn
=
ipa
(
mask
=
mask
,
inputs_1d
=
act
,
affine
=
affine
,
inputs_2d
=
static_feat_2d
,
)
mask
=
mask
,
rigid
=
affine
)
else
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
)
return
attn
return
attn
f
=
hk
.
transform
(
run_ipa
)
f
=
hk
.
transform
(
run_ipa
)
...
@@ -238,13 +285,20 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -238,13 +285,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask
=
np
.
ones
((
n_res
,
1
))
sample_mask
=
np
.
ones
((
n_res
,
1
))
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
rigids
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
quats
=
self
.
am_rigid
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/structure_module/"
"alphafold/alphafold_iteration/structure_module/"
...
...
tests/test_template.py
View file @
56d5e39c
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
...
@@ -19,7 +20,6 @@ from openfold.model.template import (
...
@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention
,
TemplatePointwiseAttention
,
TemplatePairStack
,
TemplatePairStack
,
)
)
from
openfold.utils.tensor_utils
import
tree_map
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_template_feats
from
tests.data_utils
import
random_template_feats
...
@@ -54,6 +54,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
...
@@ -54,6 +54,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
c_t
=
consts
.
c_t
c_t
=
consts
.
c_t
...
@@ -65,6 +78,8 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -65,6 +78,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout
=
0.25
dropout
=
0.25
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
blocks_per_ckpt
=
None
blocks_per_ckpt
=
None
chunk_size
=
4
chunk_size
=
4
inf
=
1e7
inf
=
1e7
...
@@ -78,6 +93,8 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -78,6 +93,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads
=
no_heads
,
no_heads
=
no_heads
,
pair_transition_n
=
pt_inner_dim
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
...
@@ -96,12 +113,40 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -96,12 +113,40 @@ class TestTemplatePairStack(unittest.TestCase):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_ee
=
config
.
model
.
embeddings_and_evoformer
c_ee
=
config
.
model
.
embeddings_and_evoformer
tps
=
alphafold
.
model
.
modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
if
consts
.
is_multimer
:
config
.
model
.
global_config
,
safe_key
=
alphafold
.
model
.
prng
.
SafeKey
(
hk
.
next_rng_key
())
name
=
"template_pair_stack"
,
template_iteration
=
self
.
am_modules
.
TemplateEmbeddingIteration
(
)
c_ee
.
template
.
template_pair_stack
,
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
config
.
model
.
global_config
,
name
=
'template_embedding_iteration'
)
def
template_iteration_fn
(
x
):
act
,
safe_key
=
x
safe_key
,
safe_subkey
=
safe_key
.
split
()
act
=
template_iteration
(
act
=
act
,
pair_mask
=
pair_mask
,
is_training
=
False
,
safe_key
=
safe_subkey
)
return
(
act
,
safe_key
)
if
config
.
model
.
global_config
.
use_remat
:
template_iteration_fn
=
hk
.
remat
(
template_iteration_fn
)
safe_key
,
safe_subkey
=
safe_key
.
split
()
template_stack
=
alphafold
.
model
.
layer_stack
.
layer_stack
(
c_ee
.
template
.
template_pair_stack
.
num_block
)(
template_iteration_fn
)
act
,
_
=
template_stack
((
pair_act
,
safe_subkey
))
else
:
tps
=
self
.
am_modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
)
act
=
tps
(
pair_act
,
pair_mask
,
is_training
=
False
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
ln
=
hk
.
LayerNorm
([
-
1
],
True
,
True
,
name
=
"output_layer_norm"
)
act
=
ln
(
act
)
act
=
ln
(
act
)
return
act
return
act
...
@@ -115,10 +160,16 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -115,10 +160,16 @@ class TestTemplatePairStack(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
).
astype
(
np
.
float32
)
).
astype
(
np
.
float32
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
if
consts
.
is_multimer
:
"alphafold/alphafold_iteration/evoformer/template_embedding/"
params
=
compare_utils
.
fetch_alphafold_module_weights
(
+
"single_template_embedding/template_pair_stack"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
)
+
"single_template_embedding/template_embedding_iteration"
)
else
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
)
params
.
update
(
params
.
update
(
compare_utils
.
fetch_alphafold_module_weights
(
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
...
@@ -132,7 +183,7 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -132,7 +183,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
out_repro
=
model
.
template_
embedder
.
template_
pair_stack
(
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
chunk_size
=
None
,
...
@@ -143,15 +194,32 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -143,15 +194,32 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
,
mc_
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
te
=
alphafold
.
model
.
modules
.
TemplateEmbedding
(
te
=
self
.
am_
modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
global_config
,
config
.
model
.
global_config
,
)
)
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
mc_mask_2d
,
is_training
=
False
)
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
return
act
f
=
hk
.
transform
(
test_template_embedding
)
f
=
hk
.
transform
(
test_template_embedding
)
...
@@ -162,6 +230,14 @@ class Template(unittest.TestCase):
...
@@ -162,6 +230,14 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
multichain_mask_2d
=
None
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
@@ -169,7 +245,7 @@ class Template(unittest.TestCase):
...
@@ -169,7 +245,7 @@ class Template(unittest.TestCase):
)
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
,
multichain_mask_2d
).
block_until_ready
()
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
@@ -177,13 +253,30 @@ class Template(unittest.TestCase):
...
@@ -177,13 +253,30 @@ class Template(unittest.TestCase):
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
torch
.
as_tensor
(
pair_act
).
cuda
(),
if
consts
.
is_multimer
:
torch
.
as_tensor
(
pair_mask
).
cuda
(),
out_repro
=
model
.
template_embedder
(
templ_dim
=
0
,
template_feats
,
inplace_safe
=
False
torch
.
as_tensor
(
pair_act
).
cuda
(),
)
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
use_lma
=
False
,
inplace_safe
=
False
)
else
:
out_repro
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
use_lma
=
False
,
inplace_safe
=
False
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
...
...
tests/test_triangular_attention.py
View file @
56d5e39c
...
@@ -86,9 +86,9 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -86,9 +86,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_start
if
starting
if
starting
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
)
# To save memory, the full model transposes inputs outside of the
# To save memory, the full model transposes inputs outside of the
...
...
tests/test_triangular_multiplicative_update.py
View file @
56d5e39c
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
re
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
openfold.model.triangular_multiplicative_update
import
*
from
openfold.model.triangular_multiplicative_update
import
*
...
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z
=
consts
.
c_z
c_z
=
consts
.
c_z
c
=
11
c
=
11
tm
=
TriangleMultiplicationOutgoing
(
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
):
c_z
,
tm
=
FusedTriangleMultiplicationOutgoing
(
c
,
c_z
,
)
c
,
)
else
:
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
)
n_res
=
consts
.
c_z
n_res
=
consts
.
c_z
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config
.
model
.
global_config
,
config
.
model
.
global_config
,
name
=
name
,
name
=
name
,
)
)
act
=
tri_mul
(
act
=
pair_act
,
mask
=
pair_mask
)
act
=
tri_mul
(
pair_act
,
pair_mask
)
return
act
return
act
f
=
hk
.
transform
(
run_tri_mul
)
f
=
hk
.
transform
(
run_tri_mul
)
...
@@ -85,10 +92,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -85,10 +92,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
))
pair_mask
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
))
pair_mask
=
pair_mask
.
astype
(
np
.
float32
)
pair_mask
=
pair_mask
.
astype
(
np
.
float32
)
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
)
out_stock
=
module
(
out_stock
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_utils.py
View file @
56d5e39c
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot
,
quat_to_rot
,
rot_to_quat
,
rot_to_quat
,
)
)
from
openfold.utils.
tensor
_utils
import
chunk_layer
,
_chunk_slice
from
openfold.utils.
chunk
_utils
import
chunk_layer
,
_chunk_slice
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment