Commit bc7864b1 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added flag in protein to indicate if protein comes from distillation dataset.

parent 0e9aaa63
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
from functools import reduce from functools import reduce, wraps
from operator import add from operator import add
import numpy as np import numpy as np
...@@ -71,7 +71,7 @@ def make_template_mask(protein): ...@@ -71,7 +71,7 @@ def make_template_mask(protein):
def curry1(f): def curry1(f):
"""Supply all arguments but the first.""" """Supply all arguments but the first."""
@wraps(f)
def fc(*args, **kwargs): def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs) return lambda x: f(x, *args, **kwargs)
...@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None): ...@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return protein return protein
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = distillation
return protein
@curry1 @curry1
def sample_msa_distillation(protein, max_seq): def sample_msa_distillation(protein, max_seq):
if(protein["is_distillation"] == 1): if(protein["is_distillation"] == 1):
......
...@@ -9,7 +9,7 @@ import numpy ...@@ -9,7 +9,7 @@ import numpy
import torch import torch
import unittest import unittest
from data.data_transforms import make_seq_mask from data.data_transforms import make_seq_mask, add_distillation_flag
from openfold.config import model_config from openfold.config import model_config
...@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase): ...@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase):
assert 'seq_mask' in protein assert 'seq_mask' in protein
assert protein['seq_mask'].shape == torch.Size((seq.shape[0], 20)) assert protein['seq_mask'].shape == torch.Size((seq.shape[0], 20))
def test_add_distillation_flag(self):
protein = {}
protein = add_distillation_flag.__wrapped__(protein, True)
print(protein)
assert 'is_distillation' in protein
assert protein['is_distillation'] is True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment