Unverified Commit 39d0ef43 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Make seeding more consistent

parent 410e1829
......@@ -16,6 +16,7 @@
import itertools
from functools import reduce, wraps
from operator import add
import random
import numpy as np
import torch
......@@ -183,12 +184,11 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
if(seed is None):
seed = random.randint(0, 2147483647)
num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
else:
g.seed()
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment