"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "099769d2ecfd01a8baa8d950030df454a042c910"
Commit b55ad675 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished constructing OpenFoldMultimerDataset filtering and sampling steps

parent 71fdc063
...@@ -212,6 +212,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -212,6 +212,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
name = self.idx_to_chain_id(idx) name = self.idx_to_chain_id(idx)
print(f"name is {name}")
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None alignment_index = None
...@@ -371,7 +372,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -371,7 +372,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
assert isinstance(self.chain_data_cache, dict) assert isinstance(self.chain_data_cache, dict)
if mmcif_data_cache_path is not None: if mmcif_data_cache_path is not None:
print(f"mmcif_data_cache_path is {mmcif_data_cache_path}")
with open(mmcif_data_cache_path,"r") as infile: with open(mmcif_data_cache_path,"r") as infile:
self.mmcif_data_cache = json.load(infile) self.mmcif_data_cache = json.load(infile)
assert isinstance(self.mmcif_data_cache,dict) assert isinstance(self.mmcif_data_cache,dict)
...@@ -678,7 +678,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -678,7 +678,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
idx = [] idx = []
for _ in range(max_cache_len): for _ in range(max_cache_len):
candidate_idx = next(idx_iter) candidate_idx = next(idx_iter)
## TO DO: add filtering cretieria for multimer
chain_id = dataset.idx_to_chain_id(candidate_idx) chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id] chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)): if(not deterministic_train_filter(chain_data_cache_entry)):
...@@ -703,7 +702,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -703,7 +702,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
yield datapoint_idx yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))] self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init): if(_roll_at_init):
self.reroll() self.reroll()
...@@ -721,13 +719,11 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -721,13 +719,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
replacement=True, replacement=True,
generator=self.generator, generator=self.generator,
) )
self.datapoints = [] self.datapoints = []
for dataset_idx in dataset_choices: for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx] samples = self._samples[dataset_idx]
datapoint_idx = next(samples) datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx)) self.datapoints.append((dataset_idx, datapoint_idx))
print(f"datapoints is {self.datapoints}")
class OpenFoldMultimerDataset(torch.utils.data.Dataset): class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...@@ -747,59 +743,23 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -747,59 +743,23 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
self.probabilities = probabilities self.probabilities = probabilities
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
if _roll_at_init: if _roll_at_init:
self.reroll() self.reroll()
def filter_samples(self, dataset_idx):
    """Return the indices of dataset ``dataset_idx`` that survive the
    multimer training filters.

    An index is kept when its mmCIF entry is a true multimer (more than
    one chain) and it passes ``deterministic_multimer_train_filter``
    (resolution cap of 9 A).

    Args:
        dataset_idx: position of the dataset in ``self.datasets``.

    Returns:
        list[int]: the surviving per-dataset sample indices.
    """
    dataset = self.datasets[dataset_idx]
    mmcif_data_cache = dataset.mmcif_data_cache
    selected_idx = []
    # NOTE(review): assumes indices 0..len(mmcif_data_cache)-1 map
    # one-to-one onto mmCIF entries via idx_to_mmcif_id — confirm
    # against the single-dataset implementation.
    for i in range(len(mmcif_data_cache)):
        mmcif_id = dataset.idx_to_mmcif_id(i)
        mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
        chains = mmcif_data_cache_entry['chain_ids']
        # Keep multimers that PASS the deterministic filter. The previous
        # condition appended entries for which the filter returned False,
        # inverting the intended selection (the replaced looped_samples
        # used `if not filter(...): continue`, i.e. it kept passers).
        # The per-entry debug print was removed — it ran once per sample.
        if len(chains) > 1 and deterministic_multimer_train_filter(
            mmcif_data_cache_entry, max_resolution=9
        ):
            selected_idx.append(i)
    return selected_idx
def __getitem__(self, idx): def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx] dataset_idx, datapoint_idx = self.datapoints[idx]
...@@ -811,20 +771,21 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -811,20 +771,21 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
def reroll(self):
    """Resample the epoch's ``(dataset_idx, datapoint_idx)`` pairs.

    Draws which datasets contribute (weighted by ``self.probabilities``),
    filters each chosen dataset's samples, and rebuilds
    ``self.datapoints`` from the surviving indices, capping
    ``self.epoch_len`` at the number of survivors.
    """
    dataset_choices = torch.multinomial(
        torch.tensor(self.probabilities),
        num_samples=len(self.probabilities),
        replacement=True,
        generator=self.generator,
    )
    self.datapoints = []
    for dataset_idx in dataset_choices:
        dataset_idx = int(dataset_idx)
        selected_idx = self.filter_samples(dataset_idx)
        # Cap the epoch length at the number of surviving samples.
        if len(selected_idx) < self.epoch_len:
            self.epoch_len = len(selected_idx)
        # Bug fix: the original rebound self.datapoints with `=` inside
        # the loop (so only the last chosen dataset survived) and used
        # range(self.epoch_len) — bare positions — instead of the
        # filtered indices, which made filter_samples a no-op for what
        # __getitem__ actually fetched. Extend with the real indices.
        # (The debug print of the full datapoint list was removed.)
        self.datapoints.extend(
            (dataset_idx, datapoint_idx)
            for datapoint_idx in selected_idx[: self.epoch_len]
        )
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __call__(self, prots): def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment