Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e1c7c9e7
Commit
e1c7c9e7
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added test for make_atom14_masks
parent
f33faf84
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
5 deletions
+18
-5
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+5
-4
tests/test_data_transforms.py
tests/test_data_transforms.py
+13
-1
No files found.
openfold/data/data_transforms.py
View file @
e1c7c9e7
...
@@ -612,17 +612,18 @@ def make_atom14_masks(protein):
...
@@ -612,17 +612,18 @@ def make_atom14_masks(protein):
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
protein
[
"aatype"
].
device
,
device
=
protein
[
"aatype"
].
device
,
)
)
protein_aatype
=
protein
[
'aatype'
].
to
(
torch
.
long
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
"
aatype
"
]
]
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
_
aatype
]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
"
aatype
"
]
]
residx_atom14_mask
=
restype_atom14_mask
[
protein
_
aatype
]
protein
[
"atom14_atom_exists"
]
=
residx_atom14_mask
protein
[
"atom14_atom_exists"
]
=
residx_atom14_mask
protein
[
"residx_atom14_to_atom37"
]
=
residx_atom14_to_atom37
.
long
()
protein
[
"residx_atom14_to_atom37"
]
=
residx_atom14_to_atom37
.
long
()
# create the gather indices for mapping back
# create the gather indices for mapping back
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
"
aatype
"
]
]
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
_
aatype
]
protein
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
.
long
()
protein
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
# create the corresponding mask
...
@@ -636,7 +637,7 @@ def make_atom14_masks(protein):
...
@@ -636,7 +637,7 @@ def make_atom14_masks(protein):
atom_type
=
rc
.
atom_order
[
atom_name
]
atom_type
=
rc
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
"
aatype
"
]
]
residx_atom37_mask
=
restype_atom37_mask
[
protein
_
aatype
]
protein
[
"atom37_atom_exists"
]
=
residx_atom37_mask
protein
[
"atom37_atom_exists"
]
=
residx_atom37_mask
return
protein
return
protein
...
...
tests/test_data_transforms.py
View file @
e1c7c9e7
...
@@ -12,7 +12,7 @@ import unittest
...
@@ -12,7 +12,7 @@ import unittest
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
,
\
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
,
\
crop_extra_msa
,
delete_extra_msa
,
nearest_neighbor_clusters
,
make_msa_mask
,
make_hhblits_profile
,
make_masked_msa
,
\
crop_extra_msa
,
delete_extra_msa
,
nearest_neighbor_clusters
,
make_msa_mask
,
make_hhblits_profile
,
make_masked_msa
,
\
make_msa_feat
,
crop_templates
make_msa_feat
,
crop_templates
,
make_atom14_masks
from
tests.config
import
config
from
tests.config
import
config
...
@@ -231,6 +231,18 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -231,6 +231,18 @@ class TestDataTransforms(unittest.TestCase):
assert
protein
[
'template_aatype'
].
shape
[
0
]
==
max_templates
assert
protein
[
'template_aatype'
].
shape
[
0
]
==
max_templates
assert
protein
[
'template_all_atom_masks'
].
shape
[
0
]
==
max_templates
assert
protein
[
'template_all_atom_masks'
].
shape
[
0
]
==
max_templates
def
test_make_atom14_masks
(
self
):
with
gzip
.
open
(
'../test_data/sample_feats.pickle.gz'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'aatype'
:
torch
.
tensor
(
features
[
'aatype'
][
0
])}
protein
=
make_atom14_masks
(
protein
)
print
(
protein
)
assert
'atom14_atom_exists'
in
protein
assert
'residx_atom14_to_atom37'
in
protein
assert
'residx_atom37_to_atom14'
in
protein
assert
'atom37_atom_exists'
in
protein
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment