Commit 0cf1541c authored by Christina Floristean

Refactoring multimer data pipeline and permutation alignment.

parent 377f854c
......@@ -19,6 +19,8 @@ dependencies:
- deepspeed==0.5.10
- dm-tree==0.1.6
- ml-collections==0.1.0
- jax==0.3.25
- pandas==2.0.2
- numpy==1.21.2
- PyYAML==5.4.1
- requests==2.26.0
......
......@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
"sym_id",
]
},
"supervised": {
"clamp_prob": 1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
......@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
"max_extra_msa": 2048,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.
"interface_threshold": 10.,
"clamp_prob": 1.,
},
},
"model": {
......
......@@ -4,7 +4,7 @@ import json
import logging
import os
import pickle
from typing import Optional, Sequence, Any
from typing import Optional, Sequence, Any, Union
import ml_collections as mlc
import pytorch_lightning as pl
......@@ -18,43 +18,31 @@ from openfold.data import (
templates,
)
from openfold.utils.tensor_utils import dict_multimap
import contextlib
import tempfile
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
import random
logging.basicConfig(level=logging.INFO)
@contextlib.contextmanager
def temp_fasta_file(sequence_str):
"""function that create temparory fasta file used in multimer datapipeline"""
with tempfile.NamedTemporaryFile("w", suffix=".fasta") as fasta_file:
fasta_file.write(sequence_str)
fasta_file.seek(0)
yield fasta_file.name
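A minimal usage sketch of the helper above, with a hypothetical query sequence; the temporary FASTA file exists only inside the with-block and is deleted on exit:

    # Assumes temp_fasta_file as defined in this diff.
    with temp_fasta_file(">query\nMKTAYIAKQRQISFVKSHFSRQ\n") as fasta_path:
        with open(fasta_path) as f:
            print(f.read())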
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
Args:
data_dir:
......@@ -116,21 +104,21 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
if mode not in valid_modes:
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
if template_release_dates_cache_path is None:
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(alignment_index is not None):
if alignment_index is not None:
self._chain_ids = list(alignment_index.keys())
else:
self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None):
if filter_path is not None:
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
......@@ -160,7 +148,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
......@@ -182,7 +170,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_featurizer=template_featurizer,
)
if(not self._output_raw):
if not self._output_raw:
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
......@@ -195,7 +183,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
if mmcif_object.mmcif_object is None:
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
......@@ -220,47 +208,46 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
if(self.alignment_index is not None):
if self.alignment_index is not None:
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
if self.mode == 'train' or self.mode == 'eval':
spl = name.rsplit('_', 1)
if(len(spl) == 2):
if len(spl) == 2:
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
if self._structure_index is not None:
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
assert (len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
if os.path.exists(path + e):
ext = e
break
if(ext is None):
if ext is None:
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
if ext == ".cif":
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(ext == ".core"):
elif ext == ".core":
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
elif(ext == ".pdb"):
elif ext == ".pdb":
structure_index = None
if(self._structure_index is not None):
if self._structure_index is not None:
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
......@@ -271,7 +258,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
_structure_index=structure_index,
)
else:
raise ValueError("Extension branch missing")
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
......@@ -280,11 +267,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_index=alignment_index,
)
if(self._output_raw):
if self._output_raw:
return data
feats = self.feature_pipeline.process_features(
data, self.mode
data, self.mode
)
feats["batch_idx"] = torch.tensor(
......@@ -295,30 +282,29 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return feats
def __len__(self):
return len(self._chain_ids)
return len(self._chain_ids)
class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
mmcif_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
mmcif_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
This class checks each individual PDB ID and returns the features/ground truth of its chain(s)
Args:
......@@ -336,15 +322,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
mmcif_data_cache_path:
Path to a cache of all mmCIF files generated by
scripts/generate_mmcif_cache.py. It should be a JSON file that records
which chain(s) each PDB ID contains
kalign_binary_path:
Path to kalign binary.
max_template_hits:
......@@ -369,17 +350,12 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
"""
super(OpenFoldSingleMultimerDataset, self).__init__()
self.data_dir = data_dir
self.mmcif_data_cache_path=mmcif_data_cache_path
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.mmcif_data_cache_path = mmcif_data_cache_path
if self.mmcif_data_cache_path is not None:
with open(self.mmcif_data_cache_path,"r") as infile:
with open(self.mmcif_data_cache_path, "r") as infile:
self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict)
assert isinstance(self.mmcif_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
......@@ -392,39 +368,36 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
if mode not in valid_modes:
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
if template_release_dates_cache_path is None:
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
if self.mmcif_data_cache_path is not None:
self._mmcifs = list(self.mmcif_data_cache.keys())
elif self.alignment_index is not None:
self._mmcifs = [i.split("_")[0] for i in list(alignment_index.keys())]
elif self.alignment_dir is not None:
self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
else:
self._chain_ids = list(os.listdir(alignment_dir))
raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
if(filter_path is not None):
if filter_path is not None:
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
mmcifs_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
self._mmcifs = [
m for m in self._mmcifs if m in mmcifs_to_include
]
if self.mmcif_data_cache_path is not None:
self._mmcifs = list(self.mmcif_data_cache.keys())
elif self.mmcif_data_cache_path is None and self.alignment_dir is not None:
self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
else:
raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
self._mmcif_id_to_idx_dict = {
mmcif: i for i, mmcif in enumerate(self._mmcifs)
}
# changed template_featurizer to hmmsearch for now just to run the test
mmcif: i for i, mmcif in enumerate(self._mmcifs)
}
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
......@@ -443,7 +416,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -453,7 +426,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
if mmcif_object.mmcif_object is None:
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
......@@ -462,34 +435,34 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
alignment_index=alignment_index
)
)
return data
def mmcif_id_to_idx(self, chain_id):
return self._mmcif_id_to_idx_dict[chain_id]
def mmcif_id_to_idx(self, mmcif_id):
return self._mmcif_id_to_idx_dict[mmcif_id]
def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx]
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None
if(self.mode == 'train' or self.mode == 'eval'):
if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
if os.path.exists(path + e):
ext = e
break
if(ext is None):
if ext is None:
raise ValueError("Invalid file type")
#TODO: Add pdb and core exts to data_pipeline for multimer
# TODO: Add pdb and core exts to data_pipeline for multimer
path += ext
if(ext == ".cif"):
if ext == ".cif":
data = self._parse_mmcif(
path, mmcif_id, self.alignment_dir, alignment_index,
)
......@@ -502,107 +475,52 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_dir=self.alignment_dir
)
if (self._output_raw):
if self._output_raw:
return data
# process all_chain_features
data,ground_truth = self.feature_pipeline.process_features(data,
data = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)
# in inference mode, only all_chain_features are needed
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)
return data, ground_truth
return data
def __len__(self):
return len(self._chain_ids)
return len(self._mmcifs)
def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
def resolution_filter(resolution: int, max_resolution: float) -> bool:
"""Check that the resolution is <= max_resolution permitted"""
return resolution is not None and resolution <= max_resolution
seq = chain_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def deterministic_multimer_train_filter(
mmcif_data_cache_entry,
max_resolution:float= 9.,
max_single_aa_prop:float=0.8,
minimum_number_of_residues:int=200,
) -> bool:
"""
Implement multimer training filtering criteria described in
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
# First check resolution
resolution = mmcif_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution) or (resolution is None):
return False
# Then check if any single amino acid accounts for more than 80% of the complex sequences
seqs = mmcif_data_cache_entry["seqs"]
def aa_count_filter(seqs: list, max_single_aa_prop: float) -> bool:
"""Check if any single amino acid accounts for more than max_single_aa_prop percent of the sequence(s)"""
counts = {}
for aa in restypes:
counts[aa] = 0
total_len = sum([len(i) for i in seqs])
if total_len<minimum_number_of_residues: # check if the complex has less than 200 residues
return False
for seq in seqs:
for aa in seq:
counts.setdefault(aa, 0)
if aa not in restypes:
return False
else:
counts[aa] += 1
total_len = sum([len(i) for i in seqs])
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / total_len
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> float:
# Stochastic filters
probabilities = []
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
return largest_single_aa_prop <= max_single_aa_prop
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
"""Check if the total combined sequence lengths are >= minimum_numer_of_residues"""
total_len = sum([len(i) for i in seqs])
return total_len >= minimum_number_of_residues
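Taken together, the three small predicates above replace the earlier monolithic filters. A toy check on a hypothetical cache entry (assuming restypes from openfold.np.residue_constants is in scope, as elsewhere in this module):

    entry = {"resolution": 2.1, "seqs": ["MKTAYIAKQR", "GAVLIPFMWS"]}
    keep = all([
        resolution_filter(resolution=entry.get("resolution"), max_resolution=9.),
        all_seq_len_filter(seqs=entry["seqs"], minimum_number_of_residues=10),
        aa_count_filter(seqs=entry["seqs"], max_single_aa_prop=0.8),
    ])
    print(keep)  # True: well resolved, long enough, no dominant residue type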
class OpenFoldDataset(torch.utils.data.Dataset):
......@@ -612,67 +530,104 @@ class OpenFoldDataset(torch.utils.data.Dataset):
length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
and filtered once at initialization.
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
if _roll_at_init:
self.reroll()
@staticmethod
def deterministic_train_filter(
cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = cache_entry.get("resolution", None)
seqs = [cache_entry["seq"]]
return all([resolution_filter(resolution=resolution,
max_resolution=max_resolution),
aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop)])
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
) -> float:
# Stochastic filters
probabilities = []
cluster_size = cache_entry.get("cluster_size", None)
if cluster_size is not None and cluster_size > 0:
probabilities.append(1 / cluster_size)
chain_length = len(cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
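For intuition, a worked example of the probability above with hypothetical cache values:

    # cluster_size = 4, chain_length = 300:
    # p = (1 / 4) * (1 / 512) * max(min(300, 512), 256) = 0.25 * 300 / 512
    cluster_size, chain_length = 4, 300
    p = (1 / cluster_size) * ((1 / 512) * max(min(chain_length, 512), 256))
    print(round(p, 3))  # 0.146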
def looped_shuffled_dataset_idx(self, dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if not self.deterministic_train_filter(chain_data_cache_entry):
continue
p = self.get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
samples = samples.squeeze()
weights.append([1. - p, p])
idx.append(candidate_idx)
cache = [i for i, s in zip(idx, samples) if s]
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
for datapoint_idx in cache:
yield datapoint_idx
cache = [i for i, s in zip(idx, samples) if s]
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
for datapoint_idx in cache:
yield datapoint_idx
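The accept/reject step above treats each [1. - p, p] row as an independent Bernoulli trial; a minimal sketch of the same idiom with hypothetical indices:

    import torch

    weights = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # rows: [1 - p, p]
    samples = torch.multinomial(weights, num_samples=1).squeeze()  # 0 = drop, 1 = keep
    kept = [i for i, s in zip([101, 102, 103], samples) if s]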
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
......@@ -695,71 +650,97 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
class OpenFoldMultimerDataset(OpenFoldDataset):
"""
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
if _roll_at_init:
self.reroll()
def filter_samples(self,dataset_idx):
dataset = self.datasets[dataset_idx]
mmcif_data_cache = dataset.mmcif_data_cache if hasattr(dataset,"mmcif_data_cache") else None
selected_idx = []
if mmcif_data_cache is not None:
for i in range(len(mmcif_data_cache)):
mmcif_id = dataset.idx_to_mmcif_id(i)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9):
selected_idx.append(i)
logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
else:
selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
return selected_idx
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __init__(self,
datasets: Sequence[OpenFoldSingleMultimerDataset],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True
):
super(OpenFoldMultimerDataset, self).__init__(datasets=datasets,
probabilities=probabilities,
epoch_len=epoch_len,
generator=generator,
_roll_at_init=_roll_at_init)
@staticmethod
def deterministic_train_filter(
cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
minimum_number_of_residues: int = 200,
) -> bool:
"""
Implement multimer training filtering criteria described in
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
resolution = cache_entry.get("resolution", None)
seqs = cache_entry["seqs"]
return all([resolution_filter(resolution=resolution,
max_resolution=max_resolution),
all_seq_len_filter(seqs=seqs,
minimum_number_of_residues=minimum_number_of_residues),
aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop)])
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
) -> float:
# Stochastic filters
cluster_sizes = cache_entry.get("cluster_sizes", [])
chain_probs = [1 / c for c in cluster_sizes if c > 0]
if chain_probs:
return sum(chain_probs)
return 1.
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
mmcif_data_cache = dataset.mmcif_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if not self.deterministic_train_filter(mmcif_data_cache_entry):
continue
p = self.get_stochastic_train_filter_prob(
mmcif_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
def __len__(self):
return self.epoch_len
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=len(self.probabilities),
replacement=True,
generator=self.generator,
)
cache = [i for i, s in zip(idx, samples) if s]
self.datapoints = []
for dataset_idx in dataset_choices:
selected_idx = self.filter_samples(dataset_idx)
random.shuffle(selected_idx)
if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx)
logging.info(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]
for datapoint_idx in cache:
yield datapoint_idx
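A worked example of the multimer variant of get_stochastic_train_filter_prob used above, with hypothetical cluster sizes (one per chain):

    entry = {"cluster_sizes": [2, 4]}
    chain_probs = [1 / c for c in entry["cluster_sizes"] if c > 0]
    p = sum(chain_probs) if chain_probs else 1.
    print(p)  # 0.75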
class OpenFoldBatchCollator:
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, prots)
return dict_multimap(stack_fn, prots)
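The collator above relies on dict_multimap to apply torch.stack leaf-by-leaf across the per-sample feature dicts; a small sketch, assuming the class as defined here:

    import torch

    prots = [{"aatype": torch.zeros(5, dtype=torch.long)} for _ in range(4)]
    batch = OpenFoldBatchCollator()(prots)
    assert batch["aatype"].shape == (4, 5)  # new leading batch dimension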
class OpenFoldDataLoader(torch.utils.data.DataLoader):
......@@ -775,8 +756,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.uniform_recycling):
if stage_cfg.uniform_recycling:
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
......@@ -785,15 +766,15 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
......@@ -803,7 +784,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
num_samples=1, # 1 per row
replacement=True,
generator=self.generator
)
......@@ -815,8 +796,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
sample,
device=aatype.device,
requires_grad=False
)
orig_shape = sample_tensor.shape
......@@ -828,9 +809,9 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
)
batch[key] = sample_tensor
if(key == "no_recycling_iters"):
no_recycling = sample
if key == "no_recycling_iters":
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
......@@ -846,50 +827,33 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return _batch_prop_gen(it)
class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
self.config = config
self.stage = stage
self.generator = generator
def __iter__(self):
it = super().__iter__()
def _batch_prop_gen(iterator):
for batch in iterator:
yield batch
return _batch_prop_gen(it)
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_filter_path: Optional[str] = None,
distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_filter_path: Optional[str] = None,
distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
......@@ -917,7 +881,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None):
if self.train_data_dir is None and self.predict_data_dir is None:
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
......@@ -925,65 +889,61 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.training_mode = self.train_data_dir is not None
if(self.training_mode and train_alignment_dir is None):
if self.training_mode and train_alignment_dir is None:
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
elif(not self.training_mode and predict_alignment_dir is None):
elif not self.training_mode and predict_alignment_dir is None:
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
)
elif val_data_dir is not None and val_alignment_dir is None:
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
)
# An ad-hoc measure for our particular filesystem restrictions
self._distillation_structure_index = None
if(_distillation_structure_index_path is not None):
if _distillation_structure_index_path is not None:
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
if alignment_index_path is not None:
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
if distillation_alignment_index_path is not None:
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=self.template_release_dates_cache_path,
obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
if self.training_mode:
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
alignment_index=self.alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
if self.distillation_data_dir is not None:
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
......@@ -997,8 +957,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
if distillation_dataset is not None:
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
......@@ -1007,10 +967,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
if self.batch_seed is not None:
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
......@@ -1018,8 +978,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
generator=generator,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
if self.val_data_dir is not None:
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
......@@ -1029,7 +989,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
else:
self.eval_dataset = None
else:
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
......@@ -1040,18 +1000,17 @@ class OpenFoldDataModule(pl.LightningDataModule):
def _gen_dataloader(self, stage):
generator = None
if(self.batch_seed is not None):
if self.batch_seed is not None:
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed)
dataset = None
if(stage == "train"):
if stage == "train":
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
elif stage == "eval":
dataset = self.eval_dataset
elif(stage == "predict"):
elif stage == "predict":
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
......@@ -1071,15 +1030,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
return dl
def train_dataloader(self):
return self._gen_dataloader("train")
return self._gen_dataloader("train")
def val_dataloader(self):
if(self.eval_dataset is not None):
if self.eval_dataset is not None:
return self._gen_dataloader("eval")
return None
def predict_dataloader(self):
return self._gen_dataloader("predict")
return self._gen_dataloader("predict")
class OpenFoldMultimerDataModule(OpenFoldDataModule):
......@@ -1091,16 +1050,19 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
scripts/generate_mmcif_cache.py. mmcif_data_cache_path should be
a file that records which chain(s) each mmCIF file has
"""
def __init__(self, config: mlc.ConfigDict,
template_mmcif_dir: str, max_template_date: str,
def __init__(self, config: mlc.ConfigDict,
template_mmcif_dir: str, max_template_date: str,
train_data_dir: Optional[str] = None,
train_mmcif_data_cache_path:Optional[str] = None,
val_mmcif_data_cache_path:Optional[str] = None,
train_mmcif_data_cache_path: Optional[str] = None,
val_mmcif_data_cache_path: Optional[str] = None,
**kwargs):
super(OpenFoldMultimerDataModule,self).__init__(config,
template_mmcif_dir,
max_template_date,
train_data_dir,**kwargs)
super(OpenFoldMultimerDataModule, self).__init__(config,
template_mmcif_dir,
max_template_date,
train_data_dir,
**kwargs)
self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
......@@ -1108,32 +1070,28 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=self.template_release_dates_cache_path,
obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
if self.training_mode:
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
mmcif_data_cache_path=self.train_mmcif_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
alignment_index=self.alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
if self.distillation_data_dir is not None:
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
......@@ -1146,8 +1104,8 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
if distillation_dataset is not None:
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
......@@ -1156,10 +1114,10 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
if self.batch_seed is not None:
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldMultimerDataset(
datasets=datasets,
probabilities=probabilities,
......@@ -1167,8 +1125,8 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
generator=generator,
_roll_at_init=True,
)
if(self.val_data_dir is not None):
if self.val_data_dir is not None:
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
......@@ -1179,7 +1137,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
)
else:
self.eval_dataset = None
else:
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
......@@ -1187,32 +1145,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
def _gen_dataloader(self, stage):
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed)
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
dl = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=self.config.data_module.data_loaders.num_workers,
)
return dl
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
......
......@@ -93,24 +93,11 @@ def np_example_to_features(
with torch.no_grad():
if is_multimer:
if mode == 'train':
features,gt_features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=True
)
return {k: v for k, v in features.items()}, gt_features
else:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=False
)
return {k: v for k, v in features.items()}
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
......
......@@ -21,16 +21,17 @@ from openfold.data import (
data_transforms_multimer,
)
def grountruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
return transforms
def groundtruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles]
return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
......@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
if is_training:
gt_tensors= prepare_ground_truth_features(tensors)
process_gt_feats = mode_cfg.supervised
gt_tensors = {}
if process_gt_feats:
gt_tensors = prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
......@@ -152,10 +157,10 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
if is_training:
return tensors,gt_tensors
else:
return tensors
if process_gt_feats:
tensors['gt_features'] = gt_tensors
return tensors
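A sketch of the new calling convention: rather than returning a (features, gt_features) tuple in training mode, the ground-truth tensors now travel inside the feature dict under a 'gt_features' key whenever mode_cfg.supervised is set (tensor_dict and cfg below are placeholders):

    feats = process_tensors_from_config(tensor_dict, cfg.common, cfg["train"])
    gt = feats.get("gt_features", {})  # empty when supervised processing is off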
@data_transforms.curry1
def compose(x, fs):
......
......@@ -13,35 +13,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
from openfold.utils.all_atom_multimer import get_rc_tensor
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
)
import random
from openfold.np import residue_constants as rc
import logging
import procrustes
import logging
import procrustes
from openfold.utils.tensor_utils import tensor_tree_map
import gc
logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum(
labels * torch.nn.functional.log_softmax(logits, dim=-1),
......@@ -185,11 +180,10 @@ def backbone_loss(
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if traj.shape[-1]==7:
if traj.shape[-1] == 7:
pred_aff = Rigid.from_tensor_7(traj)
elif traj.shape[-1]==4:
elif traj.shape[-1] == 4:
pred_aff = Rigid.from_tensor_4x4(traj)
pred_aff = Rigid(
......@@ -256,10 +250,10 @@ def sidechain_loss(
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
......@@ -297,7 +291,6 @@ def fape_loss(
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
) -> torch.Tensor:
traj = out["sm"]["frames"]
asym_id = batch.get("asym_id")
if asym_id is not None:
......@@ -328,7 +321,7 @@ def fape_loss(
)
loss = weighted_bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
......@@ -390,7 +383,7 @@ def supervised_chi_loss(
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
......@@ -502,7 +495,7 @@ def lddt_ca(
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
......@@ -532,19 +525,19 @@ def lddt_loss(
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps
)
# TODO: Remove after initial pipeline testing
score = torch.nan_to_num(score, nan=torch.nanmean(score))
score[score<0] = 0
score[score < 0] = 0
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
......@@ -586,7 +579,7 @@ def distogram_loss(
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
......@@ -707,12 +700,12 @@ def compute_tm(
n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface and (asym_id is not None):
if len(asym_id.shape)>1:
assert len(asym_id.shape)<=2
if len(asym_id.shape) > 1:
assert len(asym_id.shape) <= 2
batch_size = asym_id.shape[0]
pair_mask = residue_weights.new_ones((batch_size,n, n), dtype=torch.int32)
pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * (
......@@ -727,6 +720,7 @@ def compute_tm(
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
def tm_loss(
logits,
final_affine_tensor,
......@@ -741,9 +735,9 @@ def tm_loss(
**kwargs,
):
# first check whether this is a tensor_7 or tensor_4*4
if final_affine_tensor.shape[-1]==7:
if final_affine_tensor.shape[-1] == 7:
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
elif final_affine_tensor.shape[-1]==4:
elif final_affine_tensor.shape[-1] == 4:
pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
......@@ -844,19 +838,19 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
gt_length = (
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
gt_stddev = (
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
......@@ -1082,7 +1076,7 @@ def between_residue_clash_loss(
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard)
dists < (dists_lower_bound - overlap_tolerance_hard)
)
per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(clash_mask, dim=(-3, -1))
......@@ -1098,7 +1092,7 @@ def between_residue_clash_loss(
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
"per_atom_num_clash": per_atom_num_clash # shape (N, 14)
"per_atom_num_clash": per_atom_num_clash # shape (N, 14)
}
......@@ -1221,7 +1215,7 @@ def find_structural_violations(
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
#TODO: Consolidate monomer/multimer modes
# TODO: Consolidate monomer/multimer modes
asym_id = batch.get("asym_id")
if asym_id is not None:
residx_atom14_to_atom37 = get_rc_tensor(
......@@ -1372,8 +1366,8 @@ def extreme_ca_ca_distance_violations(
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
......@@ -1559,16 +1553,16 @@ def compute_renamed_ground_truth(
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
return {
"alt_naming_is_better": alt_naming_is_better,
......@@ -1591,13 +1585,13 @@ def experimentally_resolved_loss(
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
loss = torch.mean(loss)
return loss
......@@ -1701,20 +1695,17 @@ def compute_rmsd(
eps: float = 1e-6,
) -> torch.Tensor:
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
gc.collect()
if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps) # prevent sqrt 0
return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P, Q):
"""
Use procrustes package to calculate best rotation that minimises
the RMSD betwee P and Q
Use procrustes package to calculate the best rotation that minimises
the RMSD between P and Q
The optimal rotation matrix is calculated using
the rotational() function from the procrustes package. Details can be found here:
......@@ -1728,12 +1719,12 @@ def kabsch_rotation(P, Q):
A 3*3 rotation matrix
"""
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().float().numpy(),translate=False,scale=False)
Q.detach().cpu().float().numpy(), translate=False, scale=False)
# Rotation.t doesn't mean transpose; t just extracts the rotation matrix from the procrustes result object
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3])
rotation = torch.tensor(rotation.t, dtype=torch.float)
assert rotation.shape == torch.Size([3, 3])
return rotation.to(device=P.device, dtype=P.dtype)
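A quick self-check of kabsch_rotation as defined above (requires the procrustes package): recover a known rotation about the z-axis.

    import math
    import torch

    theta = 0.3
    true_r = torch.tensor([
        [math.cos(theta), -math.sin(theta), 0.],
        [math.sin(theta), math.cos(theta), 0.],
        [0., 0., 1.],
    ])
    P = torch.randn(10, 3)
    Q = P @ true_r
    r = kabsch_rotation(P, Q)  # should recover true_r up to numerical noise
    assert torch.allclose(P @ r, Q, atol=1e-4)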
......@@ -1756,7 +1747,7 @@ def get_optimal_transform(
# sometimes using fake test inputs generates NaN in the predicted atom positions
logging.warning("src_atoms has nan or inf")
src_atoms = torch.nan_to_num(src_atoms,nan=0.0,posinf=1.0,neginf=1.0)
src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
if mask is not None:
assert mask.dtype == torch.bool
......@@ -1767,21 +1758,15 @@ def get_optimal_transform(
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
r = kabsch_rotation(src_atoms,tgt_atoms)
del src_atoms,tgt_atoms,
gc.collect()
src_center = src_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
r = kabsch_rotation(src_atoms, tgt_atoms)
x = tgt_center - src_center @ r
del tgt_center,src_center,mask
gc.collect()
return r, x
def get_least_asym_entity_or_longest_length(batch,input_asym_id):
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunit(s) each sequence has; if there is no tie, e.g. AABBB, then select
one of the A chains as the anchor
......@@ -1805,7 +1790,7 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
for entity_id in unique_entity_ids:
asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
entity_asym_count[int(entity_id)] = len(asym_ids)
# Calculate entity length
entity_mask = (batch["entity_id"] == entity_id)
entity_length[int(entity_id)] = entity_mask.sum().item()
......@@ -1821,13 +1806,14 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
# If still multiple entities, return a random one
if len(least_asym_entities) > 1:
least_asym_entities = [random.choice(least_asym_entities)]  # keep as a list for the assertion below
assert len(least_asym_entities)==1
assert len(least_asym_entities) == 1
least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
def greedy_align(
batch,
per_asym_residue_index,
......@@ -1843,7 +1829,7 @@ def greedy_align(
"""
used = [False for _ in range(len(true_ca_poses))]
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
......@@ -1857,13 +1843,13 @@ def greedy_align(
for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1)
if not used[j]: # possible candidate
cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index)
mask = torch.index_select(true_ca_masks[j],1,cur_residue_index)
cropped_pos = torch.index_select(true_ca_poses[j], 1, cur_residue_index)
mask = torch.index_select(true_ca_masks[j], 1, cur_residue_index)
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
torch.squeeze(cropped_pos, 0), torch.squeeze(cur_pred_pos, 0),
(cur_pred_mask * mask).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
if rmsd is not None and rmsd < best_rmsd:
best_rmsd = rmsd
best_idx = j
......@@ -1873,15 +1859,15 @@ def greedy_align(
return align
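A toy illustration of the greedy strategy above, with hypothetical RMSD values between two predicted chains and two ground-truth copies of the same entity:

    rmsds = {("A", "A1"): 1.2, ("A", "A2"): 0.4,
             ("B", "A1"): 0.9, ("B", "A2"): 2.0}
    used, align = set(), []
    for pred in ["A", "B"]:  # greedily take the best unused ground-truth copy
        best = min((gt for gt in ["A1", "A2"] if gt not in used),
                   key=lambda gt: rmsds[(pred, gt)])
        used.add(best)
        align.append((pred, best))
    print(align)  # [('A', 'A2'), ('B', 'A1')]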
def pad_features(feature_tensor, nres_pad, pad_dim):
"""Pad input feature tensor"""
pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
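# Editor's illustration (not part of the commit): pad_features appends zeros
# along the chosen dimension, e.g. growing a [1, 5, 3] feature to 8 residues.
def _demo_pad_features():
    import torch
    feat = torch.ones(1, 5, 3)
    padded = pad_features(feat, nres_pad=3, pad_dim=1)
    assert padded.shape == (1, 8, 3)
    assert padded[:, 5:].abs().sum() == 0  # the appended tail is zero padding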
def merge_labels(per_asym_residue_index, labels, align, original_nres):
"""
Merge ground truth labels according to the permutation results
......@@ -1898,24 +1884,25 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
label = labels[j][k]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape) <= 1 or "template" in k or "row_mask" in k:
continue
else:
dimension_to_merge = 1
cur_out[i] = label.index_select(dimension_to_merge, cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out) > 0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
# check whether padding is needed
if new_v.shape[dimension_to_merge] != original_nres:
if new_v.shape[dimension_to_merge] != original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v, nres_pad, pad_dim=dimension_to_merge)
outs[k] = new_v
return outs
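# Editor's illustration (not part of the commit): merging two equal-length
# chains after a permutation that swaps them. Toy tensors; assumes per-chain
# residue_index values start at 0, as in the multimer features.
def _demo_merge_labels():
    import torch
    labels = [
        {"all_atom_positions": torch.zeros(1, 2, 3)},  # ground-truth copy 0
        {"all_atom_positions": torch.ones(1, 2, 3)},   # ground-truth copy 1
    ]
    per_asym_residue_index = {1: torch.tensor([0, 1]), 2: torch.tensor([0, 1])}
    align = [(0, 1), (1, 0)]  # pred chain 0 <-> gt copy 1, pred chain 1 <-> gt copy 0
    merged = merge_labels(per_asym_residue_index, labels, align, original_nres=4)
    # merged["all_atom_positions"] is [1, 4, 3]: the ones block precedes the zeros
    assert merged["all_atom_positions"].shape == (1, 4, 3)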
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
......@@ -1974,13 +1961,13 @@ class AlphaFoldLoss(nn.Module):
),
}
if self.config.tm.enabled:
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
)
if self.config.chain_center_of_mass.enabled:
loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.chain_center_of_mass},
......@@ -1991,11 +1978,11 @@ class AlphaFoldLoss(nn.Module):
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
loss = loss_fn()
if torch.isnan(loss) or torch.isinf(loss):
# for k,v in batch.items():
# if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
# logging.warning(f"{k}: is nan")
# logging.warning(f"{loss_name}: {loss}")
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
......@@ -2010,18 +1997,18 @@ class AlphaFoldLoss(nn.Module):
losses["loss"] = cum_loss.detach().clone()
if not _return_breakdown:
return cum_loss
return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False):
if not _return_breakdown:
cum_loss = self.loss(out, batch, _return_breakdown)
return cum_loss
else:
cum_loss, losses = self.loss(out, batch, _return_breakdown)
return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss):
......@@ -2029,12 +2016,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Adds multi-chain permutation alignment on top of
AlphaFoldLoss.
"""
def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config
@staticmethod
def split_ground_truth_labels(gt_features):
"""
Splits ground truth features according to chains
......@@ -2042,26 +2030,26 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
a list of per-chain feature dictionaries containing only the ground-truth
features required for multi-chain permutation
"""
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1]
def split_dim(shape):
return next(iter(i for i, size in enumerate(shape) if size == n_res), None)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(v_all, asym_id_counts.tolist(),
dim=split_dim(v_all.shape))]
for k, v_all in gt_features.items()
if n_res in v_all.shape])))
return labels
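# Editor's illustration (not part of the commit): splitting a two-chain ground
# truth (3 + 2 residues) along whichever dimension matches the residue count.
def _demo_split_ground_truth_labels():
    import torch
    gt = {
        "asym_id": torch.tensor([1, 1, 1, 2, 2]),
        "all_atom_positions": torch.randn(1, 5, 37, 3),
    }
    chains = AlphaFoldMultimerLoss.split_ground_truth_labels(gt)
    assert chains[0]["all_atom_positions"].shape[1] == 3
    assert chains[1]["all_atom_positions"].shape[1] == 2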
@staticmethod
def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0]
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
return per_asym_residue_index
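# Editor's illustration (not part of the commit): per-chain residue indices
# for a toy feature dict; the padding asym_id 0 is skipped by the filter above.
def _demo_get_per_asym_residue_index():
    import torch
    feats = {
        "asym_id": torch.tensor([1, 1, 2, 2, 0]),
        "residue_index": torch.tensor([0, 1, 0, 1, 0]),
    }
    idx = AlphaFoldMultimerLoss.get_per_asym_residue_index(feats)
    # idx == {1: tensor([0, 1]), 2: tensor([0, 1])}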
@staticmethod
......@@ -2083,10 +2071,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
@staticmethod
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
asym_mask, pred_ca_mask):
"""
Calculate an input mask for downstream optimal transformation computation
......@@ -2099,37 +2087,37 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
asym_mask = torch.squeeze(asym_mask, 0)
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
@staticmethod
def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_residue,
true_ca_masks, pred_ca_mask,
asym_mask,
pred_ca_pos):
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx, anchor_gt_residue,
asym_mask,
pred_ca_mask)
input_mask = torch.squeeze(input_mask, 0)
pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
asym_mask = torch.squeeze(asym_mask, 0)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos, 0),
mask=input_mask
)
return r, x
@staticmethod
def multi_chain_perm_align(out, features, ground_truth):
"""
A static method that permutes the ground-truth chains
before the loss is calculated.
......@@ -2137,71 +2125,68 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 of the Supplementary Information of the AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id
is_monomer = len(unique_asym_ids) == 1
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(features)
if is_monomer:
best_align = list(enumerate(range(len(per_asym_residue_index))))
return best_align, per_asym_residue_index
best_rmsd = float('inf')
best_align = None
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,
features['asym_id'])
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
assert isinstance(labels, list)
anchor_gt_idx = int(anchor_gt_asym) - 1
# Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
for candidate_pred_anchor in anchor_pred_asym_ids:
asym_mask = (features["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_residue = per_asym_residue_index[candidate_pred_anchor.item()]
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_residue,
true_ca_masks, pred_ca_mask,
asym_mask,
pred_ca_pos
)
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align(
features,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=features['aatype'].shape[-1])
rmsd = compute_rmsd(
true_atom_pos=merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
pred_atom_pos=pred_ca_pos,
atom_mask=(pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
if rmsd < best_rmsd:
best_rmsd = rmsd
best_align = align
return best_align, per_asym_residue_index
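# Editor's note (not part of the commit): the loop above implements the anchor
# search from Section 7.3 of the AlphaFold-Multimer supplement. For each
# candidate predicted anchor chain, the ground truth is aligned onto the
# prediction via the anchor's optimal transform, the remaining chains are
# permuted greedily, and the permutation with the lowest overall CA RMSD wins.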
def forward(self, out, batch, _return_breakdown=False):
"""
Overrides the AlphaFoldLoss forward function so that
it first computes the multi-chain permutation alignment.
......@@ -2210,32 +2195,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: the input features, with the ground-truth features stored under the 'gt_features' key
"""
ground_truth = batch.pop('gt_features')
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
# Then permute ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out=out,
features=batch,
ground_truth=ground_truth)
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=batch['aatype'].shape[-1])
batch.update(labels)
if not _return_breakdown:
cum_loss = self.loss(out, batch, _return_breakdown)
print(f"cum_loss: {cum_loss}")
return cum_loss
else:
cum_loss, losses = self.loss(out, batch, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses
......@@ -13,7 +13,7 @@ from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
def parse_file(f, args, chain_cluster_size_dict=None):
with open(os.path.join(args.mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
......@@ -28,6 +28,18 @@ def parse_file(f, args):
local_data["release_date"] = mmcif.header["release_date"]
chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
if chain_cluster_size_dict is not None:
cluster_sizes = []
for chain_id in chain_ids:
full_name = "_".join([file_id, chain_id])
cluster_size = chain_cluster_size_dict.get(
full_name.upper(), -1
)
cluster_sizes.append(cluster_size)
local_data["cluster_sizes"] = cluster_sizes
local_data["chain_ids"] = chain_ids
local_data["seqs"] = seqs
local_data["no_chains"] = len(chain_ids)
......@@ -38,8 +50,21 @@ def parse_file(f, args):
def main(args):
chain_cluster_size_dict = None
if args.cluster_file is not None:
chain_cluster_size_dict = {}
with open(args.cluster_file, "r") as fp:
clusters = [l.strip() for l in fp.readlines()]
for cluster in clusters:
chain_ids = cluster.split()
cluster_len = len(chain_ids)
for chain_id in chain_ids:
chain_id = chain_id.upper()
chain_cluster_size_dict[chain_id] = cluster_len
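# Editor's illustration (not part of the commit): the expected cluster-file
# layout is one whitespace-separated cluster of {PDB_ID}_{CHAIN_ID} entries
# per line. The ids below are made up; the loop mirrors the parsing above.
def _demo_cluster_file_parsing():
    lines = ["1ABC_A 2XYZ_B 3DEF_A", "4GHI_C"]
    sizes = {}
    for cluster in lines:
        chain_ids = cluster.split()
        for chain_id in chain_ids:
            sizes[chain_id.upper()] = len(chain_ids)
    assert sizes == {"1ABC_A": 3, "2XYZ_B": 3, "3DEF_A": 3, "4GHI_C": 1}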
files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
fn = partial(parse_file, args=args, chain_cluster_size_dict=chain_cluster_size_dict)
data = {}
with Pool(processes=args.no_workers) as p:
with tqdm(total=len(files)) as pbar:
......@@ -63,6 +88,15 @@ if __name__ == "__main__":
"--no_workers", type=int, default=4,
help="Number of workers to use for parsing"
)
parser.add_argument(
"--cluster_file", type=str, default=None,
help=(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser.add_argument(
"--chunksize", type=int, default=10,
help="How many files should be distributed to each worker at a time"
......
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
......@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
super(OpenFoldWrapper, self).__init__()
self.config = config
self.model = AlphaFold(config)
if self.config.globals.is_multimer:
self.loss = AlphaFoldMultimerLoss(config.loss)
else:
self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
......@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
class OpenFoldMultimerWrapper(OpenFoldWrapper):
def __init__(self, config):
super(OpenFoldMultimerWrapper, self).__init__(config)
self.config = config
self.model = AlphaFold(config)
self.loss = AlphaFoldMultimerLoss(config.loss)
self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
self.cached_weights = None
self.last_lr_step = -1
def forward(self, batch):
return self.model(batch)
def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it
if(self.ema.device != features["aatype"].device):
self.ema.to(features["aatype"].device)
# Run the model
outputs = self(features)
# Remove the recycling dimension
features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss
loss, loss_breakdown = self.loss(
outputs, (features,gt_features), _return_breakdown=True
)
# Log it
self._log(loss_breakdown, features, outputs)
return loss
def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(features)
# Compute loss and other metrics
features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, (features,gt_features), _return_breakdown=True
)
self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
......@@ -331,10 +263,8 @@ def main(args):
train=True,
low_prec=(str(args.precision) == "16")
)
if "multimer" in args.config_preset:
model_module = OpenFoldMultimerWrapper(config)
else:
model_module = OpenFoldWrapper(config)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
......@@ -359,7 +289,6 @@ def main(args):
if(args.script_modules):
script_preset_(model_module)
if "multimer" in args.config_preset:
data_module = OpenFoldMultimerDataModule(
config=config.data,
......