Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
011a6526
Commit
011a6526
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added tests for randomly_replace_msa_with_unknown
parent
cad8de7e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
2 deletions
+18
-2
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+3
-1
tests/test_data_transforms.py
tests/test_data_transforms.py
+15
-1
No files found.
openfold/data/data_transforms.py
View file @
011a6526
...
@@ -165,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
...
@@ -165,7 +165,9 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
gap_idx
=
21
gap_idx
=
21
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
"msa"
]
!=
gap_idx
)
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
"msa"
]
!=
gap_idx
)
protein
[
"msa"
]
=
torch
.
where
(
protein
[
"msa"
]
=
torch
.
where
(
msa_mask
,
torch
.
ones_like
(
protein
[
"msa"
])
*
x_idx
,
protein
[
"msa"
]
msa_mask
,
torch
.
ones_like
(
protein
[
"msa"
])
*
x_idx
,
protein
[
"msa"
]
)
)
aatype_mask
=
torch
.
rand
(
protein
[
"aatype"
].
shape
)
<
replace_proportion
aatype_mask
=
torch
.
rand
(
protein
[
"aatype"
].
shape
)
<
replace_proportion
...
...
tests/test_data_transforms.py
View file @
011a6526
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
import
unittest
import
unittest
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
,
squeeze_features
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
from
openfold.config
import
model_config
from
openfold.config
import
model_config
...
@@ -95,6 +95,20 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -95,6 +95,20 @@ class TestDataTransforms(unittest.TestCase):
print
(
k
,
protein_squeezed
[
k
].
shape
,
features
[
k
].
shape
)
print
(
k
,
protein_squeezed
[
k
].
shape
,
features
[
k
].
shape
)
assert
protein_squeezed
[
k
].
shape
==
features
[
k
].
shape
assert
protein_squeezed
[
k
].
shape
==
features
[
k
].
shape
def
test_randomly_replace_msa_with_unknown
(
self
):
with
open
(
'../test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
]),
'aatype'
:
torch
.
argmax
(
torch
.
tensor
(
features
[
'aatype'
]),
dim
=
1
)}
replace_proportion
=
0.15
x_idx
=
20
protein
=
randomly_replace_msa_with_unknown
.
__wrapped__
(
protein
,
replace_proportion
)
unknown_proportion_in_msa
=
torch
.
bincount
(
protein
[
'msa'
].
flatten
())
/
torch
.
numel
(
protein
[
'msa'
])
unknown_proportion_in_seq
=
torch
.
bincount
(
protein
[
'aatype'
].
flatten
())
/
torch
.
numel
(
protein
[
'aatype'
])
print
(
protein
)
print
(
'Proportion of X in MSA: '
,
unknown_proportion_in_msa
[
x_idx
])
print
(
'Proportion of X in sequence: '
,
unknown_proportion_in_seq
[
x_idx
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment