Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9ce96fb5
Commit
9ce96fb5
authored
Dec 27, 2021
by
Gustaf Ahdritz
Browse files
Merge branch 'main' of
ssh://github.com/aqlaboratory/openfold
into main
parents
3d5e8740
e1c7c9e7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
257 additions
and
14 deletions
+257
-14
README.md
README.md
+1
-1
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+20
-9
scripts/prep_proteinnet_msas.py
scripts/prep_proteinnet_msas.py
+0
-0
setup.py
setup.py
+1
-1
tests/config.py
tests/config.py
+14
-0
tests/test_data/features.pkl
tests/test_data/features.pkl
+0
-0
tests/test_data_transforms.py
tests/test_data_transforms.py
+221
-3
No files found.
README.md
View file @
9ce96fb5
...
...
@@ -86,7 +86,7 @@ MMseqs2 should be split according to the memory available on the system).
Alternatively, you can use raw MSAs from
[
ProteinNet
](
https://github.com/aqlaboratory/proteinnet
)
. After downloading
the database, use
`scripts/prep
are
_proteinnet_msas.py`
to convert the data into
the database, use
`scripts/prep_proteinnet_msas.py`
to convert the data into
a format recognized by the OpenFold parser. The resulting directory becomes the
`alignment_dir`
used in subsequent steps. Use
`scripts/unpack_proteinnet.py`
to
extract
`.core`
files from ProteinNet text files.
...
...
openfold/data/data_transforms.py
View file @
9ce96fb5
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
import
itertools
from
functools
import
reduce
from
functools
import
reduce
,
wraps
from
operator
import
add
import
numpy
as
np
...
...
@@ -71,7 +71,7 @@ def make_template_mask(protein):
def
curry1
(
f
):
"""Supply all arguments but the first."""
@
wraps
(
f
)
def
fc
(
*
args
,
**
kwargs
):
return
lambda
x
:
f
(
x
,
*
args
,
**
kwargs
)
...
...
@@ -145,7 +145,10 @@ def squeeze_features(protein):
if
k
in
protein
:
final_dim
=
protein
[
k
].
shape
[
-
1
]
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
if
torch
.
is_tensor
(
protein
[
k
]):
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
else
:
protein
[
k
]
=
np
.
squeeze
(
protein
[
k
],
axis
=-
1
)
for
k
in
[
"seq_length"
,
"num_alignments"
]:
if
k
in
protein
:
...
...
@@ -162,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
gap_idx
=
21
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
"msa"
]
!=
gap_idx
)
protein
[
"msa"
]
=
torch
.
where
(
msa_mask
,
torch
.
ones_like
(
protein
[
"msa"
])
*
x_idx
,
protein
[
"msa"
]
msa_mask
,
torch
.
ones_like
(
protein
[
"msa"
])
*
x_idx
,
protein
[
"msa"
]
)
aatype_mask
=
torch
.
rand
(
protein
[
"aatype"
].
shape
)
<
replace_proportion
...
...
@@ -199,6 +204,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return
protein
@curry1
def add_distillation_flag(protein, distillation):
    """Record the distillation flag on a feature dictionary.

    Args:
        protein: Mutable feature mapping; it is modified in place.
        distillation: Value stored unchanged under ``'is_distillation'``.

    Returns:
        The same ``protein`` mapping that was passed in.

    NOTE(review): the function is wrapped by ``curry1``, so pipeline callers
    presumably pass ``distillation`` first and apply the resulting callable
    to ``protein`` -- confirm against ``curry1``.
    """
    flag_key = 'is_distillation'
    protein[flag_key] = distillation
    return protein
@
curry1
def
sample_msa_distillation
(
protein
,
max_seq
):
if
(
protein
[
"is_distillation"
]
==
1
):
...
...
@@ -349,7 +359,7 @@ def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein
[
"msa_mask"
]
=
torch
.
ones
(
protein
[
"msa"
].
shape
,
dtype
=
torch
.
float32
)
protein
[
"msa_row_mask"
]
=
torch
.
ones
(
protein
[
"msa"
].
shape
[
0
],
dtype
=
torch
.
float32
(
protein
[
"msa"
].
shape
[
0
]
)
,
dtype
=
torch
.
float32
)
return
protein
...
...
@@ -602,17 +612,18 @@ def make_atom14_masks(protein):
dtype
=
torch
.
float32
,
device
=
protein
[
"aatype"
].
device
,
)
protein_aatype
=
protein
[
'aatype'
].
to
(
torch
.
long
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
"
aatype
"
]
]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
"
aatype
"
]
]
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
_
aatype
]
residx_atom14_mask
=
restype_atom14_mask
[
protein
_
aatype
]
protein
[
"atom14_atom_exists"
]
=
residx_atom14_mask
protein
[
"residx_atom14_to_atom37"
]
=
residx_atom14_to_atom37
.
long
()
# create the gather indices for mapping back
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
"
aatype
"
]
]
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
_
aatype
]
protein
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
...
...
@@ -626,7 +637,7 @@ def make_atom14_masks(protein):
atom_type
=
rc
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
"
aatype
"
]
]
residx_atom37_mask
=
restype_atom37_mask
[
protein
_
aatype
]
protein
[
"atom37_atom_exists"
]
=
residx_atom37_mask
return
protein
...
...
scripts/prep
are
_proteinnet_msas.py
→
scripts/prep_proteinnet_msas.py
View file @
9ce96fb5
File moved
setup.py
View file @
9ce96fb5
...
...
@@ -17,7 +17,7 @@ from setuptools import setup
setup
(
name
=
'openfold'
,
version
=
'
1.0
.0'
,
version
=
'
0.1
.0'
,
description
=
'A PyTorch reimplementation of DeepMind
\'
s AlphaFold 2'
,
author
=
'Gustaf Ahdritz & DeepMind'
,
author_email
=
'gahdritz@gmail.com'
,
...
...
tests/config.py
View file @
9ce96fb5
...
...
@@ -17,3 +17,17 @@ consts = mlc.ConfigDict(
"c_e"
:
64
,
}
)
# Test configuration tree. Only the masked-MSA settings under
# ``data.common`` are defined here.
# NOTE(review): the three ``*_prob`` values are presumably the replacement
# probabilities read by the masked-MSA transform -- confirm against
# openfold/data/data_transforms.py.
config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                "masked_msa": {
                    "profile_prob": 0.1,
                    "same_prob": 0.1,
                    "uniform_prob": 0.1,
                },
            }
        }
    }
)
tests/test_data/features.pkl
0 → 100644
View file @
9ce96fb5
File added
tests/test_data_transforms.py
View file @
9ce96fb5
...
...
@@ -5,12 +5,15 @@ import os
import
pickle
import
numpy
import
numpy
as
np
import
torch
import
unittest
from
data.data_transforms
import
make_seq_mask
from
openfold.config
import
model_config
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
,
\
crop_extra_msa
,
delete_extra_msa
,
nearest_neighbor_clusters
,
make_msa_mask
,
make_hhblits_profile
,
make_masked_msa
,
\
make_msa_feat
,
crop_templates
,
make_atom14_masks
from
tests.config
import
config
class
TestDataTransforms
(
unittest
.
TestCase
):
...
...
@@ -25,6 +28,221 @@ class TestDataTransforms(unittest.TestCase):
assert
'seq_mask'
in
protein
assert
protein
[
'seq_mask'
].
shape
==
torch
.
Size
((
seq
.
shape
[
0
],
20
))
def test_add_distillation_flag(self):
    """The undecorated transform must store the flag it is handed."""
    # ``__wrapped__`` is the raw two-argument function that sits underneath
    # the currying decorator.
    flagged = add_distillation_flag.__wrapped__({}, True)
    print(flagged)

    assert 'is_distillation' in flagged
    assert flagged['is_distillation'] is True
def test_make_all_atom_aatype(self):
    """make_all_atom_aatype must add 'all_atom_aatype' shaped like 'aatype'.

    The input is a (20, 20) float one-hot matrix with one row per residue
    index 0..19.
    """
    # Column vector of residue indices 0..19, used as scatter indices.
    residue_ids = torch.arange(20, dtype=torch.int64).unsqueeze(-1)

    # torch.zeros replaces the legacy ``torch.FloatTensor(...).zero_()``
    # constructor; the dtype is pinned so the result stays float32.
    seq_one_hot = torch.zeros(residue_ids.shape[0], 20, dtype=torch.float32)
    seq_one_hot.scatter_(1, residue_ids, 1)

    # Copy an existing tensor with clone().detach(); calling torch.tensor()
    # on a tensor emits a copy-construct UserWarning.
    protein = {'aatype': seq_one_hot.clone().detach()}

    protein = make_all_atom_aatype(protein)
    print(protein)

    assert 'all_atom_aatype' in protein
    assert protein['all_atom_aatype'].shape == protein['aatype'].shape
def test_fix_templates_aatype(self):
    """fix_templates_aatype must turn one-hot templates into remapped ids.

    The input is a single template of 40 residues (indices 0..19 repeated
    twice) in one-hot form with shape (1, 40, 20); the expected output is
    the hard-coded index permutation below, repeated twice.
    """
    # Column vector 0..19, 0..19 used as scatter indices.
    template_seq = torch.arange(20, dtype=torch.int64).repeat(2).unsqueeze(-1)

    # torch.zeros replaces the legacy ``torch.FloatTensor(...).zero_()``
    # constructor; the dtype is pinned so the result stays float32.
    template_seq_one_hot = torch.zeros(
        template_seq.shape[0], 20, dtype=torch.float32
    )
    template_seq_one_hot.scatter_(1, template_seq, 1)

    # Copy an existing tensor with clone().detach(); calling torch.tensor()
    # on a tensor emits a copy-construct UserWarning.
    template_aatype = template_seq_one_hot.clone().detach().unsqueeze(0)

    protein = {'template_aatype': template_aatype}
    protein = fix_templates_aatype(protein)
    print(protein)

    template_seq_ours = torch.tensor(
        [[0, 4, 3, 6, 13, 7, 8, 9, 11, 10,
          12, 2, 14, 5, 1, 15, 16, 19, 17, 18] * 2]
    )
    assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
def test_correct_msa_restypes(self):
    """correct_msa_restypes must leave the shape of 'msa' unchanged."""
    # NOTE(review): the fixture path is relative to the working directory.
    with open("../test_data/features.pkl", 'rb') as handle:
        features = pickle.load(handle)

    raw_msa = features['msa']
    protein = correct_msa_restypes(
        {'msa': torch.tensor(raw_msa, dtype=torch.int64)}
    )
    print(protein)

    shape_before = torch.tensor(raw_msa.shape)
    shape_after = torch.tensor(protein['msa'].shape)
    assert torch.all(torch.eq(shape_before, shape_after))
def test_squeeze_features(self):
    """squeeze_features must restore the shapes stored in the fixture.

    Every listed feature present in the fixture gets an extra trailing axis
    of size one ('seq_length' and 'num_alignments' additionally get a
    leading axis). After squeeze_features, each one must have the shape of
    the corresponding fixture entry again.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with open("../test_data/features.pkl", "rb") as file:
        features = pickle.load(file)
    print(os.path.realpath(file.name), 'Keys: ', features.keys())

    features_list = [
        'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
        'superfamily', 'deletion_matrix', 'resolution',
        'between_segment_residues', 'residue_index',
        'template_all_atom_mask',
    ]

    protein = {'aatype': torch.tensor(features['aatype'])}
    for k in features_list:
        if k in features:
            print(k, features[k].dtype)
            if k in ['domain_name', 'sequence']:
                # These two stay NumPy arrays instead of becoming tensors.
                protein[k] = np.expand_dims(features[k], -1)
            else:
                protein[k] = torch.tensor(features[k]).unsqueeze(-1)

    for k in ['seq_length', 'num_alignments']:
        if k in protein:
            # protein[k] is already a tensor here; copy it with
            # clone().detach() because torch.tensor() on a tensor emits a
            # copy-construct UserWarning.
            protein[k] = protein[k].clone().detach().unsqueeze(0)

    protein_squeezed = squeeze_features(protein)
    print(protein)

    for k in features_list:
        if k in protein:
            print(k, protein_squeezed[k].shape, features[k].shape)
            assert protein_squeezed[k].shape == features[k].shape
def test_randomly_replace_msa_with_unknown(self):
    """Run the random-replacement transform and report the share of X.

    The replacement is random, so the observed proportions are only
    printed. The one property asserted is that the MSA keeps its shape.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as file:
        features = pickle.load(file)

    protein = {
        'msa': torch.tensor(features['msa']),
        'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1),
    }
    replace_proportion = 0.15
    x_idx = 20

    # Captured before the call because the transform rewrites the dict.
    msa_shape_before = protein['msa'].shape

    protein = randomly_replace_msa_with_unknown.__wrapped__(
        protein, replace_proportion
    )

    # minlength guarantees that index x_idx exists even when no entry with
    # value >= x_idx is present; without it the lookups below can raise
    # IndexError.
    unknown_proportion_in_msa = torch.bincount(
        protein['msa'].flatten(), minlength=x_idx + 1
    ) / torch.numel(protein['msa'])
    unknown_proportion_in_seq = torch.bincount(
        protein['aatype'].flatten(), minlength=x_idx + 1
    ) / torch.numel(protein['aatype'])

    print(protein)
    print('Proportion of X in MSA: ', unknown_proportion_in_msa[x_idx])
    print('Proportion of X in sequence: ', unknown_proportion_in_seq[x_idx])

    # The transform substitutes values element-wise, so the MSA shape must
    # be unchanged.
    assert protein['msa'].shape == msa_shape_before
def test_sample_msa(self):
    """sample_msa must split each MSA feature into kept and extra rows.

    With ``keep_extra`` enabled, every MSA feature keeps at most
    ``max_seq`` rows and the remaining rows must appear under the
    ``'extra_'``-prefixed key.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    max_seq = 1000
    keep_extra = True

    protein = {
        name: torch.tensor(features[name])
        for name in MSA_FEATURE_NAMES
        if name in features
    }

    # A shallow copy is passed so the original dict stays available for the
    # comparisons below.
    protein_processed = sample_msa.__wrapped__(
        protein.copy(), max_seq, keep_extra
    )
    print(protein)

    for name in MSA_FEATURE_NAMES:
        if not (name in protein and keep_extra):
            continue

        extra_key = 'extra_' + name
        assert protein_processed[name].shape[0] == min(
            protein[name].shape[0], max_seq
        )
        assert extra_key in protein_processed
        print('extra_' + str(name), protein_processed[extra_key].shape)
        print('msa', protein[name].shape[0] - min(protein[name].shape[0], max_seq))
        assert protein_processed[extra_key].shape[0] == (
            protein[name].shape[0] - min(protein[name].shape[0], max_seq)
        )
def test_crop_extra_msa(self):
    """crop_extra_msa must cap every extra MSA feature at max_extra_msa rows."""
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    max_extra_msa = 10
    protein = {'extra_msa': torch.tensor(features['msa'])}
    num_seq = protein["extra_msa"].shape[0]
    expected_rows = min(max_extra_msa, num_seq)

    protein = crop_extra_msa.__wrapped__(protein, max_extra_msa)
    print(protein)

    for name in MSA_FEATURE_NAMES:
        extra_key = "extra_" + name
        if extra_key in protein:
            assert protein[extra_key].shape[0] == expected_rows
def test_delete_extra_msa(self):
    """delete_extra_msa must drop every 'extra_'-prefixed MSA feature."""
    extra_msa = torch.rand((512, 100, 23))
    protein = {'extra_msa': extra_msa}

    # Same leading dimensions as the extra MSA, with a final axis of one.
    deletion_shape = list(extra_msa.shape)
    deletion_shape[2] = 1
    protein['extra_deletion_matrix'] = torch.rand(deletion_shape)

    protein = delete_extra_msa(protein)
    print(protein)

    for name in MSA_FEATURE_NAMES:
        assert 'extra_' + name not in protein
    assert 'extra_msa' not in protein
def test_nearest_neighbor_clusters(self):
    """nearest_neighbor_clusters must add 'extra_cluster_assignment'."""
    # NOTE(review): the fixture path is relative to the working directory.
    with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as handle:
        features = pickle.load(handle)

    # (destination key, fixture key); only the first entry along the
    # leading axis of each fixture array is used.
    sources = (
        ('msa', 'true_msa'),
        ('msa_mask', 'msa_mask'),
        ('extra_msa', 'extra_msa'),
        ('extra_msa_mask', 'extra_msa_mask'),
    )
    protein = {
        dest: torch.tensor(features[src][0], dtype=torch.int64)
        for dest, src in sources
    }

    protein = nearest_neighbor_clusters.__wrapped__(protein, 0)
    print(protein)

    assert 'extra_cluster_assignment' in protein
def test_make_msa_mask(self):
    """make_msa_mask must add a row mask with one entry per MSA row."""
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    msa_mat = torch.tensor(features['msa'])
    protein = make_msa_mask({'msa': msa_mat})
    print(protein)

    assert 'msa_row_mask' in protein
    row_mask = protein['msa_row_mask']
    assert row_mask.shape[0] == msa_mat.shape[0]
def test_make_hhblits_profile(self):
    """make_hhblits_profile must add a (msa columns, 22) profile."""
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    msa = torch.tensor(features['msa'], dtype=torch.int64)
    protein = make_hhblits_profile({'msa': msa})

    assert 'hhblits_profile' in protein
    expected_shape = torch.Size((protein['msa'].shape[1], 22))
    assert protein['hhblits_profile'].shape == expected_shape
def test_make_masked_msa(self):
    """make_masked_msa must add its outputs and keep unmasked cells intact.

    Wherever 'bert_mask' is zero, 'msa' must still equal 'true_msa'.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
    # The profile is computed first and passed along to the masking step.
    protein = make_hhblits_profile(protein)

    masked_msa_config = config.data.common.masked_msa
    protein = make_masked_msa.__wrapped__(
        protein, masked_msa_config, replace_fraction=0.15
    )
    print(protein)

    assert 'bert_mask' in protein
    assert 'true_msa' in protein
    assert 'msa' in protein
    assert protein['bert_mask'].sum() >= 0

    # Multiplying by (1 - mask) zeroes the masked cells on both sides, so
    # only untouched positions are compared.
    untouched_truth = protein['true_msa'] * (1 - protein['bert_mask'])
    untouched_msa = protein['msa'] * (1 - protein['bert_mask'])
    assert torch.all(torch.eq(untouched_truth, untouched_msa))
def test_make_msa_feat(self):
    """make_msa_feat must add 'msa_feat' and 'target_feat' with set widths.

    Expected shapes: 'target_feat' is (msa columns, 22) and 'msa_feat' is
    the MSA shape with a trailing axis of 25.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with open('../test_data/features.pkl', 'rb') as handle:
        features = pickle.load(handle)

    between_segment = torch.tensor(features['between_segment_residues'])
    msa = torch.tensor(features['msa'], dtype=torch.int64)
    deletions = torch.tensor(features['deletion_matrix_int'])
    # The fixture 'aatype' is reduced to indices along dim 1.
    aatype = torch.argmax(torch.tensor(features['aatype']), dim=1)

    protein = {
        'between_segment_residues': between_segment,
        'msa': msa,
        'deletion_matrix': deletions,
        'aatype': aatype,
    }
    protein = make_msa_feat.__wrapped__(protein)

    assert 'msa_feat' in protein
    assert 'target_feat' in protein

    msa_shape = protein['msa'].shape
    assert protein['target_feat'].shape == torch.Size((msa_shape[1], 22))
    assert protein['msa_feat'].shape == torch.Size((*msa_shape, 25))
def test_crop_templates(self):
    """crop_templates must cut the leading axis down to max_templates.

    MSA arrays from the fixture stand in for the template features.
    """
    # NOTE(review): the fixture path is relative to the working directory.
    with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as handle:
        features = pickle.load(handle)

    max_templates = 2
    protein = {
        'template_aatype': torch.tensor(features['true_msa'][0]),
        'template_all_atom_masks': torch.tensor(features['msa_mask'][0]),
    }
    protein = crop_templates.__wrapped__(protein, max_templates)

    aatype_rows = protein['template_aatype'].shape[0]
    assert aatype_rows == max_templates
    mask_rows = protein['template_all_atom_masks'].shape[0]
    assert mask_rows == max_templates
def test_make_atom14_masks(self):
    """make_atom14_masks must add all four atom14/atom37 features."""
    # NOTE(review): the fixture path is relative to the working directory.
    with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as handle:
        features = pickle.load(handle)

    aatype = torch.tensor(features['aatype'][0])
    protein = make_atom14_masks({'aatype': aatype})
    print(protein)

    expected_keys = (
        'atom14_atom_exists',
        'residx_atom14_to_atom37',
        'residx_atom37_to_atom14',
        'atom37_atom_exists',
    )
    for key in expected_keys:
        assert key in protein
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment