Commit 99361481 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Greatly speed up MSA processing code

parent fd56fb0a
......@@ -45,6 +45,7 @@ def cast_to_64bit_ints(protein):
for k, v in protein.items():
if v.dtype == torch.int32:
protein[k] = v.type(torch.int64)
return protein
......@@ -97,6 +98,7 @@ def fix_templates_aatype(protein):
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
......@@ -120,6 +122,7 @@ def correct_msa_restypes(protein):
22,
], "num_dim for %s out of expected range: %s" % (k, num_dim)
protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
return protein
......@@ -147,6 +150,7 @@ def squeeze_features(protein):
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
......@@ -169,6 +173,7 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
)
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
......@@ -190,6 +195,7 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
protein[k], 0, not_sel_seq
)
protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein
......@@ -210,6 +216,7 @@ def crop_extra_msa(protein, max_extra_msa):
protein["extra_" + k] = torch.index_select(
protein["extra_" + k], 0, select_indices
)
return protein
......@@ -284,34 +291,30 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
torch.int64
)
return protein
def unsorted_segment_sum(data, segment_ids, num_segments):
"""
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
Computes the sum along segments of a tensor. Similar to
tf.unsorted_segment_sum, but only supports 1-D indices.
:param data: A tensor whose segments are to be summed.
:param segment_ids: The segment indices tensor.
:param segment_ids: The 1-D segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
# segment_ids.shape should be a prefix of data.shape
assert all([i in data.shape for i in segment_ids.shape])
# segment_ids is a 1-D tensor repeat it to have the same shape as data
if len(segment_ids.shape) == 1:
s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(
segment_ids.shape[0], *data.shape[1:]
)
# data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
assert (
len(segment_ids.shape) == 1 and
segment_ids.shape[0] == data.shape[0]
)
segment_ids = segment_ids.view(
segment_ids.shape[0], *((1,) * len(data.shape[1:]))
)
segment_ids = segment_ids.expand(data.shape)
shape = [num_segments] + list(data.shape[1:])
tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float())
tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
tensor = tensor.type(data.dtype)
return tensor
......@@ -332,14 +335,13 @@ def summarize_clusters(protein):
msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein["extra_deletion_matrix"])
del_sum += protein["deletion_matrix"] # Original sequence
protein["cluster_deletion_mean"] = del_sum / mask_counts
del del_sum
return protein
......@@ -464,7 +466,6 @@ def make_fixed_size(
num_templates=0,
):
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
......@@ -490,7 +491,7 @@ def make_fixed_size(
if padding:
protein[k] = torch.nn.functional.pad(v, padding)
protein[k] = torch.reshape(protein[k], pad_size)
return protein
......@@ -1169,4 +1170,5 @@ def random_crop_to_size(
protein[k] = v[slices]
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
......@@ -175,6 +175,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment