Commit 33d8de81 authored by Geoffrey Yu

Added multimer training filter criteria described in the AlphaFold-Multimer paper

parent d35816e3
@@ -11,7 +11,7 @@ import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from openfold.np.residue_constants import restypes
from openfold.data import (
    data_pipeline,
    feature_pipeline,
@@ -579,6 +579,40 @@ def deterministic_train_filter(
    return True
def deterministic_multimer_train_filter(
    mmcif_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
    minimum_number_of_residues: int = 200,
) -> bool:
    """
    Implements the multimer training filter criteria described in
    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf,
    Supplementary section 7.1
    """
    # First, check the resolution
    resolution = mmcif_data_cache_entry.get("resolution", None)
    if(resolution is not None and resolution > max_resolution):
        return False

    # Reject complexes with fewer than minimum_number_of_residues residues
    seqs = mmcif_data_cache_entry["seqs"]
    total_len = sum(len(s) for s in seqs)
    if(total_len < minimum_number_of_residues):
        return False

    # Reject complexes in which any single amino acid accounts for more
    # than max_single_aa_prop of the combined sequences
    counts = {aa: 0 for aa in restypes}
    for seq in seqs:
        for aa in seq:
            # .get() guards against residue types outside of restypes
            # (e.g. the unknown residue "X")
            counts[aa] = counts.get(aa, 0) + 1
    largest_aa_count = max(counts.values())
    largest_single_aa_prop = largest_aa_count / total_len
    if(largest_single_aa_prop > max_single_aa_prop):
        return False

    return True
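
For reference, a minimal sketch of the filter's behavior on hypothetical cache entries (only the "resolution" and "seqs" keys are read; the concrete values below are made up):

entry = {
    "resolution": 3.2,
    "seqs": ["MKTAYIAKQR" * 15, "GSHMLEDPVA" * 10],  # 250 residues total
}
assert deterministic_multimer_train_filter(entry)  # passes all three criteria

# Rejected: a single residue type (G) makes up 100% > 80% of the complex
poly_g = {"resolution": 2.0, "seqs": ["G" * 300]}
assert not deterministic_multimer_train_filter(poly_g)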
def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
@@ -694,12 +728,83 @@ class OpenFoldDataset(torch.utils.data.Dataset):
            self.datapoints.append((dataset_idx, datapoint_idx))
        print(f"datapoints is {self.datapoints}")
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
    """
    A torch Dataset for multimer training that adds the filtering
    steps described in the AlphaFold-Multimer paper:
    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf,
    Supplementary section 7.1
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        generator: Optional[torch.Generator] = None,
        _roll_at_init: bool = True,
    ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
    def looped_shuffled_dataset_idx(self, dataset_len):
        while True:
            # Uniformly shuffle each dataset's indices
            weights = [1. for _ in range(dataset_len)]
            shuf = torch.multinomial(
                torch.tensor(weights),
                num_samples=dataset_len,
                replacement=False,
                generator=self.generator,
            )
            for idx in shuf:
                yield idx
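
Design note: with uniform weights, drawing dataset_len samples without replacement is just a random permutation, so the generator above is behaviorally equivalent to this (hypothetical) torch.randperm-based sketch:

def looped_randperm(dataset_len, generator=None):
    # Yield a fresh random permutation of the indices, forever
    while True:
        for idx in torch.randperm(dataset_len, generator=generator):
            yield idx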
    def looped_samples(self, dataset_idx):
        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
        dataset = self.datasets[dataset_idx]
        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
        chain_data_cache = dataset.chain_data_cache
        mmcif_data_cache = dataset.mmcif_data_cache
        while True:
            weights = []
            idx = []
            for _ in range(max_cache_len):
                candidate_idx = next(idx_iter)
                mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
                # Deterministic filtering criteria for multimers
                if(not deterministic_multimer_train_filter(mmcif_data_cache_entry)):
                    continue
                # The complex-level cache entry drives the stochastic filter
                chain_data_cache_entry = chain_data_cache[mmcif_id]
                p = get_stochastic_train_filter_prob(
                    chain_data_cache_entry,
                )
                weights.append([1. - p, p])
                idx.append(candidate_idx)

            # Guard against the case in which every candidate was
            # rejected by the deterministic filter
            if(len(idx) == 0):
                continue

            samples = torch.multinomial(
                torch.tensor(weights),
                num_samples=1,
                generator=self.generator,
            )
            # squeeze(-1) (rather than squeeze()) keeps a 1-D tensor even
            # when only a single candidate survives the filter
            samples = samples.squeeze(-1)

            cache = [i for i, s in zip(idx, samples) if s]

            for datapoint_idx in cache:
                yield datapoint_idx
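
For context, the stochastic filter treats each surviving candidate as a two-outcome draw: row [1 - p, p] yields 1 ("keep") with probability p. A standalone sketch with made-up probabilities:

probs = [0.9, 0.2, 0.5]  # hypothetical per-candidate keep probabilities
weights = torch.tensor([[1. - p, p] for p in probs])
draws = torch.multinomial(weights, num_samples=1).squeeze(-1)  # 0/1 per row
kept = [i for i, d in enumerate(draws) if d]  # indices sampled as "keep"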
class OpenFoldBatchCollator:
    def __call__(self, prots):
        stack_fn = partial(torch.stack, dim=0)
        return dict_multimap(stack_fn, prots)
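
For reference, the collator stacks matching features from a list of per-example dicts along a new leading batch dimension (dict_multimap applies stack_fn key-by-key); a minimal sketch with dummy tensors:

collator = OpenFoldBatchCollator()
prots = [
    {"aatype": torch.zeros(64, dtype=torch.long)},
    {"aatype": torch.ones(64, dtype=torch.long)},
]
batch = collator(prots)
assert batch["aatype"].shape == (2, 64)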
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
@@ -796,7 +901,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        train_mmcif_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
@@ -824,7 +928,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
@@ -1045,6 +1148,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
        train_dataset = dataset_gen(
            data_dir=self.train_data_dir,
            mmcif_data_cache_path=self.train_mmcif_data_cache_path,
            chain_data_cache_path=self.train_chain_data_cache_path,
            alignment_dir=self.train_alignment_dir,
            filter_path=self.train_filter_path,
            max_template_hits=self.config.train.max_template_hits,
@@ -1084,7 +1188,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
            generator = torch.Generator()
            generator = generator.manual_seed(self.batch_seed + 1)

        self.train_dataset = OpenFoldDataset(
        self.train_dataset = OpenFoldMultimerDataset(
            datasets=datasets,
            probabilities=probabilities,
            epoch_len=self.train_epoch_len,
......