Commit cbcd81fb authored by Geoffrey Yu

Modified _parse_template_hits so that it can read Stockholm (.sto) files and actually return the hits dictionary

parent 39d4e5c7
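In short, `_parse_template_hits` now also picks up hmmsearch Stockholm output (`pdb*.sto`) when given the query sequence. A minimal sketch of the new path in isolation; the directory layout, file name, and sequence below are assumptions, not part of this commit:

```python
# Sketch only: path, file name, and sequence are placeholders.
from openfold.data import parsers

input_sequence = "MKTAYIAKQR..."  # placeholder query sequence
with open("alignments/1abc_A/pdb_hits.sto") as fp:
    # New in this commit: Stockholm hits are parsed against the query sequence.
    hits = parsers.parse_hmmsearch_sto(fp.read(), input_sequence)
```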
...@@ -22,6 +22,268 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
    "train", "eval", or "predict"
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
else:
self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
# For monomer template search, use hhsearch, as demonstrated in
# AlphaFold's run_alphafold.py:
# https://github.com/deepmind/alphafold/blob/6c4d833fbd1c6b8e7c9a21dae5d4ada2ce777e10/run_alphafold.py#L462C1-L477
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
)
else:
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
)
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats, data
def __len__(self):
return len(self._chain_ids)
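For orientation, here is a hedged sketch of instantiating the dataset above; the config preset name and every path are placeholder assumptions:

```python
from openfold.config import model_config

cfg = model_config("model_1")  # assumed preset name
dataset = OpenFoldSingleDataset(
    data_dir="data/mmcifs",                    # placeholder paths
    alignment_dir="data/alignments",
    template_mmcif_dir="data/template_mmcifs",
    max_template_date="2021-09-30",
    config=cfg.data,
    mode="train",
)
# With this commit, train/eval mode returns (features, raw data):
feats, data = dataset[0]
```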
class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
    def __init__(self,
        data_dir: str,
        alignment_dir: str,
...@@ -43,6 +305,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
        _structure_index: Optional[Any] = None,
    ):
        """
        This class checks each individual PDB ID and returns its chain(s)
        features/ground truth.
        Args:
            data_dir:
                A path to a directory containing mmCIF files (in train
...@@ -89,7 +352,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
            mode:
                "train", "eval", or "predict"
        """
        super(OpenFoldSingleMultimerDataset, self).__init__()
        self.data_dir = data_dir
        self.chain_data_cache = None
...@@ -293,7 +556,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __len__(self):
        return len(self._chain_ids)
def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
...@@ -371,7 +633,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
            yield idx
    def looped_samples(dataset_idx):
        max_cache_len = int(epoch_len * probabilities[dataset_idx])
        dataset = self.datasets[dataset_idx]
        idx_iter = looped_shuffled_dataset_idx(len(dataset))
...@@ -382,7 +643,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
        for _ in range(max_cache_len):
            candidate_idx = next(idx_iter)
            chain_id = dataset.idx_to_chain_id(candidate_idx)
            chain_data_cache_entry = chain_data_cache[chain_id]
            if(not deterministic_train_filter(chain_data_cache_entry)):
                continue
...
...@@ -803,7 +803,8 @@ class DataPipeline:
    def _parse_template_hits(
        self,
        alignment_dir: str,
        alignment_index: Optional[Any] = None,
        input_sequence: Optional[str] = None,
    ) -> Mapping[str, Any]:
        all_hits = {}
        if (alignment_index is not None):
...@@ -830,6 +831,15 @@ class DataPipeline:
            with open(path, "r") as fp:
                hits = parsers.parse_hhr(fp.read())
            all_hits[f] = hits
        elif (ext == '.sto') and (f.startswith("pdb")):
            with open(path, "r") as fp:
                hits = parsers.parse_hmmsearch_sto(fp.read(), input_sequence)
            all_hits[f] = hits
        return all_hits
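Downstream callers (process_mmcif, process_fasta, etc., below) now thread `input_sequence` through, and the returned mapping keys on alignment file names. A hedged sketch of the call and its shape, where `pipeline` is an existing DataPipeline instance and the chain ID, sequence, and file names are all assumptions:

```python
# Hypothetical call; "1abc_A", the sequence, and file names are assumptions.
all_hits = pipeline._parse_template_hits(
    "alignments/1abc_A",
    alignment_index=None,
    input_sequence="MKTAYIAKQR...",
)
# e.g. {"hhsearch_output.hhr": [<TemplateHit>, ...],
#       "pdb_hits.sto": [<TemplateHit>, ...]}
```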
    def _get_msas(self,
        alignment_dir: str,
...@@ -937,7 +947,7 @@ class DataPipeline:
        input_sequence = mmcif.chain_to_seqres[chain_id]
        hits = self._parse_template_hits(
            alignment_dir,
            alignment_index, input_sequence)
        template_features = make_template_features(
            input_sequence,
...@@ -986,7 +996,7 @@ class DataPipeline:
        hits = self._parse_template_hits(
            alignment_dir,
            alignment_index, input_sequence
        )
        template_features = make_template_features(
...@@ -1018,7 +1028,7 @@ class DataPipeline:
        hits = self._parse_template_hits(
            alignment_dir,
            alignment_index, input_sequence
        )
        template_features = make_template_features(
...@@ -1107,7 +1117,7 @@ class DataPipeline:
        alignment_dir = os.path.join(
            super_alignment_dir, desc
        )
        hits = self._parse_template_hits(
            alignment_dir, alignment_index=None, input_sequence=input_sequence
        )
        template_features = make_template_features(
            seq,
            hits,
...