Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1197e8b1
Commit
1197e8b1
authored
Oct 08, 2021
by
Sachin Kadyan
Browse files
Added feature transformations for creating masked MSAs, and some other miscellaneous ones.
parent
e28c2828
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
4 deletions
+57
-4
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+57
-4
No files found.
openfold/features/data_transforms.py
View file @
1197e8b1
from
functools
import
reduce
import
numpy
as
np
import
torch
from
operator
import
add
from
np
import
residue_constants
...
...
@@ -117,7 +120,6 @@ def sample_msa(protein, max_seq, keep_extra):
shuffled
=
torch
.
randperm
(
num_seq
-
1
)
+
1
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
num_sel
=
min
(
max_seq
,
num_seq
)
print
(
'sample_msa num_sel'
,
num_sel
,
' num_seq'
,
num_seq
)
sel_seq
,
not_sel_seq
=
torch
.
split
(
index_order
,
[
num_sel
,
num_seq
-
num_sel
])
for
k
in
MSA_FEATURE_NAMES
:
...
...
@@ -132,7 +134,6 @@ def crop_extra_msa(protein, max_extra_msa):
num_seq
=
protein
[
'extra_msa'
].
shape
[
0
]
num_sel
=
min
(
max_extra_msa
,
num_seq
)
select_indices
=
torch
.
randperm
(
num_seq
)[:
num_sel
]
print
(
'select_indices'
,
select_indices
)
for
k
in
MSA_FEATURE_NAMES
:
if
'extra_'
+
k
in
protein
:
protein
[
'extra_'
+
k
]
=
torch
.
index_select
(
protein
[
'extra_'
+
k
],
0
,
select_indices
)
...
...
@@ -183,10 +184,8 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Make agreement score as weighted Hamming distance
msa_one_hot
=
make_one_hot
(
protein
[
'msa'
],
23
)
print
(
'msa_one_hot shape'
,
msa_one_hot
.
shape
)
sample_one_hot
=
(
protein
[
'msa_mask'
][:,:,
None
]
*
msa_one_hot
)
extra_msa_one_hot
=
make_one_hot
(
protein
[
'extra_msa'
],
23
)
print
(
'extra_msa_one_hot shape'
,
extra_msa_one_hot
.
shape
)
extra_one_hot
=
(
protein
[
'extra_msa_mask'
][:,:,
None
]
*
extra_msa_one_hot
)
num_seq
,
num_res
,
_
=
sample_one_hot
.
shape
...
...
@@ -282,3 +281,57 @@ def make_pseudo_beta(protein, prefix=''):
protein
[
prefix
+
'all_atom_positions'
],
protein
[
'template_all_atom_masks'
if
prefix
else
'all_atom_mask'
]))
return
protein
@curry1
def add_constant_field(protein, key, value):
    """Store ``value`` in ``protein`` under ``key`` as a freshly built tensor.

    Args:
        protein: Mapping of feature names to values; modified in place.
        key: Name under which the new entry is written.
        value: Data handed to ``torch.tensor`` to build the stored tensor.

    Returns:
        The same ``protein`` object, now containing ``key``.
    """
    constant = torch.tensor(value)
    protein[key] = constant
    return protein
def shaped_categorical(probs, epsilon=1e-10):
    """Sample one class index for every distribution along the last axis.

    ``probs`` is flattened to two dimensions so that each row is one
    categorical distribution, a single draw is taken per row, and the
    draws are reshaped back to the leading dimensions of ``probs``.

    Args:
        probs: Tensor whose last dimension enumerates the classes.
        epsilon: Small constant added to every entry before sampling.

    Returns:
        Tensor of sampled class indices with shape ``probs.shape[:-1]``.
    """
    original_shape = probs.shape
    class_count = original_shape[-1]

    # One row per distribution; epsilon is added before the sampler sees it.
    flat_probs = (probs + epsilon).reshape(-1, class_count)
    sampler = torch.distributions.Categorical(probs=flat_probs)

    draws = sampler.sample()
    return draws.reshape(original_shape[:-1])
def make_hhblits_profile(protein):
    """Ensure ``protein`` carries an ``'hhblits_profile'`` entry.

    An existing entry is left untouched. Otherwise ``protein['msa']`` is
    one-hot encoded with 22 classes through ``make_one_hot`` and the
    encoding is averaged over dimension 0 (the MSA sequences), giving a
    per-residue profile.

    Args:
        protein: Mapping of feature names to tensors; modified in place.

    Returns:
        The same ``protein`` object.
    """
    if 'hhblits_profile' not in protein:
        # Average the one-hot encoding over all MSA sequences.
        one_hot_msa = make_one_hot(protein['msa'], 22)
        protein['hhblits_profile'] = one_hot_msa.mean(dim=0)
    return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    For every MSA position a replacement token is sampled from a mixture
    of a uniform amino-acid distribution, the HHblits profile, and the
    original token; all remaining probability goes to one extra trailing
    class (the [MASK] column). Positions are then chosen independently
    with probability ``replace_fraction`` and only those receive the
    sampled token.

    Args:
        protein: Mapping of feature names to tensors. Must already contain
            ``'msa'`` and ``'hhblits_profile'``. Modified in place.
        config: Object exposing ``uniform_prob``, ``profile_prob`` and
            ``same_prob``; their sum must not exceed 1.
        replace_fraction: Probability that a given MSA position is replaced.

    Returns:
        The same ``protein`` object with ``'bert_mask'`` (float32 mask of
        replaced positions), ``'true_msa'`` (the original MSA) and
        ``'msa'`` (the corrupted MSA) written.

    Raises:
        ValueError: If the three configured probabilities sum to more
            than 1 (beyond floating-point rounding).
    """
    msa = protein['msa']
    # Build helper tensors on the MSA's device so they can be combined
    # with features that do not live on the CPU.
    device = msa.device

    # Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0., 0.], dtype=torch.float32, device=device)
    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein['hhblits_profile']
        + config.same_prob * make_one_hot(msa, 22)
    )

    # Put all remaining probability on [MASK] which is a new column.
    mask_prob = (
        1. - config.profile_prob - config.same_prob - config.uniform_prob)
    # Probabilities that sum to exactly 1 can leave a tiny negative
    # remainder in floating point; tolerate that and clamp to zero, but
    # reject genuinely invalid configurations. This is an explicit raise
    # rather than an assert so the check survives ``python -O``.
    rounding_tolerance = 1e-6
    if mask_prob < -rounding_tolerance:
        raise ValueError(
            'uniform_prob + profile_prob + same_prob must not exceed 1, '
            'got a remaining mask probability of {}'.format(mask_prob))
    mask_prob = max(mask_prob, 0.)

    # (0, 1) appends a single entry at the end of the last dimension only.
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, (0, 1), value=mask_prob)

    mask_position = torch.rand(msa.shape, device=device) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, msa)

    # Mix real and masked MSA.
    protein['bert_mask'] = mask_position.to(torch.float32)
    protein['true_msa'] = msa
    protein['msa'] = bert_msa

    return protein
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment