Unverified Commit 39d0ef43 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Make seeding more consistent

parent 410e1829
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import itertools import itertools
from functools import reduce, wraps from functools import reduce, wraps
from operator import add from operator import add
import random
import numpy as np import numpy as np
import torch import torch
...@@ -183,12 +184,11 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion): ...@@ -183,12 +184,11 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1 @curry1
def sample_msa(protein, max_seq, keep_extra, seed=None): def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if(seed is None):
seed = random.randint(0, 2147483647)
num_seq = protein["msa"].shape[0] num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device) g = torch.Generator(device=protein["msa"].device)
if seed is not None: g.manual_seed(seed)
g.manual_seed(seed)
else:
g.seed()
shuffled = torch.randperm(num_seq - 1, generator=g) + 1 shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat( index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled), (torch.tensor([0], device=shuffled.device), shuffled),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment