Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bc7864b1
Commit
bc7864b1
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added flag in protein to indicate if protein comes from distillation dataset.
parent
0e9aaa63
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
3 deletions
+15
-3
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+7
-2
tests/test_data_transforms.py
tests/test_data_transforms.py
+8
-1
No files found.
openfold/data/data_transforms.py
View file @
bc7864b1
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
import
itertools
from
functools
import
reduce
from
functools
import
reduce
,
wraps
from
operator
import
add
import
numpy
as
np
...
...
@@ -71,7 +71,7 @@ def make_template_mask(protein):
def
curry1
(
f
):
"""Supply all arguments but the first."""
@
wraps
(
f
)
def
fc
(
*
args
,
**
kwargs
):
return
lambda
x
:
f
(
x
,
*
args
,
**
kwargs
)
...
...
@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return
protein
@
curry1
def
add_distillation_flag
(
protein
,
distillation
):
protein
[
'is_distillation'
]
=
distillation
return
protein
@
curry1
def
sample_msa_distillation
(
protein
,
max_seq
):
if
(
protein
[
"is_distillation"
]
==
1
):
...
...
tests/test_data_transforms.py
View file @
bc7864b1
...
...
@@ -9,7 +9,7 @@ import numpy
import
torch
import
unittest
from
data.data_transforms
import
make_seq_mask
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
from
openfold.config
import
model_config
...
...
@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase):
assert
'seq_mask'
in
protein
assert
protein
[
'seq_mask'
].
shape
==
torch
.
Size
((
seq
.
shape
[
0
],
20
))
def
test_add_distillation_flag
(
self
):
protein
=
{}
protein
=
add_distillation_flag
.
__wrapped__
(
protein
,
True
)
print
(
protein
)
assert
'is_distillation'
in
protein
assert
protein
[
'is_distillation'
]
is
True
if
__name__
==
'__main__'
:
unittest
.
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment