Commit f6d02cd9 authored by Tim O'Donnell
Browse files

fix

parent 1abe6160
...@@ -57,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -57,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files. Path to a directory containing template mmCIF files.
config: config:
A dataset config object. See openfold.config A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path: kalign_binary_path:
Path to kalign binary. Path to kalign binary.
max_template_hits: max_template_hits:
...@@ -121,26 +124,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -121,26 +124,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
c for c in self._chain_ids if c in chains_to_include c for c in self._chain_ids if c in chains_to_include
] ]
# Filter to include only chains where we have structure data if self.chain_data_cache is not None:
# (i.e. entries in chain_data_cache) # Filter to include only chains where we have structure data
original_chain_ids = self._chain_ids # (entries in chain_data_cache)
self._chain_ids = [ original_chain_ids = self._chain_ids
c for c in self._chain_ids if c in self.chain_data_cache self._chain_ids = [
] c for c in self._chain_ids if c in self.chain_data_cache
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids if c not in self.chain_data_cache
] ]
max_to_print = 10 if len(self._chain_ids) < len(original_chain_ids):
missing_examples = ", ".join(missing[:max_to_print]) missing = [
if len(missing) > max_to_print: c for c in original_chain_ids
missing_examples += ", ..." if c not in self.chain_data_cache
logging.warning( ]
"Ignoring %d alignment entries (%s) that have no corresponding " max_to_print = 10
"entries in chain_data_cache (%s).", missing_examples = ", ".join(missing[:max_to_print])
len(missing), if len(missing) > max_to_print:
missing_examples, missing_examples += ", ..."
chain_data_cache_path) logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment