Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
09564595
Commit
09564595
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added test for crop_extra_msa
parent
925d56f8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
1 deletion
+16
-1
tests/test_data_transforms.py
tests/test_data_transforms.py
+16
-1
No files found.
tests/test_data_transforms.py
View file @
09564595
...
@@ -10,7 +10,8 @@ import torch
...
@@ -10,7 +10,8 @@ import torch
import
unittest
import
unittest
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
,
\
crop_extra_msa
from
openfold.config
import
model_config
from
openfold.config
import
model_config
...
@@ -131,6 +132,20 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -131,6 +132,20 @@ class TestDataTransforms(unittest.TestCase):
print
(
'msa'
,
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
))
print
(
'msa'
,
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
))
assert
protein_processed
[
'extra_'
+
k
].
shape
[
0
]
==
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
)
assert
protein_processed
[
'extra_'
+
k
].
shape
[
0
]
==
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
)
def
test_crop_extra_msa
(
self
):
with
open
(
'../test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
max_extra_msa
=
10
protein
=
{
'extra_msa'
:
torch
.
tensor
(
features
[
'msa'
])}
num_seq
=
protein
[
"extra_msa"
].
shape
[
0
]
protein
=
crop_extra_msa
.
__wrapped__
(
protein
,
max_extra_msa
)
print
(
protein
)
for
k
in
MSA_FEATURE_NAMES
:
if
"extra_"
+
k
in
protein
:
assert
protein
[
"extra_"
+
k
].
shape
[
0
]
==
min
(
max_extra_msa
,
num_seq
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment