Commit 44b0bf76 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'main' into multimer

parents 959b3f25 1dc2748c
......@@ -796,11 +796,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.stage = stage
self.generator = generator
self._prep_batch_properties_probs()
......@@ -1077,8 +1073,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed)
dataset = None
......
......@@ -184,11 +184,14 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled),
......@@ -1181,8 +1184,10 @@ def random_crop_to_size(
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
seq_length = protein["seq_length"]
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import random
import torch
......@@ -154,7 +154,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
......
......@@ -202,10 +202,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt=False,
inf=inf,
eps=eps,
).eval()
).eval().cuda()
m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
msa_mask = torch.randint(
0,
2,
......@@ -214,6 +214,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t,
n_res,
),
device="cuda",
)
pair_mask = torch.randint(
0,
......@@ -223,6 +224,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
n_res,
),
device="cuda",
)
shape_z_before = z.shape
......
......@@ -61,6 +61,7 @@ class TestModel(unittest.TestCase):
# deepspeed for this test
model = AlphaFold(c)
model.eval()
batch = {}
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
......
......@@ -88,7 +88,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
for k,v in other_metrics.items():
assert(len(v.shape) == 1)
self.log(
f"{phase}/{k}",
torch.mean(v),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment