"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "76279ff212dfbf16fd06a6e4a51e9bca02304b02"
Unverified commit 9dd9cea4, authored by Gustaf Ahdritz and committed by GitHub
Browse files

Merge pull request #210 from timodonnell/remove-chains-missing-data

Drop chains that are missing (structure) data in training
parents 12caaa89 f6d02cd9
...@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
...@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files. Path to a directory containing template mmCIF files.
config: config:
A dataset config object. See openfold.config A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path: kalign_binary_path:
Path to kalign binary. Path to kalign binary.
max_template_hits: max_template_hits:
...@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
""" """
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
...@@ -104,12 +115,37 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -104,12 +115,37 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self._chain_ids = list(alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
else: else:
self._chain_ids = list(os.listdir(alignment_dir)) self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None): if(filter_path is not None):
with open(filter_path, "r") as f: with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()]) chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [c for c in self._chain_ids if c in chains_to_include] self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
...@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device) feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats return feats
...@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[int],
epoch_len: int, epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
...@@ -305,11 +343,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -305,11 +343,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.probabilities = probabilities self.probabilities = probabilities
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len): def looped_shuffled_dataset_idx(dataset_len):
while True: while True:
...@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx]) max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset)) idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx] chain_data_cache = dataset.chain_data_cache
while True: while True:
weights = [] weights = []
idx = [] idx = []
...@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode): if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
...@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path, filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
...@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else: else:
datasets = [train_dataset] datasets = [train_dataset]
probabilities = [1.] probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
generator = None generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
...@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
epoch_len=self.train_epoch_len, epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
generator=generator, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment