Commit 0cf1541c authored by Christina Floristean

Refactoring multimer data pipeline and permutation alignment.

parent 377f854c
@@ -19,6 +19,8 @@ dependencies:
   - deepspeed==0.5.10
   - dm-tree==0.1.6
   - ml-collections==0.1.0
+  - jax==0.3.25
+  - pandas==2.0.2
   - numpy==1.21.2
   - PyYAML==5.4.1
   - requests==2.26.0
...
@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
                 "sym_id",
             ]
         },
+        "supervised": {
+            "clamp_prob": 1.
+        },
         # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
         # c.model.input_embedder.num_msa = 508
         # c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
             "max_extra_msa": 2048,
             "crop_size": 640,
             "spatial_crop_prob": 0.5,
-            "interface_threshold": 10.
+            "interface_threshold": 10.,
+            "clamp_prob": 1.,
         },
     },
     "model": {
...
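Note on the two new `clamp_prob` entries: in AlphaFold-style training, the FAPE loss is clamped for a given fraction of training examples, and `clamp_prob` is the probability of applying the clamped variant (so 1. means always clamp). A minimal sketch of how such a knob is typically consumed; the function name here is illustrative, not from this commit:

    import torch

    def use_clamped_fape(clamp_prob: float, generator: torch.Generator = None) -> bool:
        # Drawn once per training example; with clamp_prob = 1. this is always True.
        return bool(torch.rand((), generator=generator) < clamp_prob)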
@@ -4,7 +4,7 @@ import json
 import logging
 import os
 import pickle
-from typing import Optional, Sequence, Any
+from typing import Optional, Sequence, Any, Union
 import ml_collections as mlc
 import pytorch_lightning as pl
@@ -18,21 +18,9 @@ from openfold.data import (
     templates,
 )
 from openfold.utils.tensor_utils import dict_multimap
-import contextlib
-import tempfile
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )
-import random
-
-logging.basicConfig(level=logging.INFO)
-
-@contextlib.contextmanager
-def temp_fasta_file(sequence_str):
-    """function that create temparory fasta file used in multimer datapipeline"""
-    with tempfile.NamedTemporaryFile("w", suffix=".fasta") as fasta_file:
-        fasta_file.write(sequence_str)
-        fasta_file.seek(0)
-        yield fasta_file.name

 class OpenFoldSingleDataset(torch.utils.data.Dataset):
@@ -116,21 +104,21 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.supported_exts = [".cif", ".core", ".pdb"]

         valid_modes = ["train", "eval", "predict"]
-        if(mode not in valid_modes):
+        if mode not in valid_modes:
             raise ValueError(f'mode must be one of {valid_modes}')

-        if(template_release_dates_cache_path is None):
+        if template_release_dates_cache_path is None:
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

-        if(alignment_index is not None):
+        if alignment_index is not None:
             self._chain_ids = list(alignment_index.keys())
         else:
             self._chain_ids = list(os.listdir(alignment_dir))

-        if(filter_path is not None):
+        if filter_path is not None:
             with open(filter_path, "r") as f:
                 chains_to_include = set([l.strip() for l in f.readlines()])
@@ -182,7 +170,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             template_featurizer=template_featurizer,
         )

-        if(not self._output_raw):
+        if not self._output_raw:
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

     def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
@@ -195,7 +183,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         # Crash if an error is encountered. Any parsing errors should have
         # been dealt with at the alignment stage.
-        if(mmcif_object.mmcif_object is None):
+        if mmcif_object.mmcif_object is None:
             raise list(mmcif_object.errors.values())[0]

         mmcif_object = mmcif_object.mmcif_object
@@ -220,47 +208,46 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         alignment_dir = os.path.join(self.alignment_dir, name)

         alignment_index = None
-        if(self.alignment_index is not None):
+        if self.alignment_index is not None:
             alignment_dir = self.alignment_dir
             alignment_index = self.alignment_index[name]

-        if(self.mode == 'train' or self.mode == 'eval'):
+        if self.mode == 'train' or self.mode == 'eval':
             spl = name.rsplit('_', 1)
-            if(len(spl) == 2):
+            if len(spl) == 2:
                 file_id, chain_id = spl
             else:
                 file_id, = spl
                 chain_id = None

             path = os.path.join(self.data_dir, file_id)
-            structure_index_entry = None
-            if(self._structure_index is not None):
+            if self._structure_index is not None:
                 structure_index_entry = self._structure_index[name]
-                assert(len(structure_index_entry["files"]) == 1)
+                assert (len(structure_index_entry["files"]) == 1)
                 filename, _, _ = structure_index_entry["files"][0]
                 ext = os.path.splitext(filename)[1]
             else:
                 ext = None
                 for e in self.supported_exts:
-                    if(os.path.exists(path + e)):
+                    if os.path.exists(path + e):
                         ext = e
                         break

-                if(ext is None):
+                if ext is None:
                     raise ValueError("Invalid file type")

             path += ext
-            if(ext == ".cif"):
+            if ext == ".cif":
                 data = self._parse_mmcif(
                     path, file_id, chain_id, alignment_dir, alignment_index,
                 )
-            elif(ext == ".core"):
+            elif ext == ".core":
                 data = self.data_pipeline.process_core(
                     path, alignment_dir, alignment_index,
                 )
-            elif(ext == ".pdb"):
+            elif ext == ".pdb":
                 structure_index = None
-                if(self._structure_index is not None):
+                if self._structure_index is not None:
                     structure_index = self._structure_index[name]
                 data = self.data_pipeline.process_pdb(
                     pdb_path=path,
@@ -280,7 +267,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     alignment_index=alignment_index,
                 )

-        if(self._output_raw):
+        if self._output_raw:
             return data

         feats = self.feature_pipeline.process_features(
@@ -305,7 +292,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                  template_mmcif_dir: str,
                  max_template_date: str,
                  config: mlc.ConfigDict,
-                 chain_data_cache_path: Optional[str] = None,
                  mmcif_data_cache_path: Optional[str] = None,
                  kalign_binary_path: str = '/usr/bin/kalign',
                  max_template_hits: int = 4,
@@ -336,15 +322,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                 Path to a directory containing template mmCIF files.
             config:
                 A dataset config object. See openfold.config
-            chain_data_cache_path:
-                Path to cache of data_dir generated by
-                scripts/generate_chain_data_cache.py
             mmcif_data_cache_path:
                 Path to cache of all mmcifs files generated by
                 scripts/generate_mmcif_cache.py It should be a json file which records
                 what PDB ID contains which chain(s)
             kalign_binary_path:
                 Path to kalign binary.
             max_template_hits:
@@ -369,17 +350,12 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         """
         super(OpenFoldSingleMultimerDataset, self).__init__()
         self.data_dir = data_dir
-        self.mmcif_data_cache_path=mmcif_data_cache_path
-        self.chain_data_cache = None
-        if chain_data_cache_path is not None:
-            with open(chain_data_cache_path, "r") as fp:
-                self.chain_data_cache = json.load(fp)
-            assert isinstance(self.chain_data_cache, dict)
+        self.mmcif_data_cache_path = mmcif_data_cache_path
         if self.mmcif_data_cache_path is not None:
-            with open(self.mmcif_data_cache_path,"r") as infile:
+            with open(self.mmcif_data_cache_path, "r") as infile:
                 self.mmcif_data_cache = json.load(infile)
-            assert isinstance(self.mmcif_data_cache,dict)
+            assert isinstance(self.mmcif_data_cache, dict)
         self.alignment_dir = alignment_dir
         self.config = config
@@ -392,39 +368,36 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         self.supported_exts = [".cif", ".core", ".pdb"]

         valid_modes = ["train", "eval", "predict"]
-        if(mode not in valid_modes):
+        if mode not in valid_modes:
             raise ValueError(f'mode must be one of {valid_modes}')

-        if(template_release_dates_cache_path is None):
+        if template_release_dates_cache_path is None:
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

-        if(alignment_index is not None):
-            self._chain_ids = list(alignment_index.keys())
+        if self.mmcif_data_cache_path is not None:
+            self._mmcifs = list(self.mmcif_data_cache.keys())
+        elif self.alignment_index is not None:
+            self._mmcifs = [i.split("_")[0] for i in list(alignment_index.keys())]
+        elif self.alignment_dir is not None:
+            self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
         else:
-            self._chain_ids = list(os.listdir(alignment_dir))
+            raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")

-        if(filter_path is not None):
+        if filter_path is not None:
             with open(filter_path, "r") as f:
-                chains_to_include = set([l.strip() for l in f.readlines()])
+                mmcifs_to_include = set([l.strip() for l in f.readlines()])

-            self._chain_ids = [
-                c for c in self._chain_ids if c in chains_to_include
+            self._mmcifs = [
+                m for m in self._mmcifs if m in mmcifs_to_include
             ]

-        if self.mmcif_data_cache_path is not None:
-            self._mmcifs = list(self.mmcif_data_cache.keys())
-        elif self.mmcif_data_cache_path is None and self.alignment_dir is not None:
-            self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
-        else:
-            raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")

         self._mmcif_id_to_idx_dict = {
             mmcif: i for i, mmcif in enumerate(self._mmcifs)
         }

-        # changed template_featurizer to hmmsearch for now just to run the test
         template_featurizer = templates.HmmsearchHitFeaturizer(
             mmcif_dir=template_mmcif_dir,
             max_template_date=max_template_date,
@@ -443,7 +416,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         )
         self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
+    def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
         with open(path, 'r') as f:
             mmcif_string = f.read()
@@ -453,7 +426,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         # Crash if an error is encountered. Any parsing errors should have
         # been dealt with at the alignment stage.
-        if(mmcif_object.mmcif_object is None):
+        if mmcif_object.mmcif_object is None:
             raise list(mmcif_object.errors.values())[0]

         mmcif_object = mmcif_object.mmcif_object
@@ -466,8 +439,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         return data

-    def mmcif_id_to_idx(self, chain_id):
-        return self._mmcif_id_to_idx_dict[chain_id]
+    def mmcif_id_to_idx(self, mmcif_id):
+        return self._mmcif_id_to_idx_dict[mmcif_id]

     def idx_to_mmcif_id(self, idx):
         return self._mmcifs[idx]
@@ -476,20 +449,20 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         mmcif_id = self.idx_to_mmcif_id(idx)

         alignment_index = None
-        if(self.mode == 'train' or self.mode == 'eval'):
+        if self.mode == 'train' or self.mode == 'eval':
             path = os.path.join(self.data_dir, f"{mmcif_id}")
             ext = None
             for e in self.supported_exts:
-                if(os.path.exists(path + e)):
+                if os.path.exists(path + e):
                     ext = e
                     break

-            if(ext is None):
+            if ext is None:
                 raise ValueError("Invalid file type")
-            #TODO: Add pdb and core exts to data_pipeline for multimer
+            # TODO: Add pdb and core exts to data_pipeline for multimer
             path += ext
-            if(ext == ".cif"):
+            if ext == ".cif":
                 data = self._parse_mmcif(
                     path, mmcif_id, self.alignment_dir, alignment_index,
                 )
@@ -502,11 +475,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                 alignment_dir=self.alignment_dir
             )

-        if (self._output_raw):
+        if self._output_raw:
             return data

         # process all_chain_features
-        data,ground_truth = self.feature_pipeline.process_features(data,
+        data = self.feature_pipeline.process_features(data,
                                                    mode=self.mode,
                                                    is_multimer=True)
@@ -516,93 +489,38 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                                        dtype=torch.int64,
                                        device=data["aatype"].device)

-        return data, ground_truth
+        return data

     def __len__(self):
-        return len(self._chain_ids)
+        return len(self._mmcifs)

-def deterministic_train_filter(
-    chain_data_cache_entry: Any,
-    max_resolution: float = 9.,
-    max_single_aa_prop: float = 0.8,
-) -> bool:
-    # Hard filters
-    resolution = chain_data_cache_entry.get("resolution", None)
-    if(resolution is not None and resolution > max_resolution):
-        return False
-
-    seq = chain_data_cache_entry["seq"]
-    counts = {}
-    for aa in seq:
-        counts.setdefault(aa, 0)
-        counts[aa] += 1
-    largest_aa_count = max(counts.values())
-    largest_single_aa_prop = largest_aa_count / len(seq)
-    if(largest_single_aa_prop > max_single_aa_prop):
-        return False
-
-    return True
+def resolution_filter(resolution: int, max_resolution: float) -> bool:
+    """Check that the resolution is <= max_resolution permitted"""
+    return resolution is not None and resolution <= max_resolution

-def deterministic_multimer_train_filter(
-    mmcif_data_cache_entry,
-    max_resolution:float= 9.,
-    max_single_aa_prop:float=0.8,
-    minimum_number_of_residues:int=200,
-) -> bool:
-    """
-    Implement multimer training filtering criteria described in
-    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
-    """
-    # First check resolution
-    resolution = mmcif_data_cache_entry.get("resolution", None)
-    if(resolution is not None and resolution > max_resolution) or (resolution is None):
-        return False
-    # Then check if any single amino acid accounts for more than 80% of the complex sequences
-    seqs = mmcif_data_cache_entry["seqs"]
+def aa_count_filter(seqs: list, max_single_aa_prop: float) -> bool:
+    """Check if any single amino acid accounts for more than max_single_aa_prop percent of the sequence(s)"""
     counts = {}
-    for aa in restypes:
-        counts[aa] = 0
-    total_len = sum([len(i) for i in seqs])
-    if total_len<minimum_number_of_residues: # check if the complex has less than 200 residues
-        return False
     for seq in seqs:
         for aa in seq:
+            counts.setdefault(aa, 0)
             if aa not in restypes:
                 return False
             else:
                 counts[aa] += 1
+    total_len = sum([len(i) for i in seqs])
     largest_aa_count = max(counts.values())
     largest_single_aa_prop = largest_aa_count / total_len
-    if(largest_single_aa_prop > max_single_aa_prop):
-        return False
-    return True
+    return largest_single_aa_prop <= max_single_aa_prop

-def get_stochastic_train_filter_prob(
-    chain_data_cache_entry: Any,
-) -> float:
-    # Stochastic filters
-    probabilities = []
-    cluster_size = chain_data_cache_entry.get("cluster_size", None)
-    if(cluster_size is not None and cluster_size > 0):
-        probabilities.append(1 / cluster_size)
-
-    chain_length = len(chain_data_cache_entry["seq"])
-    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
-
-    # Risk of underflow here?
-    out = 1
-    for p in probabilities:
-        out *= p
-
-    return out
+def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
+    """Check if the total combined sequence lengths are >= minimum_number_of_residues"""
+    total_len = sum([len(i) for i in seqs])
+    return total_len >= minimum_number_of_residues

 class OpenFoldDataset(torch.utils.data.Dataset):
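Illustrative only (not part of the diff): the three new helpers are pure predicates, so a cache entry passes the deterministic multimer filter exactly when all of them hold. With a hypothetical cache entry:

    # Hypothetical mmcif_data_cache entry; the field names follow the diff above.
    entry = {"resolution": 3.1, "seqs": ["MKTAYIAKQR" * 30, "GSHMLEDPVA" * 25]}

    keep = all([
        resolution_filter(resolution=entry.get("resolution", None), max_resolution=9.),
        all_seq_len_filter(seqs=entry["seqs"], minimum_number_of_residues=200),
        aa_count_filter(seqs=entry["seqs"], max_single_aa_prop=0.8),
    ])  # True: 550 residues total, resolution under 9 A, no residue type dominates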
@@ -612,8 +530,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
        length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
        and filtered once at initialization.
     """

     def __init__(self,
-                 datasets: Sequence[OpenFoldSingleDataset],
+                 datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
                  probabilities: Sequence[float],
                  epoch_len: int,
                  generator: torch.Generator = None,
@@ -624,7 +543,47 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.epoch_len = epoch_len
         self.generator = generator

-        def looped_shuffled_dataset_idx(dataset_len):
+        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
+        if _roll_at_init:
+            self.reroll()
+
+    @staticmethod
+    def deterministic_train_filter(
+        cache_entry: Any,
+        max_resolution: float = 9.,
+        max_single_aa_prop: float = 0.8,
+    ) -> bool:
+        # Hard filters
+        resolution = cache_entry.get("resolution", None)
+        seqs = [cache_entry["seq"]]
+        return all([resolution_filter(resolution=resolution,
+                                      max_resolution=max_resolution),
+                    aa_count_filter(seqs=seqs,
+                                    max_single_aa_prop=max_single_aa_prop)])
+
+    @staticmethod
+    def get_stochastic_train_filter_prob(
+        cache_entry: Any,
+    ) -> float:
+        # Stochastic filters
+        probabilities = []
+        cluster_size = cache_entry.get("cluster_size", None)
+        if cluster_size is not None and cluster_size > 0:
+            probabilities.append(1 / cluster_size)
+
+        chain_length = len(cache_entry["seq"])
+        probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
+
+        # Risk of underflow here?
+        out = 1
+        for p in probabilities:
+            out *= p
+
+        return out
+
+    def looped_shuffled_dataset_idx(self, dataset_len):
         while True:
             # Uniformly shuffle each dataset's indices
             weights = [1. for _ in range(dataset_len)]
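Worked example for the stochastic filter above, using hypothetical cache values: the acceptance probability is (1 / cluster_size) * (1 / 512) * clamp(chain_length, 256, 512), so a length-300 chain from a cluster of 20 members gives

    cluster_size, chain_length = 20, 300
    p = (1 / cluster_size) * ((1 / 512) * max(min(chain_length, 512), 256))
    # p = 0.05 * (300 / 512) = 0.0293..., i.e. the chain is kept roughly 3% of the time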
@@ -637,10 +596,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
            for idx in shuf:
                yield idx

-        def looped_samples(dataset_idx):
-            max_cache_len = int(epoch_len * probabilities[dataset_idx])
+    def looped_samples(self, dataset_idx):
+        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
         dataset = self.datasets[dataset_idx]
-            idx_iter = looped_shuffled_dataset_idx(len(dataset))
+        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
         chain_data_cache = dataset.chain_data_cache
         while True:
             weights = []
@@ -649,10 +608,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
                 candidate_idx = next(idx_iter)
                 chain_id = dataset.idx_to_chain_id(candidate_idx)
                 chain_data_cache_entry = chain_data_cache[chain_id]
-                if(not deterministic_train_filter(chain_data_cache_entry)):
+                if not self.deterministic_train_filter(chain_data_cache_entry):
                     continue

-                p = get_stochastic_train_filter_prob(
+                p = self.get_stochastic_train_filter_prob(
                     chain_data_cache_entry,
                 )
                 weights.append([1. - p, p])
@@ -670,10 +629,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
             for datapoint_idx in cache:
                 yield datapoint_idx

-        self._samples = [looped_samples(i) for i in range(len(self.datasets))]
-        if(_roll_at_init):
-            self.reroll()
-
     def __getitem__(self, idx):
         dataset_idx, datapoint_idx = self.datapoints[idx]
         return self.datasets[dataset_idx][datapoint_idx]
@@ -695,65 +650,91 @@ class OpenFoldDataset(torch.utils.data.Dataset):
             self.datapoints.append((dataset_idx, datapoint_idx))

-class OpenFoldMultimerDataset(torch.utils.data.Dataset):
+class OpenFoldMultimerDataset(OpenFoldDataset):
     """
     Create a torch Dataset object for multimer training and
     add filtering steps described in AlphaFold Multimer's paper:
     https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
     """

     def __init__(self,
-                 datasets: Sequence[OpenFoldSingleDataset],
+                 datasets: Sequence[OpenFoldSingleMultimerDataset],
                  probabilities: Sequence[float],
                  epoch_len: int,
                  generator: torch.Generator = None,
-                 _roll_at_init: bool = True,
+                 _roll_at_init: bool = True
                  ):
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.epoch_len = epoch_len
-        self.generator = generator
-        if _roll_at_init:
-            self.reroll()
+        super(OpenFoldMultimerDataset, self).__init__(datasets=datasets,
+                                                      probabilities=probabilities,
+                                                      epoch_len=epoch_len,
+                                                      generator=generator,
+                                                      _roll_at_init=_roll_at_init)

-    def filter_samples(self,dataset_idx):
-        dataset = self.datasets[dataset_idx]
-        mmcif_data_cache = dataset.mmcif_data_cache if hasattr(dataset,"mmcif_data_cache") else None
-        selected_idx = []
-        if mmcif_data_cache is not None:
-            for i in range(len(mmcif_data_cache)):
-                mmcif_id = dataset.idx_to_mmcif_id(i)
-                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
-                if deterministic_multimer_train_filter(mmcif_data_cache_entry,
-                                                       max_resolution=9):
-                    selected_idx.append(i)
-            logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
-        else:
-            selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
-        return selected_idx
-
-    def __getitem__(self, idx):
-        dataset_idx, datapoint_idx = self.datapoints[idx]
-        return self.datasets[dataset_idx][datapoint_idx]
-
-    def __len__(self):
-        return self.epoch_len
-
-    def reroll(self):
-        dataset_choices = torch.multinomial(
-            torch.tensor(self.probabilities),
-            num_samples=len(self.probabilities),
-            replacement=True,
-            generator=self.generator,
-        )
-        self.datapoints = []
-        for dataset_idx in dataset_choices:
-            selected_idx = self.filter_samples(dataset_idx)
-            random.shuffle(selected_idx)
-            if len(selected_idx)<self.epoch_len:
-                self.epoch_len = len(selected_idx)
-                logging.info(f"self.epoch_len is {self.epoch_len}")
-            self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
+    @staticmethod
+    def deterministic_train_filter(
+        cache_entry: Any,
+        max_resolution: float = 9.,
+        max_single_aa_prop: float = 0.8,
+        minimum_number_of_residues: int = 200,
+    ) -> bool:
+        """
+        Implement multimer training filtering criteria described in
+        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
+        """
+        resolution = cache_entry.get("resolution", None)
+        seqs = cache_entry["seqs"]
+        return all([resolution_filter(resolution=resolution,
+                                      max_resolution=max_resolution),
+                    all_seq_len_filter(seqs=seqs,
+                                       minimum_number_of_residues=minimum_number_of_residues),
+                    aa_count_filter(seqs=seqs,
+                                    max_single_aa_prop=max_single_aa_prop)])
+
+    @staticmethod
+    def get_stochastic_train_filter_prob(
+        cache_entry: Any,
+    ) -> float:
+        # Stochastic filters
+        cluster_sizes = cache_entry.get("cluster_sizes", [])
+        chain_probs = [1 / c for c in cluster_sizes if c > 0]
+        if chain_probs:
+            return sum(chain_probs)
+        return 1.
+
+    def looped_samples(self, dataset_idx):
+        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
+        dataset = self.datasets[dataset_idx]
+        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
+        mmcif_data_cache = dataset.mmcif_data_cache
+        while True:
+            weights = []
+            idx = []
+            for _ in range(max_cache_len):
+                candidate_idx = next(idx_iter)
+                mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
+                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
+                if not self.deterministic_train_filter(mmcif_data_cache_entry):
+                    continue
+
+                p = self.get_stochastic_train_filter_prob(
+                    mmcif_data_cache_entry,
+                )
+                weights.append([1. - p, p])
+                idx.append(candidate_idx)
+
+            samples = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=1,
+                generator=self.generator,
+            )
+            samples = samples.squeeze()
+
+            cache = [i for i, s in zip(idx, samples) if s]
+
+            for datapoint_idx in cache:
+                yield datapoint_idx

 class OpenFoldBatchCollator:
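Worked example for the multimer variant above, which sums the per-chain inverse cluster sizes (hypothetical values): a four-chain complex whose chains all sit in clusters of size 40 gets

    cluster_sizes = [40, 40, 40, 40]
    p = sum(1 / c for c in cluster_sizes if c > 0)  # 0.1
    # looped_samples then draws keep/drop per candidate from weights [1. - p, p],
    # so this complex enters the epoch cache roughly 10% of the time.

Each candidate is an independent keep/drop draw realized through torch.multinomial over the two weights.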
@@ -776,7 +757,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         max_iters = self.config.common.max_recycling_iters

-        if(stage_cfg.uniform_recycling):
+        if stage_cfg.uniform_recycling:
             recycling_probs = [
                 1. / (max_iters + 1) for _ in range(max_iters + 1)
             ]
@@ -828,7 +809,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
                 )
                 batch[key] = sample_tensor

-                if(key == "no_recycling_iters"):
+                if key == "no_recycling_iters":
                     no_recycling = sample

         resample_recycling = lambda t: t[..., :no_recycling + 1]
@@ -846,23 +827,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)

-class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
-    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
-        self.config = config
-        self.stage = stage
-        self.generator = generator
-
-    def __iter__(self):
-        it = super().__iter__()
-
-        def _batch_prop_gen(iterator):
-            for batch in iterator:
-                yield batch
-
-        return _batch_prop_gen(it)
-
 class OpenFoldDataModule(pl.LightningDataModule):
     def __init__(self,
                  config: mlc.ConfigDict,
@@ -917,7 +881,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.batch_seed = batch_seed
         self.train_epoch_len = train_epoch_len

-        if(self.train_data_dir is None and self.predict_data_dir is None):
+        if self.train_data_dir is None and self.predict_data_dir is None:
             raise ValueError(
                 'At least one of train_data_dir or predict_data_dir must be '
                 'specified'
@@ -925,15 +889,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.training_mode = self.train_data_dir is not None

-        if(self.training_mode and train_alignment_dir is None):
+        if self.training_mode and train_alignment_dir is None:
             raise ValueError(
                 'In training mode, train_alignment_dir must be specified'
             )
-        elif(not self.training_mode and predict_alignment_dir is None):
+        elif not self.training_mode and predict_alignment_dir is None:
             raise ValueError(
                 'In inference mode, predict_alignment_dir must be specified'
             )
-        elif(val_data_dir is not None and val_alignment_dir is None):
+        elif val_data_dir is not None and val_alignment_dir is None:
             raise ValueError(
                 'If val_data_dir is specified, val_alignment_dir must '
                 'be specified as well'
@@ -941,17 +905,17 @@ class OpenFoldDataModule(pl.LightningDataModule):
         # An ad-hoc measure for our particular filesystem restrictions
         self._distillation_structure_index = None
-        if(_distillation_structure_index_path is not None):
+        if _distillation_structure_index_path is not None:
             with open(_distillation_structure_index_path, "r") as fp:
                 self._distillation_structure_index = json.load(fp)

         self.alignment_index = None
-        if(alignment_index_path is not None):
+        if alignment_index_path is not None:
             with open(alignment_index_path, "r") as fp:
                 self.alignment_index = json.load(fp)

         self.distillation_alignment_index = None
-        if(distillation_alignment_index_path is not None):
+        if distillation_alignment_index_path is not None:
             with open(distillation_alignment_index_path, "r") as fp:
                 self.distillation_alignment_index = json.load(fp)
@@ -962,28 +926,24 @@ class OpenFoldDataModule(pl.LightningDataModule):
             max_template_date=self.max_template_date,
             config=self.config,
             kalign_binary_path=self.kalign_binary_path,
-            template_release_dates_cache_path=
-                self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=
-                self.obsolete_pdbs_file_path,
-        )
+            template_release_dates_cache_path=self.template_release_dates_cache_path,
+            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)

-        if(self.training_mode):
+        if self.training_mode:
             train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 chain_data_cache_path=self.train_chain_data_cache_path,
                 alignment_dir=self.train_alignment_dir,
                 filter_path=self.train_filter_path,
                 max_template_hits=self.config.train.max_template_hits,
-                shuffle_top_k_prefiltered=
-                    self.config.train.shuffle_top_k_prefiltered,
+                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                 treat_pdb_as_distillation=False,
                 mode="train",
                 alignment_index=self.alignment_index,
             )

             distillation_dataset = None
-            if(self.distillation_data_dir is not None):
+            if self.distillation_data_dir is not None:
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,
                     chain_data_cache_path=self.distillation_chain_data_cache_path,
@@ -998,7 +958,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             d_prob = self.config.train.distillation_prob
-            if(distillation_dataset is not None):
+            if distillation_dataset is not None:
                 datasets = [train_dataset, distillation_dataset]
                 d_prob = self.config.train.distillation_prob
                 probabilities = [1. - d_prob, d_prob]
@@ -1007,7 +967,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 probabilities = [1.]

             generator = None
-            if(self.batch_seed is not None):
+            if self.batch_seed is not None:
                 generator = torch.Generator()
                 generator = generator.manual_seed(self.batch_seed + 1)
@@ -1019,7 +979,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 _roll_at_init=False,
             )

-        if(self.val_data_dir is not None):
+        if self.val_data_dir is not None:
             self.eval_dataset = dataset_gen(
                 data_dir=self.val_data_dir,
                 alignment_dir=self.val_alignment_dir,
@@ -1040,18 +1000,17 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def _gen_dataloader(self, stage):
         generator = None
-        if(self.batch_seed is not None):
+        if self.batch_seed is not None:
             generator = torch.Generator()
             generator = generator.manual_seed(self.batch_seed)

-        dataset = None
-        if(stage == "train"):
+        if stage == "train":
             dataset = self.train_dataset
             # Filter the dataset, if necessary
             dataset.reroll()
-        elif(stage == "eval"):
+        elif stage == "eval":
             dataset = self.eval_dataset
-        elif(stage == "predict"):
+        elif stage == "predict":
             dataset = self.predict_dataset
         else:
             raise ValueError("Invalid stage")
@@ -1074,7 +1033,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         return self._gen_dataloader("train")

     def val_dataloader(self):
-        if(self.eval_dataset is not None):
+        if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
         return None
@@ -1091,16 +1050,19 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
        scripts/generate_mmcif_cache.py mmcif_data_cache_path should be
        a file that record what chain(s) each mmcif file has
     """

     def __init__(self, config: mlc.ConfigDict,
                  template_mmcif_dir: str, max_template_date: str,
                  train_data_dir: Optional[str] = None,
-                 train_mmcif_data_cache_path:Optional[str] = None,
-                 val_mmcif_data_cache_path:Optional[str] = None,
+                 train_mmcif_data_cache_path: Optional[str] = None,
+                 val_mmcif_data_cache_path: Optional[str] = None,
                  **kwargs):
-        super(OpenFoldMultimerDataModule,self).__init__(config,
-                                                        template_mmcif_dir,
-                                                        max_template_date,
-                                                        train_data_dir,**kwargs)
+        super(OpenFoldMultimerDataModule, self).__init__(config,
+                                                         template_mmcif_dir,
+                                                         max_template_date,
+                                                         train_data_dir,
+                                                         **kwargs)
         self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
@@ -1112,28 +1074,24 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
             max_template_date=self.max_template_date,
             config=self.config,
             kalign_binary_path=self.kalign_binary_path,
-            template_release_dates_cache_path=
-                self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=
-                self.obsolete_pdbs_file_path,
-        )
+            template_release_dates_cache_path=self.template_release_dates_cache_path,
+            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)

-        if(self.training_mode):
+        if self.training_mode:
             train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 mmcif_data_cache_path=self.train_mmcif_data_cache_path,
                 alignment_dir=self.train_alignment_dir,
                 filter_path=self.train_filter_path,
                 max_template_hits=self.config.train.max_template_hits,
-                shuffle_top_k_prefiltered=
-                    self.config.train.shuffle_top_k_prefiltered,
+                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                 treat_pdb_as_distillation=False,
                 mode="train",
                 alignment_index=self.alignment_index,
             )

             distillation_dataset = None
-            if(self.distillation_data_dir is not None):
+            if self.distillation_data_dir is not None:
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,
                     alignment_dir=self.distillation_alignment_dir,
@@ -1147,7 +1105,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
             d_prob = self.config.train.distillation_prob
-            if(distillation_dataset is not None):
+            if distillation_dataset is not None:
                 datasets = [train_dataset, distillation_dataset]
                 d_prob = self.config.train.distillation_prob
                 probabilities = [1. - d_prob, d_prob]
@@ -1156,7 +1114,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 probabilities = [1.]

             generator = None
-            if(self.batch_seed is not None):
+            if self.batch_seed is not None:
                 generator = torch.Generator()
                 generator = generator.manual_seed(self.batch_seed + 1)
@@ -1168,7 +1126,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 _roll_at_init=True,
             )

-        if(self.val_data_dir is not None):
+        if self.val_data_dir is not None:
             self.eval_dataset = dataset_gen(
                 data_dir=self.val_data_dir,
                 alignment_dir=self.val_alignment_dir,
@@ -1188,31 +1146,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 mode="predict",
             )

-    def _gen_dataloader(self, stage):
-        generator = None
-        if(self.batch_seed is not None):
-            generator = torch.Generator()
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if(stage == "train"):
-            dataset = self.train_dataset
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif(stage == "eval"):
-            dataset = self.eval_dataset
-        elif(stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        dl = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-        )
-
-        return dl
-
 class DummyDataset(torch.utils.data.Dataset):
     def __init__(self, batch_path):
...
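Net effect of the data_modules changes, sketched below for orientation (class and method names are from the diff; bodies elided). The multimer classes now specialize the monomer ones instead of duplicating them, which is why the bespoke multimer DataLoader and the overridden _gen_dataloader above could be deleted:

    import torch

    class OpenFoldDataset(torch.utils.data.Dataset):
        # owns reroll() / looped_samples(); the two filters are overridable
        @staticmethod
        def deterministic_train_filter(cache_entry, **kwargs) -> bool: ...
        @staticmethod
        def get_stochastic_train_filter_prob(cache_entry) -> float: ...

    class OpenFoldMultimerDataset(OpenFoldDataset):
        # overrides only the two filters and looped_samples()
        ...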
@@ -93,24 +93,11 @@ def np_example_to_features(
     with torch.no_grad():
         if is_multimer:
-            if mode == 'train':
-                features,gt_features = input_pipeline_multimer.process_tensors_from_config(
-                    tensor_dict,
-                    cfg.common,
-                    cfg[mode],
-                    is_training=True
-                )
-                return {k: v for k, v in features.items()}, gt_features
-            else:
             features = input_pipeline_multimer.process_tensors_from_config(
                 tensor_dict,
                 cfg.common,
                 cfg[mode],
-                is_training=False
             )
-            return {k: v for k, v in features.items()}
         else:
             features = input_pipeline.process_tensors_from_config(
                 tensor_dict,
...
@@ -21,17 +21,18 @@ from openfold.data import (
     data_transforms_multimer,
 )

-def grountruth_transforms_fns():
+def groundtruth_transforms_fns():
     transforms = [data_transforms.make_atom14_masks,
                   data_transforms.make_atom14_positions,
                   data_transforms.atom37_to_frames,
                   data_transforms.atom37_to_torsion_angles(""),
                   data_transforms.make_pseudo_beta(""),
                   data_transforms.get_backbone_frames,
-                  data_transforms.get_chi_angles,
-                  ]
+                  data_transforms.get_chi_angles]
     return transforms

 def nonensembled_transform_fns():
     """Input pipeline data transformers that are not ensembled."""
     transforms = [
@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
     return transforms

 def prepare_ground_truth_features(tensors):
     """Prepare ground truth features that are only needed for loss calculation during training"""
-    GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
-    gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
+    gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
+    gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
     gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
-    gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
+    gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
     return gt_tensors

-def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
+def process_tensors_from_config(tensors, common_cfg, mode_cfg):
     """Based on the config, apply filters and transformations to the data."""
-    if is_training:
-        gt_tensors= prepare_ground_truth_features(tensors)
+    process_gt_feats = mode_cfg.supervised
+    gt_tensors = {}
+    if process_gt_feats:
+        gt_tensors = prepare_ground_truth_features(tensors)

     ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
     tensors['aatype'] = tensors['aatype'].to(torch.long)
@@ -152,9 +157,9 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
         lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
     )

-    if is_training:
-        return tensors,gt_tensors
-    else:
+    if process_gt_feats:
+        tensors['gt_features'] = gt_tensors
+
     return tensors

 @data_transforms.curry1
...
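With the `is_training` flag gone, the ground-truth tensors now travel inside the single returned feature dict under the 'gt_features' key whenever `mode_cfg.supervised` is set. A hedged sketch of the calling side (the `pop` step is an assumption about downstream use, not shown in this commit):

    batch = process_tensors_from_config(tensor_dict, cfg.common, cfg["train"])
    gt_features = batch.pop("gt_features", {})  # {} when the mode is unsupervised
    # the model consumes `batch`; the loss additionally consumes `gt_features`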
@@ -13,35 +13,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import partial
-import logging
 import ml_collections
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.distributions.bernoulli import Bernoulli
 from typing import Dict, Optional, Tuple

 from openfold.np import residue_constants
-from openfold.utils import feats
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
 from openfold.utils.all_atom_multimer import get_rc_tensor
 from openfold.utils.tensor_utils import (
     tree_map,
-    tensor_tree_map,
     masked_mean,
     permute_final_dims,
-    batched_gather,
 )
 import random
 from openfold.np import residue_constants as rc
 import logging
 import procrustes
 from openfold.utils.tensor_utils import tensor_tree_map
-import gc

 logger = logging.getLogger(__name__)

 def softmax_cross_entropy(logits, labels):
     loss = -1 * torch.sum(
         labels * torch.nn.functional.log_softmax(logits, dim=-1),
@@ -185,11 +180,10 @@ def backbone_loss(
     eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
     ### need to check if the traj belongs to 4*4 matrix or a tensor_7
-    if traj.shape[-1]==7:
+    if traj.shape[-1] == 7:
         pred_aff = Rigid.from_tensor_7(traj)
-    elif traj.shape[-1]==4:
+    elif traj.shape[-1] == 4:
         pred_aff = Rigid.from_tensor_4x4(traj)

     pred_aff = Rigid(
@@ -297,7 +291,6 @@ def fape_loss(
     batch: Dict[str, torch.Tensor],
     config: ml_collections.ConfigDict,
 ) -> torch.Tensor:
-
     traj = out["sm"]["frames"]
     asym_id = batch.get("asym_id")
     if asym_id is not None:
@@ -502,7 +495,7 @@ def lddt_ca(
     ca_pos = residue_constants.atom_order["CA"]
     all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
     all_atom_positions = all_atom_positions[..., ca_pos, :]
-    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim

     return lddt(
         all_atom_pred_pos,
@@ -532,7 +525,7 @@ def lddt_loss(
     ca_pos = residue_constants.atom_order["CA"]
     all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
     all_atom_positions = all_atom_positions[..., ca_pos, :]
-    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim

     score = lddt(
         all_atom_pred_pos,
@@ -544,7 +537,7 @@ def lddt_loss(
     # TODO: Remove after initial pipeline testing
     score = torch.nan_to_num(score, nan=torch.nanmean(score))
-    score[score<0] = 0
+    score[score < 0] = 0

     score = score.detach()
     bin_index = torch.floor(score * no_bins).long()
@@ -707,10 +700,10 @@ def compute_tm(
     n = residue_weights.shape[-1]
     pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
     if interface and (asym_id is not None):
-        if len(asym_id.shape)>1:
-            assert len(asym_id.shape)<=2
+        if len(asym_id.shape) > 1:
+            assert len(asym_id.shape) <= 2
             batch_size = asym_id.shape[0]
-            pair_mask = residue_weights.new_ones((batch_size,n, n), dtype=torch.int32)
+            pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
         pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)

     predicted_tm_term *= pair_mask
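The batched interface mask built above zeroes all intra-chain pairs, so only inter-chain contacts contribute to the interface TM score. A tiny worked example:

    import torch

    asym_id = torch.tensor([[0, 0, 1]])  # one complex: two chain-A residues, one chain-B
    pair_mask = (asym_id[..., None] != asym_id[..., None, :]).to(torch.int32)
    # tensor([[[0, 0, 1],
    #          [0, 0, 1],
    #          [1, 1, 0]]], dtype=torch.int32) -- only A<->B pairs survive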
@@ -727,6 +720,7 @@ def compute_tm(
     argmax = (weighted == torch.max(weighted)).nonzero()[0]
     return per_alignment[tuple(argmax)]

+
 def tm_loss(
     logits,
     final_affine_tensor,
@@ -741,9 +735,9 @@ def tm_loss(
     **kwargs,
 ):
     # first check whether this is a tensor_7 or tensor_4*4
-    if final_affine_tensor.shape[-1]==7:
+    if final_affine_tensor.shape[-1] == 7:
         pred_affine = Rigid.from_tensor_7(final_affine_tensor)
-    elif final_affine_tensor.shape[-1]==4:
+    elif final_affine_tensor.shape[-1] == 4:
         pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)

     backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
@@ -1221,7 +1215,7 @@ def find_structural_violations(
     atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)

-    #TODO: Consolidate monomer/multimer modes
+    # TODO: Consolidate monomer/multimer modes
     asym_id = batch.get("asym_id")
     if asym_id is not None:
         residx_atom14_to_atom37 = get_rc_tensor(
@@ -1701,9 +1695,6 @@ def compute_rmsd(
     eps: float = 1e-6,
 ) -> torch.Tensor:
     sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
-    del true_atom_pos
-    del pred_atom_pos
-    gc.collect()
     if atom_mask is not None:
         sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
     msd = torch.mean(sq_diff)
@@ -1713,8 +1704,8 @@ def compute_rmsd(
 def kabsch_rotation(P, Q):
     """
-    Use procrustes package to calculate best rotation that minimises
-    the RMSD betwee P and Q
+    Use procrustes package to calculate the best rotation that minimises
+    the RMSD between P and Q

     The optimal rotation matrix was calculated using
     the rotational() function from procrustes package. Details can be found here:
...@@ -1728,12 +1719,12 @@ def kabsch_rotation(P, Q): ...@@ -1728,12 +1719,12 @@ def kabsch_rotation(P, Q):
A 3*3 rotation matrix A 3*3 rotation matrix
""" """
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().float().numpy(), rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().float().numpy(),translate=False,scale=False) Q.detach().cpu().float().numpy(), translate=False, scale=False)
# rotation.t does not mean transpose; .t simply extracts the matrix from the procrustes result object # rotation.t does not mean transpose; .t simply extracts the matrix from the procrustes result object
rotation = torch.tensor(rotation.t,dtype=torch.float) rotation = torch.tensor(rotation.t, dtype=torch.float)
assert rotation.shape == torch.Size([3,3]) assert rotation.shape == torch.Size([3, 3])
return rotation.to(device=P.device, dtype=P.dtype) return rotation.to(device=P.device, dtype=P.dtype)
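# A minimal usage sketch (toy values, assuming centered inputs since
# translate=False above): kabsch_rotation recovers a pure rotation exactly.
#   P = torch.randn(4, 3)
#   true_r = torch.tensor([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])
#   Q = P @ true_r                       # rotate P by 90 degrees about z
#   r = kabsch_rotation(P, Q)
#   assert torch.allclose(P @ r, Q, atol=1e-4)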
...@@ -1756,7 +1747,7 @@ def get_optimal_transform( ...@@ -1756,7 +1747,7 @@ def get_optimal_transform(
# sometimes using fake test inputs generates NaN in the predicted atom positions # sometimes using fake test inputs generates NaN in the predicted atom positions
logging.warning("src_atoms has NaN or Inf values") logging.warning("src_atoms has NaN or Inf values")
src_atoms = torch.nan_to_num(src_atoms,nan=0.0,posinf=1.0,neginf=1.0) src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
if mask is not None: if mask is not None:
assert mask.dtype == torch.bool assert mask.dtype == torch.bool
...@@ -1767,21 +1758,15 @@ def get_optimal_transform( ...@@ -1767,21 +1758,15 @@ def get_optimal_transform(
else: else:
src_atoms = src_atoms[mask, :] src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :] tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype) src_center = src_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype) tgt_center = tgt_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
r = kabsch_rotation(src_atoms,tgt_atoms) r = kabsch_rotation(src_atoms, tgt_atoms)
del src_atoms,tgt_atoms,
gc.collect()
x = tgt_center - src_center @ r x = tgt_center - src_center @ r
del tgt_center,src_center,mask
gc.collect()
return r, x return r, x
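# Convention note: the returned pair satisfies src_atoms @ r + x ≈ tgt_atoms
# in the least-squares sense, with r applied on the right and
# x = tgt_center - src_center @ r as the residual translation.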
def get_least_asym_entity_or_longest_length(batch,input_asym_id): def get_least_asym_entity_or_longest_length(batch, input_asym_id):
""" """
First check how many subunits each sequence has; if there is no tie (e.g. AABBB), First check how many subunits each sequence has; if there is no tie (e.g. AABBB),
then select one of the A chains as the anchor then select one of the A chains as the anchor
...@@ -1821,13 +1806,14 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id): ...@@ -1821,13 +1806,14 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
# If still multiple entities, return a random one # If still multiple entities, return a random one
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities) least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1 assert len(least_asym_entities) == 1
least_asym_entities = least_asym_entities[0] least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities]) anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id] anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids return anchor_gt_asym_id, anchor_pred_asym_ids
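# A minimal sketch of the selection rule (toy values, not the exact code
# path): count chains per entity, keep the entities with the fewest copies,
# then break any remaining tie at random. For chains A, A, B:
#   from collections import Counter
#   copies = Counter({"A": 2, "B": 1})
#   fewest = min(copies.values())                                # 1
#   candidates = [e for e, n in copies.items() if n == fewest]   # ["B"]
#   anchor_entity = random.choice(candidates)                    # "B", no tie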
def greedy_align( def greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -1843,7 +1829,7 @@ def greedy_align( ...@@ -1843,7 +1829,7 @@ def greedy_align(
""" """
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
...@@ -1857,13 +1843,13 @@ def greedy_align( ...@@ -1857,13 +1843,13 @@ def greedy_align(
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1) j = int(next_asym_id - 1)
if not used[j]: # possible candidate if not used[j]: # possible candidate
cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index) cropped_pos = torch.index_select(true_ca_poses[j], 1, cur_residue_index)
mask = torch.index_select(true_ca_masks[j],1,cur_residue_index) mask = torch.index_select(true_ca_masks[j], 1, cur_residue_index)
rmsd = compute_rmsd( rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0), torch.squeeze(cropped_pos, 0), torch.squeeze(cur_pred_pos, 0),
(cur_pred_mask * mask).bool() (cur_pred_mask * mask).bool()
) )
if (rmsd is not None) and (rmsd < best_rmsd): if rmsd is not None and rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_idx = j best_idx = j
...@@ -1873,15 +1859,15 @@ def greedy_align( ...@@ -1873,15 +1859,15 @@ def greedy_align(
return align return align
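# Output convention: `align` is a list of (prediction_chain_index,
# ground_truth_chain_index) pairs, e.g. [(0, 1), (1, 0)] when the two copies
# of a homodimer fit best swapped; merge_labels below consumes this pairing.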
def pad_features(feature_tensor,nres_pad,pad_dim): def pad_features(feature_tensor, nres_pad, pad_dim):
"""Pad input feature tensor""" """Pad input feature tensor"""
pad_shape = list(feature_tensor.shape) pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape,device=feature_tensor.device) padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim) return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
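# A minimal usage sketch (toy shapes): pad a cropped [1, 5, 3] tensor back to
# 8 residues along the residue dimension.
#   coords = torch.zeros(1, 5, 3)
#   padded = pad_features(coords, nres_pad=3, pad_dim=1)
#   assert padded.shape == (1, 8, 3)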
def merge_labels(per_asym_residue_index,labels, align,original_nres): def merge_labels(per_asym_residue_index, labels, align, original_nres):
""" """
Merge ground truth labels according to the permutation results Merge ground truth labels according to the permutation results
...@@ -1898,24 +1884,25 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres): ...@@ -1898,24 +1884,25 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
label = labels[j][k] label = labels[j][k]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k : if len(v.shape) <= 1 or "template" in k or "row_mask" in k:
continue continue
else: else:
dimension_to_merge = 1 dimension_to_merge = 1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index) cur_out[i] = label.index_select(dimension_to_merge, cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out) > 0:
new_v = torch.concat(cur_out, dim=dimension_to_merge) new_v = torch.concat(cur_out, dim=dimension_to_merge)
# check whether the merged tensor needs padding back to the original length # check whether the merged tensor needs padding back to the original length
if new_v.shape[dimension_to_merge]!=original_nres: if new_v.shape[dimension_to_merge] != original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge] nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v,nres_pad,pad_dim=dimension_to_merge) new_v = pad_features(new_v, nres_pad, pad_dim=dimension_to_merge)
outs[k] = new_v outs[k] = new_v
return outs return outs
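# Worked example (toy values): with align = [(0, 1), (1, 0)] and two chains,
# merge_labels places ground-truth chain 1 at prediction slot 0 and chain 0
# at slot 1, crops each with per_asym_residue_index, concatenates along the
# residue axis, and pads back to original_nres if the crop dropped residues.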
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement""" """Aggregation of the various losses described in the supplement"""
def __init__(self, config): def __init__(self, config):
super(AlphaFoldLoss, self).__init__() super(AlphaFoldLoss, self).__init__()
self.config = config self.config = config
...@@ -1974,13 +1961,13 @@ class AlphaFoldLoss(nn.Module): ...@@ -1974,13 +1961,13 @@ class AlphaFoldLoss(nn.Module):
), ),
} }
if(self.config.tm.enabled): if self.config.tm.enabled:
loss_fns["tm"] = lambda: tm_loss( loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"], logits=out["tm_logits"],
**{**batch, **out, **self.config.tm}, **{**batch, **out, **self.config.tm},
) )
if (self.config.chain_center_of_mass.enabled): if self.config.chain_center_of_mass.enabled:
loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss( loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
all_atom_pred_pos=out["final_atom_positions"], all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.chain_center_of_mass}, **{**batch, **self.config.chain_center_of_mass},
...@@ -1991,11 +1978,11 @@ class AlphaFoldLoss(nn.Module): ...@@ -1991,11 +1978,11 @@ class AlphaFoldLoss(nn.Module):
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)): if torch.isnan(loss) or torch.isinf(loss):
#for k,v in batch.items(): # for k,v in batch.items():
# if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))): # if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
# logging.warning(f"{k}: is nan") # logging.warning(f"{k}: is nan")
#logging.warning(f"{loss_name}: {loss}") # logging.warning(f"{loss_name}: {loss}")
logging.warning(f"{loss_name} loss is NaN. Skipping...") logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss cum_loss = cum_loss + weight * loss
...@@ -2010,17 +1997,17 @@ class AlphaFoldLoss(nn.Module): ...@@ -2010,17 +1997,17 @@ class AlphaFoldLoss(nn.Module):
losses["loss"] = cum_loss.detach().clone() losses["loss"] = cum_loss.detach().clone()
if(not _return_breakdown): if not _return_breakdown:
return cum_loss return cum_loss
return cum_loss, losses return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False): def forward(self, out, batch, _return_breakdown=False):
if(not _return_breakdown): if not _return_breakdown:
cum_loss = self.loss(out,batch,_return_breakdown) cum_loss = self.loss(out, batch, _return_breakdown)
return cum_loss return cum_loss
else: else:
cum_loss,losses = self.loss(out,batch,_return_breakdown) cum_loss, losses = self.loss(out, batch, _return_breakdown)
return cum_loss, losses return cum_loss, losses
...@@ -2029,12 +2016,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2029,12 +2016,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Adds multi-chain permutation alignment on top of Adds multi-chain permutation alignment on top of
AlphaFoldLoss AlphaFoldLoss
""" """
def __init__(self, config): def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config) super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config self.config = config
@staticmethod @staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES,split_dim=1): def split_ground_truth_labels(gt_features):
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
...@@ -2042,21 +2030,21 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2042,21 +2030,21 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
a list of feature dictionaries with only necessary ground truth features a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation
""" """
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
unique_asym_ids, asym_id_counts= unique_asym_ids.tolist(),asym_id_counts.tolist() n_res = gt_features["asym_id"].shape[-1]
if 0 in unique_asym_ids:
pop_idx = unique_asym_ids.index(0) def split_dim(shape):
padding_asym_id = unique_asym_ids.pop(pop_idx) return next(iter(i for i, size in enumerate(shape) if size == n_res), None)
padding_asym_counts = asym_id_counts.pop(pop_idx)
unique_asym_ids.append(padding_asym_id) labels = list(map(dict, zip(*[[(k, v) for v in torch.split(v_all, asym_id_counts.tolist(),
asym_id_counts.append(padding_asym_counts) dim=split_dim(v_all.shape))]
for k, v_all in gt_features.items()
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=split_dim)] for k, value in batch.items() if k in REQUIRED_FEATURES]))) if n_res in v_all.shape])))
return labels return labels
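# A minimal sketch (toy values): any tensor with an axis of length
# n_res = asym_id.shape[-1] is split per chain along that axis; features
# without such an axis are dropped. For two chains of lengths 2 and 3:
#   toy_gt = {
#       "asym_id": torch.tensor([1, 1, 2, 2, 2]),
#       "all_atom_positions": torch.zeros(5, 37, 3),
#   }
#   chains = AlphaFoldMultimerLoss.split_ground_truth_labels(toy_gt)
#   assert chains[0]["all_atom_positions"].shape == (2, 37, 3)
#   assert chains[1]["all_atom_positions"].shape == (3, 37, 3)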
@staticmethod @staticmethod
def get_per_asym_residue_index(features): def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0] unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool() asym_mask = (features["asym_id"] == cur_asym_id).bool()
...@@ -2085,8 +2073,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2085,8 +2073,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return entity_2_asym_list return entity_2_asym_list
@staticmethod @staticmethod
def calculate_input_mask(true_ca_masks,anchor_gt_idx,anchor_gt_residue, def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
asym_mask,pred_ca_mask): asym_mask, pred_ca_mask):
""" """
Calculate an input mask for downstream optimal transformation computation Calculate an input mask for downstream optimal transformation computation
...@@ -2099,37 +2087,37 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2099,37 +2087,37 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Returns: Returns:
input_mask (Tensor): A boolean mask input_mask (Tensor): A boolean mask
""" """
pred_ca_mask = torch.squeeze(pred_ca_mask,0) pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
asym_mask = torch.squeeze(asym_mask,0) asym_mask = torch.squeeze(asym_mask, 0)
anchor_pred_mask = pred_ca_mask[asym_mask] anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue) anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool() input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask return input_mask
@staticmethod @staticmethod
def calculate_optimal_transform(true_ca_poses, def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue, anchor_gt_idx, anchor_gt_residue,
true_ca_masks,pred_ca_mask, true_ca_masks, pred_ca_mask,
asym_mask, asym_mask,
pred_ca_pos): pred_ca_pos):
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks, input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx,anchor_gt_residue, anchor_gt_idx, anchor_gt_residue,
asym_mask, asym_mask,
pred_ca_mask) pred_ca_mask)
input_mask = torch.squeeze(input_mask,0) input_mask = torch.squeeze(input_mask, 0)
pred_ca_pos = torch.squeeze(pred_ca_pos,0) pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
asym_mask = torch.squeeze(asym_mask,0) asym_mask = torch.squeeze(asym_mask, 0)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_gt_residue) anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos,0), anchor_pred_pos, torch.squeeze(anchor_true_pos, 0),
mask=input_mask mask=input_mask
) )
return r, x return r, x
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch,permutate_chains=False): def multi_chain_perm_align(out, features, ground_truth):
""" """
A class method that permutes chains in the ground truth A class method that permutes chains in the ground truth
before calculating the loss. before calculating the loss.
...@@ -2137,18 +2125,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2137,18 +2125,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
feature, ground_truth = batch unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
del batch unique_asym_ids.discard(0) # Remove padding asym_id
if permutate_chains: is_monomer = len(unique_asym_ids) == 1
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(features)
if is_monomer:
best_align = list(enumerate(range(len(per_asym_residue_index))))
return best_align, per_asym_residue_index
best_rmsd = float('inf') best_rmsd = float('inf')
best_align = None best_align = None
# First select anchors from predicted structures and ground truths # First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id']) anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,
features['asym_id'])
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth) entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list) assert isinstance(labels, list)
del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
# Then calculate optimal transform by aligning anchors # Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
...@@ -2161,19 +2155,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2161,19 +2155,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
for candidate_pred_anchor in anchor_pred_asym_ids: for candidate_pred_anchor in anchor_pred_asym_ids:
asym_mask = (feature["asym_id"] == candidate_pred_anchor).bool() asym_mask = (features["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_residue = per_asym_residue_index[int(candidate_pred_anchor)] anchor_gt_residue = per_asym_residue_index[candidate_pred_anchor.item()]
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses, r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue, anchor_gt_idx, anchor_gt_residue,
true_ca_masks,pred_ca_mask, true_ca_masks, pred_ca_mask,
asym_mask, asym_mask,
pred_ca_pos pred_ca_pos
) )
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align( align = greedy_align(
feature, features,
per_asym_residue_index, per_asym_residue_index,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
...@@ -2181,27 +2174,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2181,27 +2174,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses, aligned_true_ca_poses,
true_ca_masks, true_ca_masks,
) )
merged_labels = merge_labels(per_asym_residue_index,labels,align, merged_labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=feature['aatype'].shape[-1]) original_nres=features['aatype'].shape[-1])
rmsd = compute_rmsd(true_atom_pos = merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x, rmsd = compute_rmsd(
pred_atom_pos = pred_ca_pos, true_atom_pos=merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
atom_mask = (pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool()) pred_atom_pos=pred_ca_pos,
atom_mask=(pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
if rmsd < best_rmsd: if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_align = align best_align = align
del r,x
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask
gc.collect()
else:
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
best_align = list(enumerate(range(len(per_asym_residue_index))))
return best_align, per_asym_residue_index return best_align, per_asym_residue_index
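# Summary of the loop above: for each candidate predicted anchor chain,
# (1) rigidly align the ground truth to the prediction via the anchor,
# (2) greedily match the remaining chains by CA RMSD, and (3) keep the
# permutation with the lowest overall RMSD; monomers short-circuit to the
# identity alignment returned earlier.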
def forward(self, out, batch, _return_breakdown=False):
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
""" """
Overrides the AlphaFoldLoss forward function so that Overrides the AlphaFoldLoss forward function so that
it first computes the multi-chain permutation it first computes the multi-chain permutation
...@@ -2210,32 +2195,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2210,32 +2195,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: input features, with the corresponding ground truth structure stored under the "gt_features" key
""" """
# first check if it is a monomer ground_truth = batch.pop('gt_features')
features, ground_truth = batch labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer: # Then permute ground truth chains before calculating the loss
permutate_chains = True align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out=out,
features=batch,
# Then permutate ground truth chains before calculating the loss ground_truth=ground_truth)
align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
(features,ground_truth),
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=[i for i in ground_truth.keys()])
# reorder ground truth labels according to permutation results # reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align, labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=features['aatype'].shape[-1]) original_nres=batch['aatype'].shape[-1])
features.update(labels) batch.update(labels)
if (not _return_breakdown): if not _return_breakdown:
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, batch, _return_breakdown)
print(f"cum_loss: {cum_loss}") print(f"cum_loss: {cum_loss}")
return cum_loss return cum_loss
else: else:
cum_loss, losses = self.loss(out, features, _return_breakdown) cum_loss, losses = self.loss(out, batch, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}") print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses return cum_loss, losses
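# Calling convention sketch (hypothetical names): the ground truth now rides
# inside the batch instead of a (features, ground_truth) tuple, e.g.:
#   batch["gt_features"] = gt_features   # dict of ground-truth tensors
#   loss, breakdown = multimer_loss(out, batch, _return_breakdown=True)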
...@@ -13,7 +13,7 @@ from tqdm import tqdm ...@@ -13,7 +13,7 @@ from tqdm import tqdm
from openfold.data.mmcif_parsing import parse from openfold.data.mmcif_parsing import parse
def parse_file(f, args): def parse_file(f, args, chain_cluster_size_dict=None):
with open(os.path.join(args.mmcif_dir, f), "r") as fp: with open(os.path.join(args.mmcif_dir, f), "r") as fp:
mmcif_string = fp.read() mmcif_string = fp.read()
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
...@@ -28,6 +28,18 @@ def parse_file(f, args): ...@@ -28,6 +28,18 @@ def parse_file(f, args):
local_data["release_date"] = mmcif.header["release_date"] local_data["release_date"] = mmcif.header["release_date"]
chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items())) chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
if chain_cluster_size_dict is not None:
cluster_sizes = []
for chain_id in chain_ids:
full_name = "_".join([file_id, chain_id])
cluster_size = chain_cluster_size_dict.get(
full_name.upper(), -1
)
cluster_sizes.append(cluster_size)
local_data["cluster_sizes"] = cluster_sizes
local_data["chain_ids"] = chain_ids local_data["chain_ids"] = chain_ids
local_data["seqs"] = seqs local_data["seqs"] = seqs
local_data["no_chains"] = len(chain_ids) local_data["no_chains"] = len(chain_ids)
...@@ -38,8 +50,21 @@ def parse_file(f, args): ...@@ -38,8 +50,21 @@ def parse_file(f, args):
def main(args): def main(args):
chain_cluster_size_dict = None
if args.cluster_file is not None:
chain_cluster_size_dict = {}
with open(args.cluster_file, "r") as fp:
clusters = [l.strip() for l in fp.readlines()]
for cluster in clusters:
chain_ids = cluster.split()
cluster_len = len(chain_ids)
for chain_id in chain_ids:
chain_id = chain_id.upper()
chain_cluster_size_dict[chain_id] = cluster_len
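# Expected cluster file format (illustrative IDs): one whitespace-separated
# cluster of {PDB_ID}_{CHAIN_ID} entries per line, e.g.
#   1ABC_A 2XYZ_B 3DEF_A
#   4GHI_C
# yields chain_cluster_size_dict ==
#   {"1ABC_A": 3, "2XYZ_B": 3, "3DEF_A": 3, "4GHI_C": 1}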
files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f] files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
fn = partial(parse_file, args=args) fn = partial(parse_file, args=args, chain_cluster_size_dict=chain_cluster_size_dict)
data = {} data = {}
with Pool(processes=args.no_workers) as p: with Pool(processes=args.no_workers) as p:
with tqdm(total=len(files)) as pbar: with tqdm(total=len(files)) as pbar:
...@@ -63,6 +88,15 @@ if __name__ == "__main__": ...@@ -63,6 +88,15 @@ if __name__ == "__main__":
"--no_workers", type=int, default=4, "--no_workers", type=int, default=4,
help="Number of workers to use for parsing" help="Number of workers to use for parsing"
) )
parser.add_argument(
"--cluster_file", type=str, default=None,
help=(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser.add_argument( parser.add_argument(
"--chunksize", type=int, default=10, "--chunksize", type=int, default=10,
help="How many files should be distributed to each worker at a time" help="How many files should be distributed to each worker at a time"
......
import argparse import argparse
import logging import logging
import os import os
import random
import sys import sys
import time
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import ( from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
OpenFoldDataModule,OpenFoldMultimerDataModule,
DummyDataLoader,
)
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants from openfold.np import residue_constants
...@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
super(OpenFoldWrapper, self).__init__() super(OpenFoldWrapper, self).__init__()
self.config = config self.config = config
self.model = AlphaFold(config) self.model = AlphaFold(config)
if self.config.globals.is_multimer:
self.loss = AlphaFoldMultimerLoss(config.loss)
else:
self.loss = AlphaFoldLoss(config.loss) self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage( self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay model=self.model, decay=config.ema.decay
) )
...@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
) )
class OpenFoldMultimerWrapper(OpenFoldWrapper):
def __init__(self, config):
super(OpenFoldMultimerWrapper, self).__init__(config)
self.config = config
self.model = AlphaFold(config)
self.loss = AlphaFoldMultimerLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
self.cached_weights = None
self.last_lr_step = -1
def forward(self, batch):
return self.model(batch)
def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it
if(self.ema.device != features["aatype"].device):
self.ema.to(features["aatype"].device)
# Run the model
outputs = self(features)
# Remove the recycling dimension
features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss
loss, loss_breakdown = self.loss(
outputs, (features,gt_features), _return_breakdown=True
)
# Log it
self._log(loss_breakdown, features, outputs)
return loss
def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(features)
# Compute loss and other metrics
features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, (features,gt_features), _return_breakdown=True
)
self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed)
...@@ -331,10 +263,8 @@ def main(args): ...@@ -331,10 +263,8 @@ def main(args):
train=True, train=True,
low_prec=(str(args.precision) == "16") low_prec=(str(args.precision) == "16")
) )
if "multimer" in args.config_preset:
model_module = OpenFoldMultimerWrapper(config)
else:
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt): if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)): if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
...@@ -359,7 +289,6 @@ def main(args): ...@@ -359,7 +289,6 @@ def main(args):
if(args.script_modules): if(args.script_modules):
script_preset_(model_module) script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle")
if "multimer" in args.config_preset: if "multimer" in args.config_preset:
data_module = OpenFoldMultimerDataModule( data_module = OpenFoldMultimerDataModule(
config=config.data, config=config.data,
......