Commit f7bb6999 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added feature transformations related to MSA clustering and summarization.

parent 62e820fc
...@@ -14,6 +14,11 @@ def cast_to_64bit_ints(protein): ...@@ -14,6 +14,11 @@ def cast_to_64bit_ints(protein):
protein[k] = v.type(torch.int64) protein[k] = v.type(torch.int64)
return protein return protein
def make_one_hot(x, num_classes):
    """
    Builds a one-hot encoding of an index tensor along a new trailing axis.
    :param x: An integer tensor of class indices, each in [0, num_classes).
    :param num_classes: The size of the new trailing one-hot axis.
    :return: A tensor of shape (*x.shape, num_classes) in torch's default dtype,
        located on the same device as x, holding 1 at each index and 0 elsewhere.
    """
    # Allocate on x's device: scatter_ needs the destination and the index
    # tensor to live on the same device.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    # scatter_ only accepts int64 indices, so narrower integer inputs are cast.
    x_one_hot.scatter_(-1, x.unsqueeze(-1).long(), 1)
    return x_one_hot
def make_seq_mask(protein):
    """
    Adds an all-ones float32 'seq_mask' feature shaped like protein['aatype'].
    :param protein: The feature dict; must contain the 'aatype' tensor.
    :return: The same dict, with 'seq_mask' added.
    """
    protein['seq_mask'] = torch.ones(protein['aatype'].shape, dtype=torch.float32)
    return protein
...@@ -167,3 +172,85 @@ def block_delete_msa(protein, config): ...@@ -167,3 +172,85 @@ def block_delete_msa(protein, config):
protein[k] = torch.gather(protein[k], keep_indices) protein[k] = torch.gather(protein[k], keep_indices)
return protein return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
    """
    Assigns each extra MSA sequence to the sampled MSA row it agrees with most.

    Reads protein['msa'], protein['msa_mask'], protein['extra_msa'] and
    protein['extra_msa_mask'] (all unpacked as 2-D [num_seq, num_res] tensors)
    and writes protein['extra_cluster_assignment'].
    :param protein: The feature dict holding the tensors listed above.
    :param gap_agreement_weight: Weight given to matches on class index 21
        (presumably the gap token -- confirm against the residue vocabulary).
        Classes 0-20 are weighted 1 and class 22 is weighted 0.
    :return: The same dict, with an int64 'extra_cluster_assignment' tensor of
        length extra_num_seq indexing rows of protein['msa'].
    """
    num_classes = 23

    # Masked one-hot encodings; masked-out positions contribute no agreement.
    sample_one_hot = (
        protein['msa_mask'][:, :, None] * make_one_hot(protein['msa'], num_classes)
    )
    extra_one_hot = (
        protein['extra_msa_mask'][:, :, None]
        * make_one_hot(protein['extra_msa'], num_classes)
    )

    # Per-class weights, created on the same device as the tensor they scale.
    device = sample_one_hot.device
    weights = torch.cat([
        torch.ones(21, device=device),
        gap_agreement_weight * torch.ones(1, device=device),
        torch.zeros(1, device=device),
    ], 0)

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # The agreement score is the weighted count of positions where the two
    # sequences carry the same class (a similarity, so larger is closer).
    # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
    # in an optimized fashion to avoid possible memory or computation blowup.
    agreement = torch.matmul(
        torch.reshape(extra_one_hot, [extra_num_seq, num_res * num_classes]),
        torch.reshape(
            sample_one_hot * weights, [num_seq, num_res * num_classes]
        ).transpose(0, 1),
    )

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).to(torch.int64)
    return protein
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    Rows of data (along dim 0) that share a segment id are added together.
    segment_ids must either be 1-D with one id per row of data, or have exactly
    the same shape as data. The sum is accumulated in float32 and then cast
    back to data's dtype.
    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The segment indices tensor (int64, values in [0, num_segments)).
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument, with shape
        [num_segments, *data.shape[1:]], on the same device as data.
    """
    ids_rank = segment_ids.dim()
    assert tuple(data.shape[:ids_rank]) == tuple(segment_ids.shape), \
        "segment_ids.shape should be a prefix of data.shape"

    # A 1-D id tensor is broadcast across the trailing dims of data. expand()
    # returns a view, so no full-size index tensor is materialized.
    if ids_rank == 1:
        trailing_ones = (1,) * (data.dim() - 1)
        segment_ids = segment_ids.reshape(segment_ids.shape[0], *trailing_ones)
        segment_ids = segment_ids.expand(data.shape)

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    out_shape = [num_segments] + list(data.shape[1:])
    # Explicit dtype/device: scatter_add needs the accumulator to match the
    # float32 source and to be on the same device as data.
    accumulator = torch.zeros(*out_shape, dtype=torch.float32, device=data.device)
    summed = accumulator.scatter_add(0, segment_ids, data.float())
    return summed.type(data.dtype)
@curry1
def summarize_clusters(protein):
    """
    Produce profile and deletion_matrix_mean within each cluster.

    Each extra MSA row is folded into the cluster named by
    protein['extra_cluster_assignment']; the cluster's own row from
    protein['msa'] / protein['deletion_matrix'] is added as well. Writes
    protein['cluster_profile'] and protein['cluster_deletion_mean'].
    :param protein: The feature dict; must hold 'msa', 'msa_mask', 'extra_msa',
        'extra_msa_mask', 'deletion_matrix', 'extra_deletion_matrix' and
        'extra_cluster_assignment'.
    :return: The same dict, with the two summary features added.
    """
    assignment = protein['extra_cluster_assignment']
    num_clusters = protein['msa'].shape[0]

    def sum_per_cluster(values):
        # Add up the extra-MSA rows that were assigned to each cluster.
        return unsorted_segment_sum(values, assignment, num_clusters)

    extra_mask = protein['extra_msa_mask']

    # Per-position count of contributing sequences, cluster center included.
    # The 1e-6 offset is presumably there to keep the divisions below from
    # hitting an exactly-zero denominator.
    counts = 1e-6 + protein['msa_mask'] + sum_per_cluster(extra_mask)

    extra_one_hot = make_one_hot(protein['extra_msa'], 23)
    profile_sum = sum_per_cluster(extra_mask[:, :, None] * extra_one_hot)
    profile_sum.add_(make_one_hot(protein['msa'], 23))  # cluster center itself
    protein['cluster_profile'] = profile_sum / counts[:, :, None]
    del profile_sum

    deletion_sum = sum_per_cluster(extra_mask * protein['extra_deletion_matrix'])
    deletion_sum.add_(protein['deletion_matrix'])  # cluster center itself
    protein['cluster_deletion_mean'] = deletion_sum / counts
    del deletion_sum

    return protein
def make_msa_mask(protein):
    """
    Adds all-ones float32 MSA masks to the feature dict.

    Mask features are all ones, but will later be zero-padded.
    :param protein: The feature dict; must contain the 'msa' tensor.
    :return: The same dict, with 'msa_mask' (same shape as 'msa') and
        'msa_row_mask' (one entry per MSA row) added.
    """
    msa_shape = protein['msa'].shape
    # Element-wise mask over the whole alignment.
    protein['msa_mask'] = torch.ones(msa_shape, dtype=torch.float32)
    # One flag per sequence (row) of the alignment.
    protein['msa_row_mask'] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment