Commit 44b0bf76 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'main' into multimer

parents 959b3f25 1dc2748c
...@@ -796,11 +796,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -796,11 +796,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs): def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.config = config self.config = config
self.stage = stage self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator self.generator = generator
self._prep_batch_properties_probs() self._prep_batch_properties_probs()
...@@ -1077,8 +1073,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1077,8 +1073,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage):
generator = torch.Generator() generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed) generator = generator.manual_seed(self.batch_seed)
dataset = None dataset = None
......
...@@ -184,11 +184,14 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -184,11 +184,14 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra, seed=None): def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample MSA randomly, remaining sequences are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored as `extra_*`."""
num_seq = protein["msa"].shape[0] num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed) g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1 shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat( index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled), (torch.tensor([0], device=shuffled.device), shuffled),
...@@ -1181,8 +1184,10 @@ def random_crop_to_size( ...@@ -1181,8 +1184,10 @@ def random_crop_to_size(
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way # We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed) g.manual_seed(seed)
seq_length = protein["seq_length"] seq_length = protein["seq_length"]
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial import random
import torch import torch
...@@ -154,7 +154,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -154,7 +154,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg): def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed() ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension.""" """Function to be mapped over the ensemble dimension."""
......
...@@ -202,10 +202,10 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -202,10 +202,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt=False, ckpt=False,
inf=inf, inf=inf,
eps=eps, eps=eps,
).eval() ).eval().cuda()
m = torch.rand((batch_size, s_t, n_res, c_m)) m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
msa_mask = torch.randint( msa_mask = torch.randint(
0, 0,
2, 2,
...@@ -214,6 +214,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -214,6 +214,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t, s_t,
n_res, n_res,
), ),
device="cuda",
) )
pair_mask = torch.randint( pair_mask = torch.randint(
0, 0,
...@@ -223,6 +224,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -223,6 +224,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res, n_res,
n_res, n_res,
), ),
device="cuda",
) )
shape_z_before = z.shape shape_z_before = z.shape
......
...@@ -61,6 +61,7 @@ class TestModel(unittest.TestCase): ...@@ -61,6 +61,7 @@ class TestModel(unittest.TestCase):
# deepspeed for this test # deepspeed for this test
model = AlphaFold(c) model = AlphaFold(c)
model.eval()
batch = {} batch = {}
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)) tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
......
...@@ -88,7 +88,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -88,7 +88,6 @@ class OpenFoldWrapper(pl.LightningModule):
) )
for k,v in other_metrics.items(): for k,v in other_metrics.items():
assert(len(v.shape) == 1)
self.log( self.log(
f"{phase}/{k}", f"{phase}/{k}",
torch.mean(v), torch.mean(v),
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment