Commit 60e9bd54 authored by Tim O'Donnell
Browse files

Drop alignments that are missing structure data in training

parent 12caaa89
......@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
chain_data_cache_path: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
......@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
......@@ -104,12 +110,35 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self._chain_ids = list(alignment_index.keys())
else:
self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [c for c in self._chain_ids if c in chains_to_include]
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
# Filter to include only chains where we have structure data
# (i.e. entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Ignoring %d alignment entries (%s) that have no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
......@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode
)
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
......@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int],
epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
......@@ -305,11 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
......@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx]
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
......@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
......@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
generator=generator,
_roll_at_init=False,
)
......
Markdown is supported
Attaching a file: 0%. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment