Commit 17f24bd7 authored by rostro36's avatar rostro36
Browse files

Added custom template folder

parent bb3f51e5
...@@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`, ...@@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`,
`/usr/bin`, their `binary_path` command-line arguments can be dropped. `/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with skip the expensive alignment computation here with
`--use_precomputed_alignments`. `--use_precomputed_alignments`. If you wish to use a specific template as input,
you can use the argument `--use_custom_template`, which then will read all .cif
files in `template_mmcif_dir`. Make sure the chains of interest have the identifier _A_
and have the same length as the input sequence.
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists `--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files, of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
......
...@@ -23,8 +23,19 @@ import tempfile ...@@ -23,8 +23,19 @@ import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np import numpy as np
import torch import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import (
from openfold.data.templates import get_custom_template_features, empty_template_feats templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.templates import (
get_custom_template_features,
empty_template_feats,
CustomHitFeaturizer,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -38,7 +49,9 @@ def make_template_features( ...@@ -38,7 +49,9 @@ def make_template_features(
template_featurizer: Any, template_featurizer: Any,
) -> FeatureDict: ) -> FeatureDict:
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None): if template_featurizer is None or (
len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
):
template_features = empty_template_feats(len(input_sequence)) template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = template_featurizer.get_templates( templates_result = template_featurizer.get_templates(
......
...@@ -22,6 +22,7 @@ import glob ...@@ -22,6 +22,7 @@ import glob
import json import json
import logging import logging
import os import os
from pathlib import Path
import re import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
...@@ -950,22 +951,33 @@ def get_custom_template_features( ...@@ -950,22 +951,33 @@ def get_custom_template_features(
mmcif_path: str, mmcif_path: str,
query_sequence: str, query_sequence: str,
pdb_id: str, pdb_id: str,
chain_id: str, chain_id: Optional[str] = "A",
kalign_binary_path: str): kalign_binary_path: Optional[str] = None,
):
if os.path.isfile(mmcif_path):
template_paths = [Path(mmcif_path)]
with open(mmcif_path, "r") as mmcif_path: elif os.path.isdir(mmcif_path):
template_paths = list(Path(mmcif_path).glob("*.cif"))
else:
logging.error("Custom template path %s does not exist", mmcif_path)
raise ValueError(f"Custom template path {mmcif_path} does not exist")
warnings = []
template_features = dict()
for template_path in template_paths:
logging.info("Featurizing template: %s", template_path)
# pdb_id only for error reporting, take file name
pdb_id = Path(template_path).stem
with open(template_path, "r") as mmcif_path:
cif_string = mmcif_path.read() cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse( mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string file_id=pdb_id, mmcif_string=cif_string
) )
# chain_id defaults to A, should be changed?
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id] template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x: x for x, _ in enumerate(query_sequence)}
curr_features, curr_warnings = _extract_template_features(
mapping = {x:x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object, mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id, pdb_id=pdb_id,
mapping=mapping, mapping=mapping,
...@@ -973,23 +985,21 @@ def get_custom_template_features( ...@@ -973,23 +985,21 @@ def get_custom_template_features(
query_sequence=query_sequence, query_sequence=query_sequence,
template_chain_id=chain_id, template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
_zero_center_positions=True _zero_center_positions=True,
) )
features["template_sum_probs"] = [1.0] curr_features["template_sum_probs"] = [1.0]
template_features = {
# TODO: clean up this logic curr_name: template_features.get(curr_name, []) + [curr_item]
template_features = {} for curr_name, curr_item in curr_features.items()
for template_feature_name in TEMPLATE_FEATURES: }
template_features[template_feature_name] = [] warnings = warnings.append(curr_warnings)
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
template_features = {
template_feature_name: np.stack(
template_features[template_feature_name], axis=0
).astype(template_feature_type)
for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
}
return TemplateSearchResult( return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings features=template_features, errors=None, warnings=warnings
) )
...@@ -1188,6 +1198,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1188,6 +1198,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
) )
class CustomHitFeaturizer(TemplateHitFeaturizer):
"""Featurizer for templates given in folder.
Chain of interest has to be chain A and of same residue size as input sequence."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
return get_custom_template_features(
self._mmcif_dir,
query_sequence=query_sequence,
pdb_id="test",
chain_id="A",
kalign_binary_path=self._kalign_binary_path,
)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer): class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates( def get_templates(
self, self,
......
...@@ -186,8 +186,15 @@ def main(args): ...@@ -186,8 +186,15 @@ def main(args):
) )
is_multimer = "multimer" in args.config_preset is_multimer = "multimer" in args.config_preset
is_custom_template = "use_custom_template" in args
if is_multimer: if is_custom_template:
template_featurizer = templates.CustomHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date="9999-12-31", # just dummy, not used
max_hits=-1, # just dummy, not used
kalign_binary_path=args.kalign_binary_path
)
elif is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer( template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -205,11 +212,9 @@ def main(args): ...@@ -205,11 +212,9 @@ def main(args):
release_dates_path=args.release_dates_path, release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
if is_multimer: if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer( data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor, monomer_data_pipeline=data_processor,
...@@ -222,7 +227,6 @@ def main(args): ...@@ -222,7 +227,6 @@ def main(args):
np.random.seed(random_seed) np.random.seed(random_seed)
torch.manual_seed(random_seed + 1) torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
...@@ -292,7 +296,6 @@ def main(args): ...@@ -292,7 +296,6 @@ def main(args):
) )
feature_dicts[tag] = feature_dict feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer feature_dict, mode='predict', is_multimer=is_multimer
) )
...@@ -379,6 +382,10 @@ if __name__ == "__main__": ...@@ -379,6 +382,10 @@ if __name__ == "__main__":
help="""Path to alignment directory. If provided, alignment computation help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored.""" is skipped and database path arguments are ignored."""
) )
parser.add_argument(
"--use_custom_template", action="store_true", default=False,
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser.add_argument( parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False, "--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs.""" help="""Use single sequence embeddings instead of MSAs."""
...@@ -466,5 +473,4 @@ if __name__ == "__main__": ...@@ -466,5 +473,4 @@ if __name__ == "__main__":
"""The model is being run on CPU. Consider specifying """The model is being run on CPU. Consider specifying
--model_device for better performance""" --model_device for better performance"""
) )
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment