Commit 8470b803 authored by Christina Floristean
Browse files

Fix multimer sampling

parent 14853379
...@@ -694,14 +694,14 @@ class OpenFoldMultimerDataset(OpenFoldDataset): ...@@ -694,14 +694,14 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
@staticmethod @staticmethod
def get_stochastic_train_filter_prob( def get_stochastic_train_filter_prob(
cache_entry: Any, cache_entry: Any,
) -> float: ) -> list:
# Stochastic filters # Stochastic filters
cluster_sizes = cache_entry.get("cluster_sizes", []) cluster_sizes = cache_entry.get("cluster_sizes")
chain_probs = [1 / c for c in cluster_sizes if c > 0] if cluster_sizes is not None:
if chain_probs: return [1 / c if c > 0 else 1 for c in cluster_sizes]
return sum(chain_probs)
return 1. num_chains = len(cache_entry["chain_ids"])
return [1.] * num_chains
def looped_samples(self, dataset_idx): def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx]) max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
...@@ -718,11 +718,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset): ...@@ -718,11 +718,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
if not self.deterministic_train_filter(mmcif_data_cache_entry): if not self.deterministic_train_filter(mmcif_data_cache_entry):
continue continue
p = self.get_stochastic_train_filter_prob( chain_probs = self.get_stochastic_train_filter_prob(
mmcif_data_cache_entry, mmcif_data_cache_entry,
) )
weights.append([1. - p, p]) weights.extend([[1. - p, p] for p in chain_probs])
idx.append(candidate_idx) idx.extend([candidate_idx] * len(chain_probs))
samples = torch.multinomial( samples = torch.multinomial(
torch.tensor(weights), torch.tensor(weights),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment