OpenDAS / OpenFold · Commits

Commit 304b5ff7, authored Oct 02, 2021 by Gustaf Ahdritz
Parent: abd78418

    Begin to spruce up unit tests, fix config

Showing 18 changed files with 1045 additions and 125 deletions (+1045 −125).
config.py                                        +27  −15
openfold/model/structure_module.py                +0   −2
openfold/utils/feats.py                           +9   −0
setup.py                                          +1   −1
tests/compare_utils.py                          +102   −0
tests/config.py                                  +17   −0
tests/data_utils.py                               +2   −2
tests/test_evoformer.py                         +115   −9
tests/test_loss.py                               +87  −19
tests/test_model.py                              +68  −13
tests/test_msa.py                               +164  −19
tests/test_outer_product_mean.py                 +65  −11
tests/test_pair_transition.py                    +53   −4
tests/test_structure_module.py                   +71   −5
tests/test_template.py                          +125  −13
tests/test_triangular_attention.py               +68   −7
tests/test_triangular_multiplicative_update.py   +69   −4
tests/test_utils.py                               +2   −1
config.py

```diff
@@ -2,7 +2,15 @@ import copy
 import ml_collections as mlc


-def model_config(name, train=False):
+def set_inf(c, inf):
+    for k, v in c.items():
+        if(isinstance(v, mlc.ConfigDict)):
+            set_inf(v, inf)
+        elif(k == "inf"):
+            c[k] = inf
+
+
+def model_config(name, train=False, low_prec=False):
     c = copy.deepcopy(config)
     if(name == "model_1"):
         pass
@@ -16,28 +24,34 @@ def model_config(name, train=False):
         c.model.template.enabled = False
     elif(name == "model_1_ptm"):
         c.model.heads.tm.enabled = True
-        c.model.loss.tm.weight = 0.1
+        c.loss.tm.weight = 0.1
     elif(name == "model_2_ptm"):
         c.model.heads.tm.enabled = True
-        c.model.loss.tm.weight = 0.1
+        c.loss.tm.weight = 0.1
     elif(name == "model_3_ptm"):
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
-        c.model.loss.tm.weight = 0.1
+        c.loss.tm.weight = 0.1
     elif(name == "model_4_ptm"):
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
-        c.model.loss.tm.weight = 0.1
+        c.loss.tm.weight = 0.1
     elif(name == "model_5_ptm"):
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
-        c.model.loss.tm.weight = 0.1
+        c.loss.tm.weight = 0.1
     else:
         raise ValueError("Invalid model name")

     if(train):
         c.globals.model.blocks_per_ckpt = 1
         c.globals.chunk_size = None
+    if(low_prec):
+        c.globals.eps = 1e-4
+        # If we want exact numerical parity with the original, inf can't be
+        # a global constant
+        set_inf(c, 1e4)

     return c
@@ -51,7 +65,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
 chunk_size = mlc.FieldReference(4, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
 eps = mlc.FieldReference(1e-8, field_type=float)
-inf = mlc.FieldReference(1e8, field_type=float)

 config = mlc.ConfigDict({
     # Recurring FieldReferences that can be changed globally here
@@ -64,7 +77,6 @@ config = mlc.ConfigDict({
         "c_e": c_e,
         "c_s": c_s,
         "eps": eps,
-        "inf": inf,
     },
     "model": {
         "no_cycles": 4,
@@ -82,7 +94,7 @@ config = mlc.ConfigDict({
             "min_bin": 3.25,
             "max_bin": 20.75,
             "no_bins": 15,
-            "inf": inf,  # 1e8,
+            "inf": 1e8,
         },
         "template": {
             "distogram": {
@@ -111,7 +123,7 @@ config = mlc.ConfigDict({
                 "dropout_rate": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
-                "inf": inf,
+                "inf": 1e9,
             },
             "template_pointwise_attention": {
                 "c_t": c_t,
@@ -121,9 +133,9 @@ config = mlc.ConfigDict({
                 "c_hidden": 16,
                 "no_heads": 4,
                 "chunk_size": chunk_size,
-                "inf": inf,  # 1e-9,
+                "inf": 1e9,
             },
-            "inf": inf,
+            "inf": 1e9,
             "eps": eps,  # 1e-6,
             "enabled": True,
             "embed_angles": True,
@@ -148,7 +160,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": inf,  # 1e9,
+            "inf": 1e9,
             "eps": eps,  # 1e-10,
         },
         "enabled": True,
@@ -169,7 +181,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": inf,  # 1e9,
+            "inf": 1e9,
             "eps": eps,  # 1e-10,
         },
         "structure_module": {
@@ -187,7 +199,7 @@ config = mlc.ConfigDict({
             "no_angles": 7,
             "trans_scale_factor": 10,
             "epsilon": eps,  # 1e-12,
-            "inf": inf,  # 1e5,
+            "inf": 1e5,
         },
         "heads": {
             "lddt": {
```
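The new `set_inf` helper recursively walks a `ConfigDict` and overwrites every key named `"inf"`, which is what lets `low_prec=True` cap the masking constant at `1e4` (large enough to dominate softmax logits, small enough not to overflow half precision). A minimal self-contained sketch of the same recursion, using a toy nested config rather than OpenFold's real one (the key names below are illustrative):

```python
import ml_collections as mlc


def set_inf(c, inf):
    # Recurse into nested ConfigDicts; overwrite any leaf named "inf".
    for k, v in c.items():
        if isinstance(v, mlc.ConfigDict):
            set_inf(v, inf)
        elif k == "inf":
            c[k] = inf


# Toy config standing in for OpenFold's real one.
toy = mlc.ConfigDict({
    "globals": {"eps": 1e-8},
    "model": {
        "evoformer": {"inf": 1e9},
        "structure_module": {"inf": 1e5},
    },
})

set_inf(toy, 1e4)
assert toy.model.evoformer.inf == 1e4
assert toy.model.structure_module.inf == 1e4
```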
openfold/model/structure_module.py

```diff
@@ -316,7 +316,6 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, N_res, H]
         pt_att = torch.sum(pt_att, dim=-1) * (-0.5)

         # [*, N_res, N_res]
         square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
         square_mask = self.inf * (square_mask - 1)
@@ -721,7 +720,6 @@ class StructureModule(nn.Module):
         # [*, N]
         t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)

         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
```
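The `square_mask` lines retained as context above implement the usual additive attention mask: a pairwise mask built from the outer product of per-residue masks, then mapped to 0 (keep) or a large negative value (drop) before being added to the attention logits. A hedged, standalone illustration of the arithmetic (the `inf` value and shapes are illustrative, not taken from the config):

```python
import torch

inf = 1e5   # illustrative; OpenFold reads this from its config
n_res = 4

# Per-residue validity mask: 1 = real residue, 0 = padding
mask = torch.tensor([1., 1., 1., 0.])

# [N_res, N_res] pair mask via outer product
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)

# 0 where both residues are valid, -inf where either is padding,
# ready to be added to pre-softmax attention logits
square_mask = inf * (square_mask - 1)

logits = torch.zeros(n_res, n_res)
attn = torch.softmax(logits + square_mask, dim=-1)
# Valid rows now attend only to the three valid residues
```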
openfold/utils/feats.py

```diff
@@ -23,6 +23,8 @@ from openfold.utils.affine_utils import T
 from openfold.utils.tensor_utils import (
     batched_gather,
     one_hot,
+    tree_map,
+    tensor_tree_map,
 )
@@ -143,6 +145,13 @@ def compute_residx(batch):
     return out


+def compute_residx_np(batch):
+    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    out = compute_residx(batch)
+    out = tensor_tree_map(lambda t: np.array(t), out)
+    return out
+
+
 def atom14_to_atom37(atom14, batch):
     atom37_data = batched_gather(
         atom14,
```
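`compute_residx_np` is a thin NumPy adapter: `tree_map` converts every `np.ndarray` leaf of the feature dict into a tensor, and `tensor_tree_map` converts the results back. The same round-trip can be sketched generically; the wrapped function below is a stand-in, and plain dict comprehensions replace the tree utilities so the sketch carries no assumptions about their signatures:

```python
import numpy as np
import torch


def numpy_adapter(torch_fn):
    """Wrap a dict-of-tensors function so it accepts and returns np.ndarrays.
    A sketch of the compute_residx_np pattern, not OpenFold's actual helper."""
    def wrapped(batch):
        batch = {k: torch.tensor(v) for k, v in batch.items()}
        out = torch_fn(batch)
        return {k: np.array(t) for k, t in out.items()}
    return wrapped


double_all = numpy_adapter(lambda b: {k: 2 * t for k, t in b.items()})
print(double_all({"x": np.arange(3)}))  # {'x': array([0, 2, 4])}
```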
setup.py

```diff
@@ -19,7 +19,7 @@ setup(
     name='openfold',
     version='1.0.0',
     description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
-    author='Gustaf Ahdritz',
+    author='Gustaf Ahdritz & DeepMind',
     author_email='gahdritz@gmail.com',
     license='Apache License, Version 2.0',
     url='https://github.com/aqlaboratory/openfold',
```
tests/compare_utils.py (new file, mode 100644)

```python
import os
import importlib
import pkgutil
import sys
import unittest

import numpy as np

from config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts

# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"


def alphafold_is_installed():
    return importlib.util.find_spec("alphafold") is not None


def skip_unless_alphafold_installed():
    return unittest.skipUnless(alphafold_is_installed(), "Requires AlphaFold")


def import_alphafold():
    """
    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    if("alphafold" in sys.modules):
        return sys.modules["alphafold"]

    module = importlib.import_module("alphafold")

    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)

    sys.modules["alphafold"] = module
    globals()["alphafold"] = module

    return module


def get_alphafold_config():
    config = alphafold.model.config.model_config("model_1_ptm")
    config.model.global_config.deterministic = True
    return config


_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None
def get_global_pretrained_openfold():
    global _model
    if(_model is None):
        _model = AlphaFold(model_config("model_1_ptm").model)
        _model = _model.eval()
        if(not os.path.exists(_param_path)):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the
                installation script before running tests."""
            )
        import_jax_weights_(_model, _param_path)
        _model = _model.cuda()

    return _model


_orig_weights = None
def _get_orig_weights():
    global _orig_weights
    if(_orig_weights is None):
        _orig_weights = np.load(_param_path)

    return _orig_weights


def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
        if(k.startswith(prefix)):
            d.pop(k)
            d[k[len(prefix):]] = v


def fetch_alphafold_module_weights(weight_path):
    orig_weights = _get_orig_weights()
    params = {k: v for k, v in orig_weights.items() if weight_path in k}
    if('/' in weight_path):
        spl = weight_path.split('/')
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        module_name = spl[-1]
        prefix = '/'.join(spl[:-1]) + '/'
        _remove_key_prefix(params, prefix)

    params = alphafold.model.utils.flat_params_to_haiku(params)

    return params
```
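The module establishes two conventions the new tests rely on: `skip_unless_alphafold_installed()` returns a standard `unittest.skipUnless` decorator, and the pretrained OpenFold model and original JAX weights are loaded lazily, exactly once per process, through module-level singletons. A minimal sketch of how a test file is expected to consume these helpers (module paths as in the commit; the body assumes the pretrained parameter file is present):

```python
import unittest

import tests.compare_utils as compare_utils

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()


class TestSomething(unittest.TestCase):
    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        # Only runs when the reference AlphaFold package is importable;
        # otherwise unittest reports the test as skipped.
        model = compare_utils.get_global_pretrained_openfold()
        self.assertIsNotNone(model)
```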
tests/config.py (new file, mode 100644)

```python
import ml_collections as mlc

consts = mlc.ConfigDict({
    "batch_size": 2,
    "n_res": 11,
    "n_seq": 13,
    "n_templ": 3,
    "n_extra": 17,
    "eps": 5e-4,
    # For compatibility with DeepMind's pretrained weights, it's easiest for
    # everyone if these take their real values.
    "c_m": 256,
    "c_z": 128,
    "c_s": 384,
    "c_t": 64,
    "c_e": 64,
})
```
tests/utils.py → tests/data_utils.py (renamed)

```diff
@@ -54,7 +54,7 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
     return batch


-def random_affine_vectors(dim):
+def random_affines_vector(dim):
     prod_dim = 1
     for d in dim:
         prod_dim *= d
@@ -68,7 +68,7 @@ def random_affine_vectors(dim):
     return affines.reshape(*dim, 7)


-def random_affine_4x4s(dim):
+def random_affines_4x4(dim):
     prod_dim = 1
     for d in dim:
         prod_dim *= d
```
tests/test_evoformer.py

```diff
@@ -15,21 +15,33 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.evoformer import *
+from openfold.model.evoformer import (
+    MSATransition,
+    EvoformerStack,
+    ExtraMSAStack,
+)
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestEvoformerStack(unittest.TestCase):
     def test_shape(self):
-        batch_size = 5
-        s_t = 27
-        n_res = 29
-        c_m = 7
-        c_z = 11
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
+        c_z = consts.c_z
         c_hidden_msa_att = 12
         c_hidden_opm = 17
         c_hidden_mul = 19
         c_hidden_pair_att = 14
-        c_s = 23
+        c_s = consts.c_s
         no_heads_msa = 3
         no_heads_pair = 7
         no_blocks = 2
@@ -59,9 +71,9 @@ class TestEvoformerStack(unittest.TestCase):
             eps=eps,
         ).eval()

-        m = torch.rand((batch_size, s_t, n_res, c_m))
+        m = torch.rand((batch_size, n_seq, n_res, c_m))
         z = torch.rand((batch_size, n_res, n_res, c_z))
-        msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res))
+        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
         pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

         shape_m_before = m.shape
@@ -73,6 +85,59 @@ class TestEvoformerStack(unittest.TestCase):
         self.assertTrue(z.shape == shape_z_before)
         self.assertTrue(s.shape == (batch_size, n_res, c_s))

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_ei(activations, masks):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            ei = alphafold.model.modules.EvoformerIteration(
+                c_e, config.model.global_config, is_extra_msa=False
+            )
+            return ei(activations, masks, is_training=False)
+
+        f = hk.transform(run_ei)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+
+        activations = {
+            'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
+            'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+        }
+
+        masks = {
+            'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
+            'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
+        }
+
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+        key = jax.random.PRNGKey(42)
+        out_gt = f.apply(params, key, activations, masks)
+        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
+        out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
+        out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
+            torch.as_tensor(activations["msa"]).cuda(),
+            torch.as_tensor(activations["pair"]).cuda(),
+            torch.as_tensor(masks["msa"]).cuda(),
+            torch.as_tensor(masks["pair"]).cuda(),
+            _mask_trans=False,
+        )
+
+        out_repro_msa = out_repro_msa.cpu()
+        out_repro_pair = out_repro_pair.cpu()
+
+        assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
+        assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))

 class TestExtraMSAStack(unittest.TestCase):
     def test_shape(self):
@@ -143,6 +208,47 @@ class TestMSATransition(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_msa_transition(msa_act, msa_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            msa_trans = alphafold.model.modules.Transition(
+                c_e.msa_transition,
+                config.model.global_config,
+                name="msa_transition",
+            )
+            act = msa_trans(act=msa_act, mask=msa_mask)
+            return act
+
+        f = hk.transform(run_msa_transition)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+
+        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
+        msa_mask = np.ones((n_seq, n_res)).astype(np.float32)  # no mask here either
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_transition"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.evoformer.blocks[0].msa_transition(
+            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


 if __name__ == "__main__":
     unittest.main()
```
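Every `test_compare` added in this commit follows the same Haiku golden-test recipe: wrap the reference module in a closure, `hk.transform` it into a pure `(params, rng, *args) -> out` callable, apply pretrained parameters, and compare against the PyTorch module's output elementwise. A stripped-down sketch of that recipe with a trivial module; everything here (`run_linear`, `toy_linear`, the tolerance) is illustrative, not an OpenFold or AlphaFold API:

```python
import haiku as hk
import jax
import numpy as np
import torch


def run_linear(x):
    # Stand-in for a reference module such as alphafold.model.modules.Transition
    return hk.Linear(4, name="toy_linear")(x)

f = hk.transform(run_linear)

x = np.random.rand(3, 8).astype(np.float32)
params = f.init(jax.random.PRNGKey(0), x)   # in the tests: pretrained weights
out_gt = np.array(f.apply(params, None, x))

# Reproduce with torch using the same weights, then compare elementwise
w = torch.as_tensor(np.array(params["toy_linear"]["w"]))
b = torch.as_tensor(np.array(params["toy_linear"]["b"]))
out_repro = torch.as_tensor(x) @ w + b

assert torch.max(torch.abs(out_repro - torch.as_tensor(out_gt))) < 5e-4
```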
tests/test_loss.py

```diff
@@ -17,23 +17,37 @@ import torch
 import numpy as np
 import unittest

-from alphafold.utils.loss import *
-from alphafold.utils.utils import T
+from openfold.utils.loss import (
+    torsion_angle_loss,
+    compute_fape,
+    between_residue_bond_loss,
+    between_residue_clash_loss,
+    find_structural_violations,
+)
+from openfold.utils.affine_utils import T
+from openfold.utils.tensor_utils import tensor_tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestLoss(unittest.TestCase):
     def test_run_torsion_angle_loss(self):
-        batch_size = 2
-        n = 5
+        batch_size = consts.batch_size
+        n_res = consts.n_res

-        a = torch.rand((batch_size, n, 7, 2))
-        a_gt = torch.rand((batch_size, n, 7, 2))
-        a_alt_gt = torch.rand((batch_size, n, 7, 2))
+        a = torch.rand((batch_size, n_res, 7, 2))
+        a_gt = torch.rand((batch_size, n_res, 7, 2))
+        a_alt_gt = torch.rand((batch_size, n_res, 7, 2))

         loss = torsion_angle_loss(a, a_gt, a_alt_gt)

     def test_run_fape(self):
-        batch_size = 2
+        batch_size = consts.batch_size
         n_frames = 7
         n_atoms = 5
@@ -45,12 +59,23 @@ class TestLoss(unittest.TestCase):
         trans_gt = torch.rand((batch_size, n_frames, 3))

         t = T(rots, trans)
         t_gt = T(rots_gt, trans_gt)

+        frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
+        positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
+        length_scale = 10
+
-        loss = compute_fape(t, x, t_gt, x_gt)
+        loss = compute_fape(
+            pred_frames=t,
+            target_frames=t_gt,
+            frames_mask=frames_mask,
+            pred_positions=x,
+            target_positions=x_gt,
+            positions_mask=positions_mask,
+            length_scale=length_scale,
+        )

-    def test_between_residue_bond_loss(self):
-        bs = 2
-        n = 10
+    def test_run_between_residue_bond_loss(self):
+        bs = consts.batch_size
+        n = consts.n_res

         pred_pos = torch.rand(bs, n, 14, 3)
         pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
         residue_index = torch.arange(n).unsqueeze(0)
@@ -63,9 +88,52 @@ class TestLoss(unittest.TestCase):
             aatype,
         )

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_between_residue_bond_loss_compare(self):
+        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
+            return alphafold.model.all_atom.between_residue_bond_loss(
+                pred_pos, pred_atom_mask, residue_index, aatype,
+            )
+
+        f = hk.transform(run_brbl)
+
+        n_res = consts.n_res
+
+        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
+        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
+        residue_index = np.arange(n_res)
+        aatype = np.random.randint(0, 22, (n_res,))
+
+        out_gt = f.apply(
+            {}, None, pred_pos, pred_atom_mask, residue_index, aatype,
+        )
+        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
+        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
+
+        out_repro = between_residue_bond_loss(
+            torch.tensor(pred_pos).cuda(),
+            torch.tensor(pred_atom_mask).cuda(),
+            torch.tensor(residue_index).cuda(),
+            torch.tensor(aatype).cuda(),
+        )
+        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
+
+        for k in out_gt.keys():
+            self.assertTrue(
+                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
+            )
+
     def test_between_residue_clash_loss(self):
-        bs = 2
-        n = 10
+        bs = consts.batch_size
+        n = consts.n_res

         pred_pos = torch.rand(bs, n, 14, 3)
         pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
         atom14_atom_radius = torch.rand(bs, n, 14)
@@ -79,7 +147,7 @@ class TestLoss(unittest.TestCase):
         )

     def test_find_structural_violations(self):
-        n = 10
+        n = consts.n_res

         batch = {
             "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
@@ -90,12 +158,12 @@ class TestLoss(unittest.TestCase):
         pred_pos = torch.rand(n, 14, 3)

-        config = ml_collections.ConfigDict(
-            {
-                "clash_overlap_tolerance": 1.5,
-                "violation_tolerance_factor": 12.0,
-            }
-        )
+        config = {
+            "clash_overlap_tolerance": 1.5,
+            "violation_tolerance_factor": 12.0,
+        }

-        find_structural_violations(batch, pred_pos, config)
+        find_structural_violations(batch, pred_pos, **config)


 if __name__ == "__main__":
```
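The `find_structural_violations` call site changes from passing a `ConfigDict` positionally to splatting a plain dict as keyword arguments, which implies the loss function now takes the tolerances as individual parameters. A runnable sketch of what that calling convention means, with a stand-in signature (the real function lives in `openfold.utils.loss`):

```python
def find_structural_violations(
    batch, pred_pos, clash_overlap_tolerance, violation_tolerance_factor
):
    # Stand-in body for illustration only
    return {
        "overlap_tol": clash_overlap_tolerance,
        "viol_tol": violation_tolerance_factor,
    }


config = {
    "clash_overlap_tolerance": 1.5,
    "violation_tolerance_factor": 12.0,
}

# **config expands the dict into the two keyword arguments above
out = find_structural_violations({}, None, **config)
assert out == {"overlap_tol": 1.5, "viol_tol": 12.0}
```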
tests/test_model.py

```diff
@@ -12,25 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pickle
 import torch
 import torch.nn as nn
 import numpy as np
 import unittest

 from config import *
-from alphafold.model.model import *
-from alphafold.utils.utils import my_tree_map
-from tests.utils import (
+from openfold.model.model import AlphaFold
+import openfold.utils.feats as feats
+from openfold.utils.tensor_utils import tree_map, tensor_tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+from tests.data_utils import (
     random_template_feats,
     random_extra_msa_feats,
 )
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestModel(unittest.TestCase):
     def test_dry_run(self):
-        batch_size = 2
-        n_seq = 5
-        n_templ = 7
-        n_res = 11
-        n_extra_seq = 13
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_templ = consts.n_templ
+        n_res = consts.n_res
+        n_extra_seq = consts.n_extra

         c = model_config("model_1").model
         c.no_cycles = 2
@@ -59,20 +69,65 @@ class TestModel(unittest.TestCase):
         batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(batch_size, n_seq, n_res)
-        )
+        ).float()
         batch["seq_mask"] = torch.randint(
             low=0, high=2, size=(batch_size, n_res)
-        )
+        ).float()
         batch.update(feats.compute_residx(batch))

         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
         )
-        batch = my_tree_map(add_recycling_dims, batch, torch.Tensor)
+        batch = tensor_tree_map(add_recycling_dims, batch)

         with torch.no_grad():
             out = model(batch)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_alphafold(batch):
+            config = compare_utils.get_alphafold_config()
+            model = alphafold.model.modules.AlphaFold(config.model)
+            return model(
+                batch=batch,
+                is_training=False,
+                return_representations=True,
+            )
+
+        f = hk.transform(run_alphafold)
+
+        params = compare_utils.fetch_alphafold_module_weights('')
+        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
+            batch = pickle.load(fp)
+
+        out_gt = jax.jit(f.apply)(params, jax.random.PRNGKey(42), batch)
+        out_gt = out_gt["structure_module"]["final_atom_positions"]
+
+        # atom37_to_atom14 doesn't like batches
+        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
+        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
+
+        out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
+        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
+
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+        batch["aatype"] = batch["aatype"].long()
+        batch["template_aatype"] = batch["template_aatype"].long()
+        batch["extra_msa"] = batch["extra_msa"].long()
+        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
+
+        # Move the recycling dimension to the end
+        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
+        batch = tensor_tree_map(move_dim, batch)
+
+        with torch.no_grad():
+            model = compare_utils.get_global_pretrained_openfold()
+            out_repro = model(batch)
+
+        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
+        out_repro = out_repro["sm"]["positions"][-1]
+        out_repro = out_repro.squeeze(0)
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))


 if __name__ == "__main__":
     unittest.main()
```
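The trickiest line in the new `test_compare` is the recycling-dimension shuffle: `t.permute(*range(len(t.shape))[1:], 0)` rotates a leading axis to the end, which the diff's own comment describes as moving the recycling dimension to the end of each feature tensor. A small worked example of that permutation (shapes are illustrative):

```python
import torch

move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)

# e.g. 4 recycling iterations of an [N_res, C] feature
t = torch.zeros(4, 11, 256)
print(move_dim(t).shape)  # torch.Size([11, 256, 4])
```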
tests/test_msa.py

```diff
@@ -15,23 +15,36 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.msa import *
+from openfold.model.msa import (
+    MSARowAttentionWithPairBias,
+    MSAColumnAttention,
+    MSAColumnGlobalAttention,
+)
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestMSARowAttentionWithPairBias(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        s_t = 3
-        n = 5
-        c_m = 7
-        c_z = 11
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
+        c_z = consts.c_z
         c = 52
         no_heads = 4
+        chunk_size = None

-        mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads)
+        mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)

-        m = torch.rand((batch_size, s_t, n, c_m))
-        z = torch.rand((batch_size, n, n, c_z))
+        m = torch.rand((batch_size, n_seq, n_res, c_m))
+        z = torch.rand((batch_size, n_res, n_res, c_z))

         shape_before = m.shape
         m = mrapb(m, z)
@@ -39,19 +52,65 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_msa_row_att(msa_act, msa_mask, pair_act):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
+                c_e.msa_row_attention_with_pair_bias,
+                config.model.global_config,
+            )
+            act = msa_row(msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act)
+            return act
+
+        f = hk.transform(run_msa_row_att)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+
+        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
+        msa_mask = np.random.randint(
+            low=0, high=2, size=(n_seq, n_res)
+        ).astype(np.float32)
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_row_attention"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, msa_act, msa_mask, pair_act).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.evoformer.blocks[0].msa_att_row(
+            torch.as_tensor(msa_act).cuda(),
+            torch.as_tensor(pair_act).cuda(),
+            torch.as_tensor(msa_mask).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))


 class TestMSAColumnAttention(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        s_t = 3
-        n = 5
-        c_m = 7
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
         c = 44
         no_heads = 4

         msaca = MSAColumnAttention(c_m, c, no_heads)

-        x = torch.rand((batch_size, s_t, n, c_m))
+        x = torch.rand((batch_size, n_seq, n_res, c_m))

         shape_before = x.shape
         x = msaca(x)
@@ -59,19 +118,63 @@ class TestMSAColumnAttention(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_msa_col_att(msa_act, msa_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            msa_col = alphafold.model.modules.MSAColumnAttention(
+                c_e.msa_column_attention, config.model.global_config,
+            )
+            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
+            return act
+
+        f = hk.transform(run_msa_col_att)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+
+        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
+        msa_mask = np.random.randint(
+            low=0, high=2, size=(n_seq, n_res)
+        ).astype(np.float32)
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_column_attention"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.evoformer.blocks[0].msa_att_col(
+            torch.as_tensor(msa_act).cuda(),
+            torch.as_tensor(msa_mask).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))


 class TestMSAColumnGlobalAttention(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        s_t = 3
-        n = 5
-        c_m = 7
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
         c = 44
         no_heads = 4

         msagca = MSAColumnGlobalAttention(c_m, c, no_heads)

-        x = torch.rand((batch_size, s_t, n, c_m))
+        x = torch.rand((batch_size, n_seq, n_res, c_m))

         shape_before = x.shape
         x = msagca(x)
@@ -79,6 +182,48 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_msa_col_global_att(msa_act, msa_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
+                c_e.msa_column_attention,
+                config.model.global_config,
+                name="msa_column_global_attention",
+            )
+            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
+            return act
+
+        f = hk.transform(run_msa_col_global_att)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+        c_e = consts.c_e
+
+        msa_act = np.random.rand(n_seq, n_res, c_e)
+        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+            + "msa_column_global_attention"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
+            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


 if __name__ == "__main__":
     unittest.main()
```
tests/test_outer_product_mean.py

```diff
@@ -15,25 +15,79 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.outer_product_mean import *
+from openfold.model.outer_product_mean import OuterProductMean
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestOuterProductMean(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        s = 5
-        n_res = 7
-        c_m = 11
-        c = 13
-        c_z = 17
+        c = 31

-        opm = OuterProductMean(c_m, c_z, c)
+        opm = OuterProductMean(consts.c_m, consts.c_z, c)

-        m = torch.rand((batch_size, s, n_res, c_m))
-        mask = torch.randint(0, 2, size=(batch_size, s, n_res))
+        m = torch.rand(
+            (consts.batch_size, consts.n_seq, consts.n_res, consts.c_m)
+        )
+        mask = torch.randint(
+            0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res)
+        )
         m = opm(m, mask)

-        self.assertTrue(m.shape == (batch_size, n_res, n_res, c_z))
+        self.assertTrue(
+            m.shape ==
+            (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
+        )

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_opm_compare(self):
+        def run_opm(msa_act, msa_mask):
+            config = compare_utils.get_alphafold_config()
+            c_evo = config.model.embeddings_and_evoformer.evoformer
+            opm = alphafold.model.modules.OuterProductMean(
+                c_evo.outer_product_mean,
+                config.model.global_config,
+                consts.c_z,
+            )
+            act = opm(act=msa_act, mask=msa_mask)
+            return act
+
+        f = hk.transform(run_opm)
+
+        n_res = consts.n_res
+        n_seq = consts.n_seq
+        c_m = consts.c_m
+
+        msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
+        msa_mask = np.random.randint(
+            low=0, high=2, size=(n_seq, n_res)
+        ).astype(np.float32)
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/"
+            + "evoformer_iteration/outer_product_mean"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.evoformer.blocks[0].outer_product_mean(
+            torch.as_tensor(msa_act).cuda(),
+            mask=torch.as_tensor(msa_mask).cuda(),
+        ).cpu()
+
+        # Even when correct, OPM has large, precision-related errors. It gets
+        # a special pass from consts.eps.
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4))


 if __name__ == "__main__":
```
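The comment in this diff notes that the outer-product mean accumulates larger floating-point error than the other modules, so it gets a fixed `5e-4` bound instead of `consts.eps`. The underlying check is a max-absolute-difference comparison between two tensors; a hedged sketch of the idiom, with synthetic tensors and the equivalent `torch.allclose` formulation:

```python
import torch

gt = torch.randn(8, 8)
repro = gt + 1e-5 * torch.randn(8, 8)

tol = 5e-4  # looser bound for error-prone reductions
max_err = torch.max(torch.abs(gt - repro))
assert max_err < tol

# Same intent expressed with a built-in:
assert torch.allclose(gt, repro, atol=tol, rtol=0)
```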
tests/test_pair_transition.py

```diff
@@ -15,18 +15,26 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.pair_transition import *
+from openfold.model.pair_transition import PairTransition
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestPairTransition(unittest.TestCase):
     def test_shape(self):
-        c_z = 5
+        c_z = consts.c_z
         n = 4

         pt = PairTransition(c_z, n)

-        batch_size = 4
-        n_res = 256
+        batch_size = consts.batch_size
+        n_res = consts.n_res

         z = torch.rand((batch_size, n_res, n_res, c_z))
         mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
@@ -36,6 +44,47 @@ class TestPairTransition(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_pair_transition(pair_act, pair_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            pt = alphafold.model.modules.Transition(
+                c_e.pair_transition,
+                config.model.global_config,
+                name="pair_transition",
+            )
+            act = pt(act=pair_act, mask=pair_mask)
+            return act
+
+        f = hk.transform(run_pair_transition)
+
+        n_res = consts.n_res
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+        pair_mask = np.ones((n_res, n_res)).astype(np.float32)  # no mask
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "pair_transition"
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.evoformer.blocks[0].pair_transition(
+            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
+            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


 if __name__ == "__main__":
     unittest.main()
```
tests/test_structure_module.py

```diff
@@ -16,18 +16,28 @@ import torch
 import numpy as np
 import unittest

-from alphafold.np.residue_constants import (
+from openfold.np.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
     restype_atom14_mask,
     restype_atom14_rigid_group_positions,
 )
-from alphafold.model.structure_module import *
-from alphafold.model.structure_module import (
+from openfold.model.structure_module import *
+from openfold.model.structure_module import (
     _torsion_angles_to_frames,
     _frames_and_literature_positions_to_atom14_pos,
 )
-from alphafold.utils.utils import T
+from openfold.utils.affine_utils import T
+import tests.compare_utils as compare_utils
+from tests.config import consts
+from tests.data_utils import (
+    random_affines_4x4,
+)
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestStructureModule(unittest.TestCase):
@@ -75,7 +85,7 @@ class TestStructureModule(unittest.TestCase):
         out = sm(s, z, f)

         self.assertTrue(
-            out["transformations"].shape == (no_layers, batch_size, n, 4, 4)
+            out["frames"].shape == (no_layers, batch_size, n, 4, 4)
         )
         self.assertTrue(
             out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
@@ -190,6 +200,62 @@ class TestInvariantPointAttention(unittest.TestCase):
         self.assertTrue(s.shape == shape_before)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_ipa_compare(self):
+        def run_ipa(act, static_feat_2d, mask, affine):
+            config = compare_utils.get_alphafold_config()
+            ipa = alphafold.model.folding.InvariantPointAttention(
+                config.model.heads.structure_module,
+                config.model.global_config,
+            )
+            attn = ipa(
+                inputs_1d=act,
+                inputs_2d=static_feat_2d,
+                mask=mask,
+                affine=affine,
+            )
+            return attn
+
+        f = hk.transform(run_ipa)
+
+        n_res = consts.n_res
+        c_s = consts.c_s
+        c_z = consts.c_z
+
+        sample_act = np.random.rand(n_res, c_s)
+        sample_2d = np.random.rand(n_res, n_res, c_z)
+        sample_mask = np.ones((n_res, 1))
+
+        affines = random_affines_4x4((n_res,))
+        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
+        quats = alphafold.model.r3.rigids_to_quataffine(rigids)
+        transformations = T.from_4x4(
+            torch.as_tensor(affines).float().cuda()
+        )
+        sample_affine = quats
+
+        ipa_params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/structure_module/"
+            + "fold_iteration/invariant_point_attention"
+        )
+
+        out_gt = f.apply(
+            ipa_params, None, sample_act, sample_2d, sample_mask, sample_affine
+        ).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        with torch.no_grad():
+            model = compare_utils.get_global_pretrained_openfold()
+            out_repro = model.structure_module.ipa(
+                torch.as_tensor(sample_act).float().cuda(),
+                torch.as_tensor(sample_2d).float().cuda(),
+                transformations,
+                torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
+            ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


 class TestAngleResnet(unittest.TestCase):
     def test_shape(self):
```
tests/test_template.py

```diff
@@ -15,23 +15,38 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.template import *
+from openfold.model.template import (
+    TemplatePointwiseAttention,
+    TemplatePairStack,
+)
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+from tests.data_utils import random_template_feats
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestTemplatePointwiseAttention(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        s_t = 3
-        c_t = 5
-        c_z = 7
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        c_t = consts.c_t
+        c_z = consts.c_z
         c = 26
         no_heads = 13
-        n = 17
+        n_res = consts.n_res
+        inf = 1e7

-        tpa = TemplatePointwiseAttention(c_t, c_z, c, no_heads, chunk_size=4)
+        tpa = TemplatePointwiseAttention(
+            c_t, c_z, c, no_heads, chunk_size=4, inf=inf
+        )

-        t = torch.rand((batch_size, s_t, n, n, c_t))
-        z = torch.rand((batch_size, n, n, c_z))
+        t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
+        z = torch.rand((batch_size, n_res, n_res, c_z))

         z_update = tpa(t, z)
@@ -40,17 +55,20 @@
 class TestTemplatePairStack(unittest.TestCase):
     def test_shape(self):
-        batch_size = 2
-        c_t = 5
+        batch_size = consts.batch_size
+        c_t = consts.c_t
         c_hidden_tri_att = 7
         c_hidden_tri_mul = 7
         no_blocks = 2
         no_heads = 4
         pt_inner_dim = 15
         dropout = 0.25
-        n_templ = 3
-        n_res = 5
+        n_templ = consts.n_templ
+        n_res = consts.n_res
         blocks_per_ckpt = None
+        chunk_size = 4
+        inf = 1e7
+        eps = 1e-7

         tpe = TemplatePairStack(
             c_t,
@@ -60,7 +78,10 @@
             no_heads=no_heads,
             pair_transition_n=pt_inner_dim,
             dropout_rate=dropout,
             blocks_per_ckpt=None,
+            chunk_size=chunk_size,
+            inf=inf,
+            eps=eps,
         )

         t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
@@ -71,7 +92,98 @@
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def run_template_pair_stack(pair_act, pair_mask):
+            config = compare_utils.get_alphafold_config()
+            c_ee = config.model.embeddings_and_evoformer
+            tps = alphafold.model.modules.TemplatePairStack(
+                c_ee.template.template_pair_stack,
+                config.model.global_config,
+                name="template_pair_stack",
+            )
+            act = tps(pair_act, pair_mask, is_training=False)
+            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
+            act = ln(act)
+            return act
+
+        f = hk.transform(run_template_pair_stack)
+
+        n_res = consts.n_res
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
+        pair_mask = np.random.randint(
+            low=0, high=2, size=(n_res, n_res)
+        ).astype(np.float32)
+
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/template_embedding/"
+            + "single_template_embedding/template_pair_stack"
+        )
+        params.update(compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/template_embedding/"
+            + "single_template_embedding/output_layer_norm"
+        ))
+
+        out_gt = f.apply(
+            params, jax.random.PRNGKey(42), pair_act, pair_mask
+        ).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.template_pair_stack(
+            torch.as_tensor(pair_act).cuda(),
+            torch.as_tensor(pair_mask).cuda(),
+            _mask_trans=False,
+        ).cpu()
+
+        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
+
+
+class Template(unittest.TestCase):
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compare(self):
+        def test_template_embedding(pair, batch, mask_2d):
+            config = compare_utils.get_alphafold_config()
+            te = alphafold.model.modules.TemplateEmbedding(
+                config.model.embeddings_and_evoformer.template,
+                config.model.global_config,
+            )
+            act = te(pair, batch, mask_2d, is_training=False)
+            return act
+
+        f = hk.transform(test_template_embedding)
+
+        n_res = consts.n_res
+        n_templ = consts.n_templ
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+        batch = random_template_feats(n_templ, n_res)
+        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/template_embedding"
+        )
+
+        out_gt = f.apply(
+            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
+        ).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        inds = np.random.randint(0, 21, (n_res,))
+        batch["target_feat"] = np.eye(22)[inds]
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.embed_templates(
+            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
+            torch.as_tensor(pair_act).cuda(),
+            torch.as_tensor(pair_mask).cuda(),
+            templ_dim=0,
+        )
+        out_repro = out_repro["template_pair_embedding"]
+        out_repro = out_repro.cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


 if __name__ == "__main__":
```
tests/test_triangular_attention.py

```diff
@@ -15,12 +15,21 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.triangular_attention import *
+from openfold.model.triangular_attention import TriangleAttention
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestTriangularAttention(unittest.TestCase):
     def test_shape(self):
-        c_z = 2
+        c_z = consts.c_z
         c = 12
         no_heads = 4
         starting = True
@@ -32,8 +41,8 @@ class TestTriangularAttention(unittest.TestCase):
             starting
         )

-        batch_size = 4
-        n_res = 7
+        batch_size = consts.batch_size
+        n_res = consts.n_res

         x = torch.rand((batch_size, n_res, n_res, c_z))
         shape_before = x.shape
@@ -42,9 +51,61 @@ class TestTriangularAttention(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

-
-if __name__ == "__main__":
-    unittest.main()
+    def _tri_att_compare(self, starting=False):
+        name = (
+            "triangle_attention_"
+            + ("starting" if starting else "ending")
+            + "_node"
+        )
+
+        def run_tri_att(pair_act, pair_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            tri_att = alphafold.model.modules.TriangleAttention(
+                c_e.triangle_attention_starting_node if starting else
+                c_e.triangle_attention_ending_node,
+                config.model.global_config,
+                name=name,
+            )
+            act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
+            return act
+
+        f = hk.transform(run_tri_att)
+
+        n_res = consts.n_res
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z)
+        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        module = (
+            model.evoformer.blocks[0].tri_att_start if starting else
+            model.evoformer.blocks[0].tri_att_end
+        )
+        out_repro = module(
+            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
+            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+
+    def test_tri_att_end_compare(self):
+        self._tri_att_compare()
+
+    def test_tri_att_start_compare(self):
+        self._tri_att_compare(starting=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
```
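Both triangle-module test files share the same structure: a private helper parameterized on the module variant, with two thin public test methods selecting "starting"/"ending" (or, in the next file, "incoming"/"outgoing"). A compact, self-contained sketch of that unittest parameterization pattern, with a trivial body standing in for the real comparison:

```python
import unittest


class TestVariants(unittest.TestCase):
    def _compare(self, starting=False):
        # Shared body; `starting` selects one of two symmetric variants.
        name = (
            "triangle_attention_"
            + ("starting" if starting else "ending")
            + "_node"
        )
        self.assertIn("node", name)  # stand-in for the real tensor comparison

    def test_ending(self):
        self._compare()

    def test_starting(self):
        self._compare(starting=True)
```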
tests/test_triangular_multiplicative_update.py

```diff
@@ -15,12 +15,20 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.triangular_multiplicative_update import *
+from openfold.model.triangular_multiplicative_update import *
+from openfold.utils.tensor_utils import tree_map
+import tests.compare_utils as compare_utils
+from tests.config import consts
+
+if(compare_utils.alphafold_is_installed()):
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk


 class TestTriangularMultiplicativeUpdate(unittest.TestCase):
     def test_shape(self):
-        c_z = 7
+        c_z = consts.c_z
         c = 11
         outgoing = True
@@ -30,8 +38,8 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
             outgoing,
         )

-        n_res = 5
-        batch_size = 2
+        n_res = consts.c_z
+        batch_size = consts.batch_size

         x = torch.rand((batch_size, n_res, n_res, c_z))
         mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
@@ -41,6 +49,63 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    def _tri_mul_compare(self, incoming=False):
+        name = (
+            "triangle_multiplication_"
+            + ("incoming" if incoming else "outgoing")
+        )
+
+        def run_tri_mul(pair_act, pair_mask):
+            config = compare_utils.get_alphafold_config()
+            c_e = config.model.embeddings_and_evoformer.evoformer
+            tri_mul = alphafold.model.modules.TriangleMultiplication(
+                c_e.triangle_multiplication_incoming if incoming else
+                c_e.triangle_multiplication_outgoing,
+                config.model.global_config,
+                name=name,
+            )
+            act = tri_mul(act=pair_act, mask=pair_mask)
+            return act
+
+        f = hk.transform(run_tri_mul)
+
+        n_res = consts.n_res
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
+        pair_mask = pair_mask.astype(np.float32)
+
+        # Fetch pretrained parameters (but only from one block)]
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
+        )
+        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
+        out_gt = torch.as_tensor(np.array(out_gt))
+
+        model = compare_utils.get_global_pretrained_openfold()
+        module = (
+            model.evoformer.blocks[0].tri_mul_in if incoming else
+            model.evoformer.blocks[0].tri_mul_out
+        )
+        out_repro = module(
+            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
+            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
+        ).cpu()
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_tri_mul_out_compare(self):
+        self._tri_mul_compare()
+
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_tri_mul_in_compare(self):
+        self._tri_mul_compare(incoming=True)


 if __name__ == "__main__":
     unittest.main()
```
tests/test_utils.py

```diff
@@ -16,7 +16,8 @@ import math
 import torch
 import unittest

-from alphafold.utils.utils import *
+from openfold.utils.affine_utils import *
+from openfold.utils.tensor_utils import *


 X_90_ROT = torch.tensor([
```