Commit 3d5e8740 authored by Gustaf Ahdritz
Browse files

Fix RNG bug

parent 43116de0
......@@ -244,10 +244,31 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator:
    """Turn a list of raw protein samples into one stacked batch.

    Every raw sample is passed through the feature pipeline for the
    configured stage; the resulting per-sample feature dicts are then
    merged by ``dict_multimap`` using ``torch.stack`` along dim 0.
    """

    def __init__(self, config, generator, stage="train"):
        """Remember the collation settings and build the feature pipeline.

        Args:
            config: Configuration object. Stored on the instance and also
                passed to ``feature_pipeline.FeaturePipeline``.
            generator: Stored as ``self.generator``. Nothing else in this
                class reads it.
            stage: Stage name forwarded to ``process_features`` for every
                sample.
        """
        self.stage = stage
        self.generator = generator
        self.config = config
        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def __call__(self, raw_prots):
        """Process each raw sample and stack the results into a batch.

        Args:
            raw_prots: Iterable of raw samples, each accepted by
                ``FeaturePipeline.process_features``.

        Returns:
            Whatever ``dict_multimap`` produces when given a dim-0
            ``torch.stack`` and the list of processed feature dicts.
        """
        per_sample_features = [
            self.feature_pipeline.process_features(raw, self.stage)
            for raw in raw_prots
        ]
        stack_along_batch = partial(torch.stack, dim=0)
        return dict_multimap(stack_along_batch, per_sample_features)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
    """Initialise the underlying DataLoader and the batch-property sampler.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            ``DataLoader`` constructor.
        config: Configuration object, stored as ``self.config``.
        stage: Stage name, stored as ``self.stage``.
        generator: Optional ``torch.Generator``. When omitted, a new
            unseeded generator is created. Note that it is consumed here
            and is NOT forwarded to the parent constructor.
        **kwargs: Keyword arguments forwarded unchanged to the parent
            ``DataLoader`` constructor.
    """
    super().__init__(*args, **kwargs)
    self.config, self.stage = config, stage
    # Assigned after the parent constructor has run, so this value replaces
    # whatever ``generator`` attribute the parent set up.
    self.generator = torch.Generator() if generator is None else generator
    self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self):
......@@ -286,7 +307,7 @@ class OpenFoldBatchCollator:
dtype=torch.float32,
)
def _add_batch_properties(self, raw_prots):
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
......@@ -294,22 +315,42 @@ class OpenFoldBatchCollator:
generator=self.generator
)
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys):
sample = samples[i][0]
for prot in raw_prots:
prot[key] = np.array(sample, dtype=np.float32)
def __call__(self, raw_prots):
self._add_batch_properties(raw_prots)
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
requires_grad=False
)
processed_prots.append(features)
orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
if(key == "no_recycling_iters"):
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
def __iter__(self):
    """Iterate over batches, post-processing each one lazily.

    Returns:
        A generator that pulls batches from the parent ``DataLoader``
        iterator and yields the result of ``_add_batch_properties`` for
        each of them, one batch at a time.
    """
    parent_batches = super().__iter__()
    return (
        self._add_batch_properties(raw_batch)
        for raw_batch in parent_batches
    )
class OpenFoldDataModule(pl.LightningDataModule):
......@@ -427,7 +468,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
if(self.val_data_dir is not None):
self.val_dataset = dataset_gen(
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
......@@ -436,7 +477,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
_output_raw=True,
)
else:
self.val_dataset = None
self.eval_dataset = None
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
......@@ -446,42 +487,45 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict",
)
def _gen_batch_collator(self, stage):
""" We want each process to use the same batch collation seed """
def _gen_dataloader(self, stage):
generator = torch.Generator()
if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
collate_fn = OpenFoldBatchCollator(
self.config, generator, stage
)
return collate_fn
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("train"),
)
dataset = None
if(stage == "train"):
dataset = self.train_dataset
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
def val_dataloader(self):
if(self.val_dataset is not None):
return torch.utils.data.DataLoader(
self.val_dataset,
batch_collator = OpenFoldBatchCollator(self.config, stage)
dl = OpenFoldDataLoader(
dataset,
config=self.config,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval")
collate_fn=batch_collator,
)
return dl
def train_dataloader(self):
return self._gen_dataloader("train")
def val_dataloader(self):
if(self.eval_dataset is not None):
return self._gen_dataloader("eval")
return None
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("predict")
)
return self._gen_dataloader("predict")
class DummyDataset(torch.utils.data.Dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment