"vscode:/vscode.git/clone" did not exist on "c3b847901099bf5c3dd174a3c8ec994b73426833"
Commit 566ca1a3 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

added openfold multimer dataloader class and overwrite batch processing

parent dbc0b085
......@@ -24,6 +24,9 @@ import tempfile
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
import logging
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def temp_fasta_file(sequence_str):
"""function that create temparory fasta file used in multimer datapipeline"""
......@@ -468,6 +471,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
print(f"mmcif_id is :{mmcif_id}")
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
......@@ -599,7 +603,10 @@ def deterministic_multimer_train_filter(
for seq in seqs:
for aa in seq:
counts[aa] += 1
if aa not in restypes:
return False
else:
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / total_len
if(largest_single_aa_prop > max_single_aa_prop):
......@@ -867,6 +874,52 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return _batch_prop_gen(it)
class OpenFoldMultimerDataLoader(OpenFoldDataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, config=config, stage=stage, generator=generator, **kwargs)
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator)
def process_samples(batch,samples):
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
requires_grad=False
)
orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
if(key == "no_recycling_iters"):
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
all_chain_features,ground_truth =batch
all_chain_features = process_samples(all_chain_features,samples)
ground_truth = [process_samples(i,samples) for i in ground_truth]
return (all_chain_features,ground_truth)
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
......@@ -1123,7 +1176,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
mmcif_data_cache_path=self.train_mmcif_data_cache_path,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -1138,7 +1190,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -1189,6 +1240,35 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
dl = OpenFoldMultimerDataLoader(
dataset,
config=self.config,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
)
return dl
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment