Commit b61e99bc authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update multimer datasets

parent 33d8de81
......@@ -371,6 +371,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
assert isinstance(self.chain_data_cache, dict)
if mmcif_data_cache_path is not None:
print(f"mmcif_data_cache_path is {mmcif_data_cache_path}")
with open(mmcif_data_cache_path,"r") as infile:
self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict)
......@@ -747,6 +748,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len
self.generator = generator
self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
if _roll_at_init:
self.reroll()
def looped_shuffled_dataset_idx(self,dataset_len):
while True:
......@@ -774,9 +777,10 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
candidate_idx = next(idx_iter)
## TO DO: add filtering cretieria for multimer
mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
chains = chain_data_cache[mmcif_id]['chain_ids']
print(f"mmcif_id is {mmcif_id} and candidate_idx: {candidate_idx}")
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(not deterministic_multimer_train_filter(mmcif_data_cache_entry)):
if(not deterministic_multimer_train_filter(mmcif_data_cache_entry,max_resolution=9)):
continue
p = get_stochastic_train_filter_prob(
......@@ -797,6 +801,27 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
for datapoint_idx in cache:
yield datapoint_idx
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
print(f"datapoints is {self.datapoints}")
......@@ -1193,7 +1218,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
probabilities=probabilities,
epoch_len=self.train_epoch_len,
generator=generator,
_roll_at_init=False,
_roll_at_init=True,
)
if(self.val_data_dir is not None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment