Commit 011a6526 authored by Sachin Kadyan
Browse files

Added tests for randomly_replace_msa_with_unknown

parent cad8de7e
......@@ -165,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
protein["msa"] = torch.where(
msa_mask, torch.ones_like(protein["msa"]) * x_idx, protein["msa"]
msa_mask,
torch.ones_like(protein["msa"]) * x_idx,
protein["msa"]
)
aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
......
......@@ -10,7 +10,7 @@ import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown
from openfold.config import model_config
......@@ -95,6 +95,20 @@ class TestDataTransforms(unittest.TestCase):
print(k, protein_squeezed[k].shape, features[k].shape)
assert protein_squeezed[k].shape == features[k].shape
def test_randomly_replace_msa_with_unknown(self):
    """Check that random masking only ever writes the unknown token.

    Loads the pickled test features, runs the undecorated transform
    (via ``__wrapped__``) and verifies on the MSA that:

    * the tensor shape is preserved,
    * every position whose value changed now holds ``x_idx``,
    * gap positions (``gap_idx``) are never replaced.

    The observed proportion of X in the MSA and in the sequence is
    printed for manual inspection; it is not asserted because the
    replacement is random and the input may already contain X.
    """
    with open('../test_data/features.pkl', 'rb') as file:
        features = pickle.load(file)
    protein = {'msa': torch.tensor(features['msa']),
               'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1)}
    replace_proportion = 0.15
    # Token indices as used by the transform under test: the gap index is
    # set to 21 inside the function; 20 is assumed to be its "unknown"
    # index (that assignment is outside the visible part of the function).
    x_idx = 20
    gap_idx = 21
    # Snapshot the input so the result can be compared position by position,
    # independent of whether the transform works in place.
    msa_before = protein['msa'].clone()
    protein = randomly_replace_msa_with_unknown.__wrapped__(protein, replace_proportion)
    msa_after = protein['msa']

    self.assertEqual(msa_after.shape, msa_before.shape)
    changed = msa_after != msa_before
    # Any position that was modified must have been set to the unknown token.
    self.assertTrue(bool(torch.all(msa_after[changed] == x_idx)))
    # Gaps are excluded from the replacement mask and must survive untouched.
    self.assertTrue(bool(torch.all(msa_after[msa_before == gap_idx] == gap_idx)))

    # minlength guarantees index x_idx exists even if no X is present.
    unknown_proportion_in_msa = (
        torch.bincount(msa_after.flatten(), minlength=x_idx + 1) / torch.numel(msa_after)
    )
    unknown_proportion_in_seq = (
        torch.bincount(protein['aatype'].flatten(), minlength=x_idx + 1)
        / torch.numel(protein['aatype'])
    )
    print(protein)
    print('Proportion of X in MSA: ', unknown_proportion_in_msa[x_idx])
    print('Proportion of X in sequence: ', unknown_proportion_in_seq[x_idx])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment