Commit cdc28a07 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_masked_msa

parent e683dce3
......@@ -17,3 +17,17 @@ consts = mlc.ConfigDict(
"c_e": 64,
}
)
config = mlc.ConfigDict(
{
"data": {
"common": {
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
}
}
}
)
......@@ -11,8 +11,8 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile
from openfold.config import model_config
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa
from tests.config import config
class TestDataTransforms(unittest.TestCase):
......@@ -189,6 +189,22 @@ class TestDataTransforms(unittest.TestCase):
assert 'hhblits_profile' in protein
assert protein['hhblits_profile'].shape == torch.Size((protein['msa'].shape[1], 22))
def test_make_masked_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
print(protein)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
assert protein['bert_mask'].sum() >= 0
assert torch.all(torch.eq(
protein['true_msa'] * (1-protein['bert_mask']), protein['msa'] * (1-protein['bert_mask'])))
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment