Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
8923e536
Commit
8923e536
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added tests for squeeze_features.
parent
9e4fb16f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
3 deletions
+38
-3
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+4
-1
tests/test_data/features.pkl
tests/test_data/features.pkl
+0
-0
tests/test_data_transforms.py
tests/test_data_transforms.py
+34
-2
No files found.
openfold/data/data_transforms.py
View file @
8923e536
...
@@ -145,7 +145,10 @@ def squeeze_features(protein):
...
@@ -145,7 +145,10 @@ def squeeze_features(protein):
if
k
in
protein
:
if
k
in
protein
:
final_dim
=
protein
[
k
].
shape
[
-
1
]
final_dim
=
protein
[
k
].
shape
[
-
1
]
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
if
torch
.
is_tensor
(
protein
[
k
]):
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
else
:
protein
[
k
]
=
np
.
squeeze
(
protein
[
k
],
axis
=-
1
)
for
k
in
[
"seq_length"
,
"num_alignments"
]:
for
k
in
[
"seq_length"
,
"num_alignments"
]:
if
k
in
protein
:
if
k
in
protein
:
...
...
tests/test_data/features.pkl
View file @
8923e536
No preview for this file type
tests/test_data_transforms.py
View file @
8923e536
...
@@ -5,12 +5,12 @@ import os
...
@@ -5,12 +5,12 @@ import os
import
pickle
import
pickle
import
numpy
import
numpy
as
np
import
torch
import
torch
import
unittest
import
unittest
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
correct_msa_restypes
,
squeeze_features
from
openfold.config
import
model_config
from
openfold.config
import
model_config
...
@@ -65,6 +65,38 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -65,6 +65,38 @@ class TestDataTransforms(unittest.TestCase):
print
(
protein
)
print
(
protein
)
assert
torch
.
all
(
torch
.
eq
(
torch
.
tensor
(
features
[
'msa'
].
shape
),
torch
.
tensor
(
protein
[
'msa'
].
shape
)))
assert
torch
.
all
(
torch
.
eq
(
torch
.
tensor
(
features
[
'msa'
].
shape
),
torch
.
tensor
(
protein
[
'msa'
].
shape
)))
def
test_squeeze_features
(
self
):
with
open
(
"../test_data/features.pkl"
,
"rb"
)
as
file
:
features
=
pickle
.
load
(
file
)
print
(
os
.
path
.
realpath
(
file
.
name
),
'Keys: '
,
features
.
keys
())
features_list
=
[
'domain_name'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'sequence'
,
'superfamily'
,
'deletion_matrix'
,
'resolution'
,
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask'
]
protein
=
{
'aatype'
:
torch
.
tensor
(
features
[
'aatype'
])}
for
k
in
features_list
:
if
k
in
features
:
print
(
k
,
features
[
k
].
dtype
)
if
k
in
[
'domain_name'
,
'sequence'
]:
protein
[
k
]
=
np
.
expand_dims
(
features
[
k
],
-
1
)
else
:
protein
[
k
]
=
torch
.
tensor
(
features
[
k
]).
unsqueeze
(
-
1
)
for
k
in
[
'seq_length'
,
'num_alignments'
]:
if
k
in
protein
:
protein
[
k
]
=
torch
.
tensor
(
protein
[
k
]).
unsqueeze
(
0
)
protein_squeezed
=
squeeze_features
(
protein
)
print
(
protein
)
for
k
in
features_list
:
if
k
in
protein
:
print
(
k
,
protein_squeezed
[
k
].
shape
,
features
[
k
].
shape
)
assert
protein_squeezed
[
k
].
shape
==
features
[
k
].
shape
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment