Commit ab4a2459 authored by Jennifer Wei

Merge branch 'main' into pl_upgrades

parents 100a309e 815a042c
@@ -11,7 +11,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Cleanup # https://github.com/actions/virtual-environments/issues/2840
        run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY"
      - name: Build the Docker image
        run: docker build . --file Dockerfile --tag openfold:$(date +%s)
@@ -68,9 +68,9 @@ All together, the file directory would look like:
└── 6kwc.cif
└── alignment_db
    ├── alignment_db_0.db
    ├── alignment_db_1.db
    ...
    ├── alignment_db_9.db
    └── alignment_db.index
```
@@ -42,7 +42,7 @@ $ bash scripts/download_openfold_params.sh $PARAMS_DIR
We recommend selecting `openfold/resources` as the params directory, as this is the default directory used by the `run_pretrained_openfold.py` script to locate parameters.
If you choose to use a different directory, you may make a symlink to the `openfold/resources` directory, or specify an alternate parameter path with the command line argument `--jax_param_path` for AlphaFold parameters or `--openfold_checkpoint_path` for OpenFold parameters.

### Model Inference
@@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view
- `--data_random_seed`: Specifies a random seed to use.
- `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the MSA track and pTM heads.
- `--experiment_config_json`: Specify configuration settings using a JSON file. For example, passing a JSON with `{"globals.relax.max_iterations": 10}` specifies 10 as the maximum number of relaxation iterations. See [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) for the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`.
- `--use_custom_template`: Uses all .cif files in `template_mmcif_dir` as template input. Make sure the chains of interest have the identifier _A_ and have the same length as the input sequence. The same templates will be read for all sequences that are passed for inference.
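For instance, a hypothetical invocation combining several of these flags might look like the following sketch (all directory and file names here are placeholders; the two positional arguments are assumed to be the query FASTA directory and the template mmCIF directory, as in the inference examples elsewhere in these docs):

```bash
python3 run_pretrained_openfold.py \
    fasta_dir/ \
    custom_templates/ \
    --use_custom_template \
    --data_random_seed 42 \
    --save_outputs \
    --experiment_config_json my_config.json \
    --output_dir ./
```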
### Advanced Options for Increasing Efficiency
@@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)

#### Long sequence inference

To minimize memory usage during inference on long sequences, consider the following changes:

- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Power users can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.

Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficient AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option, as in the sketch below.
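A minimal sketch of such a run (paths are placeholders):

```bash
# Enable all of the long-sequence memory interventions at once.
python3 run_pretrained_openfold.py \
    fasta_dir/ \
    data/pdb_mmcif/mmcif_files/ \
    --long_sequence_inference \
    --output_dir ./
```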
Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
# Setting Up OpenFold
In this guide, we will install OpenFold and its dependencies. A condensed sketch of the full sequence of commands appears at the end of the [Installation](#Installation) section.
**Pre-requisites**
This package is currently supported for CUDA 11 and PyTorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). To install OpenFold for CUDA 12, please refer to the [Environment specific modifications](#Environment-specific-modifications) section.
At this time, only Linux systems are supported.
## Instructions
### Installation:
1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
1. From the `openfold` repo:
- Create a [Mamba](https://github.com/conda-forge/miniforge) environment, e.g.
`mamba env create -n openfold_env -f environment.yml`
Mamba is recommended as the dependencies required by OpenFold are quite large and mamba can speed up the process.
- Activate the environment, e.g. `conda activate openfold_env`
1. Run the setup script to configure kernels and folding resources.
> scripts/install_third_party_dependencies.sh
1. Prepend the conda environment to `$LD_LIBRARY_PATH` and `$LIBRARY_PATH`, e.g.:
```
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
```
You may optionally set these as conda environment variables according to the [conda docs](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#saving-environment-variables) so that they are applied each time the environment is activated.
1. Download parameters. We recommend using `openfold/resources` as the destination, as our unit tests will look for the weights there.
- For AlphaFold2 weights, use
> ./scripts/download_alphafold_params.sh <dest>
- For OpenFold weights, use:
> ./scripts/download_openfold_params.sh <dest>
- For OpenFold SoloSeq weights, use:
> ./scripts/download_openfold_soloseq_params.sh <dest>
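All together, the steps above condense to roughly the following sketch (the environment name and parameter destination are the recommended defaults from this guide; adjust as needed):

```bash
git clone https://github.com/aqlaboratory/openfold.git
cd openfold
mamba env create -n openfold_env -f environment.yml
conda activate openfold_env
scripts/install_third_party_dependencies.sh
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
bash scripts/download_alphafold_params.sh openfold/resources
bash scripts/download_openfold_params.sh openfold/resources
```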
### Checking your build with unit tests:
To test your installation, you can run OpenFold unit tests. Make sure that the OpenFold and AlphaFold parameters have been downloaded and that they are located (or symlinked) in the directory `openfold/resources`.
Run with the following script:
> scripts/run_unit_tests.sh
The script is a thin wrapper around Python's `unittest` suite, and recognizes `unittest` arguments. E.g., to run a specific test verbosely:
> scripts/run_unit_tests.sh -v tests.test_model
**AlphaFold comparison tests:**
Certain tests perform equivalence comparisons with the AlphaFold implementation. Running this level of tests requires an environment with both AlphaFold 2.0.1 and OpenFold installed, which is not covered in this guide. These tests are skipped by default if no installation of AlphaFold is found.
## Environment specific modifications
### CUDA 12
To use OpenFold in a CUDA 12 environment rather than a CUDA 11 environment:
In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the command in step 1 with `git clone -b pl_upgrades https://github.com/aqlaboratory/openfold.git`, and follow the rest of the steps of the [Installation Guide](#Installation).
### MPI
To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`.
### Install OpenFold parameters without aws
If you don't have access to `aws` on your system, you can use a different download source:
- HuggingFace (requires `git-lfs`): `scripts/download_openfold_params_huggingface.sh`
- Google Drive: `scripts/download_openfold_params_gdrive.sh`
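For example, assuming these scripts accept a destination directory like the aws variant (check the script headers to confirm):

```bash
# Download OpenFold weights from HuggingFace instead of S3.
bash scripts/download_openfold_params_huggingface.sh openfold/resources
```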
### Docker setup
A `Dockerfile` is provided to build an OpenFold Docker image. Additional notes for setting up a Docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
@@ -72,8 +72,7 @@ python3 run_pretrained_openfold.py \
    --output_dir ./
```

**Notes:**
- Template searching in the multimer pipeline uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
- As with monomer inference, if you've already computed alignments for the query, you can use the `--use_precomputed_alignments` option.
- At this time, only AlphaFold parameter weights are available for multimer mode.
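For reference, a hypothetical multimer run that reuses precomputed alignments might look like the following sketch (paths are placeholders, and the `model_1_multimer_v3` preset name is an assumption; pick whichever multimer preset you intend to use):

```bash
python3 run_pretrained_openfold.py \
    fasta_dir/ \
    data/pdb_mmcif/mmcif_files/ \
    --config_preset model_1_multimer_v3 \
    --use_precomputed_alignments alignments/ \
    --output_dir ./
```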
@@ -5,7 +5,7 @@
:align: center
:alt: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B.
```

Welcome to the Documentation for [OpenFold](https://github.com/aqlaboratory/openfold), the fully open source, trainable, PyTorch-based reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).

Here, you will find guides and documentation for:
@@ -115,4 +115,4 @@ Aux_seq_files.md
OpenFold_Parameters.md
FAQ.md
original_readme.md
```
@@ -25,8 +25,8 @@ dependencies:
  - modelcif==0.7
  - awscli
  - ml-collections
  - aria2
  - mkl=2024.0
  - git
  - bioconda::hmmer
  - bioconda::hhsuite
@@ -111,7 +111,7 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
@@ -127,7 +127,7 @@
"\n",
"    %shell mkdir -p /content/openfold/openfold/resources\n",
"\n",
"    commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n",
"    os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
"\n",
"    os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
@@ -907,4 +907,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -23,8 +23,19 @@ import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union

import numpy as np
import torch

from openfold.data import (
    templates,
    parsers,
    mmcif_parsing,
    msa_identifiers,
    msa_pairing,
    feature_processing_multimer,
)
from openfold.data.templates import (
    get_custom_template_features,
    empty_template_feats,
    CustomHitFeaturizer,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein
@@ -38,7 +49,9 @@ def make_template_features(
    template_featurizer: Any,
) -> FeatureDict:
    hits_cat = sum(hits.values(), [])
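    # CustomHitFeaturizer builds features directly from a template directory,
    # so it must still run even when no template hits were found.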
    if template_featurizer is None or (
        len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
    ):
        template_features = empty_template_feats(len(input_sequence))
    else:
        templates_result = template_featurizer.get_templates(
@@ -283,7 +283,7 @@ def parse(
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
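                # The extended table also maps modified residues such as MSE
                # (selenomethionine) to one-letter codes instead of "X".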
                code = PDBData.protein_letters_3to1_extended.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq
@@ -22,6 +22,7 @@ import glob
import json
import logging
import os
from pathlib import Path
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
@@ -947,55 +948,71 @@ def _process_single_hit(
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: Optional[str] = "A",
    kalign_binary_path: Optional[str] = None,
):
    if os.path.isfile(mmcif_path):
        template_paths = [Path(mmcif_path)]
    elif os.path.isdir(mmcif_path):
        template_paths = list(Path(mmcif_path).glob("*.cif"))
    else:
        logging.error("Custom template path %s does not exist", mmcif_path)
        raise ValueError(f"Custom template path {mmcif_path} does not exist")

    warnings = []
    template_features = dict()
    for template_path in template_paths:
        logging.info("Featurizing template: %s", template_path)

        # pdb_id only for error reporting, take file name
        pdb_id = Path(template_path).stem
        with open(template_path, "r") as mmcif_path:
            cif_string = mmcif_path.read()
        mmcif_parse_result = mmcif_parsing.parse(
            file_id=pdb_id, mmcif_string=cif_string
        )

        # mapping skipping "-"
        mapping = {
            x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum()
        }
        realigned_sequence, realigned_mapping = _realign_pdb_template_to_query(
            old_template_sequence=query_sequence,
            template_chain_id=chain_id,
            mmcif_object=mmcif_parse_result.mmcif_object,
            old_mapping=mapping,
            kalign_binary_path=kalign_binary_path,
        )
        curr_features, curr_warnings = _extract_template_features(
            mmcif_object=mmcif_parse_result.mmcif_object,
            pdb_id=pdb_id,
            mapping=realigned_mapping,
            template_sequence=realigned_sequence,
            query_sequence=query_sequence,
            template_chain_id=chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=True,
        )
        curr_features["template_sum_probs"] = [
            1.0
        ]  # template given by user, 100% confident
        template_features = {
            curr_name: template_features.get(curr_name, []) + [curr_item]
            for curr_name, curr_item in curr_features.items()
        }
        warnings.append(curr_warnings)

    template_features = {
        template_feature_name: np.stack(
            template_features[template_feature_name], axis=0
        ).astype(template_feature_type)
        for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    features: Mapping[str, Any]
@@ -1188,6 +1205,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
        )
class CustomHitFeaturizer(TemplateHitFeaturizer):
    """Featurizer for templates given in a folder.

    The chain of interest must be chain A and have the same sequence length
    as the input sequence."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for a given query sequence (more details above)."""
        logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
        return get_custom_template_features(
            self._mmcif_dir,
            query_sequence=query_sequence,
            pdb_id="test",
            chain_id="A",
            kalign_binary_path=self._kalign_binary_path,
        )
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
    def get_templates(
        self,
@@ -185,12 +185,7 @@ def main(args):
        use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
    )

    if args.experiment_config_json:
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)
@@ -202,8 +197,15 @@ def main(args):
    )

    is_multimer = "multimer" in args.config_preset
    is_custom_template = "use_custom_template" in args and args.use_custom_template

    if is_custom_template:
        template_featurizer = templates.CustomHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date="9999-12-31",  # just dummy, not used
            max_hits=-1,  # just dummy, not used
            kalign_binary_path=args.kalign_binary_path
        )
    elif is_multimer:
        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
@@ -221,11 +223,9 @@ def main(args):
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
@@ -238,7 +238,6 @@ def main(args):
    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
@@ -273,6 +272,11 @@ def main(args):
    seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
    sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
    feature_dicts = {}

    if is_multimer and args.openfold_checkpoint_path:
        raise ValueError(
            '`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode')

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
@@ -308,7 +312,6 @@ def main(args):
        )
        feature_dicts[tag] = feature_dict

        processed_feature_dict = feature_processor.process_features(
            feature_dict, mode='predict', is_multimer=is_multimer
        )
@@ -395,6 +398,10 @@ if __name__ == "__main__":
        help="""Path to alignment directory. If provided, alignment computation
        is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--use_custom_template", action="store_true", default=False,
        help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
    )
    parser.add_argument(
        "--use_single_seq_mode", action="store_true", default=False,
        help="""Use single sequence embeddings instead of MSAs."""
    )
@@ -489,5 +496,4 @@ if __name__ == "__main__":
        """The model is being run on CPU. Consider specifying
        --model_device for better performance"""
    )

    main(args)
"""
This script generates a FASTA file for all chains in an alignment directory or
alignment DB.
"""
import json
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from tqdm import tqdm
def chain_dir_to_fasta(dir: Path) -> str:
"""
Generates a FASTA string from a chain directory.
"""
# take some alignment file
for alignment_file_type in [
"mgnify_hits.a3m",
"uniref90_hits.a3m",
"bfd_uniclust_hits.a3m",
]:
alignment_file = dir / alignment_file_type
if alignment_file.exists():
break
with open(alignment_file, "r") as f:
next(f) # skip the first line
seq = next(f).strip()
try:
next_line = next(f)
except StopIteration:
pass
else:
assert next_line.startswith(">") # ensure that sequence ended
chain_id = dir.name
return f">{chain_id}\n{seq}\n"
def index_entry_to_fasta(index_entry: dict, db_dir: Path, chain_id: str) -> str:
"""
Generates a FASTA string from an alignment-db index entry.
"""
db_file = db_dir / index_entry["db"]
# look for an alignment file
for alignment_file_type in [
"mgnify_hits.a3m",
"uniref90_hits.a3m",
"bfd_uniclust_hits.a3m",
]:
for file_info in index_entry["files"]:
if file_info[0] == alignment_file_type:
start, size = file_info[1], file_info[2]
break
with open(db_file, "rb") as f:
f.seek(start)
msa_lines = f.read(size).decode("utf-8").splitlines()
seq = msa_lines[1]
try:
next_line = msa_lines[2]
except IndexError:
pass
else:
assert next_line.startswith(">") # ensure that sequence ended
return f">{chain_id}\n{seq}\n"
def main(
output_path: Path, alignment_db_index: Optional[Path], alignment_dir: Optional[Path]
) -> None:
"""
Generate a FASTA file from either an alignment-db index or a chain directory using multi-threading.
"""
fasta = []
if alignment_dir and alignment_db_index:
raise ValueError(
"Only one of alignment_db_index and alignment_dir can be provided."
)
if alignment_dir:
print("Creating FASTA from alignment directory...")
chain_dirs = list(alignment_dir.iterdir())
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(chain_dir_to_fasta, chain_dir)
for chain_dir in chain_dirs
]
for future in tqdm(as_completed(futures), total=len(chain_dirs)):
fasta.append(future.result())
elif alignment_db_index:
print("Creating FASTA from alignment dbs...")
with open(alignment_db_index, "r") as f:
index = json.load(f)
db_dir = alignment_db_index.parent
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(index_entry_to_fasta, index_entry, db_dir, chain_id)
for chain_id, index_entry in index.items()
]
for future in tqdm(as_completed(futures), total=len(index)):
fasta.append(future.result())
else:
raise ValueError("Either alignment_db_index or alignment_dir must be provided.")
with open(output_path, "w") as f:
f.write("".join(fasta))
print(f"FASTA file written to {output_path}.")
if __name__ == "__main__":
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"output_path",
type=Path,
help="Path to output FASTA file.",
)
parser.add_argument(
"--alignment_db_index",
type=Path,
help="Path to alignment-db index file.",
)
parser.add_argument(
"--alignment_dir",
type=Path,
help="Path to alignment directory.",
)
args = parser.parse_args()
main(args.output_path, args.alignment_db_index, args.alignment_dir)
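For reference, this script can be invoked as in the following sketch; the file name `alignments_to_fasta.py` is a stand-in (the script's actual name is not shown in this diff), while the flags match the argparse definitions above:

```bash
# From a flattened alignment directory:
python alignments_to_fasta.py all_chains.fasta --alignment_dir alignments/

# Or from an alignment-db index (the .db shards must sit next to the index):
python alignments_to_fasta.py all_chains.fasta --alignment_db_index alignment_db.index
```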
@@ -5,17 +5,19 @@ super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""

import argparse
import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from math import ceil
from multiprocessing import cpu_count
from pathlib import Path

from tqdm import tqdm


def split_file_list(file_list: list[Path], n_shards: int):
    """
    Split up the total file list into n_shards sublists.
    """
@@ -29,26 +31,25 @@ def split_file_list(file_list, n_shards):
    return split_list


def chunked_iterator(lst: list, chunk_size: int):
    """Iterate over a list in chunks of size chunk_size."""
    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]


def read_chain_dir(chain_dir: Path) -> dict:
    """
    Read all alignment files in a single chain directory and return a dict
    mapping chain name to file names and bytes.
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    # ensure that PDB IDs are all lowercase
    pdb_id, chain = chain_dir.name.split("_")
    pdb_id = pdb_id.lower()
    chain_name = f"{pdb_id}_{chain}"

    file_data = []

    for file_path in sorted(chain_dir.iterdir()):
@@ -62,7 +63,7 @@ def read_chain_dir(chain_dir) -> dict:
    return {chain_name: file_data}


def process_chunk(chain_files: list[Path]) -> dict:
    """
    Returns the file names and bytes for all chains in a chunk of files.
    """
@@ -83,7 +84,7 @@ def create_index_default_dict() -> dict:


def create_shard(
    shard_files: list[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Creates a single shard of the alignment database, and returns the
@@ -92,7 +93,7 @@ def create_shard(
    CHUNK_SIZE = 200
    shard_index = defaultdict(
        create_index_default_dict
    )  # e.g. {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}

    chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
    pbar_desc = f"Shard {shard_num}"
@@ -101,7 +102,11 @@ def create_shard(
    db_offset = 0
    db_file = open(output_path, "wb")
    for files_chunk in tqdm(
        chunk_iter,
        total=ceil(len(shard_files) / CHUNK_SIZE),
        desc=pbar_desc,
        position=shard_num,
        leave=False,
    ):
        # get processed files for one chunk
        chunk_data = process_chunk(files_chunk)
@@ -125,9 +130,17 @@ def create_shard(


def main(args):
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_dir.mkdir(exist_ok=True, parents=True)
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    n_cpus = cpu_count()
    if n_shards > n_cpus:
        print(
            f"Warning: Your number of shards ({n_shards}) is greater than the number of cores on your machine ({n_cpus}). "
            "This may result in slower performance. Consider using a smaller number of shards."
        )

    # get all chain dirs in alignment_dir
    print("Getting chain directories...")
    all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
@@ -153,12 +166,36 @@ def main(args):
        super_index.update(shard_index)
    print("\nCreated all shards.")

    if args.duplicate_chains_file:
        print("Extending super index with duplicate chains...")
        duplicates_added = 0
        with open(args.duplicate_chains_file, "r") as fp:
            duplicate_chains = [line.strip().split() for line in fp]

        for chains in duplicate_chains:
            # find representative with alignment
            for chain in chains:
                if chain in super_index:
                    representative_chain = chain
                    break
            else:
                print(f"No representative chain found for {chains}, skipping...")
                continue

            # add duplicates to index
            for chain in chains:
                if chain != representative_chain:
                    super_index[chain] = super_index[representative_chain]
                    duplicates_added += 1

        print(f"Added {duplicates_added} duplicate chains to index.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)
    print("Done.")
@@ -179,13 +216,27 @@ if __name__ == "__main__":
    parser.add_argument(
        "alignment_dir",
        type=Path,
        help="""Path to precomputed flattened alignment directory, with one
        subdirectory per chain.""",
    )
    parser.add_argument("output_db_path", type=Path)
    parser.add_argument("output_db_name", type=str)
    parser.add_argument(
        "--n_shards",
        type=int,
        help="Number of shards to split the database into",
        default=10,
    )
    parser.add_argument(
        "--duplicate_chains_file",
        type=Path,
        help="""
        Optional path to file containing duplicate chain information, where each
        line contains chains that are 100% sequence identical. If provided,
        duplicate chains will be added to the index and point to the same
        underlying database entry as their representatives in the alignment dir.
        """,
        default=None,
    )

    args = parser.parse_args()
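A hypothetical invocation of this script; the file name `create_alignment_db_sharded.py` and `duplicate_pdb_chains.txt` are stand-ins, while the positional and optional arguments match the argparse definitions above:

```bash
# Shard a flattened alignment directory into 10 .db files plus one unified
# .index file, folding in 100%-identity duplicate chains.
python create_alignment_db_sharded.py \
    alignments/ \
    alignment_db/ \
    alignment_db \
    --n_shards 10 \
    --duplicate_chains_file duplicate_pdb_chains.txt
```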
"""
The OpenProteinSet alignment database is non-redundant, meaning that it only
stores one explicit representative alignment directory for all PDB chains in a
100% sequence identity cluster. In order to add explicit alignments for all PDB
chains, this script will add the missing chain directories and symlink them to
their representative alignment directories. This is required in order to train
OpenFold on the full PDB, not just one representative chain per cluster.
"""
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
def create_duplicate_dirs(duplicate_chains: list[list[str]], alignment_dir: Path):
"""
Create duplicate directory symlinks for all chains in the given duplicate lists.
Args:
duplicate_lists (list[list[str]]): A list of lists, where each inner list
contains chains that are 100% sequence identical.
alignment_dir (Path): Path to flattened alignment directory, with one
subdirectory per chain.
"""
print("Creating duplicate directory symlinks...")
dirs_created = 0
for chains in tqdm(duplicate_chains):
# find the chain that has an alignment
for chain in chains:
if (alignment_dir / chain).exists():
representative_chain = chain
break
else:
print(f"No representative chain found for {chains}, skipping...")
continue
# create symlinks for all other chains
for chain in chains:
if chain != representative_chain:
target_path = alignment_dir / chain
if target_path.exists():
print(f"Chain {chain} already exists, skipping...")
else:
(target_path).symlink_to(alignment_dir / representative_chain)
dirs_created += 1
print(f"Created directories for {dirs_created} duplicate chains.")
def main(alignment_dir: Path, duplicate_chains_file: Path):
# read duplicate chains file
with open(duplicate_chains_file, "r") as fp:
duplicate_chains = [list(line.strip().split()) for line in fp]
# convert to absolute path for symlink creation
alignment_dir = alignment_dir.resolve()
create_duplicate_dirs(duplicate_chains, alignment_dir)
if __name__ == "__main__":
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"alignment_dir",
type=Path,
help="""Path to flattened alignment directory, with one subdirectory
per chain.""",
)
parser.add_argument(
"duplicate_chains_file",
type=Path,
help="""Path to file containing duplicate chains, where each line
contains a space-separated list of chains that are 100%%
sequence identical.
""",
)
args = parser.parse_args()
main(args.alignment_dir, args.duplicate_chains_file)
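A hypothetical invocation (the script file name `create_duplicate_alignment_dirs.py` and the duplicates file name are stand-ins; the two positional arguments match the argparse definitions above):

```bash
python create_duplicate_alignment_dirs.py alignments/ duplicate_pdb_chains.txt
```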
@@ -85,7 +85,7 @@ def main(args):
if __name__ == "__main__":
    parser = ArgumentParser(
        description=__doc__
    )
    parser.add_argument(
        "input_fasta",
@@ -9,8 +9,8 @@
# output_dir:
#     The directory in which to construct the reformatted data

if [ "$#" -ne 2 ]; then
    echo "Usage: ./flatten_roda.sh <roda_dir> <output_dir>"
    exit 1
fi
@@ -23,25 +23,36 @@ ALIGNMENT_DIR="${OUTPUT_DIR}/alignments"
mkdir -p "${DATA_DIR}"
mkdir -p "${ALIGNMENT_DIR}"

for chain_dir in "${RODA_DIR}"/*; do
    if [ ! -d "$chain_dir" ]; then
        continue
    fi

    chain_name=$(basename "$chain_dir")

    for subdir in "$chain_dir"/*; do
        if [ ! -d "$subdir" ]; then
            echo "$subdir is not a directory"
            continue
        fi

        if [ -z "$(ls -A "$subdir")" ]; then
            continue
        fi

        subdir_name=$(basename "$subdir")
        if [ "$subdir_name" = "pdb" ] || [ "$subdir_name" = "cif" ]; then
            mv "$subdir"/* "${DATA_DIR}/"
        else
            CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_name}"
            mkdir -p "${CHAIN_ALIGNMENT_DIR}"
            mv "$subdir"/* "${CHAIN_ALIGNMENT_DIR}/"
        fi
    done
done

NO_DATA_FILES=$(find "${DATA_DIR}" -type f | wc -l)
if [ "$NO_DATA_FILES" -eq 0 ]; then
    rm -rf "${DATA_DIR}"
fi