Commit 925d56f8 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for sample_msa

parent 011a6526
......@@ -10,7 +10,7 @@ import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa
from openfold.config import model_config
......@@ -110,6 +110,26 @@ class TestDataTransforms(unittest.TestCase):
print('Proportion of X in MSA: ', unknown_proportion_in_msa[x_idx])
print('Proportion of X in sequence: ', unknown_proportion_in_seq[x_idx])
def test_sample_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
max_seq = 1000
keep_extra = True
protein = {}
for k in MSA_FEATURE_NAMES:
if k in features:
protein[k] = torch.tensor(features[k])
protein_processed = sample_msa.__wrapped__(protein.copy(), max_seq, keep_extra)
print(protein)
for k in MSA_FEATURE_NAMES:
if k in protein and keep_extra:
assert protein_processed[k].shape[0] == min(protein[k].shape[0], max_seq)
assert 'extra_'+k in protein_processed
print('extra_'+str(k), protein_processed['extra_'+k].shape)
print('msa', protein[k].shape[0] - min(protein[k].shape[0], max_seq))
assert protein_processed['extra_'+k].shape[0] == protein[k].shape[0] - min(protein[k].shape[0], max_seq)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment