OpenDAS / OpenFold · Commits

Commit 9f6b67f3, authored Oct 06, 2021 by Gustaf Ahdritz

Fix another FP16 overflow, finish tests, update util scripts

Parent: 893fe372

Showing 18 changed files with 1130 additions and 98 deletions (+1130 / -98).
Files changed:

config.py (+6 / -6)
openfold/model/template.py (+1 / -2)
openfold/utils/feats.py (+6 / -2)
openfold/utils/import_weights.py (+9 / -1)
openfold/utils/loss.py (+3 / -3)
scripts/install_third_party_dependencies.sh (+4 / -0)
scripts/run_unit_tests.sh (+1 / -16)
tests/compare_utils.py (+1 / -1)
tests/sample_feats.py (+0 / -0, deleted)
tests/test_embedders.py (+6 / -1)
tests/test_feats.py (+319 / -0, new file)
tests/test_import_weights.py (+6 / -8)
tests/test_loss.py (+681 / -6)
tests/test_outer_product_mean.py (+0 / -1)
tests/test_structure_module.py (+81 / -48)
tests/test_template.py (+2 / -1)
tests/test_triangular_attention.py (+2 / -0)
tests/test_utils.py (+2 / -2)
config.py

```diff
@@ -123,7 +123,7 @@ config = mlc.ConfigDict({
                 "dropout_rate": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
-                "inf": 1e9,
+                "inf": 1e5,  # 1e9,
             },
             "template_pointwise_attention": {
                 "c_t": c_t,
@@ -133,11 +133,11 @@ config = mlc.ConfigDict({
                 "c_hidden": 16,
                 "no_heads": 4,
                 "chunk_size": chunk_size,
-                "inf": 1e9,
+                "inf": 1e5,  # 1e9,
             },
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
             "eps": eps,  # 1e-6,
-            "enabled": False,  # True,
+            "enabled": True,
             "embed_angles": True,
         },
         "extra_msa": {
@@ -160,7 +160,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
             "eps": eps,  # 1e-10,
         },
         "enabled": True,
@@ -181,7 +181,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,  # 1e9,
             "eps": eps,  # 1e-10,
         },
         "structure_module": {
```
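The commit message cites an FP16 overflow, and these `inf` reductions are the relevant change: half precision tops out at 65504, so large masking constants blow up once tensors are cast down. A minimal demonstration of the failure mode (plain PyTorch, not OpenFold code; exactly where the old 1e9 first overflowed is not visible in these hunks):

```python
import torch

print(torch.finfo(torch.float16).max)      # 65504.0, the fp16 ceiling

big = torch.tensor(1e9)
print(big.half())                          # inf: 1e9 is unrepresentable

masked = torch.tensor([0.0, -1e9]).half()  # becomes [-0., -inf]
print(masked.softmax(dim=-1))              # forward pass still fine: [1., 0.]
print(masked[1] - masked[1])               # but inf arithmetic yields nan
```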
openfold/model/template.py

```diff
@@ -192,7 +192,6 @@ class TemplatePairStackBlock(nn.Module):
         return z

-
 class TemplatePairStack(nn.Module):
     """
     Implements Algorithm 16.
@@ -273,7 +272,7 @@ class TemplatePairStack(nn.Module):
                     _mask_trans=_mask_trans,
                 )
                 for b in self.blocks
             ],
-            args=(t),
+            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )
```
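The second hunk is a genuine bug fix: in Python, `(t)` is just `t` wrapped in redundant parentheses, while `(t,)` is a one-element tuple, which is what a checkpointing helper expecting an argument sequence needs. A quick illustration (generic Python, not OpenFold code):

```python
t = "some tensor"

args = (t)      # parentheses only group; this is still a str
print(type(args))    # <class 'str'>

args = (t,)     # the trailing comma is what makes a tuple
print(type(args))    # <class 'tuple'>

def call(fn, args):
    return fn(*args)  # unpacking a non-tuple would iterate it element-wise

print(call(len, (t,)))   # 11
```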
openfold/utils/feats.py

```diff
@@ -115,8 +115,8 @@ def compute_residx(batch):
     restype_atom37_to_atom14 = aatype.new_tensor(restype_atom37_to_atom14)
-    restype_atom14_mask = aatype.new_tensor(restype_atom14_mask, dtype=float_type)
+    restype_atom14_mask = batch["seq_mask"].new_tensor(restype_atom14_mask)

     residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
@@ -527,13 +527,17 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
     )

     n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]

+    # TODO: Consider running this in double precision
     affines = T.make_transform_from_reference(
         n_xyz=batch["template_all_atom_positions"][..., n, :],
         ca_xyz=batch["template_all_atom_positions"][..., ca, :],
         c_xyz=batch["template_all_atom_positions"][..., c, :],
         eps=eps,
     )

     points = affines.get_trans()[..., None, :, :]
     affine_vec = affines[..., None].invert_apply(points)

     inv_distance_scalar = torch.rsqrt(
         eps + torch.sum(affine_vec ** 2, dim=-1)
     )
```
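The `eps` inside the `rsqrt` in the second hunk is doing numerical-stability work: template coordinates can coincide or be zero-padded, making `affine_vec` vanish, and `rsqrt(0)` is infinite. A toy illustration of the guarded form (plain PyTorch):

```python
import torch

eps = 1e-20
affine_vec = torch.zeros(2, 4, 3)  # degenerate case: all-zero displacement vectors

# Unguarded: rsqrt(0) overflows to inf and poisons downstream products
print(torch.rsqrt(torch.sum(affine_vec ** 2, dim=-1)))        # all inf

# Guarded: large but finite (1e10 here), so later masking can handle it
print(torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1)))  # all 1e10
```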
openfold/utils/import_weights.py

```diff
@@ -407,7 +407,15 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         },
     }

-    if(version not in ["model_1", "model_2"]):
+    no_templ = [
+        "model_3",
+        "model_4",
+        "model_5",
+        "model_3_ptm",
+        "model_4_ptm",
+        "model_5_ptm",
+    ]
+    if(version in no_templ):
         evo_dict = translations["evoformer"]
         keys = list(evo_dict.keys())
         for k in keys:
```
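The rewritten check matters for the pTM weights this commit starts using: `"model_1_ptm"` is not in `["model_1", "model_2"]`, so the old condition stripped template weights from a model that actually has them. Testing membership in an explicit list of the six template-free presets fixes that:

```python
version = "model_1_ptm"

# Old condition: wrongly classifies model_1_ptm as template-free
print(version not in ["model_1", "model_2"])   # True

# New condition: only the presets that genuinely lack template weights
no_templ = [
    "model_3", "model_4", "model_5",
    "model_3_ptm", "model_4_ptm", "model_5_ptm",
]
print(version in no_templ)                     # False
```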
openfold/utils/loss.py

```diff
@@ -1428,10 +1428,10 @@ class AlphaFoldLoss(nn.Module):
         for k, loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
-                print(k)
+                # print(k)
                 loss = loss_fn()
-                print(weight * loss)
+                # print(weight * loss)
                 cum_loss = cum_loss + weight * loss
-                print(cum_loss)
+                # print(cum_loss)

         return cum_loss
```
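For context, the loop above implements a weighted sum over the model's sub-losses, skipping zero-weight terms; the debug prints are now commented out. A stripped-down sketch of the same pattern (illustrative names and weights, not OpenFold's real config):

```python
import torch

loss_fns = {
    "fape": lambda: torch.tensor(1.2),       # stand-in sub-loss values
    "distogram": lambda: torch.tensor(0.7),
}
weights = {"fape": 1.0, "distogram": 0.0}

cum_loss = torch.tensor(0.0)
for k, loss_fn in loss_fns.items():
    weight = weights[k]
    if weight:  # zero-weight terms are skipped; their loss_fn never runs
        cum_loss = cum_loss + weight * loss_fn()

print(cum_loss)  # tensor(1.2000)
```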
scripts/install_third_party_dependencies.sh

```diff
@@ -31,6 +31,10 @@ pushd lib/conda/envs/$ENV_NAME/lib/python3.9/site-packages/ \
 wget -q -P openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt

+# Certain tests need access to this file
+mkdir -p tests/test_data/alphafold/common
+ln -s openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
+
 # Download pretrained openfold weights
 scripts/download_alphafold_params.sh openfold/resources
```
scripts/run_unit_tests.sh (rewritten; +1 / -16)

```bash
#!/bin/bash

FLAGS=""
while getopts ":v" option; do
    case $option in
        v)
            FLAGS=$(echo "-v $FLAGS" | xargs)  # strip whitespace
            ;;
        *)
            echo "Invalid option: ${option}"
            ;;
    esac
done

python3 -m unittest $FLAGS "$@" || \
python3 -m unittest "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
```
tests/compare_utils.py

```diff
@@ -64,7 +64,7 @@ def get_global_pretrained_openfold():
             """Cannot load pretrained parameters. Make sure to run the
             installation script before running tests."""
         )

-    import_jax_weights_(_model, _param_path)
+    import_jax_weights_(_model, _param_path, version="model_1_ptm")
     _model = _model.cuda()
     return _model
```
tests/sample_feats.py (deleted; file mode 100644 → 0)
tests/test_embedders.py

```diff
@@ -15,7 +15,12 @@
 import torch
 import numpy as np
 import unittest

-from alphafold.model.embedders import *
+from openfold.model.embedders import (
+    InputEmbedder,
+    RecyclingEmbedder,
+    TemplateAngleEmbedder,
+    TemplatePairEmbedder,
+)

 class TestInputEmbedder(unittest.TestCase):
```
tests/test_feats.py (new file, 0 → 100644)

```python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)

import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4

if(compare_utils.alphafold_is_installed()):
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(
            0, 2, (n_templ, n_res, 37)
        ).astype(np.float32)

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = feats.atom37_to_torsion_angles(
            torch.as_tensor(aatype).cuda(),
            torch.as_tensor(all_atom_pos).cuda(),
            torch.as_tensor(all_atom_mask).cuda(),
        )

        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(
                torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
            ) < 0.01
        )
        self.assertTrue(
            torch.mean(
                torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
            ) < 0.01
        )
        self.assertTrue(
            torch.max(
                torch.abs(out_gt["torsion_angles_mask"] - tam)
            ) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions":
                np.random.rand(n_res, 37, 3).astype(np.float32),
            "all_atom_mask":
                np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
        }

        out_gt = f.apply({}, None, **batch)
        to_tensor = lambda t: torch.tensor(np.array(t))
        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def flat12_to_4x4(flat12):
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]
            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1
            return four_by_four

        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    def test_torsion_angles_to_frames_shape(self):
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))
        ts = T(rots, trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return alphafold.model.all_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())
        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        rots_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
        )
        trans_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
        )
        rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
        rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
        trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_4x4().cpu()

        self.assertTrue(
            torch.max(
                torch.abs(transforms_gt - transforms_repro) < consts.eps
            )
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))
        ts = T(rots, trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        def run_f(aatype, affines):
            am = alphafold.model
            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res, 8))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
        )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt - out_repro) < consts.eps)
        )


if __name__ == "__main__":
    unittest.main()
```
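One detail worth pulling out of `test_atom37_to_frames_compare`: AlphaFold packs each rigid frame as a flat 12-vector (nine row-major rotation entries plus a three-entry translation), while OpenFold's `T` works with 4x4 homogeneous transforms, hence the `flat12_to_4x4` helper. The same logic, run standalone on a dummy vector:

```python
import torch

flat12 = torch.arange(12, dtype=torch.float32)  # dummy frame: rot 0-8, trans 9-11

rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]

four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
four_by_four[..., :3, :3] = rot    # upper-left 3x3: rotation
four_by_four[..., :3, 3] = trans   # last column: translation
four_by_four[..., 3, 3] = 1        # homogeneous bottom row ends in 1

print(four_by_four)
```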
tests/test_import_weights.py

```diff
@@ -17,19 +17,17 @@ import numpy as np
 import unittest

 from config import model_config
-from alphafold.model.model import AlphaFold
-from alphafold.model.import_weights import *
+from openfold.model.model import AlphaFold
+from openfold.utils.import_weights import import_jax_weights_

 class TestImportWeights(unittest.TestCase):
     def test_import_jax_weights_(self):
-        npz_path = "tests/model/alphafold_2/params_model_1.npz"
+        npz_path = "openfold/resources/params/params_model_1_ptm.npz"

-        c = model_config("model_1").model
-        c.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
-                                                  # deepspeed for this test
-        model = AlphaFold(c)
+        c = model_config("model_1_ptm")
+        c.globals.blocks_per_ckpt = None
+        model = AlphaFold(c.model)

         import_jax_weights_(
             model,
             npz_path,
```
tests/test_loss.py

Hunk @@ -12,22 +12,41 @@ expands the imports; they now read:

```python
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import torch
import numpy as np
import unittest
import ml_collections as mlc

from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
)
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    dict_multimap,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4

if(compare_utils.alphafold_is_installed()):
    alphafold = compare_utils.import_alphafold()
```

Hunk @@ -102,7 +121,9 @@ reflows a long line:

```diff
         n_res = consts.n_res

         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
-        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
+        pred_atom_mask = np.random.randint(
+            0, 2, (n_res, 14)
+        ).astype(np.float32)
         residue_index = np.arange(n_res)
         aatype = np.random.randint(0, 22, (n_res,))
```

Hunk @@ -130,12 +151,12 @@ renames the smoke test and casts the mask to float:

```diff
-    def test_between_residue_clash_loss(self):
+    def test_run_between_residue_clash_loss(self):
         bs = consts.batch_size
         n = consts.n_res

         pred_pos = torch.rand(bs, n, 14, 3)
-        pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
+        pred_atom_mask = torch.randint(0, 2, (bs, n, 14)).float()
         atom14_atom_radius = torch.rand(bs, n, 14)
         residue_index = torch.arange(n).unsqueeze(0)
```

Hunk @@ -146,6 +167,48 @@ adds an AlphaFold comparison test for the clash loss:

```python
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
            return alphafold.model.all_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind,
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
        res_ind = np.arange(n_res,)

        out_gt = f.apply(
            {}, None, pred_pos, atom_exists, atom_radius, res_ind,
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    def test_find_structural_violations(self):
        n = consts.n_res
```

Hunk @@ -165,7 +228,619 @@ adds comparison tests for the remaining losses, after the existing call to `find_structural_violations(batch, pred_pos, **config)`:

```python
    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        def run_fsv(batch, pos, config):
            cwd = os.getcwd()
            os.chdir("tests/test_data")
            loss = alphafold.model.folding.find_structural_violations(
                batch,
                pos,
                config,
            )
            os.chdir(cwd)
            return loss

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
            "residx_atom14_to_atom37":
                np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
        }

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict({
            "clash_overlap_tolerance": 1.5,
            "violation_tolerance_factor": 12.0,
        })

        out_gt = f.apply({}, None, batch, pred_pos, config)
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        def compare(out):
            gt, repro = out
            assert(torch.max(torch.abs(gt - repro)) < consts.eps)

        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 21, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists":
                np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
        }

        def _build_extra_feats_np():
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b.update(feats.build_ambiguity_feats(b))
            b.update(feats.compute_residx(b))
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        def run_msa_loss(value, batch):
            config = compare_utils.get_alphafold_config()
            msa_head = alphafold.model.modules.MaskedMsaHead(
                config.model.heads.masked_msa, config.model.global_config
            )
            return msa_head.loss(value, batch)

        f = hk.transform(run_msa_loss)

        n_res = consts.n_res
        n_seq = consts.n_seq

        value = {
            "logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
        }
        batch = {
            "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
            "bert_mask":
                np.random.randint(0, 2, (n_res, n_seq)).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = masked_msa_loss(
                value["logits"],
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_distogram = config.model.heads.distogram

        def run_distogram_loss(value, batch):
            dist_head = alphafold.model.modules.DistogramHead(
                c_distogram, config.model.global_config
            )
            return dist_head.loss(value, batch)

        f = hk.transform(run_distogram_loss)

        n_res = consts.n_res

        value = {
            "logits":
                np.random.rand(
                    n_res, n_res, c_distogram.num_bins
                ).astype(np.float32),
            "bin_edges":
                np.linspace(
                    c_distogram.first_break,
                    c_distogram.last_break,
                    c_distogram.num_bins,
                )
        }
        batch = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,))
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = distogram_loss(
                logits=value["logits"],
                min_bin=c_distogram.first_break,
                max_bin=c_distogram.last_break,
                no_bins=c_distogram.num_bins,
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_experimentally_resolved = config.model.heads.experimentally_resolved

        def run_experimentally_resolved_loss(value, batch):
            er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                c_experimentally_resolved, config.model.global_config
            )
            return er_head.loss(value, batch)

        f = hk.transform(run_experimentally_resolved_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, 37).astype(np.float32),
        }
        batch = {
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
            "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
            "resolution": np.array(1.0)
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = experimentally_resolved_loss(
                logits=value["logits"],
                min_resolution=c_experimentally_resolved.min_resolution,
                max_resolution=c_experimentally_resolved.max_resolution,
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_chi_loss = config.model.heads.structure_module

        def run_supervised_chi_loss(value, batch):
            ret = {
                "loss": jax.numpy.array(0.),
            }
            alphafold.model.folding.supervised_chi_loss(
                ret, batch, value, c_chi_loss
            )
            return ret["loss"]

        f = hk.transform(run_supervised_chi_loss)

        n_res = consts.n_res

        value = {
            "sidechains": {
                "angles_sin_cos":
                    np.random.rand(8, n_res, 7, 2).astype(np.float32),
                "unnormalized_angles_sin_cos":
                    np.random.rand(8, n_res, 7, 2).astype(np.float32),
            }
        }
        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        batch["chi_angles_sin_cos"] = torch.stack(
            [
                torch.sin(batch["chi_angles"]),
                torch.cos(batch["chi_angles"]),
            ],
            dim=-1,
        )

        with torch.no_grad():
            out_repro = supervised_chi_loss(
                chi_weight=c_chi_loss.chi_weight,
                angle_norm_weight=c_chi_loss.angle_norm_weight,
                **{**batch, **value["sidechains"]}
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module

        def run_viol_loss(batch, atom14_pred_pos):
            ret = {
                "loss": np.array(0.).astype(np.float32),
            }
            value = {}
            value["violations"] = (
                alphafold.model.folding.find_structural_violations(
                    batch,
                    atom14_pred_pos,
                    c_viol,
                )
            )
            alphafold.model.folding.structural_violation_loss(
                ret,
                batch,
                value,
                c_viol,
            )
            return ret["loss"]

        f = hk.transform(run_viol_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }
        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
        batch = {k: np.array(v) for k, v in batch.items()}

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
        batch.update(feats.compute_residx(batch))

        out_repro = violation_loss(
            find_structural_violations(batch, atom14_pred_pos, **c_viol),
            **batch,
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_plddt = config.model.heads.predicted_lddt

        def run_plddt_loss(value, batch):
            head = alphafold.model.modules.PredictedLDDTHead(
                c_plddt, config.model.global_config
            )
            return head.loss(value, batch)

        f = hk.transform(run_plddt_loss)

        n_res = consts.n_res

        value = {
            "predicted_lddt": {
                "logits":
                    np.random.rand(n_res, c_plddt.num_bins).astype(np.float32),
            },
            "structure_module": {
                "final_atom_positions":
                    np.random.rand(n_res, 37, 3).astype(np.float32),
            }
        }
        batch = {
            "all_atom_positions":
                np.random.rand(n_res, 37, 3).astype(np.float32),
            "all_atom_mask":
                np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
            "resolution": np.array(1.).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt["loss"]))

        to_tensor = lambda t: torch.tensor(t).cuda()
        value = tree_map(to_tensor, value, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = lddt_loss(
            logits=value["predicted_lddt"]["logits"],
            all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
            **{**batch, **c_plddt},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss(self):
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_bb_loss(batch, value):
            ret = {
                "loss": np.array(0.),
            }
            alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_bb_loss)

        n_res = consts.n_res

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask":
                np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "use_clamped_fape": np.array(0.),
        }
        value = {
            "traj": random_affines_vector((c_sm.num_layer, n_res,)),
        }

        out_gt = f.apply({}, None, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        value["traj"] = affine_vector_to_4x4(value["traj"])

        out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            batch = {
                **batch,
                **alphafold.model.all_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                )
            }
            v = {}
            v["sidechains"] = {}
            v["sidechains"]["frames"] = (
                alphafold.model.r3.rigids_from_tensor4x4(
                    value["sidechains"]["frames"]
                )
            )
            v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
                value["sidechains"]["atom_pos"]
            )
            v.update(alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_positions,
            ))
            value = v
            ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 21, (n_res,)),
            "atom14_gt_positions":
                np.random.rand(n_res, 14, 3).astype(np.float32),
            "atom14_gt_exists":
                np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
            "all_atom_positions":
                np.random.rand(n_res, 37, 3).astype(np.float32),
            "all_atom_mask":
                np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
        }

        def _build_extra_feats_np():
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b.update(feats.build_ambiguity_feats(b))
            b.update(feats.compute_residx(b))
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos":
                    np.random.rand(
                        c_sm.num_layer, n_res, 14, 3
                    ).astype(np.float32),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        batch.update(feats.atom37_to_frames(eps=1e-8, **batch))
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_tm_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            v = {}
            v.update(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        n_res = consts.n_res

        representations = {
            "pair":
                np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }
        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask":
                np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "resolution": np.array(1.).astype(np.float32),
        }
        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )
        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        batch["backbone_affine_tensor"] = (
            affine_vector_to_4x4(batch["backbone_affine_tensor"])
        )
        value["structure_module"]["final_affines"] = (
            affine_vector_to_4x4(value["structure_module"]["final_affines"])
        )

        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    unittest.main()
```
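All of the new `*_compare` tests follow one harness: wrap the AlphaFold (JAX/haiku) call in `hk.transform`, `apply` it on NumPy inputs to get a ground truth, run the OpenFold (PyTorch) port on the same inputs, and compare the maximum absolute deviation against a tolerance. Distilled to its skeleton, with a toy function standing in for a real loss:

```python
import numpy as np
import torch
import haiku as hk

def reference(x):
    import jax.numpy as jnp
    return jnp.square(x)  # stand-in for an alphafold.* loss call

f = hk.transform(reference)

x = np.random.rand(8).astype(np.float32)
out_gt = torch.tensor(np.array(f.apply({}, None, x)))  # JAX ground truth
out_repro = torch.tensor(x) ** 2                       # the PyTorch "port"

assert torch.max(torch.abs(out_gt - out_repro)) < 1e-6
```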
tests/test_outer_product_mean.py

```diff
@@ -92,4 +92,3 @@ class TestOuterProductMean(unittest.TestCase):
-
 if __name__ == "__main__":
     unittest.main()
```
tests/test_structure_module.py

Hunk @@ -21,13 +21,17 @@ replaces the star import with explicit names and pulls in extra helpers; the imports now read:

```python
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
    restype_atom37_mask,
)
from openfold.model.structure_module import (
    _torsion_angles_to_frames,
    _frames_and_literature_positions_to_atom14_pos,
    StructureModule,
    StructureModuleTransition,
    BackboneUpdate,
    AngleResnet,
    InvariantPointAttention,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
```

Hunk @@ -42,10 +46,10 @@ (under `if(compare_utils.alphafold_is_installed()):`) switches hard-coded dimensions to shared constants:

```diff
 class TestStructureModule(unittest.TestCase):
     def test_structure_module_shape(self):
-        batch_size = 2
-        n = 5
-        c_s = 7
-        c_z = 11
+        batch_size = consts.batch_size
+        n = consts.n_res
+        c_s = consts.c_s
+        c_z = consts.c_z
         c_ipa = 13
         c_resnet = 17
         no_heads_ipa = 6
```

Hunk @@ -94,47 +98,6 @@ removes the shape tests for the now-private helpers (their public counterparts live in tests/test_feats.py); the removed code:

```python
    def test_torsion_angles_to_frames_shape(self):
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))
        ts = T(rots, trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = _torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 8, 3, 3))
        trans = torch.rand((batch_size, n, 8, 3))
        ts = T(rots, trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n)).long()

        xyz = _frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n, 14, 3))
```

Hunk @@ -152,6 +115,76 @@ adds a comparison test against the pretrained AlphaFold structure module, after `self.assertTrue(shape_before == shape_after)`:

```python
    @compare_utils.skip_unless_alphafold_installed()
    def test_structure_module_compare(self):
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module
        c_global = config.model.global_config

        def run_sm(representations, batch):
            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
            representations = {
                k: jax.lax.stop_gradient(v)
                for k, v in representations.items()
            }
            batch = {
                k: jax.lax.stop_gradient(v) for k, v in batch.items()
            }
            return sm(representations, batch, is_training=False)

        f = hk.transform(run_sm)

        n_res = 200

        representations = {
            'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
            'pair':
                np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            'seq_mask': np.random.randint(0, 2, (n_res,)),
            'aatype': np.random.randint(0, 21, (n_res,)),
        }
        batch['atom14_atom_exists'] = np.take(
            restype_atom14_mask, batch['aatype'], axis=0
        )
        batch['atom37_atom_exists'] = np.take(
            restype_atom37_mask, batch['aatype'], axis=0
        )
        batch.update(feats.compute_residx_np(batch))

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
        )
        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, representations, batch)
        out_gt = torch.as_tensor(
            np.array(out_gt["final_atom14_positions"].block_until_ready())
        )

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.structure_module(
            torch.as_tensor(representations["single"]).cuda(),
            torch.as_tensor(representations["pair"]).cuda(),
            torch.as_tensor(batch["aatype"]).cuda(),
            mask=torch.as_tensor(batch["seq_mask"]).cuda(),
        )
        out_repro = out_repro["positions"][-1].cpu()

        # The structure module, thanks to angle normalization, is very volatile.
        # We only assess the mean here. Heuristically speaking, it seems to
        # have lower error in general on real rather than synthetic data.
        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)


class TestBackboneUpdate(unittest.TestCase):
    def test_shape(self):
```
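`test_structure_module_compare` is the one test here that hands haiku a real RNG key rather than `None`, pinned to a constant so the JAX reference stays reproducible across runs. The relevant JAX idiom, in isolation:

```python
import jax

# JAX randomness is explicit: the same key always yields the same draw,
# so seeding with a constant makes the reference output deterministic.
key = jax.random.PRNGKey(42)
print(jax.random.normal(key, (2,)))  # identical on every run

k1, k2 = jax.random.split(key)       # derive fresh, independent keys
print(jax.random.normal(k1, (2,)))
```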
tests/test_template.py

```diff
@@ -137,7 +137,8 @@ class TestTemplatePairStack(unittest.TestCase):
             _mask_trans=False,
         ).cpu()

-        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+
 class Template(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
```
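Both assertion styles that appear in this commit are worth a second look, because parenthesization changes the semantics: `torch.max(torch.abs(d)) < eps` compares the worst-case error against the tolerance, while `torch.max(torch.abs(d) < eps)` reduces a boolean mask and is true if any single element passes. A small demonstration (generic PyTorch, not OpenFold code):

```python
import torch

d = torch.tensor([1e-8, 0.5])  # one element far outside tolerance
eps = 1e-4

print(torch.max(torch.abs(d)) < eps)  # tensor(False): worst case is checked
print(torch.max(torch.abs(d) < eps))  # tensor(True): max over a boolean mask
print(torch.all(torch.abs(d) < eps))  # tensor(False): the strict variant
```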
tests/test_triangular_attention.py

```diff
@@ -100,9 +100,11 @@ class TestTriangularAttention(unittest.TestCase):
         self.assertTrue(
             torch.max(torch.abs(out_gt - out_repro) < consts.eps)
         )

+    @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
         self._tri_att_compare()

+    @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_start_compare(self):
         self._tri_att_compare(starting=True)
```
tests/test_utils.py

```diff
@@ -16,8 +16,8 @@ import math
 import torch
 import unittest

-from openfold.utils.affine_utils import *
-from openfold.utils.tensor_utils import *
+from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.tensor_utils import chunk_layer

 X_90_ROT = torch.tensor([
```