Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
49767099
Commit
49767099
authored
Oct 19, 2021
by
Gustaf Ahdritz
Browse files
Bring tests up to speed
parent
a6f56d16
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
167 additions
and
108 deletions
+167
-108
openfold/config.py
openfold/config.py
+10
-5
openfold/data/data_modules.py
openfold/data/data_modules.py
+61
-21
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+2
-5
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+2
-11
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+5
-7
openfold/model/model.py
openfold/model/model.py
+7
-5
openfold/model/primitives.py
openfold/model/primitives.py
+11
-8
scripts/run_unit_tests.sh
scripts/run_unit_tests.sh
+2
-0
tests/compare_utils.py
tests/compare_utils.py
+1
-1
tests/data_utils.py
tests/data_utils.py
+9
-3
tests/test_evoformer.py
tests/test_evoformer.py
+7
-6
tests/test_feats.py
tests/test_feats.py
+7
-5
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+1
-1
tests/test_model.py
tests/test_model.py
+17
-11
tests/test_msa.py
tests/test_msa.py
+10
-7
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+4
-3
tests/test_pair_transition.py
tests/test_pair_transition.py
+2
-1
tests/test_structure_module.py
tests/test_structure_module.py
+2
-2
tests/test_template.py
tests/test_template.py
+6
-5
No files found.
openfold/config.py
View file @
49767099
...
...
@@ -64,7 +64,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size
=
mlc
.
FieldReference
(
4
,
field_type
=
int
)
aux_distogram_bins
=
mlc
.
FieldReference
(
64
,
field_type
=
int
)
eps
=
mlc
.
FieldReference
(
1e-8
,
field_type
=
float
)
num_recycle
=
mlc
.
FieldReference
(
3
,
field_type
=
int
)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
...
...
@@ -77,7 +76,6 @@ config = mlc.ConfigDict(
{
"data"
:
{
"common"
:
{
"batch_modes"
:
[(
"clamped"
,
0.9
),
(
"unclamped"
,
0.1
)],
"feat"
:
{
"aatype"
:
[
NUM_RES
],
"all_atom_mask"
:
[
NUM_RES
,
None
],
...
...
@@ -93,7 +91,7 @@ config = mlc.ConfigDict(
"backbone_affine_mask"
:
[
NUM_RES
],
"backbone_affine_tensor"
:
[
NUM_RES
,
None
,
None
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
,
None
],
"chi_mask"
:
[
NUM_RES
,
None
],
"extra_deletion_value"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"extra_has_deletion"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
...
...
@@ -104,6 +102,7 @@ config = mlc.ConfigDict(
"msa_feat"
:
[
NUM_MSA_SEQ
,
NUM_RES
,
None
],
"msa_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"msa_row_mask"
:
[
NUM_MSA_SEQ
],
"no_recycling_iters"
:
[],
"pseudo_beta"
:
[
NUM_RES
,
None
],
"pseudo_beta_mask"
:
[
NUM_RES
],
"residue_index"
:
[
NUM_RES
],
...
...
@@ -149,8 +148,8 @@ config = mlc.ConfigDict(
"uniform_prob"
:
0.1
,
},
"max_extra_msa"
:
1024
,
"max_recycling_iters"
:
3
,
"msa_cluster_features"
:
True
,
"num_recycle"
:
num_recycle
,
"reduce_msa_clusters_by_max_templates"
:
False
,
"resample_msa_in_recycling"
:
True
,
"template_features"
:
[
...
...
@@ -167,9 +166,14 @@ config = mlc.ConfigDict(
"seq_length"
,
"between_segment_residues"
,
"deletion_matrix"
,
"no_recycling_iters"
,
],
"use_templates"
:
templates_enabled
,
"use_template_torsion_angles"
:
embed_template_torsion_angles
,
},
"supervised"
:
{
"clamp_prob"
:
0.9
,
"uniform_recycling"
:
True
,
"supervised_features"
:
[
"all_atom_mask"
,
"all_atom_positions"
,
...
...
@@ -212,6 +216,8 @@ config = mlc.ConfigDict(
"crop"
:
True
,
"crop_size"
:
256
,
"supervised"
:
True
,
"clamp_prob"
:
0.9
,
"subsample_recycling"
:
True
,
},
"data_module"
:
{
"use_small_bfd"
:
False
,
...
...
@@ -234,7 +240,6 @@ config = mlc.ConfigDict(
"eps"
:
eps
,
},
"model"
:
{
"num_recycle"
:
num_recycle
,
"_mask_trans"
:
False
,
"input_embedder"
:
{
"tf_dim"
:
22
,
...
...
openfold/data/data_modules.py
View file @
49767099
...
...
@@ -5,6 +5,7 @@ import os
from
typing
import
Optional
,
Sequence
import
ml_collections
as
mlc
import
numpy
as
np
import
pytorch_lightning
as
pl
import
torch
from
torch.utils.data
import
RandomSampler
...
...
@@ -216,31 +217,66 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class
OpenFoldBatchCollator
:
def
__init__
(
self
,
config
,
generator
,
stage
=
"train"
):
self
.
config
=
config
batch_modes
=
config
.
common
.
batch_modes
batch_mode_names
,
batch_mode_probs
=
list
(
zip
(
*
batch_modes
))
self
.
batch_mode_names
=
batch_mode_names
self
.
batch_mode_probs
=
batch_mode_probs
self
.
generator
=
generator
self
.
stage
=
stage
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
_prep_batch_properties_probs
()
self
.
batch_mode_probs_tensor
=
torch
.
tensor
(
self
.
batch_mode_probs
)
def
_prep_batch_properties_probs
(
self
):
keyed_probs
=
[]
stage_cfg
=
self
.
config
[
self
.
stage
]
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
self
.
config
)
max_iters
=
self
.
config
.
common
.
max_recycling_iters
if
(
stage_cfg
.
supervised
):
clamp_prob
=
self
.
config
.
supervised
.
clamp_prob
keyed_probs
.
append
(
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
)
if
(
self
.
config
.
supervised
.
uniform_recycling
):
recycling_probs
=
[
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
]
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
)
else
:
recycling_probs
=
[
0.
for
_
in
range
(
max_iters
+
1
)
]
recycling_probs
[
-
1
]
=
1.
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
)
def
__call__
(
self
,
raw_prots
):
# We use torch.multinomial here rather than Categorical because the
# latter doesn't accept a generator for some reason
batch_mode_idx
=
torch
.
multinomial
(
self
.
batch_mode_probs_tensor
,
1
,
keys
,
probs
=
zip
(
*
keyed_probs
)
max_len
=
max
([
len
(
p
)
for
p
in
probs
])
padding
=
[[
0.
]
*
(
max_len
-
len
(
p
))
for
p
in
probs
]
self
.
prop_keys
=
keys
self
.
prop_probs_tensor
=
torch
.
tensor
(
[
p
+
pad
for
p
,
pad
in
zip
(
probs
,
padding
)],
dtype
=
torch
.
float32
,
)
def
_add_batch_properties
(
self
,
raw_prots
):
samples
=
torch
.
multinomial
(
self
.
prop_probs_tensor
,
num_samples
=
1
,
# 1 per row
replacement
=
True
,
generator
=
self
.
generator
).
item
()
batch_mode_name
=
self
.
batch_mode_names
[
batch_mode_idx
]
)
for
i
,
key
in
enumerate
(
self
.
prop_keys
):
sample
=
samples
[
i
][
0
]
for
prot
in
raw_prots
:
prot
[
key
]
=
np
.
array
(
sample
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
raw_prots
):
self
.
_add_batch_properties
(
raw_prots
)
processed_prots
=
[]
for
prot
in
raw_prots
:
features
=
self
.
feature_pipeline
.
process_features
(
prot
,
self
.
stage
,
batch_mode_name
prot
,
self
.
stage
)
processed_prots
.
append
(
features
)
...
...
@@ -265,6 +301,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_mapping_path
:
Optional
[
str
]
=
None
,
distillation_mapping_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
**
kwargs
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
...
@@ -286,6 +323,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
template_release_dates_cache_path
=
(
template_release_dates_cache_path
)
self
.
batch_seed
=
batch_seed
if
(
self
.
train_data_dir
is
None
and
self
.
predict_data_dir
is
None
):
raise
ValueError
(
...
...
@@ -309,7 +347,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
def
setup
(
self
,
stage
):
def
setup
(
self
,
stage
:
Optional
[
str
]
=
None
):
if
(
stage
is
None
):
stage
=
"train"
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
...
...
@@ -369,12 +410,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode
=
"predict"
,
)
self
.
batch_collation_seed
=
torch
.
Generator
().
seed
()
def
_gen_batch_collator
(
self
,
stage
):
""" We want each process to use the same batch collation seed """
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_collation_seed
)
if
(
self
.
batch_seed
is
not
None
):
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
collate_fn
=
OpenFoldBatchCollator
(
self
.
config
,
generator
,
stage
)
...
...
@@ -404,5 +444,5 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
predict_dataset
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
self
.
_gen_batch_collator
(
"
eval
"
)
collate_fn
=
self
.
_gen_batch_collator
(
"
predict
"
)
)
openfold/data/data_transforms.py
View file @
49767099
...
...
@@ -1095,7 +1095,6 @@ def random_crop_to_size(
shape_schema
,
subsample_templates
=
False
,
seed
=
None
,
batch_mode
=
"clamped"
,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length
=
protein
[
"seq_length"
]
...
...
@@ -1133,13 +1132,11 @@ def random_crop_to_size(
num_templates_crop_size
=
num_templates
n
=
seq_length
-
num_res_crop_size
if
batch_mode
==
"clamped"
:
if
protein
[
"use_clamped_fape"
]
==
1.
:
right_anchor
=
n
el
if
batch_mode
==
"unclamped"
:
el
se
:
x
=
_randint
(
0
,
n
)
right_anchor
=
n
-
x
else
:
raise
ValueError
(
"Invalid batch mode"
)
num_res_crop_start
=
_randint
(
0
,
right_anchor
)
...
...
openfold/data/feature_pipeline.py
View file @
49767099
...
...
@@ -64,7 +64,7 @@ def make_data_config(
feature_names
+=
cfg
.
common
.
template_features
if
cfg
[
mode
].
supervised
:
feature_names
+=
cfg
.
common
.
supervised_features
feature_names
+=
cfg
.
supervised
.
supervised_features
return
cfg
,
feature_names
...
...
@@ -73,7 +73,6 @@ def np_example_to_features(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
batch_mode
:
str
,
):
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
...
...
@@ -84,11 +83,6 @@ def np_example_to_features(
"deletion_matrix_int"
).
astype
(
np
.
float32
)
if
batch_mode
==
"clamped"
:
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
1.0
).
astype
(
np
.
float32
)
elif
batch_mode
==
"unclamped"
:
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
0.0
).
astype
(
np
.
float32
)
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
)
...
...
@@ -97,7 +91,6 @@ def np_example_to_features(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
...
...
@@ -116,11 +109,9 @@ class FeaturePipeline:
self
,
raw_features
:
FeatureDict
,
mode
:
str
=
"train"
,
batch_mode
:
str
=
"clamped"
,
)
->
FeatureDict
:
return
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
mode
=
mode
,
batch_mode
=
batch_mode
,
)
openfold/data/input_pipeline.py
View file @
49767099
...
...
@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
return
transforms
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
,
ensemble_seed
):
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
ensemble_seed
):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms
=
[]
...
...
@@ -116,7 +116,6 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
mode_cfg
.
max_templates
,
crop_feats
,
mode_cfg
.
subsample_templates
,
batch_mode
=
batch_mode
,
seed
=
ensemble_seed
,
)
)
...
...
@@ -137,9 +136,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
return
transforms
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
"clamped"
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed
=
torch
.
Generator
().
seed
()
...
...
@@ -150,7 +147,6 @@ def process_tensors_from_config(
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
,
ensemble_seed
,
)
fn
=
compose
(
fns
)
...
...
@@ -160,9 +156,11 @@ def process_tensors_from_config(
tensors
=
compose
(
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
))(
tensors
)
num_ensemble
=
mode_cfg
.
num_ensemble
num_recycling
=
tensors
[
"no_recycling_iters"
].
item
()
if
common_cfg
.
resample_msa_in_recycling
:
# Separate batch per ensembling & recycling step.
num_ensemble
*=
common_cfg
.
num_recycl
e
+
1
num_ensemble
*=
num_recycl
ing
+
1
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
tensors
=
map_fn
(
...
...
openfold/model/model.py
View file @
49767099
...
...
@@ -202,7 +202,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if
self
.
config
.
num_recycle
>
0
:
if
feats
[
"no_recycling_iters"
]
>
0
:
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
...
...
@@ -236,7 +236,7 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z
=
z
+
z_prev_emb
#
This can matter during inference when N_res is very large
#
Possibly prevents memory fragmentation
del
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
...
...
@@ -395,19 +395,21 @@ class AlphaFold(nn.Module):
# Initialize recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled
=
torch
.
is_grad_enabled
()
self
.
_disable_activation_checkpointing
()
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
num_recycle
+
1
):
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
for
cycle_no
in
range
(
num_iters
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
cycle_no
==
self
.
config
.
num_recycle
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug
discussed in pyt
orch issue #65766
# Sidestep AMP bug
(PyT
orch issue #65766
)
if
is_final_iter
:
self
.
_enable_activation_checkpointing
()
if
torch
.
is_autocast_enabled
():
...
...
openfold/model/primitives.py
View file @
49767099
...
...
@@ -258,11 +258,14 @@ class Attention(nn.Module):
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, H, Q, C_hidden]
q
=
permute_final_dims
(
q
,
(
1
,
0
,
2
))
# [*, H, C_hidden, K]
k
=
permute_final_dims
(
k
,
(
1
,
2
,
0
))
# [*, H, Q, K]
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, K]
)
a
=
torch
.
matmul
(
q
,
k
)
del
q
,
k
...
...
@@ -273,11 +276,11 @@ class Attention(nn.Module):
a
=
a
+
b
a
=
self
.
softmax
(
a
)
# [*, H, V, C_hidden]
v
=
permute_final_dims
(
v
,
(
1
,
0
,
2
))
# [*, H, Q, C_hidden]
o
=
torch
.
matmul
(
a
,
permute_final_dims
(
v
,
(
1
,
0
,
2
)),
# [*, H, V, C_hidden]
)
o
=
torch
.
matmul
(
a
,
v
)
# [*, Q, H, C_hidden]
o
=
o
.
transpose
(
-
2
,
-
3
)
...
...
scripts/run_unit_tests.sh
View file @
49767099
#!/bin/bash
#CUDA_VISIBLE_DEVICES="5"
python3
-m
unittest
"
$@
"
||
\
echo
-e
"
\n
Test(s) failed. Make sure you've installed all Python dependencies."
tests/compare_utils.py
View file @
49767099
...
...
@@ -60,7 +60,7 @@ _model = None
def
get_global_pretrained_openfold
():
global
_model
if
_model
is
None
:
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
)
.
model
)
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
))
_model
=
_model
.
eval
()
if
not
os
.
path
.
exists
(
_param_path
):
raise
FileNotFoundError
(
...
...
tests/data_utils.py
View file @
49767099
...
...
@@ -25,11 +25,17 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_pseudo_beta_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
)),
"template_pseudo_beta"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
3
),
"template_aatype"
:
np
.
random
.
randint
(
0
,
22
,
(
*
b
,
n_templ
,
n
)),
"template_all_atom_mask
s
"
:
np
.
random
.
randint
(
"template_all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
,
37
)
),
"template_all_atom_positions"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
37
,
3
)
*
10
,
"template_all_atom_positions"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
37
,
3
)
*
10
,
"template_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_alt_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_torsion_angles_mask"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
),
}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
...
...
tests/test_evoformer.py
View file @
49767099
...
...
@@ -66,7 +66,6 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout
,
pair_stack_dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
4
,
inf
=
inf
,
eps
=
eps
,
).
eval
()
...
...
@@ -79,7 +78,9 @@ class TestEvoformerStack(unittest.TestCase):
shape_m_before
=
m
.
shape
shape_z_before
=
z
.
shape
m
,
z
,
s
=
es
(
m
,
z
,
msa_mask
,
pair_mask
)
m
,
z
,
s
=
es
(
m
,
z
,
chunk_size
=
4
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
)
self
.
assertTrue
(
m
.
shape
==
shape_m_before
)
self
.
assertTrue
(
z
.
shape
==
shape_z_before
)
...
...
@@ -127,6 +128,7 @@ class TestEvoformerStack(unittest.TestCase):
torch
.
as_tensor
(
activations
[
"pair"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"msa"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"pair"
]).
cuda
(),
chunk_size
=
4
,
_mask_trans
=
False
,
)
...
...
@@ -171,7 +173,6 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout
,
pair_stack_dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
4
,
inf
=
inf
,
eps
=
eps
,
).
eval
()
...
...
@@ -199,7 +200,7 @@ class TestExtraMSAStack(unittest.TestCase):
shape_z_before
=
z
.
shape
z
=
es
(
m
,
z
,
msa_mask
,
pair_mask
)
z
=
es
(
m
,
z
,
chunk_size
=
4
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
)
self
.
assertTrue
(
z
.
shape
==
shape_z_before
)
...
...
@@ -212,12 +213,12 @@ class TestMSATransition(unittest.TestCase):
c_m
=
7
n
=
11
mt
=
MSATransition
(
c_m
,
n
,
chunk_size
=
4
)
mt
=
MSATransition
(
c_m
,
n
)
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_r
,
c_m
))
shape_before
=
m
.
shape
m
=
mt
(
m
)
m
=
mt
(
m
,
chunk_size
=
4
)
shape_after
=
m
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
tests/test_feats.py
View file @
49767099
...
...
@@ -16,7 +16,7 @@ import torch
import
numpy
as
np
import
unittest
import
openfold.
features
.data_transforms
as
data_transforms
import
openfold.
data
.data_transforms
as
data_transforms
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -102,10 +102,12 @@ class TestFeats(unittest.TestCase):
out_gt
=
f
.
apply
({},
None
,
aatype
,
all_atom_pos
,
all_atom_mask
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
)
out_repro
=
feats
.
atom37_to_torsion_angles
(
torch
.
as_tensor
(
aatype
).
cuda
(),
torch
.
as_tensor
(
all_atom_pos
).
cuda
(),
torch
.
as_tensor
(
all_atom_mask
).
cuda
(),
out_repro
=
data_transforms
.
atom37_to_torsion_angles
()(
{
"aatype"
:
torch
.
as_tensor
(
aatype
).
cuda
(),
"all_atom_positions"
:
torch
.
as_tensor
(
all_atom_pos
).
cuda
(),
"all_atom_mask"
:
torch
.
as_tensor
(
all_atom_mask
).
cuda
(),
},
)
tasc
=
out_repro
[
"torsion_angles_sin_cos"
].
cpu
()
atasc
=
out_repro
[
"alt_torsion_angles_sin_cos"
].
cpu
()
...
...
tests/test_import_weights.py
View file @
49767099
...
...
@@ -27,7 +27,7 @@ class TestImportWeights(unittest.TestCase):
c
=
model_config
(
"model_1_ptm"
)
c
.
globals
.
blocks_per_ckpt
=
None
model
=
AlphaFold
(
c
.
model
)
model
=
AlphaFold
(
c
)
import_jax_weights_
(
model
,
...
...
tests/test_loss.py
View file @
49767099
...
...
@@ -19,7 +19,7 @@ import numpy as np
import
unittest
import
ml_collections
as
mlc
from
openfold.
features
import
data_transforms
from
openfold.
data
import
data_transforms
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
...
...
tests/test_model.py
View file @
49767099
...
...
@@ -18,7 +18,7 @@ import torch.nn as nn
import
numpy
as
np
import
unittest
from
openfold.config
import
model_config
from
openfold.
features.data_transforms
import
make_atom14_mask
s
from
openfold.
data
import
data_transform
s
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
...
...
@@ -42,22 +42,21 @@ class TestModel(unittest.TestCase):
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"model_1"
).
model
c
.
no_cycles
=
2
c
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
=
model_config
(
"model_1"
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
model
=
AlphaFold
(
c
)
batch
=
{}
tf
=
torch
.
randint
(
c
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
batch
[
"target_feat"
]
=
nn
.
functional
.
one_hot
(
tf
,
c
.
input_embedder
.
tf_dim
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
extra_feats
=
random_extra_msa_feats
(
n_extra_seq
,
n_res
)
...
...
@@ -66,10 +65,11 @@ class TestModel(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_seq
,
n_res
)
).
float
()
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,)).
float
()
batch
.
update
(
make_atom14_masks
(
batch
))
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
no_cycle
s
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iter
s
)
)
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
...
...
@@ -94,7 +94,7 @@ class TestModel(unittest.TestCase):
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
out_gt
=
jax
.
jit
(
f
.
apply
)
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
out_gt
[
"structure_module"
][
"final_atom_positions"
]
# atom37_to_atom14 doesn't like batches
...
...
@@ -103,13 +103,19 @@ class TestModel(unittest.TestCase):
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
batch
[
"aatype"
]
=
batch
[
"aatype"
].
long
()
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
long
()
batch
[
"extra_msa"
]
=
batch
[
"extra_msa"
].
long
()
batch
[
"residx_atom37_to_atom14"
]
=
batch
[
"residx_atom37_to_atom14"
].
long
()
batch
[
"template_all_atom_mask"
]
=
batch
[
"template_all_atom_masks"
]
batch
.
update
(
data_transforms
.
atom37_to_torsion_angles
(
"template_"
)(
batch
)
)
# Move the recycling dimension to the end
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
...
...
tests/test_msa.py
View file @
49767099
...
...
@@ -41,13 +41,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
no_heads
=
4
chunk_size
=
None
mrapb
=
MSARowAttentionWithPairBias
(
c_m
,
c_z
,
c
,
no_heads
,
chunk_size
)
mrapb
=
MSARowAttentionWithPairBias
(
c_m
,
c_z
,
c
,
no_heads
)
m
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
shape_before
=
m
.
shape
m
=
mrapb
(
m
,
z
)
m
=
mrapb
(
m
,
z
=
z
,
chunk_size
=
chunk_size
)
shape_after
=
m
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -91,8 +91,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
msa_att_row
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
msa_mask
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
...
...
@@ -114,7 +115,7 @@ class TestMSAColumnAttention(unittest.TestCase):
x
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
shape_before
=
x
.
shape
x
=
msaca
(
x
)
x
=
msaca
(
x
,
chunk_size
=
None
)
shape_after
=
x
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -155,7 +156,8 @@ class TestMSAColumnAttention(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_mask
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
...
...
@@ -177,7 +179,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
x
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
shape_before
=
x
.
shape
x
=
msagca
(
x
)
x
=
msagca
(
x
,
chunk_size
=
None
)
shape_after
=
x
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -219,6 +221,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model
.
extra_msa_stack
.
stack
.
blocks
[
0
]
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
.
cpu
()
...
...
tests/test_outer_product_mean.py
View file @
49767099
...
...
@@ -38,11 +38,11 @@ class TestOuterProductMean(unittest.TestCase):
mask
=
torch
.
randint
(
0
,
2
,
size
=
(
consts
.
batch_size
,
consts
.
n_seq
,
consts
.
n_res
)
)
m
=
opm
(
m
,
mask
)
m
=
opm
(
m
,
mask
=
mask
,
chunk_size
=
None
)
self
.
assertTrue
(
m
.
shape
==
(
consts
.
batch_size
,
consts
.
n_res
,
consts
.
n_res
,
consts
.
c_z
)
m
.
shape
==
(
consts
.
batch_size
,
consts
.
n_res
,
consts
.
n_res
,
consts
.
c_z
)
)
@
compare_utils
.
skip_unless_alphafold_installed
()
...
...
@@ -84,6 +84,7 @@ class TestOuterProductMean(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
...
...
tests/test_pair_transition.py
View file @
49767099
...
...
@@ -39,7 +39,7 @@ class TestPairTransition(unittest.TestCase):
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
mask
=
torch
.
randint
(
0
,
2
,
size
=
(
batch_size
,
n_res
,
n_res
))
shape_before
=
z
.
shape
z
=
pt
(
z
,
mask
)
z
=
pt
(
z
,
mask
=
mask
,
chunk_size
=
None
)
shape_after
=
z
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -79,6 +79,7 @@ class TestPairTransition(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
.
cpu
()
...
...
tests/test_structure_module.py
View file @
49767099
...
...
@@ -16,7 +16,7 @@ import torch
import
numpy
as
np
import
unittest
from
openfold.
features
.data_transforms
import
make_atom14_masks_np
from
openfold.
data
.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -174,7 +174,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.0
1
)
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.0
5
)
class
TestBackboneUpdate
(
unittest
.
TestCase
):
...
...
tests/test_template.py
View file @
49767099
...
...
@@ -42,13 +42,13 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
inf
=
1e7
tpa
=
TemplatePointwiseAttention
(
c_t
,
c_z
,
c
,
no_heads
,
chunk_size
=
4
,
inf
=
inf
c_t
,
c_z
,
c
,
no_heads
,
inf
=
inf
)
t
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
n_res
,
c_t
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
z_update
=
tpa
(
t
,
z
)
z_update
=
tpa
(
t
,
z
,
chunk_size
=
None
)
self
.
assertTrue
(
z_update
.
shape
==
z
.
shape
)
...
...
@@ -79,7 +79,6 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
chunk_size
,
inf
=
inf
,
eps
=
eps
,
)
...
...
@@ -87,7 +86,7 @@ class TestTemplatePairStack(unittest.TestCase):
t
=
torch
.
rand
((
batch_size
,
n_templ
,
n_res
,
n_res
,
c_t
))
mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_templ
,
n_res
,
n_res
))
shape_before
=
t
.
shape
t
=
tpe
(
t
,
mask
)
t
=
tpe
(
t
,
mask
,
chunk_size
=
chunk_size
)
shape_after
=
t
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -136,6 +135,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_repro
=
model
.
template_pair_stack
(
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
chunk_size
=
None
,
_mask_trans
=
False
,
).
cpu
()
...
...
@@ -161,8 +161,8 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding"
...
...
@@ -182,6 +182,7 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
None
,
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment