"tools/cfgs/vscode:/vscode.git/clone" did not exist on "a481fba1b92c753ec9a8d8e3c917fcf3d8a1c201"
Commit a18f98cf authored by Christina Floristean's avatar Christina Floristean
Browse files

Seed fixes for multimer data pipeline

parent 44b0bf76
......@@ -155,8 +155,6 @@ def model_config(
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.update(multimer_config_update.copy_and_resolve_references())
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
......@@ -354,6 +352,8 @@ config = mlc.ConfigDict(
"max_templates": 4,
"crop": False,
"crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": False,
"uniform_recycling": False,
},
......@@ -367,6 +367,8 @@ config = mlc.ConfigDict(
"max_templates": 4,
"crop": False,
"crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True,
"uniform_recycling": False,
},
......@@ -381,6 +383,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True,
"clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000,
......@@ -709,7 +713,9 @@ multimer_config_update = mlc.ConfigDict({
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"crop_size": 640
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.
},
},
"model": {
......@@ -735,6 +741,7 @@ multimer_config_update = mlc.ConfigDict({
"tri_mul_first": True,
"fuse_projection_weights": True
},
"template_pointwise_attention": None, # Not used in Multimer
"c_t": c_t,
"c_z": c_z,
"use_unit_vector": True
......@@ -778,7 +785,8 @@ multimer_config_update = mlc.ConfigDict({
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5
}
},
"backbone": None # Not used in Multimer
},
"masked_msa": {
"num_classes": 22
......
......@@ -4,7 +4,7 @@ import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any
from typing import Optional, Sequence, Any
import ml_collections as mlc
import pytorch_lightning as pl
......@@ -882,9 +882,6 @@ class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
print('initialised a multimer dataloader')
def __iter__(self):
......@@ -1220,8 +1217,9 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed)
dataset = None
......
......@@ -477,8 +477,9 @@ def make_masked_msa(protein, config, replace_fraction, seed):
sh = protein["msa"].shape
g = torch.Generator(device=protein["msa"].device)
g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
sample = torch.rand(sh, device=device, generator=g)
......
......@@ -100,8 +100,9 @@ def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
logits = torch.log(categorical_probs + eps)
g = torch.Generator(device=batch["msa"].device)
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
......@@ -262,8 +263,9 @@ def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
Returns:
Protein with sampled msa.
"""
g = torch.Generator(device=batch["msa"].device)
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
......@@ -417,8 +419,9 @@ def random_crop_to_size(
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
use_spatial_crop = torch.rand((1,),
......
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import random
import torch
from openfold.data import (
......@@ -75,13 +74,44 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
if mode_cfg.crop:
transforms.append(
data_transforms_multimer.random_crop_to_size(
crop_size=mode_cfg.crop_size,
max_templates=mode_cfg.max_templates,
shape_schema=crop_feats,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
interface_threshold=mode_cfg.interface_threshold,
subsample_templates=mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
shape_schema=crop_feats,
msa_cluster_size=pad_msa_clusters,
extra_msa_size=mode_cfg.max_extra_msa,
num_res=mode_cfg.crop_size,
num_templates=mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment