Unverified Commit 9dd9cea4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #210 from timodonnell/remove-chains-missing-data

Drop chains that are missing (structure) data in training
parents 12caaa89 f6d02cd9
......@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
......@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
Path to kalign binary.
max_template_hits:
......@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
......@@ -109,7 +120,32 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [c for c in self._chain_ids if c in chains_to_include]
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
......@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode
)
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
......@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int],
epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
......@@ -306,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len
self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
......@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx]
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
......@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
generator = None
if(self.batch_seed is not None):
......@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
generator=generator,
_roll_at_init=False,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment