Commit 1dc2748c authored by Christina Floristean's avatar Christina Floristean
Browse files

Additional changes to seeding

parent 39d0ef43
......@@ -7,7 +7,6 @@ import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
......@@ -428,10 +427,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
super().__init__(*args, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
self._prep_batch_properties_probs()
......@@ -687,8 +682,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed)
dataset = None
......
......@@ -16,7 +16,6 @@
import itertools
from functools import reduce, wraps
from operator import add
import random
import numpy as np
import torch
......@@ -184,11 +183,13 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if(seed is None):
seed = random.randint(0, 2147483647)
num_seq = protein["msa"].shape[0]
g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled),
......@@ -1141,8 +1142,10 @@ def random_crop_to_size(
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
seq_length = protein["seq_length"]
......
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import random
import torch
......@@ -154,7 +153,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = random.randint(0, 2147483647)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment