Commit 99361481 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Greatly speed up MSA processing code

parent fd56fb0a
...@@ -45,6 +45,7 @@ def cast_to_64bit_ints(protein): ...@@ -45,6 +45,7 @@ def cast_to_64bit_ints(protein):
for k, v in protein.items(): for k, v in protein.items():
if v.dtype == torch.int32: if v.dtype == torch.int32:
protein[k] = v.type(torch.int64) protein[k] = v.type(torch.int64)
return protein return protein
...@@ -97,6 +98,7 @@ def fix_templates_aatype(protein): ...@@ -97,6 +98,7 @@ def fix_templates_aatype(protein):
protein["template_aatype"] = torch.gather( protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"] new_order, 1, index=protein["template_aatype"]
) )
return protein return protein
...@@ -120,6 +122,7 @@ def correct_msa_restypes(protein): ...@@ -120,6 +122,7 @@ def correct_msa_restypes(protein):
22, 22,
], "num_dim for %s out of expected range: %s" % (k, num_dim) ], "num_dim for %s out of expected range: %s" % (k, num_dim)
protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim]) protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
return protein return protein
...@@ -147,6 +150,7 @@ def squeeze_features(protein): ...@@ -147,6 +150,7 @@ def squeeze_features(protein):
for k in ["seq_length", "num_alignments"]: for k in ["seq_length", "num_alignments"]:
if k in protein: if k in protein:
protein[k] = protein[k][0] protein[k] = protein[k][0]
return protein return protein
...@@ -169,6 +173,7 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -169,6 +173,7 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
) )
return protein return protein
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra, seed=None): def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
...@@ -190,6 +195,7 @@ def sample_msa(protein, max_seq, keep_extra, seed=None): ...@@ -190,6 +195,7 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
protein[k], 0, not_sel_seq protein[k], 0, not_sel_seq
) )
protein[k] = torch.index_select(protein[k], 0, sel_seq) protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein return protein
...@@ -210,6 +216,7 @@ def crop_extra_msa(protein, max_extra_msa): ...@@ -210,6 +216,7 @@ def crop_extra_msa(protein, max_extra_msa):
protein["extra_" + k] = torch.index_select( protein["extra_" + k] = torch.index_select(
protein["extra_" + k], 0, select_indices protein["extra_" + k], 0, select_indices
) )
return protein return protein
...@@ -290,28 +297,24 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0): ...@@ -290,28 +297,24 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    :param data: A tensor whose segments are to be summed; the segment id of
        row ``data[i]`` is ``segment_ids[i]``.
    :param segment_ids: The 1-D segment indices tensor, one id per leading row
        of ``data``. Ids are assumed to lie in ``[0, num_segments)`` —
        NOTE(review): out-of-range ids are not checked here.
    :param num_segments: The number of segments (size of the output's leading
        dimension).
    :return: A tensor of the same dtype as ``data`` with shape
        ``[num_segments, *data.shape[1:]]``; segments with no members are zero.
    """
    # scatter_add_ needs an index tensor shaped like the source, so the 1-D
    # ids must cover every trailing element of their row.
    assert (
        len(segment_ids.shape) == 1 and
        segment_ids.shape[0] == data.shape[0]
    )
    # Reshape ids to [N, 1, 1, ...] then broadcast (without copying) to
    # data's full shape.
    segment_ids = segment_ids.view(
        segment_ids.shape[0], *((1,) * len(data.shape[1:]))
    )
    segment_ids = segment_ids.expand(data.shape)
    shape = [num_segments] + list(data.shape[1:])
    # Allocate the accumulator on data's device so this also works for CUDA
    # inputs (torch.zeros(*shape) alone would land on the default device).
    # Accumulate in float32, then cast back to the input dtype.
    tensor = torch.zeros(*shape, device=data.device).scatter_add_(
        0, segment_ids, data.float()
    )
    tensor = tensor.type(data.dtype)
    return tensor
...@@ -332,7 +335,6 @@ def summarize_clusters(protein): ...@@ -332,7 +335,6 @@ def summarize_clusters(protein):
msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23)) msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
msa_sum += make_one_hot(protein["msa"], 23) # Original sequence msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
protein["cluster_profile"] = msa_sum / mask_counts[:, :, None] protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
del msa_sum del msa_sum
del_sum = csum(mask * protein["extra_deletion_matrix"]) del_sum = csum(mask * protein["extra_deletion_matrix"])
...@@ -464,7 +466,6 @@ def make_fixed_size( ...@@ -464,7 +466,6 @@ def make_fixed_size(
num_templates=0, num_templates=0,
): ):
"""Guess at the MSA and sequence dimension to make fixed size.""" """Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = { pad_size_map = {
NUM_RES: num_res, NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size, NUM_MSA_SEQ: msa_cluster_size,
...@@ -1169,4 +1170,5 @@ def random_crop_to_size( ...@@ -1169,4 +1170,5 @@ def random_crop_to_size(
protein[k] = v[slices] protein[k] = v[slices]
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein return protein
...@@ -175,6 +175,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -175,6 +175,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
common_cfg, common_cfg,
mode_cfg, mode_cfg,
) )
tensors = compose(nonensembled)(tensors) tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors): if("no_recycling_iters" in tensors):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment