Commit 3d5e8740 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix RNG bug

parent 43116de0
...@@ -244,10 +244,31 @@ class OpenFoldDataset(torch.utils.data.IterableDataset): ...@@ -244,10 +244,31 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"): def __init__(self, config, generator, stage="train"):
self.config = config
self.generator = generator
self.stage = stage self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
self._prep_batch_properties_probs() self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self): def _prep_batch_properties_probs(self):
...@@ -286,7 +307,7 @@ class OpenFoldBatchCollator: ...@@ -286,7 +307,7 @@ class OpenFoldBatchCollator:
dtype=torch.float32, dtype=torch.float32,
) )
def _add_batch_properties(self, raw_prots): def _add_batch_properties(self, batch):
samples = torch.multinomial( samples = torch.multinomial(
self.prop_probs_tensor, self.prop_probs_tensor,
num_samples=1, # 1 per row num_samples=1, # 1 per row
...@@ -294,22 +315,42 @@ class OpenFoldBatchCollator: ...@@ -294,22 +315,42 @@ class OpenFoldBatchCollator:
generator=self.generator generator=self.generator
) )
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys): for i, key in enumerate(self.prop_keys):
sample = samples[i][0] sample = int(samples[i][0])
for prot in raw_prots: sample_tensor = torch.tensor(
prot[key] = np.array(sample, dtype=np.float32) sample,
device=aatype.device,
def __call__(self, raw_prots): requires_grad=False
self._add_batch_properties(raw_prots)
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
) )
processed_prots.append(features) orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
stack_fn = partial(torch.stack, dim=0) if(key == "no_recycling_iters"):
return dict_multimap(stack_fn, processed_prots) no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
def __iter__(self):
it = super().__iter__()
def _batch_prop_gen(iterator):
for batch in iterator:
yield self._add_batch_properties(batch)
return _batch_prop_gen(it)
class OpenFoldDataModule(pl.LightningDataModule): class OpenFoldDataModule(pl.LightningDataModule):
...@@ -427,7 +468,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -427,7 +468,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
if(self.val_data_dir is not None): if(self.val_data_dir is not None):
self.val_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
mapping_path=None, mapping_path=None,
...@@ -436,7 +477,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -436,7 +477,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
_output_raw=True, _output_raw=True,
) )
else: else:
self.val_dataset = None self.eval_dataset = None
else: else:
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
...@@ -446,42 +487,45 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -446,42 +487,45 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict", mode="predict",
) )
def _gen_batch_collator(self, stage): def _gen_dataloader(self, stage):
""" We want each process to use the same batch collation seed """
generator = torch.Generator() generator = torch.Generator()
if(self.batch_seed is not None): if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed) generator = generator.manual_seed(self.batch_seed)
collate_fn = OpenFoldBatchCollator(
self.config, generator, stage
)
return collate_fn
def train_dataloader(self): dataset = None
return torch.utils.data.DataLoader( if(stage == "train"):
self.train_dataset, dataset = self.train_dataset
batch_size=self.config.data_module.data_loaders.batch_size, elif(stage == "eval"):
num_workers=self.config.data_module.data_loaders.num_workers, dataset = self.eval_dataset
collate_fn=self._gen_batch_collator("train"), elif(stage == "predict"):
) dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
def val_dataloader(self): batch_collator = OpenFoldBatchCollator(self.config, stage)
if(self.val_dataset is not None):
return torch.utils.data.DataLoader( dl = OpenFoldDataLoader(
self.val_dataset, dataset,
config=self.config,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size, batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers, num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("eval") collate_fn=batch_collator,
) )
return dl
def train_dataloader(self):
return self._gen_dataloader("train")
def val_dataloader(self):
if(self.eval_dataset is not None):
return self._gen_dataloader("eval")
return None return None
def predict_dataloader(self): def predict_dataloader(self):
return torch.utils.data.DataLoader( return self._gen_dataloader("predict")
self.predict_dataset,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("predict")
)
class DummyDataset(torch.utils.data.Dataset): class DummyDataset(torch.utils.data.Dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment