Commit 8470b803 authored by Christina Floristean
Browse files

Fix multimer sampling

parent 14853379
......@@ -694,14 +694,14 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
) -> float:
) -> list:
# Stochastic filters
cluster_sizes = cache_entry.get("cluster_sizes", [])
chain_probs = [1 / c for c in cluster_sizes if c > 0]
if chain_probs:
return sum(chain_probs)
cluster_sizes = cache_entry.get("cluster_sizes")
if cluster_sizes is not None:
return [1 / c if c > 0 else 1 for c in cluster_sizes]
return 1.
num_chains = len(cache_entry["chain_ids"])
return [1.] * num_chains
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
......@@ -718,11 +718,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
if not self.deterministic_train_filter(mmcif_data_cache_entry):
continue
p = self.get_stochastic_train_filter_prob(
chain_probs = self.get_stochastic_train_filter_prob(
mmcif_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
weights.extend([[1. - p, p] for p in chain_probs])
idx.extend([candidate_idx] * len(chain_probs))
samples = torch.multinomial(
torch.tensor(weights),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment