"...examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "08fcd7e93ba5df3093a8b54fe79e0895fe7a5f15"
Commit 62e820fc authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Add feature transformations related to MSAs: MSA sampling, handling extra_msa, and block_delete_msa

parent d54d5c55
...@@ -3,6 +3,9 @@ import torch ...@@ -3,6 +3,9 @@ import torch
from np import residue_constants from np import residue_constants
MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'
]
def cast_to_64bit_ints(protein): def cast_to_64bit_ints(protein):
# We keep all ints as int64 # We keep all ints as int64
...@@ -83,3 +86,84 @@ def squeeze_features(protein): ...@@ -83,3 +86,84 @@ def squeeze_features(protein):
def make_protein_crop_to_size_seed(protein): def make_protein_crop_to_size_seed(protein):
protein['random_crop_to_size_seed'] = torch.distributions.Uniform(low=torch.int32, high=torch.int32).sample((2)) protein['random_crop_to_size_seed'] = torch.distributions.Uniform(low=torch.int32, high=torch.int32).sample((2))
return protein return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a portion of the MSA with 'X'."""
msa_mask = (torch.rand(protein['msa'].shape) < replace_proportion)
x_idx = 20
gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx)
protein['msa'] = torch.where(msa_mask, torch.ones_like(protein['msa'])*x_idx,
protein['msa'])
aatype_mask = (
torch.rand(protein['aatype'].shape) < replace_proportion
)
protein['aatype'] = torch.where(aatype_mask, torch.ones_like(protein['aatype']) * x_idx,
protein['aatype'])
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.
"""
num_seq = protein['msa'].shape[0]
shuffled = torch.randperm(num_seq-1)+1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq)
print('sample_msa num_sel', num_sel, ' num_seq', num_seq)
sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq-num_sel])
for k in MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein['extra_'+k] = torch.index_select(protein[k], 0, not_sel_seq)
protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
num_seq = protein['extra_msa'].shape[0]
num_sel = min(max_extra_msa, num_seq)
select_indices = torch.randperm(num_seq)[:num_sel]
print('select_indices', select_indices)
for k in MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_'+k] = torch.index_select(protein['extra_'+k], 0, select_indices)
return protein
def delete_extra_msa(protein):
for k in MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
del protein['extra_' + k]
return protein
# Not used in inference
@curry1
def block_delete_msa(protein, config):
num_seq = protein['msa'].shape[0]
block_num_seq = torch.floor(torch.tensor(num_seq, dtype=torch.float32) * config.msa_fraction_per_block).to(torch.int32)
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample()
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq-1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
for k in MSA_FEATURE_NAMES:
if k in protein:
protein[k] = torch.gather(protein[k], keep_indices)
return protein
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment