"segmentation/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "4182170a9b4843d153b58197701c5352232313d9"
Commit a3b5c162 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_msa_mask

parent 974fe5a9
......@@ -359,7 +359,7 @@ def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
protein["msa_row_mask"] = torch.ones(
protein["msa"].shape[0], dtype=torch.float32
(protein["msa"].shape[0]), dtype=torch.float32
)
return protein
......
......@@ -11,7 +11,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask
from openfold.config import model_config
......@@ -169,6 +169,17 @@ class TestDataTransforms(unittest.TestCase):
print(protein)
assert 'extra_cluster_assignment' in protein
def test_make_msa_mask(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
msa_mat = torch.tensor(features['msa'])
protein = {'msa': msa_mat}
protein = make_msa_mask(protein)
print(protein)
assert 'msa_row_mask' in protein
assert protein['msa_row_mask'].shape[0] == msa_mat.shape[0]
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment