Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
62e820fc
Commit
62e820fc
authored
Oct 08, 2021
by
Sachin Kadyan
Browse files
Add feature transformations related to MSAs: MSA sampling, handling extra_msa, and block_delete_msa
parent
d54d5c55
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
84 additions
and
0 deletions
+84
-0
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+84
-0
No files found.
openfold/features/data_transforms.py
View file @
62e820fc
...
...
@@ -3,6 +3,9 @@ import torch
from
np
import
residue_constants
MSA_FEATURE_NAMES
=
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
]
def
cast_to_64bit_ints
(
protein
):
# We keep all ints as int64
...
...
@@ -83,3 +86,84 @@ def squeeze_features(protein):
def
make_protein_crop_to_size_seed
(
protein
):
protein
[
'random_crop_to_size_seed'
]
=
torch
.
distributions
.
Uniform
(
low
=
torch
.
int32
,
high
=
torch
.
int32
).
sample
((
2
))
return
protein
@
curry1
def
randomly_replace_msa_with_unknown
(
protein
,
replace_proportion
):
"""Replace a portion of the MSA with 'X'."""
msa_mask
=
(
torch
.
rand
(
protein
[
'msa'
].
shape
)
<
replace_proportion
)
x_idx
=
20
gap_idx
=
21
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
'msa'
]
!=
gap_idx
)
protein
[
'msa'
]
=
torch
.
where
(
msa_mask
,
torch
.
ones_like
(
protein
[
'msa'
])
*
x_idx
,
protein
[
'msa'
])
aatype_mask
=
(
torch
.
rand
(
protein
[
'aatype'
].
shape
)
<
replace_proportion
)
protein
[
'aatype'
]
=
torch
.
where
(
aatype_mask
,
torch
.
ones_like
(
protein
[
'aatype'
])
*
x_idx
,
protein
[
'aatype'
])
return
protein
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.
"""
num_seq
=
protein
[
'msa'
].
shape
[
0
]
shuffled
=
torch
.
randperm
(
num_seq
-
1
)
+
1
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
num_sel
=
min
(
max_seq
,
num_seq
)
print
(
'sample_msa num_sel'
,
num_sel
,
' num_seq'
,
num_seq
)
sel_seq
,
not_sel_seq
=
torch
.
split
(
index_order
,
[
num_sel
,
num_seq
-
num_sel
])
for
k
in
MSA_FEATURE_NAMES
:
if
k
in
protein
:
if
keep_extra
:
protein
[
'extra_'
+
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
not_sel_seq
)
protein
[
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
sel_seq
)
return
protein
@
curry1
def
crop_extra_msa
(
protein
,
max_extra_msa
):
num_seq
=
protein
[
'extra_msa'
].
shape
[
0
]
num_sel
=
min
(
max_extra_msa
,
num_seq
)
select_indices
=
torch
.
randperm
(
num_seq
)[:
num_sel
]
print
(
'select_indices'
,
select_indices
)
for
k
in
MSA_FEATURE_NAMES
:
if
'extra_'
+
k
in
protein
:
protein
[
'extra_'
+
k
]
=
torch
.
index_select
(
protein
[
'extra_'
+
k
],
0
,
select_indices
)
return
protein
def
delete_extra_msa
(
protein
):
for
k
in
MSA_FEATURE_NAMES
:
if
'extra_'
+
k
in
protein
:
del
protein
[
'extra_'
+
k
]
return
protein
# Not used in inference
@
curry1
def
block_delete_msa
(
protein
,
config
):
num_seq
=
protein
[
'msa'
].
shape
[
0
]
block_num_seq
=
torch
.
floor
(
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
)
*
config
.
msa_fraction_per_block
).
to
(
torch
.
int32
)
if
config
.
randomize_num_blocks
:
nb
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
config
.
num_blocks
+
1
).
sample
()
else
:
nb
=
config
.
num_blocks
del_block_starts
=
torch
.
distributions
.
Uniform
(
0
,
num_seq
).
sample
(
nb
)
del_blocks
=
del_block_starts
[:,
None
]
+
torch
.
range
(
block_num_seq
)
del_blocks
=
torch
.
clip
(
del_blocks
,
0
,
num_seq
-
1
)
del_indices
=
torch
.
unique
(
torch
.
sort
(
torch
.
reshape
(
del_blocks
,
[
-
1
])))[
0
]
# Make sure we keep the original sequence
combined
=
torch
.
cat
((
torch
.
range
(
1
,
num_seq
)[
None
],
del_indices
[
None
]))
uniques
,
counts
=
combined
.
unique
(
return_counts
=
True
)
difference
=
uniques
[
counts
==
1
]
intersection
=
uniques
[
counts
>
1
]
keep_indices
=
torch
.
squeeze
(
difference
,
0
)
for
k
in
MSA_FEATURE_NAMES
:
if
k
in
protein
:
protein
[
k
]
=
torch
.
gather
(
protein
[
k
],
keep_indices
)
return
protein
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment