Commit 1dc2748c authored by Christina Floristean
Browse files

Additional changes to seeding

parent 39d0ef43
...@@ -7,7 +7,6 @@ import pickle ...@@ -7,7 +7,6 @@ import pickle
from typing import Optional, Sequence, List, Any from typing import Optional, Sequence, List, Any
import ml_collections as mlc import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
...@@ -427,11 +426,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -427,11 +426,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs): def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.config = config self.config = config
self.stage = stage self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator self.generator = generator
self._prep_batch_properties_probs() self._prep_batch_properties_probs()
...@@ -687,8 +682,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -687,8 +682,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage):
generator = torch.Generator() generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed) generator = generator.manual_seed(self.batch_seed)
dataset = None dataset = None
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import itertools import itertools
from functools import reduce, wraps from functools import reduce, wraps
from operator import add from operator import add
import random
import numpy as np import numpy as np
import torch import torch
...@@ -183,12 +182,14 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -183,12 +182,14 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra, seed=None): def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if(seed is None):
seed = random.randint(0, 2147483647)
num_seq = protein["msa"].shape[0] num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed) g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1 shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat( index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled), (torch.tensor([0], device=shuffled.device), shuffled),
...@@ -1141,8 +1142,10 @@ def random_crop_to_size( ...@@ -1141,8 +1142,10 @@ def random_crop_to_size(
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way # We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed) g.manual_seed(seed)
seq_length = protein["seq_length"] seq_length = protein["seq_length"]
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import random import random
import torch import torch
...@@ -154,7 +153,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -154,7 +153,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg): def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
ensemble_seed = random.randint(0, 2147483647) ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension.""" """Function to be mapped over the ensemble dimension."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment