Commit b55ad675 authored by Geoffrey Yu
Browse files

finished constructing OpenFoldMultimerDataset filtering and sampling steps

parent 71fdc063
......@@ -212,6 +212,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
print(f"name is {name}")
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
......@@ -371,7 +372,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
assert isinstance(self.chain_data_cache, dict)
if mmcif_data_cache_path is not None:
print(f"mmcif_data_cache_path is {mmcif_data_cache_path}")
with open(mmcif_data_cache_path,"r") as infile:
self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict)
......@@ -678,7 +678,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
## TODO: add filtering criteria for multimer
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
......@@ -703,7 +702,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
......@@ -721,13 +719,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
print(f"datapoints is {self.datapoints}")
class OpenFoldMultimerDataset(torch.utils.data.Dataset):
......@@ -747,59 +743,23 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
if _roll_at_init:
self.reroll()
def looped_shuffled_dataset_idx(self, dataset_len):
    """Yield dataset indices endlessly, one shuffled pass after another.

    Each pass draws ``dataset_len`` indices without replacement from equal
    weights via ``torch.multinomial``, so every index in
    ``range(dataset_len)`` appears exactly once per pass. Randomness comes
    from ``self.generator``. The yielded items are the elements of the
    tensor returned by ``torch.multinomial`` (0-d tensors, not Python ints).

    Args:
        dataset_len: Number of indices to shuffle in each pass.

    Yields:
        One index per step; a new permutation is drawn once the current
        one is exhausted.
    """
    while True:
        # Equal weights + sampling without replacement == uniform shuffle.
        uniform_weights = torch.tensor([1.0] * dataset_len)
        permutation = torch.multinomial(
            uniform_weights,
            num_samples=dataset_len,
            replacement=False,
            generator=self.generator,
        )
        for shuffled_idx in permutation:
            yield shuffled_idx
def looped_samples(self,dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
def filter_samples(self,dataset_idx):
dataset = self.datasets[dataset_idx]
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
mmcif_data_cache = dataset.mmcif_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
## TODO: add filtering criteria for multimer
mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
print(f"mmcif_id is {mmcif_id} and candidate_idx: {candidate_idx}")
selected_idx = []
for i in range(len(mmcif_data_cache)):
mmcif_id = dataset.idx_to_mmcif_id(i)
print(f"mmcif_id is {mmcif_id} and candidate_idx: {i}")
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if(not deterministic_multimer_train_filter(mmcif_data_cache_entry,max_resolution=9)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
if(len(chains)>1) and (not deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9)):
selected_idx.append(i)
return selected_idx
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
......@@ -811,20 +771,21 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
num_samples=len(self.probabilities),
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
selected_idx = self.filter_samples(dataset_idx)
if len(selected_idx)<self.epoch_len:
self.epoch_len = len(selected_idx)
self.datapoints = [(dataset_idx, datapoint_idx) for datapoint_idx in range(self.epoch_len) ]
print(f"datapoints is {self.datapoints}")
class OpenFoldBatchCollator:
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment