Commit a18f98cf authored by Christina Floristean's avatar Christina Floristean
Browse files

Seed fixes for multimer data pipeline

parent 44b0bf76
...@@ -155,8 +155,6 @@ def model_config( ...@@ -155,8 +155,6 @@ def model_config(
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif "multimer" in name: elif "multimer" in name:
c.update(multimer_config_update.copy_and_resolve_references()) c.update(multimer_config_update.copy_and_resolve_references())
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name): if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
...@@ -354,6 +352,8 @@ config = mlc.ConfigDict( ...@@ -354,6 +352,8 @@ config = mlc.ConfigDict(
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": False, "supervised": False,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -367,6 +367,8 @@ config = mlc.ConfigDict( ...@@ -367,6 +367,8 @@ config = mlc.ConfigDict(
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True, "supervised": True,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -381,6 +383,8 @@ config = mlc.ConfigDict( ...@@ -381,6 +383,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True, "supervised": True,
"clamp_prob": 0.9, "clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
...@@ -709,7 +713,9 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -709,7 +713,9 @@ multimer_config_update = mlc.ConfigDict({
"train": { "train": {
"max_msa_clusters": 508, "max_msa_clusters": 508,
"max_extra_msa": 2048, "max_extra_msa": 2048,
"crop_size": 640 "crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.
}, },
}, },
"model": { "model": {
...@@ -735,6 +741,7 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -735,6 +741,7 @@ multimer_config_update = mlc.ConfigDict({
"tri_mul_first": True, "tri_mul_first": True,
"fuse_projection_weights": True "fuse_projection_weights": True
}, },
"template_pointwise_attention": None, # Not used in Multimer
"c_t": c_t, "c_t": c_t,
"c_z": c_z, "c_z": c_z,
"use_unit_vector": True "use_unit_vector": True
...@@ -778,7 +785,8 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -778,7 +785,8 @@ multimer_config_update = mlc.ConfigDict({
"clamp_distance": 30.0, "clamp_distance": 30.0,
"loss_unit_distance": 20.0, "loss_unit_distance": 20.0,
"weight": 0.5 "weight": 0.5
} },
"backbone": None # Not used in Multimer
}, },
"masked_msa": { "masked_msa": {
"num_classes": 22 "num_classes": 22
......
...@@ -4,7 +4,7 @@ import json ...@@ -4,7 +4,7 @@ import json
import logging import logging
import os import os
import pickle import pickle
from typing import Optional, Sequence, List, Any from typing import Optional, Sequence, Any
import ml_collections as mlc import ml_collections as mlc
import pytorch_lightning as pl import pytorch_lightning as pl
...@@ -880,10 +880,7 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader): ...@@ -880,10 +880,7 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs): def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs) super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
self.config = config self.config = config
self.stage = stage self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator self.generator = generator
print('initialised a multimer dataloader') print('initialised a multimer dataloader')
...@@ -1220,8 +1217,9 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1220,8 +1217,9 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
) )
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage):
generator = torch.Generator() generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed) generator = generator.manual_seed(self.batch_seed)
dataset = None dataset = None
......
...@@ -477,8 +477,9 @@ def make_masked_msa(protein, config, replace_fraction, seed): ...@@ -477,8 +477,9 @@ def make_masked_msa(protein, config, replace_fraction, seed):
sh = protein["msa"].shape sh = protein["msa"].shape
g = torch.Generator(device=protein["msa"].device) g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed) g.manual_seed(seed)
sample = torch.rand(sh, device=device, generator=g) sample = torch.rand(sh, device=device, generator=g)
......
...@@ -100,8 +100,9 @@ def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6): ...@@ -100,8 +100,9 @@ def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
logits = torch.log(categorical_probs + eps) logits = torch.log(categorical_probs + eps)
g = torch.Generator(device=batch["msa"].device) g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed) g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g) bert_msa = gumbel_max_sample(logits, generator=g)
...@@ -262,8 +263,9 @@ def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6): ...@@ -262,8 +263,9 @@ def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
Returns: Returns:
Protein with sampled msa. Protein with sampled msa.
""" """
g = torch.Generator(device=batch["msa"].device) g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed) g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position. # Sample uniformly among sequences with at least one non-masked position.
...@@ -417,8 +419,9 @@ def random_crop_to_size( ...@@ -417,8 +419,9 @@ def random_crop_to_size(
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way # We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device) g = None
if seed is not None: if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed) g.manual_seed(seed)
use_spatial_crop = torch.rand((1,), use_spatial_crop = torch.rand((1,),
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial import random
import torch import torch
from openfold.data import ( from openfold.data import (
...@@ -75,13 +74,44 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -75,13 +74,44 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
transforms.append(data_transforms_multimer.nearest_neighbor_clusters()) transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat) transforms.append(data_transforms_multimer.create_msa_feat)
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
if mode_cfg.crop:
transforms.append(
data_transforms_multimer.random_crop_to_size(
crop_size=mode_cfg.crop_size,
max_templates=mode_cfg.max_templates,
shape_schema=crop_feats,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
interface_threshold=mode_cfg.interface_threshold,
subsample_templates=mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
shape_schema=crop_feats,
msa_cluster_size=pad_msa_clusters,
extra_msa_size=mode_cfg.max_extra_msa,
num_res=mode_cfg.crop_size,
num_templates=mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg): def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed() ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension.""" """Function to be mapped over the ensemble dimension."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment