Commit 60e9bd54, authored by Tim O'Donnell
Browse files

Drop alignments that are missing structure data in training

parent 12caaa89
...@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap ...@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
data_dir: str, data_dir: str,
chain_data_cache_path: str,
alignment_dir: str, alignment_dir: str,
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
...@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
""" """
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
...@@ -109,7 +115,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -109,7 +115,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
with open(filter_path, "r") as f: with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()]) chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [c for c in self._chain_ids if c in chains_to_include] self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
# Filter to include only chains where we have structure data
# (i.e. entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Ignoring %d alignment entries (%s) that have no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
...@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device) feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats return feats
...@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[int],
epoch_len: int, epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
...@@ -306,11 +337,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -306,11 +337,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len): def looped_shuffled_dataset_idx(dataset_len):
while True: while True:
# Uniformly shuffle each dataset's indices # Uniformly shuffle each dataset's indices
...@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx]) max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset)) idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx] chain_data_cache = dataset.chain_data_cache
while True: while True:
weights = [] weights = []
idx = [] idx = []
...@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode): if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
...@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path, filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
...@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else: else:
datasets = [train_dataset] datasets = [train_dataset]
probabilities = [1.] probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
generator = None generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
...@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
epoch_len=self.train_epoch_len, epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
generator=generator, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment