"lib/engines/python/src/lib.rs" did not exist on "c0e008b4fe459a60f2bb3e8784b3cf5c16b71afd"
Commit 09564595 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for crop_extra_msa

parent 925d56f8
......@@ -10,7 +10,8 @@ import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa
from openfold.config import model_config
......@@ -131,6 +132,20 @@ class TestDataTransforms(unittest.TestCase):
print('msa', protein[k].shape[0] - min(protein[k].shape[0], max_seq))
assert protein_processed['extra_'+k].shape[0] == protein[k].shape[0] - min(protein[k].shape[0], max_seq)
def test_crop_extra_msa(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
max_extra_msa = 10
protein = {'extra_msa': torch.tensor(features['msa'])}
num_seq = protein["extra_msa"].shape[0]
protein = crop_extra_msa.__wrapped__(protein, max_extra_msa)
print(protein)
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
assert protein["extra_" + k].shape[0] == min(max_extra_msa, num_seq)
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment