Unverified Commit c48f850a authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #408 from rostro36/main

Add use_custom_templates option
parents f37d0d96 7d227395
...@@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view ...@@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view
- `--data_random_seed`: Specifies a random seed to use. - `--data_random_seed`: Specifies a random seed to use.
- `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads. - `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads.
- `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See for [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`. - `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See for [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`.
- `--use_custom_template`: Uses all .cif files in `template_mmcif_dir` as template input. Make sure the chains of interest have the identifier _A_ and have the same length as the input sequence. The same templates will be read for all sequences that are passed for inference.
### Advanced Options for Increasing Efficiency ### Advanced Options for Increasing Efficiency
...@@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement) ...@@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
#### Long sequence inference #### Long sequence inference
To minimize memory usage during inference on long sequences, consider the following changes: To minimize memory usage during inference on long sequences, consider the following changes:
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either. - As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Powerusers can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint. - Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Powerusers can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time. - Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model. - As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences. - Disable FlashAttention, which seems unstable on long sequences.
Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficent AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficent AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option
Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
\ No newline at end of file \ No newline at end of file
...@@ -23,8 +23,19 @@ import tempfile ...@@ -23,8 +23,19 @@ import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np import numpy as np
import torch import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer from openfold.data import (
from openfold.data.templates import get_custom_template_features, empty_template_feats templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.templates import (
get_custom_template_features,
empty_template_feats,
CustomHitFeaturizer,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -38,7 +49,9 @@ def make_template_features( ...@@ -38,7 +49,9 @@ def make_template_features(
template_featurizer: Any, template_featurizer: Any,
) -> FeatureDict: ) -> FeatureDict:
hits_cat = sum(hits.values(), []) hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None): if template_featurizer is None or (
len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
):
template_features = empty_template_feats(len(input_sequence)) template_features = empty_template_feats(len(input_sequence))
else: else:
templates_result = template_featurizer.get_templates( templates_result = template_featurizer.get_templates(
......
...@@ -22,6 +22,7 @@ import glob ...@@ -22,6 +22,7 @@ import glob
import json import json
import logging import logging
import os import os
from pathlib import Path
import re import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
...@@ -947,55 +948,71 @@ def _process_single_hit( ...@@ -947,55 +948,71 @@ def _process_single_hit(
def get_custom_template_features( def get_custom_template_features(
mmcif_path: str, mmcif_path: str,
query_sequence: str, query_sequence: str,
pdb_id: str, pdb_id: str,
chain_id: str, chain_id: Optional[str] = "A",
kalign_binary_path: str): kalign_binary_path: Optional[str] = None,
):
with open(mmcif_path, "r") as mmcif_path: if os.path.isfile(mmcif_path):
cif_string = mmcif_path.read() template_paths = [Path(mmcif_path)]
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x:x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True
)
features["template_sum_probs"] = [1.0]
# TODO: clean up this logic
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
elif os.path.isdir(mmcif_path):
template_paths = list(Path(mmcif_path).glob("*.cif"))
else:
logging.error("Custom template path %s does not exist", mmcif_path)
raise ValueError(f"Custom template path {mmcif_path} does not exist")
warnings = []
template_features = dict()
for template_path in template_paths:
logging.info("Featurizing template: %s", template_path)
# pdb_id only for error reporting, take file name
pdb_id = Path(template_path).stem
with open(template_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
# mapping skipping "-"
mapping = {
x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum()
}
realigned_sequence, realigned_mapping = _realign_pdb_template_to_query(
old_template_sequence=query_sequence,
template_chain_id=chain_id,
mmcif_object=mmcif_parse_result.mmcif_object,
old_mapping=mapping,
kalign_binary_path=kalign_binary_path,
)
curr_features, curr_warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=realigned_mapping,
template_sequence=realigned_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True,
)
curr_features["template_sum_probs"] = [
1.0
] # template given by user, 100% confident
template_features = {
curr_name: template_features.get(curr_name, []) + [curr_item]
for curr_name, curr_item in curr_features.items()
}
warnings.append(curr_warnings)
template_features = {
template_feature_name: np.stack(
template_features[template_feature_name], axis=0
).astype(template_feature_type)
for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
}
return TemplateSearchResult( return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings features=template_features, errors=None, warnings=warnings
) )
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateSearchResult: class TemplateSearchResult:
features: Mapping[str, Any] features: Mapping[str, Any]
...@@ -1188,6 +1205,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1188,6 +1205,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
) )
class CustomHitFeaturizer(TemplateHitFeaturizer):
"""Featurizer for templates given in folder.
Chain of interest has to be chain A and of same sequence length as input sequence."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
return get_custom_template_features(
self._mmcif_dir,
query_sequence=query_sequence,
pdb_id="test",
chain_id="A",
kalign_binary_path=self._kalign_binary_path,
)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer): class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates( def get_templates(
self, self,
......
...@@ -202,8 +202,15 @@ def main(args): ...@@ -202,8 +202,15 @@ def main(args):
) )
is_multimer = "multimer" in args.config_preset is_multimer = "multimer" in args.config_preset
is_custom_template = "use_custom_template" in args and args.use_custom_template
if is_multimer: if is_custom_template:
template_featurizer = templates.CustomHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date="9999-12-31", # just dummy, not used
max_hits=-1, # just dummy, not used
kalign_binary_path=args.kalign_binary_path
)
elif is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer( template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -221,11 +228,9 @@ def main(args): ...@@ -221,11 +228,9 @@ def main(args):
release_dates_path=args.release_dates_path, release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
if is_multimer: if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer( data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor, monomer_data_pipeline=data_processor,
...@@ -238,7 +243,6 @@ def main(args): ...@@ -238,7 +243,6 @@ def main(args):
np.random.seed(random_seed) np.random.seed(random_seed)
torch.manual_seed(random_seed + 1) torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
...@@ -313,7 +317,6 @@ def main(args): ...@@ -313,7 +317,6 @@ def main(args):
) )
feature_dicts[tag] = feature_dict feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer feature_dict, mode='predict', is_multimer=is_multimer
) )
...@@ -400,6 +403,10 @@ if __name__ == "__main__": ...@@ -400,6 +403,10 @@ if __name__ == "__main__":
help="""Path to alignment directory. If provided, alignment computation help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored.""" is skipped and database path arguments are ignored."""
) )
parser.add_argument(
"--use_custom_template", action="store_true", default=False,
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser.add_argument( parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False, "--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs.""" help="""Use single sequence embeddings instead of MSAs."""
...@@ -494,5 +501,4 @@ if __name__ == "__main__": ...@@ -494,5 +501,4 @@ if __name__ == "__main__":
"""The model is being run on CPU. Consider specifying """The model is being run on CPU. Consider specifying
--model_device for better performance""" --model_device for better performance"""
) )
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment