Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bc7864b1
Commit
bc7864b1
authored
Dec 21, 2021
by
Sachin Kadyan
Browse files
Added flag in protein to indicate if protein comes from distillation dataset.
parent
0e9aaa63
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
3 deletions
+15
-3
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+7
-2
tests/test_data_transforms.py
tests/test_data_transforms.py
+8
-1
No files found.
openfold/data/data_transforms.py
View file @
bc7864b1
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
import
itertools
import
itertools
from
functools
import
reduce
from
functools
import
reduce
,
wraps
from
operator
import
add
from
operator
import
add
import
numpy
as
np
import
numpy
as
np
...
@@ -71,7 +71,7 @@ def make_template_mask(protein):
...
@@ -71,7 +71,7 @@ def make_template_mask(protein):
def
curry1
(
f
):
def
curry1
(
f
):
"""Supply all arguments but the first."""
"""Supply all arguments but the first."""
@
wraps
(
f
)
def
fc
(
*
args
,
**
kwargs
):
def
fc
(
*
args
,
**
kwargs
):
return
lambda
x
:
f
(
x
,
*
args
,
**
kwargs
)
return
lambda
x
:
f
(
x
,
*
args
,
**
kwargs
)
...
@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
...
@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return
protein
return
protein
@
curry1
def
add_distillation_flag
(
protein
,
distillation
):
protein
[
'is_distillation'
]
=
distillation
return
protein
@
curry1
@
curry1
def
sample_msa_distillation
(
protein
,
max_seq
):
def
sample_msa_distillation
(
protein
,
max_seq
):
if
(
protein
[
"is_distillation"
]
==
1
):
if
(
protein
[
"is_distillation"
]
==
1
):
...
...
tests/test_data_transforms.py
View file @
bc7864b1
...
@@ -9,7 +9,7 @@ import numpy
...
@@ -9,7 +9,7 @@ import numpy
import
torch
import
torch
import
unittest
import
unittest
from
data.data_transforms
import
make_seq_mask
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
from
openfold.config
import
model_config
from
openfold.config
import
model_config
...
@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase):
assert
'seq_mask'
in
protein
assert
'seq_mask'
in
protein
assert
protein
[
'seq_mask'
].
shape
==
torch
.
Size
((
seq
.
shape
[
0
],
20
))
assert
protein
[
'seq_mask'
].
shape
==
torch
.
Size
((
seq
.
shape
[
0
],
20
))
def
test_add_distillation_flag
(
self
):
protein
=
{}
protein
=
add_distillation_flag
.
__wrapped__
(
protein
,
True
)
print
(
protein
)
assert
'is_distillation'
in
protein
assert
protein
[
'is_distillation'
]
is
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment