"examples/python_rs/vscode:/vscode.git/clone" did not exist on "3f84cdadfa3fb50f0764023219ab7cad11275f74"
Commit 1197e8b1 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added feature transformations for creating masked MSAs, and some other miscellaneous ones.

parent e28c2828
from functools import reduce
import numpy as np
import torch
from operator import add
from np import residue_constants
......@@ -117,7 +120,6 @@ def sample_msa(protein, max_seq, keep_extra):
shuffled = torch.randperm(num_seq-1)+1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq)
print('sample_msa num_sel', num_sel, ' num_seq', num_seq)
sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq-num_sel])
for k in MSA_FEATURE_NAMES:
......@@ -132,7 +134,6 @@ def crop_extra_msa(protein, max_extra_msa):
num_seq = protein['extra_msa'].shape[0]
num_sel = min(max_extra_msa, num_seq)
select_indices = torch.randperm(num_seq)[:num_sel]
print('select_indices', select_indices)
for k in MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_'+k] = torch.index_select(protein['extra_'+k], 0, select_indices)
......@@ -183,10 +184,8 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Make agreement score as weighted Hamming distance
msa_one_hot = make_one_hot(protein['msa'], 23)
print('msa_one_hot shape', msa_one_hot.shape)
sample_one_hot = (protein['msa_mask'][:,:,None] * msa_one_hot)
extra_msa_one_hot = make_one_hot(protein['extra_msa'], 23)
print('extra_msa_one_hot shape', extra_msa_one_hot.shape)
extra_one_hot = (protein['extra_msa_mask'][:,:,None] * extra_msa_one_hot)
num_seq, num_res, _ = sample_one_hot.shape
......@@ -282,3 +281,57 @@ def make_pseudo_beta(protein, prefix=''):
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
return protein
@curry1
def add_constant_field(protein, key, value):
protein[key] = torch.tensor(value)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = probs.shape
num_classes = ds[-1]
distribution = torch.distributions.categorical.Categorical(torch.reshape(probs+epsilon,[-1, num_classes]))
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if 'hhblits_profile' in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
msa_one_hot = make_one_hot(protein['msa'], 22)
protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor([0.05] * 20 + [0., 0.], dtype=torch.float32)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * protein['hhblits_profile'] +
config.same_prob * make_one_hot(protein['msa'], 22))
# Put all remaining probability on [MASK] which is a new column
pad_shapes = list(reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))]))
pad_shapes[1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = torch.nn.functional.pad(categorical_probs, pad_shapes, value=mask_prob)
sh = protein['msa'].shape
mask_position = torch.rand(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein['msa'])
# Mix real and masked MSA
protein['bert_mask'] = mask_position.to(torch.float32)
protein['true_msa'] = protein['msa']
protein['msa'] = bert_msa
return protein
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment