Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
99361481
Commit
99361481
authored
Nov 04, 2021
by
Gustaf Ahdritz
Browse files
Greatly speed up MSA processing code
parent
fd56fb0a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
21 deletions
+24
-21
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+23
-21
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+1
-0
No files found.
openfold/data/data_transforms.py
View file @
99361481
...
...
@@ -45,6 +45,7 @@ def cast_to_64bit_ints(protein):
for
k
,
v
in
protein
.
items
():
if
v
.
dtype
==
torch
.
int32
:
protein
[
k
]
=
v
.
type
(
torch
.
int64
)
return
protein
...
...
@@ -97,6 +98,7 @@ def fix_templates_aatype(protein):
protein
[
"template_aatype"
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
"template_aatype"
]
)
return
protein
...
...
@@ -120,6 +122,7 @@ def correct_msa_restypes(protein):
22
,
],
"num_dim for %s out of expected range: %s"
%
(
k
,
num_dim
)
protein
[
k
]
=
torch
.
dot
(
protein
[
k
],
perm_matrix
[:
num_dim
,
:
num_dim
])
return
protein
...
...
@@ -147,6 +150,7 @@ def squeeze_features(protein):
for
k
in
[
"seq_length"
,
"num_alignments"
]:
if
k
in
protein
:
protein
[
k
]
=
protein
[
k
][
0
]
return
protein
...
...
@@ -169,6 +173,7 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
)
return
protein
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
"""Sample MSA randomly; remaining sequences are stored as `extra_*`."""
...
...
@@ -190,6 +195,7 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
protein
[
k
],
0
,
not_sel_seq
)
protein
[
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
sel_seq
)
return
protein
...
...
@@ -210,6 +216,7 @@ def crop_extra_msa(protein, max_extra_msa):
protein
[
"extra_"
+
k
]
=
torch
.
index_select
(
protein
[
"extra_"
+
k
],
0
,
select_indices
)
return
protein
...
...
@@ -284,34 +291,30 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
protein
[
"extra_cluster_assignment"
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
return
protein
def
unsorted_segment_sum
(
data
,
segment_ids
,
num_segments
):
"""
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
Computes the sum along segments of a tensor. Similar to
tf.unsorted_segment_sum, but only supports 1-D indices.
:param data: A tensor whose segments are to be summed.
:param segment_ids: The segment indices tensor.
:param segment_ids: The
1-D
segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
# segment_ids.shape should be a prefix of data.shape
assert
all
([
i
in
data
.
shape
for
i
in
segment_ids
.
shape
])
# segment_ids is a 1-D tensor repeat it to have the same shape as data
if
len
(
segment_ids
.
shape
)
==
1
:
s
=
torch
.
prod
(
torch
.
tensor
(
data
.
shape
[
1
:])).
long
()
segment_ids
=
segment_ids
.
repeat_interleave
(
s
).
view
(
segment_ids
.
shape
[
0
],
*
data
.
shape
[
1
:]
)
# data.shape and segment_ids.shape should be equal
assert
data
.
shape
==
segment_ids
.
shape
assert
(
len
(
segment_ids
.
shape
)
==
1
and
segment_ids
.
shape
[
0
]
==
data
.
shape
[
0
]
)
segment_ids
=
segment_ids
.
view
(
segment_ids
.
shape
[
0
],
*
((
1
,)
*
len
(
data
.
shape
[
1
:]))
)
segment_ids
=
segment_ids
.
expand
(
data
.
shape
)
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add
(
0
,
segment_ids
,
data
.
float
())
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add
_
(
0
,
segment_ids
,
data
.
float
())
tensor
=
tensor
.
type
(
data
.
dtype
)
return
tensor
...
...
@@ -332,14 +335,13 @@ def summarize_clusters(protein):
msa_sum
=
csum
(
mask
[:,
:,
None
]
*
make_one_hot
(
protein
[
"extra_msa"
],
23
))
msa_sum
+=
make_one_hot
(
protein
[
"msa"
],
23
)
# Original sequence
protein
[
"cluster_profile"
]
=
msa_sum
/
mask_counts
[:,
:,
None
]
del
msa_sum
del_sum
=
csum
(
mask
*
protein
[
"extra_deletion_matrix"
])
del_sum
+=
protein
[
"deletion_matrix"
]
# Original sequence
protein
[
"cluster_deletion_mean"
]
=
del_sum
/
mask_counts
del
del_sum
return
protein
...
...
@@ -464,7 +466,6 @@ def make_fixed_size(
num_templates
=
0
,
):
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map
=
{
NUM_RES
:
num_res
,
NUM_MSA_SEQ
:
msa_cluster_size
,
...
...
@@ -490,7 +491,7 @@ def make_fixed_size(
if
padding
:
protein
[
k
]
=
torch
.
nn
.
functional
.
pad
(
v
,
padding
)
protein
[
k
]
=
torch
.
reshape
(
protein
[
k
],
pad_size
)
return
protein
...
...
@@ -1169,4 +1170,5 @@ def random_crop_to_size(
protein
[
k
]
=
v
[
slices
]
protein
[
"seq_length"
]
=
protein
[
"seq_length"
].
new_tensor
(
num_res_crop_size
)
return
protein
openfold/data/input_pipeline.py
View file @
99361481
...
...
@@ -175,6 +175,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
common_cfg
,
mode_cfg
,
)
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment