Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f7bb6999
Commit
f7bb6999
authored
Oct 08, 2021
by
Sachin Kadyan
Browse files
Added feature transformations related to MSA clustering and summarization.
parent
62e820fc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
0 deletions
+87
-0
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+87
-0
No files found.
openfold/features/data_transforms.py
View file @
f7bb6999
...
...
@@ -14,6 +14,11 @@ def cast_to_64bit_ints(protein):
protein
[
k
]
=
v
.
type
(
torch
.
int64
)
return
protein
def
make_one_hot
(
x
,
num_classes
):
x_one_hot
=
torch
.
zeros
(
*
x
.
shape
,
num_classes
)
x_one_hot
.
scatter_
(
-
1
,
x
.
unsqueeze
(
-
1
),
1
)
return
x_one_hot
def
make_seq_mask
(
protein
):
protein
[
'seq_mask'
]
=
torch
.
ones
(
protein
[
'aatype'
].
shape
,
dtype
=
torch
.
float32
)
return
protein
...
...
@@ -167,3 +172,85 @@ def block_delete_msa(protein, config):
protein
[
k
]
=
torch
.
gather
(
protein
[
k
],
keep_indices
)
return
protein
@
curry1
def
nearest_neighbor_clusters
(
protein
,
gap_agreement_weight
=
0.
):
weights
=
torch
.
cat
([
torch
.
ones
(
21
),
gap_agreement_weight
*
torch
.
ones
(
1
),
torch
.
zeros
(
1
)
],
0
)
# Make agreement score as weighted Hamming distance
msa_one_hot
=
make_one_hot
(
protein
[
'msa'
],
23
)
print
(
'msa_one_hot shape'
,
msa_one_hot
.
shape
)
sample_one_hot
=
(
protein
[
'msa_mask'
][:,:,
None
]
*
msa_one_hot
)
extra_msa_one_hot
=
make_one_hot
(
protein
[
'extra_msa'
],
23
)
print
(
'extra_msa_one_hot shape'
,
extra_msa_one_hot
.
shape
)
extra_one_hot
=
(
protein
[
'extra_msa_mask'
][:,:,
None
]
*
extra_msa_one_hot
)
num_seq
,
num_res
,
_
=
sample_one_hot
.
shape
extra_num_seq
,
_
,
_
=
extra_one_hot
.
shape
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement
=
torch
.
matmul
(
torch
.
reshape
(
extra_one_hot
,
[
extra_num_seq
,
num_res
*
23
]),
torch
.
reshape
(
sample_one_hot
*
weights
,
[
num_seq
,
num_res
*
23
]).
transpose
(
0
,
1
),
)
# Assign each sequence in the extra sequences to the closest MSA sample
protein
[
'extra_cluster_assignment'
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
return
protein
def
unsorted_segment_sum
(
data
,
segment_ids
,
num_segments
):
"""
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
:param data: A tensor whose segments are to be summed.
:param segment_ids: The segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
assert
all
([
i
in
data
.
shape
for
i
in
segment_ids
.
shape
]),
"segment_ids.shape should be a prefix of data.shape"
# segment_ids is a 1-D tensor repeat it to have the same shape as data
if
len
(
segment_ids
.
shape
)
==
1
:
s
=
torch
.
prod
(
torch
.
tensor
(
data
.
shape
[
1
:])).
long
()
segment_ids
=
segment_ids
.
repeat_interleave
(
s
).
view
(
segment_ids
.
shape
[
0
],
*
data
.
shape
[
1
:])
assert
data
.
shape
==
segment_ids
.
shape
,
"data.shape and segment_ids.shape should be equal"
shape
=
[
num_segments
]
+
list
(
data
.
shape
[
1
:])
tensor
=
torch
.
zeros
(
*
shape
).
scatter_add
(
0
,
segment_ids
,
data
.
float
())
tensor
=
tensor
.
type
(
data
.
dtype
)
return
tensor
@
curry1
def
summarize_clusters
(
protein
):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq
=
protein
[
'msa'
].
shape
[
0
]
def
csum
(
x
):
return
unsorted_segment_sum
(
x
,
protein
[
'extra_cluster_assignment'
],
num_seq
)
mask
=
protein
[
'extra_msa_mask'
]
mask_counts
=
1e-6
+
protein
[
'msa_mask'
]
+
csum
(
mask
)
# Include center
msa_sum
=
csum
(
mask
[:,
:,
None
]
*
make_one_hot
(
protein
[
'extra_msa'
],
23
))
msa_sum
+=
make_one_hot
(
protein
[
'msa'
],
23
)
# Original sequence
protein
[
'cluster_profile'
]
=
msa_sum
/
mask_counts
[:,
:,
None
]
del
msa_sum
del_sum
=
csum
(
mask
*
protein
[
'extra_deletion_matrix'
])
del_sum
+=
protein
[
'deletion_matrix'
]
# Original sequence
protein
[
'cluster_deletion_mean'
]
=
del_sum
/
mask_counts
del
del_sum
return
protein
def
make_msa_mask
(
protein
):
"""Mask features are all ones, but will later be zero-padded."""
protein
[
'msa_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
,
dtype
=
torch
.
float32
)
protein
[
'msa_row_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
[
0
],
dtype
=
torch
.
float32
)
return
protein
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment