Commit e482a269 authored by Geoffrey Yu

updated data_module

parent 7f2a3267
@@ -24,8 +24,7 @@ import tempfile
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )
+import logging
+logger = logging.getLogger(__name__)
 
 @contextlib.contextmanager
 def temp_fasta_file(sequence_str):
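For context, `temp_fasta_file` (whose body sits outside this hunk) is a context manager that stages a FASTA string in a temporary file. A minimal sketch of what such a helper typically looks like, assuming the standard `tempfile.NamedTemporaryFile` approach rather than the file's actual body:

```python
import contextlib
import tempfile

@contextlib.contextmanager
def temp_fasta_file(sequence_str):
    # Write the FASTA text to a named temporary file and yield its path;
    # the file is deleted automatically when the with-block exits.
    with tempfile.NamedTemporaryFile("w", suffix=".fasta") as fasta_file:
        fasta_file.write(sequence_str)
        fasta_file.seek(0)
        yield fasta_file.name
```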
@@ -471,8 +470,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         mmcif_id = self.idx_to_mmcif_id(idx)
-        print(f"mmcif_id is :{mmcif_id}")
         chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
+        print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
         seqs = self.mmcif_data_cache[mmcif_id]['seqs']
         fasta_str = ""
         for c, s in zip(chains, seqs):
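The loop that closes this hunk presumably concatenates every chain of the assembly into one multi-record FASTA string; the real loop body falls outside the hunk. A hedged sketch of that pattern, with invented chain IDs and sequences:

```python
chains = ["A", "B"]              # hypothetical chain_ids from mmcif_data_cache
seqs = ["MKTAYIAKQR", "GSHMLE"]  # hypothetical per-chain sequences

fasta_str = ""
for c, s in zip(chains, seqs):
    # One FASTA record per chain: a ">chain_id" header, then its sequence.
    fasta_str += f">{c}\n{s}\n"
```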
@@ -779,7 +778,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
         selected_idx = self.filter_samples(dataset_idx)
         if len(selected_idx) < self.epoch_len:
             self.epoch_len = len(selected_idx)
-        print(f"self.epoch_len is {self.epoch_len}")
+        logging.info(f"self.epoch_len is {self.epoch_len}")
         self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len)]
 
 class OpenFoldBatchCollator:
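Note that the `print` → `logging.info` swap above only produces output if the process configures logging; an unconfigured root logger drops INFO records. A minimal sketch of the configuration the training entry point would need (the call site shown is an assumption):

```python
import logging

# The root logger defaults to WARNING, so INFO messages like the one
# added in this commit are silent until logging is configured, e.g.:
logging.basicConfig(level=logging.INFO)

logging.info("self.epoch_len is %d", 1000)  # now emitted to stderr
```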
@@ -874,51 +873,25 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)
 
-class OpenFoldMultimerDataLoader(OpenFoldDataLoader):
+class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
     def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super(OpenFoldMultimerDataLoader, self).__init__(*args, config=config, stage=stage, generator=generator, **kwargs)
+        super(OpenFoldMultimerDataLoader, self).__init__(*args, **kwargs)
         self.config = config
         self.stage = stage
 
         if(generator is None):
             generator = torch.Generator()
         self.generator = generator
         print('initialised a multimer dataloader')
 
-    def _add_batch_properties(self, batch):
-        samples = torch.multinomial(
-            self.prop_probs_tensor,
-            num_samples=1, # 1 per row
-            replacement=True,
-            generator=self.generator)
-
-        def process_samples(batch, samples):
-            aatype = batch["aatype"]
-            batch_dims = aatype.shape[:-2]
-            recycling_dim = aatype.shape[-1]
-            no_recycling = recycling_dim
-            for i, key in enumerate(self.prop_keys):
-                sample = int(samples[i][0])
-                sample_tensor = torch.tensor(
-                    sample,
-                    device=aatype.device,
-                    requires_grad=False
-                )
-                orig_shape = sample_tensor.shape
-                sample_tensor = sample_tensor.view(
-                    (1,) * len(batch_dims) + sample_tensor.shape + (1,)
-                )
-                sample_tensor = sample_tensor.expand(
-                    batch_dims + orig_shape + (recycling_dim,)
-                )
-                batch[key] = sample_tensor
-
-                if(key == "no_recycling_iters"):
-                    no_recycling = sample
-
-            resample_recycling = lambda t: t[..., :no_recycling + 1]
-            batch = tensor_tree_map(resample_recycling, batch)
-            return batch
-
-        all_chain_features, ground_truth = batch
-        all_chain_features = process_samples(all_chain_features, samples)
-        ground_truth = [process_samples(i, samples) for i in ground_truth]
-        return (all_chain_features, ground_truth)
-
+    def __iter__(self):
+        it = super().__iter__()
+
+        def _batch_prop_gen(iterator):
+            for batch in iterator:
+                yield batch
+
+        return _batch_prop_gen(it)
 
 class OpenFoldDataModule(pl.LightningDataModule):
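For reference, the `_add_batch_properties` method deleted above drew one value per batch property with `torch.multinomial` (most importantly `no_recycling_iters`), broadcast it across the batch and recycling dimensions, and truncated every tensor to the sampled recycling count. A self-contained sketch of that pattern; `prop_keys`, `max_iters`, and the toy shapes are assumptions, not the loader's real attributes:

```python
import torch

# Hypothetical stand-ins for the loader's sampling attributes: one
# categorical distribution per property, here only "no_recycling_iters".
prop_keys = ["no_recycling_iters"]
max_iters = 3
prop_probs_tensor = torch.ones((len(prop_keys), max_iters + 1))  # uniform

generator = torch.Generator().manual_seed(0)
samples = torch.multinomial(
    prop_probs_tensor,
    num_samples=1,  # one draw per row, i.e. per property
    replacement=True,
    generator=generator,
)

# Toy batch with a trailing recycling dimension, as in the deleted code.
aatype = torch.zeros((8, 256, max_iters + 1), dtype=torch.long)
batch = {"aatype": aatype}

batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(prop_keys):
    sample = int(samples[i][0])
    # Broadcast the scalar draw over batch and recycling dims so every
    # recycling iteration in the batch sees the same sampled value.
    sample_tensor = torch.tensor(sample, device=aatype.device)
    sample_tensor = sample_tensor.view((1,) * len(batch_dims) + (1,))
    batch[key] = sample_tensor.expand(batch_dims + (recycling_dim,))
    if key == "no_recycling_iters":
        no_recycling = sample

# Truncate the recycling dimension to the sampled iteration count.
batch = {k: t[..., : no_recycling + 1] for k, t in batch.items()}
```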
@@ -1259,15 +1232,12 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
             raise ValueError("Invalid stage")
 
-        dl = OpenFoldMultimerDataLoader(
+        dl = torch.utils.data.DataLoader(
             dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
+            batch_size=1,
             num_workers=self.config.data_module.data_loaders.num_workers,
         )
         print(f"generated training dataloader")
 
         return dl
 
 class DummyDataset(torch.utils.data.Dataset):
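The data module now builds a stock `torch.utils.data.DataLoader` pinned to `batch_size=1`, i.e. one multimer assembly per step, instead of the custom loader. A minimal usage sketch with a toy dataset standing in for `OpenFoldSingleMultimerDataset` (its return structure here is an assumption based on the tuple unpacking in the deleted `_add_batch_properties`):

```python
import torch

class ToyMultimerDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for OpenFoldSingleMultimerDataset."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        feats = {"aatype": torch.zeros(16, dtype=torch.long)}
        # Mirrors the (all_chain_features, ground_truth) pair the
        # removed multimer loader code unpacked from each sample.
        return feats, [feats]

dl = torch.utils.data.DataLoader(
    ToyMultimerDataset(),
    batch_size=1,   # fixed at 1 in the new code
    num_workers=0,  # config.data_module.data_loaders.num_workers in the module
)

for all_chain_features, ground_truth in dl:
    # Default collation adds a leading batch dimension of size 1.
    assert all_chain_features["aatype"].shape == (1, 16)
```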