Commit bc7864b1 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added flag in protein to indicate if protein comes from distillation dataset.

parent 0e9aaa63
......@@ -14,7 +14,7 @@
# limitations under the License.
import itertools
from functools import reduce
from functools import reduce, wraps
from operator import add
import numpy as np
......@@ -71,7 +71,7 @@ def make_template_mask(protein):
def curry1(f):
"""Supply all arguments but the first."""
@wraps(f)
def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)
......@@ -199,6 +199,11 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
return protein
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = distillation
return protein
@curry1
def sample_msa_distillation(protein, max_seq):
if(protein["is_distillation"] == 1):
......
......@@ -9,7 +9,7 @@ import numpy
import torch
import unittest
from data.data_transforms import make_seq_mask
from data.data_transforms import make_seq_mask, add_distillation_flag
from openfold.config import model_config
......@@ -25,6 +25,13 @@ class TestDataTransforms(unittest.TestCase):
assert 'seq_mask' in protein
assert protein['seq_mask'].shape == torch.Size((seq.shape[0], 20))
def test_add_distillation_flag(self):
protein = {}
protein = add_distillation_flag.__wrapped__(protein, True)
print(protein)
assert 'is_distillation' in protein
assert protein['is_distillation'] is True
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment