"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "9fdb4435ac5375adba530b8813ef608dffb681c0"
Commit 8923e536 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added tests for squeeze_features.

parent 9e4fb16f
......@@ -145,7 +145,10 @@ def squeeze_features(protein):
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
protein[k] = torch.squeeze(protein[k], dim=-1)
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]:
if k in protein:
......
......@@ -5,12 +5,12 @@ import os
import pickle
import numpy
import numpy as np
import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes
correct_msa_restypes, squeeze_features
from openfold.config import model_config
......@@ -65,6 +65,38 @@ class TestDataTransforms(unittest.TestCase):
print(protein)
assert torch.all(torch.eq(torch.tensor(features['msa'].shape), torch.tensor(protein['msa'].shape)))
def test_squeeze_features(self):
with open("../test_data/features.pkl", "rb") as file:
features = pickle.load(file)
print(os.path.realpath(file.name), 'Keys: ', features.keys())
features_list = [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_mask']
protein = {'aatype': torch.tensor(features['aatype'])}
for k in features_list:
if k in features:
print(k, features[k].dtype)
if k in ['domain_name', 'sequence']:
protein[k] = np.expand_dims(features[k], -1)
else:
protein[k] = torch.tensor(features[k]).unsqueeze(-1)
for k in ['seq_length', 'num_alignments']:
if k in protein:
protein[k] = torch.tensor(protein[k]).unsqueeze(0)
protein_squeezed = squeeze_features(protein)
print(protein)
for k in features_list:
if k in protein:
print(k, protein_squeezed[k].shape, features[k].shape)
assert protein_squeezed[k].shape == features[k].shape
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment