Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
49767099
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "204ed191059a5fadf993a7cab8ca4bd33b744c16"
Commit
49767099
authored
Oct 19, 2021
by
Gustaf Ahdritz
Browse files
Bring tests up to speed
parent
a6f56d16
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
167 additions
and
108 deletions
+167
-108
openfold/config.py
openfold/config.py
+10
-5
openfold/data/data_modules.py
openfold/data/data_modules.py
+61
-21
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+2
-5
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+2
-11
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+5
-7
openfold/model/model.py
openfold/model/model.py
+7
-5
openfold/model/primitives.py
openfold/model/primitives.py
+11
-8
scripts/run_unit_tests.sh
scripts/run_unit_tests.sh
+2
-0
tests/compare_utils.py
tests/compare_utils.py
+1
-1
tests/data_utils.py
tests/data_utils.py
+9
-3
tests/test_evoformer.py
tests/test_evoformer.py
+7
-6
tests/test_feats.py
tests/test_feats.py
+7
-5
tests/test_import_weights.py
tests/test_import_weights.py
+1
-1
tests/test_loss.py
tests/test_loss.py
+1
-1
tests/test_model.py
tests/test_model.py
+17
-11
tests/test_msa.py
tests/test_msa.py
+10
-7
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+4
-3
tests/test_pair_transition.py
tests/test_pair_transition.py
+2
-1
tests/test_structure_module.py
tests/test_structure_module.py
+2
-2
tests/test_template.py
tests/test_template.py
+6
-5
No files found.
openfold/config.py
View file @
49767099
...
...
@@ -64,7 +64,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size
=
mlc
.
FieldReference
(
4
,
field_type
=
int
)
aux_distogram_bins
=
mlc
.
FieldReference
(
64
,
field_type
=
int
)
eps
=
mlc
.
FieldReference
(
1e-8
,
field_type
=
float
)
num_recycle
=
mlc
.
FieldReference
(
3
,
field_type
=
int
)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
...
...
@@ -77,7 +76,6 @@ config = mlc.ConfigDict(
{
"data"
:
{
"common"
:
{
"batch_modes"
:
[(
"clamped"
,
0.9
),
(
"unclamped"
,
0.1
)],
"feat"
:
{
"aatype"
:
[
NUM_RES
],
"all_atom_mask"
:
[
NUM_RES
,
None
],
...
...
@@ -93,7 +91,7 @@ config = mlc.ConfigDict(
"backbone_affine_mask"
:
[
NUM_RES
],
"backbone_affine_tensor"
:
[
NUM_RES
,
None
,
None
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
,
None
],
"chi_mask"
:
[
NUM_RES
,
None
],
"extra_deletion_value"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"extra_has_deletion"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
...
...
@@ -104,6 +102,7 @@ config = mlc.ConfigDict(
"msa_feat"
:
[
NUM_MSA_SEQ
,
NUM_RES
,
None
],
"msa_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"msa_row_mask"
:
[
NUM_MSA_SEQ
],
"no_recycling_iters"
:
[],
"pseudo_beta"
:
[
NUM_RES
,
None
],
"pseudo_beta_mask"
:
[
NUM_RES
],
"residue_index"
:
[
NUM_RES
],
...
...
@@ -149,8 +148,8 @@ config = mlc.ConfigDict(
"uniform_prob"
:
0.1
,
},
"max_extra_msa"
:
1024
,
"max_recycling_iters"
:
3
,
"msa_cluster_features"
:
True
,
"num_recycle"
:
num_recycle
,
"reduce_msa_clusters_by_max_templates"
:
False
,
"resample_msa_in_recycling"
:
True
,
"template_features"
:
[
...
...
@@ -167,9 +166,14 @@ config = mlc.ConfigDict(
"seq_length"
,
"between_segment_residues"
,
"deletion_matrix"
,
"no_recycling_iters"
,
],
"use_templates"
:
templates_enabled
,
"use_template_torsion_angles"
:
embed_template_torsion_angles
,
},
"supervised"
:
{
"clamp_prob"
:
0.9
,
"uniform_recycling"
:
True
,
"supervised_features"
:
[
"all_atom_mask"
,
"all_atom_positions"
,
...
...
@@ -212,6 +216,8 @@ config = mlc.ConfigDict(
"crop"
:
True
,
"crop_size"
:
256
,
"supervised"
:
True
,
"clamp_prob"
:
0.9
,
"subsample_recycling"
:
True
,
},
"data_module"
:
{
"use_small_bfd"
:
False
,
...
...
@@ -234,7 +240,6 @@ config = mlc.ConfigDict(
"eps"
:
eps
,
},
"model"
:
{
"num_recycle"
:
num_recycle
,
"_mask_trans"
:
False
,
"input_embedder"
:
{
"tf_dim"
:
22
,
...
...
openfold/data/data_modules.py
View file @
49767099
...
...
@@ -5,6 +5,7 @@ import os
from
typing
import
Optional
,
Sequence
import
ml_collections
as
mlc
import
numpy
as
np
import
pytorch_lightning
as
pl
import
torch
from
torch.utils.data
import
RandomSampler
...
...
@@ -216,31 +217,66 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class
OpenFoldBatchCollator
:
def
__init__
(
self
,
config
,
generator
,
stage
=
"train"
):
self
.
config
=
config
batch_modes
=
config
.
common
.
batch_modes
batch_mode_names
,
batch_mode_probs
=
list
(
zip
(
*
batch_modes
))
self
.
batch_mode_names
=
batch_mode_names
self
.
batch_mode_probs
=
batch_mode_probs
self
.
generator
=
generator
self
.
stage
=
stage
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
_prep_batch_properties_probs
()
self
.
batch_mode_probs_tensor
=
torch
.
tensor
(
self
.
batch_mode_probs
)
def
_prep_batch_properties_probs
(
self
):
keyed_probs
=
[]
stage_cfg
=
self
.
config
[
self
.
stage
]
max_iters
=
self
.
config
.
common
.
max_recycling_iters
if
(
stage_cfg
.
supervised
):
clamp_prob
=
self
.
config
.
supervised
.
clamp_prob
keyed_probs
.
append
(
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
)
if
(
self
.
config
.
supervised
.
uniform_recycling
):
recycling_probs
=
[
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
]
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
)
else
:
recycling_probs
=
[
0.
for
_
in
range
(
max_iters
+
1
)
]
recycling_probs
[
-
1
]
=
1.
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
self
.
config
)
keys
,
probs
=
zip
(
*
keyed_probs
)
max_len
=
max
([
len
(
p
)
for
p
in
probs
])
padding
=
[[
0.
]
*
(
max_len
-
len
(
p
))
for
p
in
probs
]
self
.
prop_keys
=
keys
self
.
prop_probs_tensor
=
torch
.
tensor
(
[
p
+
pad
for
p
,
pad
in
zip
(
probs
,
padding
)],
dtype
=
torch
.
float32
,
)
def
__call__
(
self
,
raw_prots
):
# We use torch.multinomial here rather than Categorical because the
# latter doesn't accept a generator for some reason
batch_mode_idx
=
torch
.
multinomial
(
self
.
batch_mode_probs_tensor
,
1
,
def
_add_batch_properties
(
self
,
raw_prots
):
samples
=
torch
.
multinomial
(
self
.
prop_probs_tensor
,
num_samples
=
1
,
# 1 per row
replacement
=
True
,
generator
=
self
.
generator
).
item
()
batch_mode_name
=
self
.
batch_mode_names
[
batch_mode_idx
]
)
for
i
,
key
in
enumerate
(
self
.
prop_keys
):
sample
=
samples
[
i
][
0
]
for
prot
in
raw_prots
:
prot
[
key
]
=
np
.
array
(
sample
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
raw_prots
):
self
.
_add_batch_properties
(
raw_prots
)
processed_prots
=
[]
for
prot
in
raw_prots
:
features
=
self
.
feature_pipeline
.
process_features
(
prot
,
self
.
stage
,
batch_mode_name
prot
,
self
.
stage
)
processed_prots
.
append
(
features
)
...
...
@@ -264,7 +300,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
train_mapping_path
:
Optional
[
str
]
=
None
,
distillation_mapping_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
**
kwargs
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
...
@@ -286,6 +323,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
template_release_dates_cache_path
=
(
template_release_dates_cache_path
)
self
.
batch_seed
=
batch_seed
if
(
self
.
train_data_dir
is
None
and
self
.
predict_data_dir
is
None
):
raise
ValueError
(
...
...
@@ -309,7 +347,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
def
setup
(
self
,
stage
):
def
setup
(
self
,
stage
:
Optional
[
str
]
=
None
):
if
(
stage
is
None
):
stage
=
"train"
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
...
...
@@ -369,12 +410,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode
=
"predict"
,
)
self
.
batch_collation_seed
=
torch
.
Generator
().
seed
()
def
_gen_batch_collator
(
self
,
stage
):
""" We want each process to use the same batch collation seed """
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_collation_seed
)
if
(
self
.
batch_seed
is
not
None
):
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
collate_fn
=
OpenFoldBatchCollator
(
self
.
config
,
generator
,
stage
)
...
...
@@ -404,5 +444,5 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
predict_dataset
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
self
.
_gen_batch_collator
(
"
eval
"
)
collate_fn
=
self
.
_gen_batch_collator
(
"
predict
"
)
)
openfold/data/data_transforms.py
View file @
49767099
...
...
@@ -1095,7 +1095,6 @@ def random_crop_to_size(
shape_schema
,
subsample_templates
=
False
,
seed
=
None
,
batch_mode
=
"clamped"
,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length
=
protein
[
"seq_length"
]
...
...
@@ -1133,13 +1132,11 @@ def random_crop_to_size(
num_templates_crop_size
=
num_templates
n
=
seq_length
-
num_res_crop_size
if
batch_mode
==
"clamped"
:
if
protein
[
"use_clamped_fape"
]
==
1.
:
right_anchor
=
n
el
if
batch_mode
==
"unclamped"
:
el
se
:
x
=
_randint
(
0
,
n
)
right_anchor
=
n
-
x
else
:
raise
ValueError
(
"Invalid batch mode"
)
num_res_crop_start
=
_randint
(
0
,
right_anchor
)
...
...
openfold/data/feature_pipeline.py
View file @
49767099
...
...
@@ -64,7 +64,7 @@ def make_data_config(
feature_names
+=
cfg
.
common
.
template_features
if
cfg
[
mode
].
supervised
:
feature_names
+=
cfg
.
common
.
supervised_features
feature_names
+=
cfg
.
supervised
.
supervised_features
return
cfg
,
feature_names
...
...
@@ -73,7 +73,6 @@ def np_example_to_features(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
batch_mode
:
str
,
):
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
...
...
@@ -84,11 +83,6 @@ def np_example_to_features(
"deletion_matrix_int"
).
astype
(
np
.
float32
)
if
batch_mode
==
"clamped"
:
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
1.0
).
astype
(
np
.
float32
)
elif
batch_mode
==
"unclamped"
:
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
0.0
).
astype
(
np
.
float32
)
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
)
...
...
@@ -97,7 +91,6 @@ def np_example_to_features(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
...
...
@@ -115,12 +108,10 @@ class FeaturePipeline:
def
process_features
(
self
,
raw_features
:
FeatureDict
,
mode
:
str
=
"train"
,
batch_mode
:
str
=
"clamped"
,
mode
:
str
=
"train"
,
)
->
FeatureDict
:
return
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
mode
=
mode
,
batch_mode
=
batch_mode
,
)
openfold/data/input_pipeline.py
View file @
49767099
...
...
@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
return
transforms
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
,
ensemble_seed
):
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
ensemble_seed
):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms
=
[]
...
...
@@ -116,7 +116,6 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
mode_cfg
.
max_templates
,
crop_feats
,
mode_cfg
.
subsample_templates
,
batch_mode
=
batch_mode
,
seed
=
ensemble_seed
,
)
)
...
...
@@ -137,9 +136,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
return
transforms
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
"clamped"
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed
=
torch
.
Generator
().
seed
()
...
...
@@ -150,7 +147,6 @@ def process_tensors_from_config(
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
,
ensemble_seed
,
)
fn
=
compose
(
fns
)
...
...
@@ -160,9 +156,11 @@ def process_tensors_from_config(
tensors
=
compose
(
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
))(
tensors
)
num_ensemble
=
mode_cfg
.
num_ensemble
num_recycling
=
tensors
[
"no_recycling_iters"
].
item
()
if
common_cfg
.
resample_msa_in_recycling
:
# Separate batch per ensembling & recycling step.
num_ensemble
*=
common_cfg
.
num_recycl
e
+
1
num_ensemble
*=
num_recycl
ing
+
1
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
tensors
=
map_fn
(
...
...
openfold/model/model.py
View file @
49767099
...
...
@@ -202,7 +202,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if
self
.
config
.
num_recycle
>
0
:
if
feats
[
"no_recycling_iters"
]
>
0
:
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
...
...
@@ -236,7 +236,7 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z
=
z
+
z_prev_emb
#
This can matter during inference when N_res is very large
#
Possibly prevents memory fragmentation
del
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
...
...
@@ -395,19 +395,21 @@ class AlphaFold(nn.Module):
# Initialize recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled
=
torch
.
is_grad_enabled
()
self
.
_disable_activation_checkpointing
()
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
num_recycle
+
1
):
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
for
cycle_no
in
range
(
num_iters
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
cycle_no
==
self
.
config
.
num_recycle
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug
discussed in pyt
orch issue #65766
# Sidestep AMP bug
(PyT
orch issue #65766
)
if
is_final_iter
:
self
.
_enable_activation_checkpointing
()
if
torch
.
is_autocast_enabled
():
...
...
openfold/model/primitives.py
View file @
49767099
...
...
@@ -258,11 +258,14 @@ class Attention(nn.Module):
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, H, Q, C_hidden]
q
=
permute_final_dims
(
q
,
(
1
,
0
,
2
))
# [*, H, C_hidden, K]
k
=
permute_final_dims
(
k
,
(
1
,
2
,
0
))
# [*, H, Q, K]
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, K]
)
a
=
torch
.
matmul
(
q
,
k
)
del
q
,
k
...
...
@@ -273,11 +276,11 @@ class Attention(nn.Module):
a
=
a
+
b
a
=
self
.
softmax
(
a
)
# [*, H, V, C_hidden]
v
=
permute_final_dims
(
v
,
(
1
,
0
,
2
))
# [*, H, Q, C_hidden]
o
=
torch
.
matmul
(
a
,
permute_final_dims
(
v
,
(
1
,
0
,
2
)),
# [*, H, V, C_hidden]
)
o
=
torch
.
matmul
(
a
,
v
)
# [*, Q, H, C_hidden]
o
=
o
.
transpose
(
-
2
,
-
3
)
...
...
scripts/run_unit_tests.sh
View file @
49767099
#!/bin/bash
#CUDA_VISIBLE_DEVICES="5"
python3
-m
unittest
"
$@
"
||
\
echo
-e
"
\n
Test(s) failed. Make sure you've installed all Python dependencies."
tests/compare_utils.py
View file @
49767099
...
...
@@ -60,7 +60,7 @@ _model = None
def
get_global_pretrained_openfold
():
global
_model
if
_model
is
None
:
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
)
.
model
)
_model
=
AlphaFold
(
model_config
(
"model_1_ptm"
))
_model
=
_model
.
eval
()
if
not
os
.
path
.
exists
(
_param_path
):
raise
FileNotFoundError
(
...
...
tests/data_utils.py
View file @
49767099
...
...
@@ -25,11 +25,17 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_pseudo_beta_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
)),
"template_pseudo_beta"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
3
),
"template_aatype"
:
np
.
random
.
randint
(
0
,
22
,
(
*
b
,
n_templ
,
n
)),
"template_all_atom_mask
s
"
:
np
.
random
.
randint
(
"template_all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
,
37
)
),
"template_all_atom_positions"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
37
,
3
)
*
10
,
"template_all_atom_positions"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
37
,
3
)
*
10
,
"template_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_alt_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_torsion_angles_mask"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
),
}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
...
...
tests/test_evoformer.py
View file @
49767099
...
...
@@ -66,7 +66,6 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout
,
pair_stack_dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
4
,
inf
=
inf
,
eps
=
eps
,
).
eval
()
...
...
@@ -79,7 +78,9 @@ class TestEvoformerStack(unittest.TestCase):
shape_m_before
=
m
.
shape
shape_z_before
=
z
.
shape
m
,
z
,
s
=
es
(
m
,
z
,
msa_mask
,
pair_mask
)
m
,
z
,
s
=
es
(
m
,
z
,
chunk_size
=
4
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
)
self
.
assertTrue
(
m
.
shape
==
shape_m_before
)
self
.
assertTrue
(
z
.
shape
==
shape_z_before
)
...
...
@@ -127,6 +128,7 @@ class TestEvoformerStack(unittest.TestCase):
torch
.
as_tensor
(
activations
[
"pair"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"msa"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"pair"
]).
cuda
(),
chunk_size
=
4
,
_mask_trans
=
False
,
)
...
...
@@ -171,7 +173,6 @@ class TestExtraMSAStack(unittest.TestCase):
msa_dropout
,
pair_stack_dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
4
,
inf
=
inf
,
eps
=
eps
,
).
eval
()
...
...
@@ -199,7 +200,7 @@ class TestExtraMSAStack(unittest.TestCase):
shape_z_before
=
z
.
shape
z
=
es
(
m
,
z
,
msa_mask
,
pair_mask
)
z
=
es
(
m
,
z
,
chunk_size
=
4
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
)
self
.
assertTrue
(
z
.
shape
==
shape_z_before
)
...
...
@@ -212,12 +213,12 @@ class TestMSATransition(unittest.TestCase):
c_m
=
7
n
=
11
mt
=
MSATransition
(
c_m
,
n
,
chunk_size
=
4
)
mt
=
MSATransition
(
c_m
,
n
)
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_r
,
c_m
))
shape_before
=
m
.
shape
m
=
mt
(
m
)
m
=
mt
(
m
,
chunk_size
=
4
)
shape_after
=
m
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
tests/test_feats.py
View file @
49767099
...
...
@@ -16,7 +16,7 @@ import torch
import
numpy
as
np
import
unittest
import
openfold.
features
.data_transforms
as
data_transforms
import
openfold.
data
.data_transforms
as
data_transforms
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -102,10 +102,12 @@ class TestFeats(unittest.TestCase):
out_gt
=
f
.
apply
({},
None
,
aatype
,
all_atom_pos
,
all_atom_mask
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
)
out_repro
=
feats
.
atom37_to_torsion_angles
(
torch
.
as_tensor
(
aatype
).
cuda
(),
torch
.
as_tensor
(
all_atom_pos
).
cuda
(),
torch
.
as_tensor
(
all_atom_mask
).
cuda
(),
out_repro
=
data_transforms
.
atom37_to_torsion_angles
()(
{
"aatype"
:
torch
.
as_tensor
(
aatype
).
cuda
(),
"all_atom_positions"
:
torch
.
as_tensor
(
all_atom_pos
).
cuda
(),
"all_atom_mask"
:
torch
.
as_tensor
(
all_atom_mask
).
cuda
(),
},
)
tasc
=
out_repro
[
"torsion_angles_sin_cos"
].
cpu
()
atasc
=
out_repro
[
"alt_torsion_angles_sin_cos"
].
cpu
()
...
...
tests/test_import_weights.py
View file @
49767099
...
...
@@ -27,7 +27,7 @@ class TestImportWeights(unittest.TestCase):
c
=
model_config
(
"model_1_ptm"
)
c
.
globals
.
blocks_per_ckpt
=
None
model
=
AlphaFold
(
c
.
model
)
model
=
AlphaFold
(
c
)
import_jax_weights_
(
model
,
...
...
tests/test_loss.py
View file @
49767099
...
...
@@ -19,7 +19,7 @@ import numpy as np
import
unittest
import
ml_collections
as
mlc
from
openfold.
features
import
data_transforms
from
openfold.
data
import
data_transforms
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
...
...
tests/test_model.py
View file @
49767099
...
...
@@ -18,7 +18,7 @@ import torch.nn as nn
import
numpy
as
np
import
unittest
from
openfold.config
import
model_config
from
openfold.
features.data_transforms
import
make_atom14_mask
s
from
openfold.
data
import
data_transform
s
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
...
...
@@ -42,22 +42,21 @@ class TestModel(unittest.TestCase):
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
"model_1"
).
model
c
.
no_cycles
=
2
c
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
=
model_config
(
"model_1"
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
model
=
AlphaFold
(
c
)
batch
=
{}
tf
=
torch
.
randint
(
c
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
batch
[
"target_feat"
]
=
nn
.
functional
.
one_hot
(
tf
,
c
.
input_embedder
.
tf_dim
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
extra_feats
=
random_extra_msa_feats
(
n_extra_seq
,
n_res
)
...
...
@@ -66,10 +65,11 @@ class TestModel(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_seq
,
n_res
)
).
float
()
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,)).
float
()
batch
.
update
(
make_atom14_masks
(
batch
))
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
no_cycle
s
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iter
s
)
)
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
...
...
@@ -94,7 +94,7 @@ class TestModel(unittest.TestCase):
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
out_gt
=
jax
.
jit
(
f
.
apply
)
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
out_gt
[
"structure_module"
][
"final_atom_positions"
]
# atom37_to_atom14 doesn't like batches
...
...
@@ -103,13 +103,19 @@ class TestModel(unittest.TestCase):
out_gt
=
alphafold
.
model
.
all_atom
.
atom37_to_atom14
(
out_gt
,
batch
)
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
[
"no_recycling_iters"
]
=
np
.
array
([
3.
,
3.
,
3.
,
3.
,])
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
batch
[
"aatype"
]
=
batch
[
"aatype"
].
long
()
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
long
()
batch
[
"extra_msa"
]
=
batch
[
"extra_msa"
].
long
()
batch
[
"residx_atom37_to_atom14"
]
=
batch
[
"residx_atom37_to_atom14"
].
long
()
batch
[
"template_all_atom_mask"
]
=
batch
[
"template_all_atom_masks"
]
batch
.
update
(
data_transforms
.
atom37_to_torsion_angles
(
"template_"
)(
batch
)
)
# Move the recycling dimension to the end
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
...
...
tests/test_msa.py
View file @
49767099
...
...
@@ -41,13 +41,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
no_heads
=
4
chunk_size
=
None
mrapb
=
MSARowAttentionWithPairBias
(
c_m
,
c_z
,
c
,
no_heads
,
chunk_size
)
mrapb
=
MSARowAttentionWithPairBias
(
c_m
,
c_z
,
c
,
no_heads
)
m
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
shape_before
=
m
.
shape
m
=
mrapb
(
m
,
z
)
m
=
mrapb
(
m
,
z
=
z
,
chunk_size
=
chunk_size
)
shape_after
=
m
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -91,8 +91,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
msa_att_row
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
msa_mask
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
...
...
@@ -114,7 +115,7 @@ class TestMSAColumnAttention(unittest.TestCase):
x
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
shape_before
=
x
.
shape
x
=
msaca
(
x
)
x
=
msaca
(
x
,
chunk_size
=
None
)
shape_after
=
x
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -155,7 +156,8 @@ class TestMSAColumnAttention(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_mask
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
)
...
...
@@ -177,7 +179,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
x
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
c_m
))
shape_before
=
x
.
shape
x
=
msagca
(
x
)
x
=
msagca
(
x
,
chunk_size
=
None
)
shape_after
=
x
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -219,6 +221,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model
.
extra_msa_stack
.
stack
.
blocks
[
0
]
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
.
cpu
()
...
...
tests/test_outer_product_mean.py
View file @
49767099
...
...
@@ -38,11 +38,11 @@ class TestOuterProductMean(unittest.TestCase):
mask
=
torch
.
randint
(
0
,
2
,
size
=
(
consts
.
batch_size
,
consts
.
n_seq
,
consts
.
n_res
)
)
m
=
opm
(
m
,
mask
)
m
=
opm
(
m
,
mask
=
mask
,
chunk_size
=
None
)
self
.
assertTrue
(
m
.
shape
==
(
consts
.
batch_size
,
consts
.
n_res
,
consts
.
n_res
,
consts
.
c_z
)
m
.
shape
==
(
consts
.
batch_size
,
consts
.
n_res
,
consts
.
n_res
,
consts
.
c_z
)
)
@
compare_utils
.
skip_unless_alphafold_installed
()
...
...
@@ -84,6 +84,7 @@ class TestOuterProductMean(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
outer_product_mean
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
.
cpu
()
...
...
tests/test_pair_transition.py
View file @
49767099
...
...
@@ -39,7 +39,7 @@ class TestPairTransition(unittest.TestCase):
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
mask
=
torch
.
randint
(
0
,
2
,
size
=
(
batch_size
,
n_res
,
n_res
))
shape_before
=
z
.
shape
z
=
pt
(
z
,
mask
)
z
=
pt
(
z
,
mask
=
mask
,
chunk_size
=
None
)
shape_after
=
z
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -79,6 +79,7 @@ class TestPairTransition(unittest.TestCase):
model
.
evoformer
.
blocks
[
0
]
.
pair_transition
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
.
cpu
()
...
...
tests/test_structure_module.py
View file @
49767099
...
...
@@ -16,7 +16,7 @@ import torch
import
numpy
as
np
import
unittest
from
openfold.
features
.data_transforms
import
make_atom14_masks_np
from
openfold.
data
.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -174,7 +174,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.0
1
)
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.0
5
)
class
TestBackboneUpdate
(
unittest
.
TestCase
):
...
...
tests/test_template.py
View file @
49767099
...
...
@@ -42,13 +42,13 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
inf
=
1e7
tpa
=
TemplatePointwiseAttention
(
c_t
,
c_z
,
c
,
no_heads
,
chunk_size
=
4
,
inf
=
inf
c_t
,
c_z
,
c
,
no_heads
,
inf
=
inf
)
t
=
torch
.
rand
((
batch_size
,
n_seq
,
n_res
,
n_res
,
c_t
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
z_update
=
tpa
(
t
,
z
)
z_update
=
tpa
(
t
,
z
,
chunk_size
=
None
)
self
.
assertTrue
(
z_update
.
shape
==
z
.
shape
)
...
...
@@ -79,7 +79,6 @@ class TestTemplatePairStack(unittest.TestCase):
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
blocks_per_ckpt
=
None
,
chunk_size
=
chunk_size
,
inf
=
inf
,
eps
=
eps
,
)
...
...
@@ -87,7 +86,7 @@ class TestTemplatePairStack(unittest.TestCase):
t
=
torch
.
rand
((
batch_size
,
n_templ
,
n_res
,
n_res
,
c_t
))
mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_templ
,
n_res
,
n_res
))
shape_before
=
t
.
shape
t
=
tpe
(
t
,
mask
)
t
=
tpe
(
t
,
mask
,
chunk_size
=
chunk_size
)
shape_after
=
t
.
shape
self
.
assertTrue
(
shape_before
==
shape_after
)
...
...
@@ -136,6 +135,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_repro
=
model
.
template_pair_stack
(
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
chunk_size
=
None
,
_mask_trans
=
False
,
).
cpu
()
...
...
@@ -161,8 +161,8 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding"
...
...
@@ -182,6 +182,7 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
None
,
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment