Unverified Commit bb3f51e5 authored by Christina Floristean, committed by GitHub

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
.vscode/
.idea/
__pycache__/
*.egg-info
build
...@@ -8,3 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass/
...@@ -7,13 +7,31 @@ _Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental s
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Contents
- [OpenFold](#openfold)
- [Contents](#contents)
- [Features](#features)
- [Installation (Linux)](#installation-linux)
- [Download Alignment Databases](#download-alignment-databases)
- [Inference](#inference)
- [Monomer Inference](#monomer-inference)
- [Multimer Inference](#multimer-inference)
- [SoloSeq Inference](#soloseq-inference)
- [Training](#training)
- [Testing](#testing)
- [Building and Using the Docker Container](#building-and-using-the-docker-container)
- [Copyright Notice](#copyright-notice)
- [Contributing](#contributing)
- [Citing this Work](#citing-this-work)
## Features

OpenFold carefully reproduces (almost) all of the features of the original open
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,

...@@ -63,7 +81,7 @@ To install:

For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need to do this for each new bash instance.
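A minimal sketch of that export, assuming the environment was created at `lib/conda/envs/openfold_venv` under the repository root (the path is an assumption; adjust it to your installation):

```bash
# Append the Conda environment's library directory to the dynamic linker path.
# The environment location below is an assumption; point it at your own openfold_venv.
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$(pwd)/lib/conda/envs/openfold_venv/lib"
```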
## Download Alignment Databases

If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use
...@@ -112,7 +130,16 @@ DeepMind's pretrained parameters, you will only be able to make changes that
do not affect the shapes of model parameters. For an example of initializing
the model, consult `run_pretrained_openfold.py`.

## Inference
OpenFold now supports three inference modes:
- [Monomer Inference](#monomer-inference): OpenFold reproduction of AlphaFold2. Inference is available with either DeepMind's pretrained parameters or OpenFold-trained parameters.
- [Multimer Inference](#multimer-inference): OpenFold reproduction of AlphaFold-Multimer. Inference is available with DeepMind's pretrained parameters.
- [Single Sequence Inference (SoloSeq)](#soloseq-inference): language-model-based structure prediction using [ESM-1b](https://github.com/facebookresearch/esm) embeddings.

Instructions for each inference mode are provided below.
### Monomer Inference

To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, first download the OpenFold weights, e.g.:
...@@ -131,14 +158,14 @@ python3 run_pretrained_openfold.py \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
...@@ -176,13 +203,6 @@ To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with fewer than 1000 residues.
To minimize memory usage during inference on long sequences, consider the
following changes:

...@@ -221,7 +241,78 @@ efficient AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
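For example, a minimal sketch of a long-sequence monomer run, reusing the setup from the example above and assuming alignments were already computed into a hypothetical `alignments/` directory:

```bash
# Sketch: monomer inference with the long-sequence settings enabled.
# fasta_dir and alignments/ are assumptions carried over from earlier steps.
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --use_precomputed_alignments alignments/ \
    --config_preset "model_1_ptm" \
    --model_device "cuda:0" \
    --output_dir ./ \
    --long_sequence_inference
```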
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
### Multimer Inference
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--uniprot_database_path data/uniprot/uniprot.fasta \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
```
As with monomer inference, if you've already computed alignments for the query, you can use
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
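For instance, a sketch of a multimer run that reuses existing alignments (the `alignments/` directory is an assumption; point it at the output of a previous run):

```bash
# Sketch: multimer inference that skips the MSA search stage entirely.
# alignments/ is a hypothetical directory holding per-chain alignment outputs.
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --use_precomputed_alignments alignments/ \
    --config_preset "model_1_multimer_v3" \
    --model_device "cuda:0" \
    --output_dir ./
```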
**Upgrade from an existing OpenFold installation**

The above command requires several upgrades to an existing OpenFold installation:

1. Re-download the AlphaFold parameters to get the latest
AlphaFold-Multimer v3 weights:
```bash
bash scripts/download_alphafold_params.sh openfold/resources
```
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
and [PDB SeqRes](https://www.rcsb.org/) databases:
```bash
bash scripts/download_uniprot.sh data/
```
The PDB SeqRes and PDB databases must be from the same date to avoid potential
errors during template searching. Remove the existing `data/pdb_mmcif` directory
and download both databases:
```bash
bash scripts/download_pdb_mmcif.sh data/
bash scripts/download_pdb_seqres.sh data/
```
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
```bash
bash scripts/download_uniref30.sh data/
bash scripts/download_mgnify.sh data/
```
Multimer inference can also run with the older database versions if desired.
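As a rough sketch (not verified against every database combination), pointing the same flags at the monomer-era downloads might look like:

```bash
# Sketch: multimer inference against the older database versions.
# These paths assume the monomer-era downloads described above; treat as illustrative.
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --uniref30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --config_preset "model_1_multimer_v3" \
    --model_device "cuda:0" \
    --output_dir ./
```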
### SoloSeq Inference

To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.

For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. The following is an example command to use the script:
...@@ -260,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
...@@ -274,7 +365,7 @@ SoloSeq allows you to use the same flags and optimizations as the MSA-based Open
**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.

## Training

To train the model, you will first need to precompute protein alignments.
...@@ -412,9 +503,9 @@ environment. These run components of AlphaFold and OpenFold side by side and
ensure that output activations are adequately similar. For most modules, we
target a maximum pointwise difference of `1e-4`.
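To run the comparison suite yourself, a sketch of the invocation (the script path is an assumption based on the repository layout):

```bash
# Sketch: run the full OpenFold unit test suite (script path is an assumption).
scripts/run_unit_tests.sh

# Or target a single module, e.g. the end-to-end model tests:
scripts/run_unit_tests.sh -v tests.test_model
```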
## Building and Using the Docker Container

**Building the Docker Image**

OpenFold can be built as a Docker container using the included Dockerfile. To build it, run the following command from the root of this repository:
...@@ -422,7 +513,7 @@ Openfold can be built as a docker container using the included dockerfile. To bu
docker build -t openfold .
```

**Running the Docker Container**

The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the download instructions in the sections above.
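A minimal sketch of launching the container interactively, with GPU access and the host database directory mounted (the `/database` mount point is an assumption chosen to match the example command below):

```bash
# Sketch: start the container with GPUs and the downloaded databases mounted.
# /path/to/databases and the /database mount point are assumptions.
docker run --gpus all -it --rm \
    -v /path/to/databases:/database \
    openfold \
    bash
```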
...@@ -462,7 +553,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```
## Copyright Notice

While AlphaFold's and, by extension, OpenFold's source code is licensed under
the permissive Apache License, Version 2.0, DeepMind's pretrained parameters
...@@ -475,7 +566,7 @@ replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
If you encounter problems using OpenFold, feel free to create an issue! We also
welcome pull requests from the community.

## Citing this Work

Please cite our paper:
...@@ -504,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```

Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and, if applicable, [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).
...@@ -14,6 +14,7 @@ dependencies:
- pytorch-lightning==1.5.10
- biopython==1.79
- numpy==1.21
- pandas==2.0
- PyYAML==5.4.1
- requests
- scipy==1.7
...
...@@ -3,8 +3,7 @@
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
...@@ -52,25 +51,44 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rowN0bVYLe9n"
},
"outputs": [],
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"#@markdown For multiple sequences, separate sequences with a colon `:`\n",
"input_sequence = 'MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER: MKLKQVADKLEEVASKLYHNANELARVAKLLGER:MKLKQVADKLEEVASKLYHNANELARVAKLLGER' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'AlphaFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"model_mode = 'multimer' #@param [\"monomer\", \"multimer\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"input_sequence = input_sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"allowed_chars = aatypes.union({':'})\n",
"if not set(input_sequence).issubset(allowed_chars):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(input_sequence) - allowed_chars}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"if ':' in input_sequence and weight_set != 'AlphaFold':\n",
" raise ValueError('Input sequence is a multimer, must select Alphafold weight set')\n",
"\n",
"import enum\n",
"@enum.unique\n",
"class ModelType(enum.Enum):\n",
" MONOMER = 0\n",
" MULTIMER = 1\n",
"\n",
"model_type_dict = {\n",
" 'monomer': ModelType.MONOMER,\n",
" 'multimer': ModelType.MULTIMER,\n",
"}\n",
"\n",
"model_type = model_type_dict[model_mode]\n",
"print(f'Length of input sequence : {len(input_sequence.replace(\":\", \"\"))}')\n",
"#@markdown After making your selections, execute this cell by pressing the\n", "#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left." "#@markdown *Play* button on the left."
] ]
...@@ -79,17 +97,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "woIxeCPygt7K"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab\n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"import os, time\n",
...@@ -103,10 +120,8 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
...@@ -119,12 +134,12 @@
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
"\n",
" commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
"\n",
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
"\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)"
...@@ -134,18 +149,17 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
"source": [
"#@title Download model weights\n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
...@@ -184,17 +198,17 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on\n",
"#@markdown the left.\n",
"\n",
"import unittest.mock\n",
"import sys\n",
"from typing import Dict, Sequence\n",
"\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
...@@ -234,22 +248,12 @@
" return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data import msa_pairing\n",
"from openfold.data import feature_processing_multimer\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
...@@ -276,22 +280,16 @@
},
{
"cell_type": "code",
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment\n",
"#@markdown (MSA) that will be used by OpenFold. In particular,\n",
"#@markdown you’ll see how well each residue is covered by similar\n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n",
"def fetch(source):\n",
...@@ -304,114 +302,156 @@
" ex.shutdown()\n",
" break\n",
"\n",
"# Run the search against chunks of genetic databases (since the genetic\n",
"# databases don't fit in Colab ramdisk).\n",
"\n",
"jackhmmer_binary_path = '/usr/bin/jackhmmer'\n",
"\n",
"# --- Parse multiple sequences, if there are any ---\n",
"def split_multiple_sequences(sequence):\n",
" seqs = sequence.split(':')\n",
" sorted_seqs = sorted(seqs, key=lambda s: len(s))\n",
"\n",
" # TODO: Handle the homomer case when writing fasta sequences\n",
" fasta_path_tuples = []\n",
" for idx, seq in enumerate(set(sorted_seqs)):\n",
" fasta_path = f'target_{idx+1}.fasta'\n",
" with open(fasta_path, 'wt') as f:\n",
" f.write(f'>query\\n{seq}\\n')\n",
" fasta_path_tuples.append((seq, fasta_path))\n",
" fasta_path_by_seq = dict(fasta_path_tuples)\n",
"\n",
" return sorted_seqs, fasta_path_by_seq\n",
"\n",
"sequences, fasta_path_by_sequence = split_multiple_sequences(input_sequence)\n",
"db_results_by_sequence = {seq: {} for seq in fasta_path_by_sequence.keys()}\n",
"\n",
"DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'\n",
"db_configs = {}\n",
"db_configs['smallbfd'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2021_03.fasta',\n",
" 'z_value': 65984053,\n",
" 'num_jackhmmer_chunks': 17,\n",
"}\n",
"db_configs['mgnify'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}mgy_clusters_2022_05.fasta',\n",
" 'z_value': 304820129,\n",
" 'num_jackhmmer_chunks': 120,\n",
"}\n",
"db_configs['uniref90'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniref90_2022_01.fasta',\n",
" 'z_value': 144113457,\n",
" 'num_jackhmmer_chunks': 62,\n",
"}\n",
"\n",
"# Search UniProt and construct the all_seq features only for heteromers, not homomers.\n",
"if model_type == ModelType.MULTIMER and len(set(sequences)) > 1:\n",
" db_configs['uniprot'] = {\n",
" 'database_path': f'{DB_ROOT_PATH}uniprot_2021_04.fasta',\n",
" 'z_value': 225013025 + 565928,\n",
" 'num_jackhmmer_chunks': 101,\n",
" }\n",
"\n",
"total_jackhmmer_chunks = sum([d['num_jackhmmer_chunks'] for d in db_configs.values()])\n",
"with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" def jackhmmer_chunk_callback(i):\n", " def jackhmmer_chunk_callback(i):\n",
" pbar.update(n=1)\n", " pbar.update(n=1)\n",
"\n", "\n",
" pbar.set_description('Searching uniref90')\n", " for db_name, db_config in db_configs.items():\n",
" jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n", " pbar.set_description(f'Searching {db_name}')\n",
" binary_path=jackhmmer_binary_path,\n", " jackhmmer_runner = jackhmmer.Jackhmmer(\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n", " binary_path=jackhmmer_binary_path,\n",
" get_tblout=True,\n", " database_path=db_config['database_path'],\n",
" num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n", " get_tblout=True,\n",
" streaming_callback=jackhmmer_chunk_callback,\n", " num_streamed_chunks=db_config['num_jackhmmer_chunks'],\n",
" z_value=135301051)\n", " streaming_callback=jackhmmer_chunk_callback,\n",
" dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n", " z_value=db_config['z_value'])\n",
"\n", "\n",
" pbar.set_description('Searching smallbfd')\n", " db_results = jackhmmer_runner.query_multiple(fasta_path_by_sequence.values())\n",
" jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n", " for seq, result in zip(fasta_path_by_sequence.keys(), db_results):\n",
" binary_path=jackhmmer_binary_path,\n", " db_results_by_sequence[seq][db_name] = result\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=65984053)\n",
" dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching mgnify')\n",
" jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=304820129)\n",
" dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n",
"\n", "\n",
"\n", "\n",
"# --- Extract the MSAs and visualize ---\n", "# --- Extract the MSAs and visualize ---\n",
"# Extract the MSAs from the Stockholm files.\n", "# Extract the MSAs from the Stockholm files.\n",
"# NB: deduplication happens later in data_pipeline.make_msa_features.\n", "# NB: deduplication happens later in data_pipeline.make_msa_features.\n",
"\n", "\n",
"mgnify_max_hits = 501\n", "MAX_HITS_BY_DB = {\n",
"\n", " 'uniref90': 10000,\n",
"msas = []\n", " 'smallbfd': 5000,\n",
"deletion_matrices = []\n", " 'mgnify': 501,\n",
"full_msa = []\n", " 'uniprot': 50000,\n",
"for db_name, db_results in dbs:\n", "}\n",
" unsorted_results = []\n", "\n",
" for i, result in enumerate(db_results):\n", "msas_by_seq_by_db = {seq: {} for seq in sequences}\n",
" msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n", "full_msa_by_seq = {seq: [] for seq in sequences}\n",
" e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n", "\n",
" e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n", "for seq, sequence_result in db_results_by_sequence.items():\n",
" zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n", " print(f'parsing_results_for_sequence {seq}')\n",
" if i != 0:\n", " for db_name, db_results in sequence_result.items():\n",
" # Only take query from the first chunk\n", " unsorted_results = []\n",
" zipped_results = [x for x in zipped_results if x[2] != 'query']\n", " for i, result in enumerate(db_results):\n",
" unsorted_results.extend(zipped_results)\n", " msa_obj = parsers.parse_stockholm(result['sto'])\n",
" sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n", " e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n",
" db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n", " target_names = msa_obj.descriptions\n",
" if db_msas:\n", " e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n",
" if db_name == 'mgnify':\n", " zipped_results = zip(msa_obj.sequences, msa_obj.deletion_matrix, target_names, e_values)\n",
" db_msas = db_msas[:mgnify_max_hits]\n", " if i != 0:\n",
" db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n", " # Only take query from the first chunk\n",
" full_msa.extend(db_msas)\n", " zipped_results = [x for x in zipped_results if x[2] != 'query']\n",
" msas.append(db_msas)\n", " unsorted_results.extend(zipped_results)\n",
" deletion_matrices.append(db_deletion_matrices)\n", " sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n",
" msa_size = len(set(db_msas))\n", " msas, del_matrix, targets, _ = zip(*sorted_by_evalue)\n",
" print(f'{msa_size} Sequences Found in {db_name}')\n", " db_msas = parsers.Msa(msas, del_matrix, targets)\n",
"\n", " if db_msas:\n",
"deduped_full_msa = list(dict.fromkeys(full_msa))\n", " if db_name in MAX_HITS_BY_DB:\n",
"total_msa_size = len(deduped_full_msa)\n", " db_msas.truncate(MAX_HITS_BY_DB[db_name])\n",
"print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n", " msas_by_seq_by_db[seq][db_name] = db_msas\n",
"\n", " full_msa_by_seq[seq].extend(db_msas.sequences)\n",
"aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n", " msa_size = len(set(db_msas.sequences))\n",
"msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n", " print(f'{msa_size} Sequences Found in {db_name}')\n",
"num_alignments, num_res = msa_arr.shape\n", "\n",
"\n", "\n",
"fig = plt.figure(figsize=(12, 3))\n", "fig = plt.figure(figsize=(12, 3))\n",
"max_num_alignments = 0\n",
"\n",
"for seq_idx, seq in enumerate(set(sequences)):\n",
" full_msas = full_msa_by_seq[seq]\n",
" deduped_full_msa = list(dict.fromkeys(full_msas))\n",
" total_msa_size = len(deduped_full_msa)\n",
" print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
" aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
" msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
" num_alignments, num_res = msa_arr.shape\n",
" plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), label=f'Chain {seq_idx}')\n",
" max_num_alignments = max(num_alignments, max_num_alignments)\n",
"\n",
"\n",
"plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n", "plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n",
"plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n",
"plt.ylabel('Non-Gap Count')\n", "plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n", "plt.yticks(range(0, max_num_alignments + 1, max(1, int(max_num_alignments / 3))))\n",
"plt.legend()\n",
"plt.show()" "plt.show()"
] ],
"metadata": {
"id": "o7BqQN_gfYtq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUo6foMQxwS2"
},
"outputs": [],
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
"#@markdown Once this cell has been executed, a zip-archive with\n",
"#@markdown the obtained prediction will be automatically downloaded\n",
"#@markdown to your computer.\n",
"\n",
"# Color bands for visualizing plddt\n",
...@@ -423,13 +463,22 @@
"]\n",
"\n",
"# --- Run the model ---\n",
"if model_type == ModelType.MONOMER:\n",
" model_names = [\n",
" 'finetuning_3.pt',\n",
" 'finetuning_4.pt',\n",
" 'finetuning_5.pt',\n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
" ]\n",
"elif model_type == ModelType.MULTIMER:\n",
" model_names = [\n",
" 'model_1_multimer_v3',\n",
" 'model_2_multimer_v3',\n",
" 'model_3_multimer_v3',\n",
" 'model_4_multimer_v3',\n",
" 'model_5_multimer_v3',\n",
" ]\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
...@@ -440,26 +489,72 @@
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n",
"\n",
"\n",
"def make_features(\n",
" sequences: Sequence[str],\n",
" msas_by_seq_by_db: Dict[str, Dict[str, parsers.Msa]],\n",
" model_type: ModelType):\n",
" num_templates = 1 # Placeholder for generating fake template features\n",
" feature_dict = {}\n",
"\n",
" for idx, seq in enumerate(sequences, start=1):\n",
" _chain_id = f'chain_{idx}'\n",
" num_res = len(seq)\n",
"\n",
" feats = data_pipeline.make_sequence_features(seq, _chain_id, num_res)\n",
" msas_without_uniprot = [msas_by_seq_by_db[seq][db] for db in db_configs.keys() if db != 'uniprot']\n",
" msa_feats = data_pipeline.make_msa_features(msas_without_uniprot)\n",
" feats.update(msa_feats)\n",
" feats.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" feature_dict[seq] = feats\n",
" if model_type == ModelType.MULTIMER:\n",
" # Perform extra pair processing steps for heteromers\n",
" if len(set(sequences)) > 1:\n",
" uniprot_msa = msas_by_seq_by_db[seq]['uniprot']\n",
" uniprot_msa_features = data_pipeline.make_msa_features([uniprot_msa])\n",
" valid_feat_names = msa_pairing.MSA_FEATURES + (\n",
" 'msa_species_identifiers',\n",
" )\n",
" pair_feats = {\n",
" f'{k}_all_seq': v for k, v in uniprot_msa_features.items()\n",
" if k in valid_feat_names\n",
" }\n",
" feats.update(pair_feats)\n",
"\n",
" feats = data_pipeline.convert_monomer_features(feats, _chain_id)\n",
" feature_dict[_chain_id] = feats\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" np_example = feature_dict[sequences[0]]\n",
" elif model_type == ModelType.MULTIMER:\n",
" all_chain_feats = data_pipeline.add_assembly_features(feature_dict)\n",
" features = feature_processing_multimer.pair_and_merge(all_chain_features=all_chain_feats)\n",
" np_example = data_pipeline.pad_msa(features, 512)\n",
"\n",
" return np_example\n",
"\n",
"\n",
"output_dir = 'prediction'\n", "output_dir = 'prediction'\n",
"os.makedirs(output_dir, exist_ok=True)\n", "os.makedirs(output_dir, exist_ok=True)\n",
"\n", "\n",
"plddts = {}\n", "plddts = {}\n",
"pae_outputs = {}\n", "pae_outputs = {}\n",
"weighted_ptms = {}\n",
"unrelaxed_proteins = {}\n", "unrelaxed_proteins = {}\n",
"\n", "\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names), bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in list(enumerate(model_names)):\n", " for i, model_name in enumerate(model_names, start = 1):\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n", "\n",
" num_res = len(sequence)\n", " feature_dict = make_features(sequences, msas_by_seq_by_db, model_type)\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n", "\n",
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n", " if model_type == ModelType.MONOMER:\n",
" config_preset = f\"model_{i}\"\n",
" elif model_type == ModelType.MULTIMER:\n",
" config_preset = f'model_{i}_multimer_v3'\n",
" else:\n", " else:\n",
" if(\"_no_templ_\" in model_name):\n", " if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n", " config_preset = \"model_3\"\n",
...@@ -469,6 +564,11 @@
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
"\n",
" # Force the model to only use 3 recycling updates\n",
" cfg.data.common.max_recycling_iters = 3\n",
" cfg.model.recycle_early_stop_tolerance = -1\n",
"\n",
" openfold_model = model.AlphaFold(cfg)\n", " openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n", " openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n", " if(weight_set == \"AlphaFold\"):\n",
...@@ -490,7 +590,9 @@
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n",
" feature_dict,\n",
" mode='predict',\n",
" is_multimer = (model_type == ModelType.MULTIMER),\n",
" )\n", " )\n",
"\n", "\n",
" processed_feature_dict = tensor_tree_map(\n", " processed_feature_dict = tensor_tree_map(\n",
...@@ -510,21 +612,32 @@
"\n",
" mean_plddt = prediction_result['plddt'].mean()\n",
"\n",
" if model_type == ModelType.MONOMER:\n",
" if 'predicted_aligned_error' in prediction_result:\n",
" pae_outputs[model_name] = (\n",
" prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error']\n",
" )\n",
" else:\n",
" # Get the pLDDT confidence metrics. Do not put pTM models here as they\n",
" # should never get selected.\n",
" plddts[model_name] = prediction_result['plddt']\n",
" elif model_type == ModelType.MULTIMER:\n",
" # Multimer models are sorted by pTM+ipTM.\n",
" plddts[model_name] = prediction_result['plddt']\n", " plddts[model_name] = prediction_result['plddt']\n",
" pae_outputs[model_name] = (prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error'])\n",
"\n",
" weighted_ptms[model_name] = prediction_result['weighted_ptm_score']\n",
"\n", "\n",
" # Set the b-factors to the per-residue plddt.\n", " # Set the b-factors to the per-residue plddt.\n",
" final_atom_mask = prediction_result['final_atom_mask']\n", " final_atom_mask = prediction_result['final_atom_mask']\n",
" b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n", " b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n",
" unrelaxed_protein = protein.from_prediction(\n", " unrelaxed_protein = protein.from_prediction(\n",
" processed_feature_dict, prediction_result, b_factors=b_factors\n", " processed_feature_dict,\n",
" prediction_result,\n",
" remove_leading_feature_dimension=False,\n",
" b_factors=b_factors,\n",
" )\n", " )\n",
" unrelaxed_proteins[model_name] = unrelaxed_protein\n", " unrelaxed_proteins[model_name] = unrelaxed_protein\n",
"\n", "\n",
...@@ -535,7 +648,10 @@
" pbar.update(n=1)\n",
"\n",
" # Find the best model according to the mean pLDDT.\n",
" if model_type == ModelType.MONOMER:\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" elif model_type == ModelType.MULTIMER:\n",
" best_model_name = max(weighted_ptms.keys(), key=lambda x: weighted_ptms[x])\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n", " best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n", "\n",
" # --- AMBER relax the best model ---\n", " # --- AMBER relax the best model ---\n",
...@@ -547,7 +663,7 @@
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=True,\n",
" )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
...@@ -598,6 +714,15 @@
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"# Show the structure coloured by chain if the multimer model has been used.\n",
"if model_type == ModelType.MULTIMER:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n", "# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n", "color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n", "view = py3Dmol.view(width=800, height=600)\n",
...@@ -643,6 +768,15 @@
" pae, max_pae = list(pae_outputs.values())[0]\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n",
"\n",
" # Display lines at chain boundaries.\n",
" best_unrelaxed_prot = unrelaxed_proteins[best_model_name]\n",
" total_num_res = best_unrelaxed_prot.residue_index.shape[-1]\n",
" chain_ids = best_unrelaxed_prot.chain_index\n",
" for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
" if chain_boundary.size:\n",
" plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
" plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
" plt.title('Predicted Aligned Error')\n", " plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n", " plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n", " plt.ylabel('Aligned residue')\n",
...@@ -680,7 +814,7 @@
"source": [
"### Interpreting the prediction\n",
"\n",
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions. More information about the predictions of the AlphaFold-Multimer model can be found in the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2022.03.11.484043v3.full.pdf)."
]
},
{
...@@ -718,7 +852,7 @@
" * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n",
" * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs.\n",
"* Does this tool install anything on my computer?\n",
" * No, everything happens in the cloud on Google Colab.\n",
" * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n",
...@@ -766,13 +900,10 @@
}
],
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
...@@ -780,8 +911,9 @@
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}
\ No newline at end of file
...@@ -3,15 +3,15 @@ channels:
- conda-forge
- bioconda
dependencies:
- openmm=7.7
- pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
from . import model
from . import utils
from . import data
from . import np
from . import resources
...
import re
import copy
import importlib
import ml_collections as mlc
...@@ -16,7 +17,7 @@ def enforce_config_constraints(config):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting
...@@ -161,44 +162,70 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name.startswith("seq"): # SINGLE SEQUENCE EMBEDDING PRESETS
c.update(seq_mode_config.copy_and_resolve_references())
if name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name: # MULTIMER PRESETS
c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 252
c.data.eval.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else: else:
raise ValueError("Invalid model name") raise ValueError("Invalid model name")
if name.startswith("seq"):
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.update(seq_mode_config.copy_and_resolve_references())
if long_sequence_inference: if long_sequence_inference:
assert(not train) assert(not train)
c.globals.offload_inference = True c.globals.offload_inference = True
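Editor's note: as a hedged usage sketch of the preset dispatch above (preset names taken from this diff; the exact `model_config` signature may differ):

```python
from openfold.config import model_config

# Multimer preset: c.update(multimer_config_update) runs first,
# then per-model tweaks such as fuse_projection_weights.
c = model_config("model_1_multimer_v3")
assert c.globals.is_multimer

# Single-sequence preset: seq_mode_config is applied before the
# name-specific overrides, and seqemb mode is switched on.
c_seq = model_config("seqemb_initial_training")
assert c_seq.globals.seqemb_mode_enabled
```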
...@@ -380,6 +407,8 @@ config = mlc.ConfigDict( ...@@ -380,6 +407,8 @@ config = mlc.ConfigDict(
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": False, "supervised": False,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -394,6 +423,8 @@ config = mlc.ConfigDict( ...@@ -394,6 +423,8 @@ config = mlc.ConfigDict(
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True, "supervised": True,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -409,6 +440,8 @@ config = mlc.ConfigDict( ...@@ -409,6 +440,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True, "supervised": True,
"clamp_prob": 0.9, "clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
...@@ -426,7 +459,6 @@ config = mlc.ConfigDict( ...@@ -426,7 +459,6 @@ config = mlc.ConfigDict(
}, },
# Recurring FieldReferences that can be changed globally here # Recurring FieldReferences that can be changed globally here
"globals": { "globals": {
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually # Use DeepSpeed memory-efficient attention kernel. Mutually
...@@ -446,6 +478,8 @@ config = mlc.ConfigDict( ...@@ -446,6 +478,8 @@ config = mlc.ConfigDict(
"c_e": c_e, "c_e": c_e,
"c_s": c_s, "c_s": c_s,
"eps": eps, "eps": eps,
"is_multimer": False,
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
}, },
"model": { "model": {
"_mask_trans": False, "_mask_trans": False,
...@@ -470,7 +504,7 @@ config = mlc.ConfigDict( ...@@ -470,7 +504,7 @@ config = mlc.ConfigDict(
"max_bin": 50.75, "max_bin": 50.75,
"no_bins": 39, "no_bins": 39,
}, },
"template_angle_embedder": { "template_single_embedder": {
# DISCREPANCY: c_in is supposed to be 51. # DISCREPANCY: c_in is supposed to be 51.
"c_in": 57, "c_in": 57,
"c_out": c_m, "c_out": c_m,
...@@ -489,6 +523,8 @@ config = mlc.ConfigDict( ...@@ -489,6 +523,8 @@ config = mlc.ConfigDict(
"no_heads": 4, "no_heads": 4,
"pair_transition_n": 2, "pair_transition_n": 2,
"dropout_rate": 0.25, "dropout_rate": 0.25,
"tri_mul_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
...@@ -537,6 +573,8 @@ config = mlc.ConfigDict( ...@@ -537,6 +573,8 @@ config = mlc.ConfigDict(
"transition_n": 4, "transition_n": 4,
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
...@@ -560,6 +598,8 @@ config = mlc.ConfigDict( ...@@ -560,6 +598,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"no_column_attention": False, "no_column_attention": False,
"opm_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
...@@ -607,6 +647,12 @@ config = mlc.ConfigDict( ...@@ -607,6 +647,12 @@ config = mlc.ConfigDict(
"c_out": 37, "c_out": 37,
}, },
}, },
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1.
}, },
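Editor's note: to make the tolerance semantics above concrete, here is a hedged sketch (not OpenFold's exact implementation) of how such a check can be applied between recycling iterations:

```python
import torch

def should_stop_recycling(prev_ca: torch.Tensor,
                          curr_ca: torch.Tensor,
                          tol: float) -> bool:
    """Stop early once pairwise CA distances have converged."""
    if tol < 0:
        # A negative tolerance disables early stopping entirely.
        return False
    prev_d = torch.cdist(prev_ca, prev_ca)  # [N, N] pairwise distances
    curr_d = torch.cdist(curr_ca, curr_ca)
    return (prev_d - curr_d).abs().mean().item() < tol
```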
"relax": { "relax": {
"max_iterations": 0, # no max "max_iterations": 0, # no max
...@@ -652,6 +698,7 @@ config = mlc.ConfigDict( ...@@ -652,6 +698,7 @@ config = mlc.ConfigDict(
"weight": 0.01, "weight": 0.01,
}, },
"masked_msa": { "masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 2.0, "weight": 2.0,
}, },
...@@ -664,6 +711,7 @@ config = mlc.ConfigDict( ...@@ -664,6 +711,7 @@ config = mlc.ConfigDict(
"violation": { "violation": {
"violation_tolerance_factor": 12.0, "violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5, "clash_overlap_tolerance": 1.5,
"average_clashes": False,
"eps": eps, # 1e-6, "eps": eps, # 1e-6,
"weight": 0.0, "weight": 0.0,
}, },
...@@ -676,12 +724,199 @@ config = mlc.ConfigDict( ...@@ -676,12 +724,199 @@ config = mlc.ConfigDict(
"weight": 0., "weight": 0.,
"enabled": tm_enabled, "enabled": tm_enabled,
}, },
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.,
"eps": eps,
"enabled": False,
},
"eps": eps, "eps": eps,
}, },
"ema": {"decay": 0.999}, "ema": {"decay": 0.999},
} }
) )
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True
},
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
# "all_crops_all_chains_mask": [],
# "all_crops_all_chains_positions": [],
# "all_crops_all_chains_residue_ids": [],
"assembly_num_chains": [],
"asym_id": [NUM_RES],
"atom14_atom_exists": [NUM_RES, None],
"atom37_atom_exists": [NUM_RES, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"cluster_bias_mask": [NUM_MSA_SEQ],
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
"deletion_mean": [NUM_RES],
"entity_id": [NUM_RES],
"entity_mask": [NUM_RES],
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
# "mem_peak": [],
"msa": [NUM_MSA_SEQ, NUM_RES],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_profile": [NUM_RES, None],
"num_alignments": [],
"num_templates": [],
# "queue_size": [],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"seq_length": [],
"seq_mask": [NUM_RES],
"sym_id": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
# Additional multimer features
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
},
"supervised": {
"clamp_prob": 1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"eval": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"block_delete_msa" : False,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.,
"clamp_prob": 1.,
},
},
"model": {
"input_embedder": {
"tf_dim": 21,
#"num_msa": 508,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True
},
"template": {
"template_single_embedder": {
"c_in": 34,
"c_out": c_m
},
"template_pair_embedder": {
"c_in": c_z,
"c_out": c_t,
"c_dgram": 39,
"c_aatype": 22
},
"template_pair_stack": {
"tri_mul_first": True,
"fuse_projection_weights": True
},
"c_t": c_t,
"c_z": c_z,
"use_unit_vector": True
},
"extra_msa": {
# "extra_msa_embedder": {
# "num_extra_msa": 2048
# },
"extra_msa_stack": {
"opm_first": True,
"fuse_projection_weights": True
}
},
"evoformer_stack": {
"opm_first": True,
"fuse_projection_weights": True
},
"structure_module": {
"trans_scale_factor": 20
},
"heads": {
"tm": {
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True
},
"masked_msa": {
"c_out": 22
},
},
"recycle_early_stop_tolerance": 0.5 # For training, value is -1.
},
"loss": {
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5
},
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5
}
},
"masked_msa": {
"num_classes": 22
},
"violation": {
"average_clashes": True,
"weight": 0.03 # Not finetuning
},
"tm": {
"weight": 0.1,
"enabled": True
},
"chain_center_of_mass": {
"weight": 0.05,
"enabled": True
}
}
})
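Editor's note: the `feat` table above is a shape schema: each feature name maps to a list of per-dimension placeholders, and `None` marks a dimension left at its native size. A hedged toy illustration of how such a schema can be resolved (names here are hypothetical; the real logic lives in the feature pipeline):

```python
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"

schema = {
    "aatype": [NUM_RES],
    "msa_mask": [NUM_MSA_SEQ, NUM_RES],
    "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
}

def resolve_shape(dims, bindings, fixed):
    # Placeholders come from bindings (e.g. crop_size / max_msa_clusters);
    # None keeps the feature's existing size for that dimension.
    return [bindings[d] if d is not None else fixed for d in dims]

bindings = {NUM_RES: 384, NUM_MSA_SEQ: 252}
print(resolve_shape(schema["msa_feat"], bindings, fixed=49))  # [252, 384, 49]
```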
seq_mode_config = mlc.ConfigDict({ seq_mode_config = mlc.ConfigDict({
"data": { "data": {
"common": { "common": {
...@@ -700,12 +935,18 @@ seq_mode_config = mlc.ConfigDict({ ...@@ -700,12 +935,18 @@ seq_mode_config = mlc.ConfigDict({
"seqemb_mode_enabled": True, "seqemb_mode_enabled": True,
}, },
"model": { "model": {
"preembedding_embedder": { # Used in sequence embedding mode "preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22, "tf_dim": 22,
"preembedding_dim": preemb_dim_size, "preembedding_dim": preemb_dim_size,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"relpos_k": 32, "relpos_k": 32,
}, },
"extra_msa": {
"enabled": False # Disable Extra MSA Stack
},
"evoformer_stack": {
"no_column_attention": True # Turn off Evoformer's column attention
},
} }
}) })
\ No newline at end of file
...@@ -4,42 +4,45 @@ import json ...@@ -4,42 +4,45 @@ import json
import logging import logging
import os import os
import pickle import pickle
from typing import Optional, Sequence, List, Any from typing import Optional, Sequence, Any, Union
import ml_collections as mlc import ml_collections as mlc
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
from openfold.np.residue_constants import restypes
from openfold.data import ( from openfold.data import (
data_pipeline, data_pipeline,
feature_pipeline, feature_pipeline,
mmcif_parsing, mmcif_parsing,
templates, templates,
) )
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap from openfold.utils.tensor_utils import dict_multimap
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
data_dir: str, data_dir: str,
alignment_dir: str, alignment_dir: str,
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None, chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None, filter_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None,
_output_raw: bool = False, _output_raw: bool = False,
_structure_index: Optional[Any] = None, _structure_index: Optional[Any] = None,
): ):
""" """
Args: Args:
data_dir: data_dir:
...@@ -101,21 +104,21 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -101,21 +104,21 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.supported_exts = [".cif", ".core", ".pdb"] self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"] valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes): if mode not in valid_modes:
raise ValueError(f'mode must be one of {valid_modes}') raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None): if template_release_dates_cache_path is None:
logging.warning( logging.warning(
"Template release dates cache does not exist. Remember to run " "Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
if(alignment_index is not None): if alignment_index is not None:
self._chain_ids = list(alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
else: else:
self._chain_ids = list(os.listdir(alignment_dir)) self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None): if filter_path is not None:
with open(filter_path, "r") as f: with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()]) chains_to_include = set([l.strip() for l in f.readlines()])
...@@ -145,12 +148,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -145,12 +148,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
len(missing), len(missing),
missing_examples, missing_examples,
chain_data_cache_path) chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
} }
template_featurizer = templates.TemplateHitFeaturizer( # If it's running template search for a monomer, then use hhsearch
# as demonstrated in AlphaFold's run_alphafold.py code
# https://github.com/deepmind/alphafold/blob/6c4d833fbd1c6b8e7c9a21dae5d4ada2ce777e10/run_alphafold.py#L462C1-L477
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir, mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date, max_template_date=max_template_date,
max_hits=max_template_hits, max_hits=max_template_hits,
...@@ -164,8 +170,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -164,8 +170,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
if(not self._output_raw): if not self._output_raw:
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
...@@ -177,7 +183,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -177,7 +183,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
# Crash if an error is encountered. Any parsing errors should have # Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage. # been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None): if mmcif_object.mmcif_object is None:
raise list(mmcif_object.errors.values())[0] raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object mmcif_object = mmcif_object.mmcif_object
...@@ -203,48 +209,47 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -203,48 +209,47 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None alignment_index = None
if(self.alignment_index is not None): if self.alignment_index is not None:
alignment_dir = self.alignment_dir alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name] alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if self.mode == 'train' or self.mode == 'eval':
spl = name.rsplit('_', 1) spl = name.rsplit('_', 1)
if(len(spl) == 2): if len(spl) == 2:
file_id, chain_id = spl file_id, chain_id = spl
else: else:
file_id, = spl file_id, = spl
chain_id = None chain_id = None
path = os.path.join(self.data_dir, file_id) path = os.path.join(self.data_dir, file_id)
structure_index_entry = None if self._structure_index is not None:
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name] structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1) assert (len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0] filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1] ext = os.path.splitext(filename)[1]
else: else:
ext = None ext = None
for e in self.supported_exts: for e in self.supported_exts:
if(os.path.exists(path + e)): if os.path.exists(path + e):
ext = e ext = e
break break
if(ext is None): if ext is None:
raise ValueError("Invalid file type") raise ValueError("Invalid file type")
path += ext path += ext
if(ext == ".cif"): if ext == ".cif":
data = self._parse_mmcif( data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index, path, file_id, chain_id, alignment_dir, alignment_index,
) )
elif(ext == ".core"): elif ext == ".core":
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index, path, alignment_dir, alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled, seqemb_mode=self.config.seqemb_mode.enabled,
) )
elif(ext == ".pdb"): elif ext == ".pdb":
structure_index = None structure_index = None
if(self._structure_index is not None): if self._structure_index is not None:
structure_index = self._structure_index[name] structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path, pdb_path=path,
...@@ -256,7 +261,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -256,7 +261,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
seqemb_mode=self.config.seqemb_mode.enabled, seqemb_mode=self.config.seqemb_mode.enabled,
) )
else: else:
raise ValueError("Extension branch missing") raise ValueError("Extension branch missing")
else: else:
path = os.path.join(name, name + ".fasta") path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
...@@ -266,11 +271,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -266,11 +271,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
seqemb_mode=self.config.seqemb_mode.enabled, seqemb_mode=self.config.seqemb_mode.enabled,
) )
if(self._output_raw): if self._output_raw:
return data return data
feats = self.feature_pipeline.process_features( feats = self.feature_pipeline.process_features(
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor( feats["batch_idx"] = torch.tensor(
...@@ -281,51 +286,250 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -281,51 +286,250 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return feats return feats
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
def deterministic_train_filter( class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
chain_data_cache_entry: Any, def __init__(self,
max_resolution: float = 9., data_dir: str,
max_single_aa_prop: float = 0.8, alignment_dir: str,
) -> bool: template_mmcif_dir: str,
# Hard filters max_template_date: str,
resolution = chain_data_cache_entry.get("resolution", None) config: mlc.ConfigDict,
if(resolution is not None and resolution > max_resolution): mmcif_data_cache_path: Optional[str] = None,
return False kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
This class checks each individual PDB ID and returns its chain(s)' features/ground truth.
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
mmcif_data_cache_path:
Path to a cache of all mmCIF files generated by
scripts/generate_mmcif_cache.py. It should be a JSON file that
records which chain(s) each PDB ID contains.
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleMultimerDataset, self).__init__()
self.data_dir = data_dir
self.mmcif_data_cache_path = mmcif_data_cache_path
seq = chain_data_cache_entry["seq"] if self.mmcif_data_cache_path is not None:
counts = {} with open(self.mmcif_data_cache_path, "r") as infile:
for aa in seq: self.mmcif_data_cache = json.load(infile)
counts.setdefault(aa, 0) assert isinstance(self.mmcif_data_cache, dict)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
def get_stochastic_train_filter_prob( valid_modes = ["train", "eval", "predict"]
chain_data_cache_entry: Any, if mode not in valid_modes:
) -> List[float]: raise ValueError(f'mode must be one of {valid_modes}')
# Stochastic filters
probabilities = []
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here? if template_release_dates_cache_path is None:
out = 1 logging.warning(
for p in probabilities: "Template release dates cache does not exist. Remember to run "
out *= p "scripts/generate_mmcif_cache.py before running OpenFold"
)
return out if self.mmcif_data_cache_path is not None:
self._mmcifs = list(self.mmcif_data_cache.keys())
elif self.alignment_index is not None:
self._mmcifs = [i.split("_")[0] for i in list(alignment_index.keys())]
elif self.alignment_dir is not None:
self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
else:
raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")
if filter_path is not None:
with open(filter_path, "r") as f:
mmcifs_to_include = set([l.strip() for l in f.readlines()])
self._mmcifs = [
m for m in self._mmcifs if m in mmcifs_to_include
]
self._mmcif_id_to_idx_dict = {
mmcif: i for i, mmcif in enumerate(self._mmcifs)
}
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
self.data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if mmcif_object.mmcif_object is None:
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
alignment_index=alignment_index
)
return data
def mmcif_id_to_idx(self, mmcif_id):
return self._mmcif_id_to_idx_dict[mmcif_id]
def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx]
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None
if self.alignment_index is not None:
alignment_index = {k: v for k, v in self.alignment_index.items()
if f'{mmcif_id}_' in k}
if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if os.path.exists(path + e):
ext = e
break
if ext is None:
raise ValueError("Invalid file type")
# TODO: Add pdb and core exts to data_pipeline for multimer
path += ext
if ext == ".cif":
data = self._parse_mmcif(
path, mmcif_id, self.alignment_dir, alignment_index,
)
else:
raise ValueError("Extension branch missing")
else:
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=self.alignment_dir,
alignment_index=alignment_index
)
if self._output_raw:
return data
# process all_chain_features
data = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)
# In inference mode, only all_chain_features are needed
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)
return data
def __len__(self):
return len(self._mmcifs)
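Editor's note: a hedged usage sketch of the class above in predict mode (all paths are placeholders):

```python
from openfold.config import model_config

cfg = model_config("model_1_multimer_v3")
ds = OpenFoldSingleMultimerDataset(
    data_dir="fastas/",                 # one {mmcif_id}.fasta per target
    alignment_dir="alignments/",
    template_mmcif_dir="mmcifs/",
    max_template_date="2021-10-10",
    config=cfg.data,
    mode="predict",
)
feats = ds[0]  # fully processed multimer features for the first target
```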
def resolution_filter(resolution: int, max_resolution: float) -> bool:
"""Check that the resolution is <= max_resolution permitted"""
return resolution is not None and resolution <= max_resolution
def aa_count_filter(seqs: list, max_single_aa_prop: float) -> bool:
"""Check if any single amino acid accounts for more than max_single_aa_prop percent of the sequence(s)"""
counts = {}
for seq in seqs:
for aa in seq:
counts.setdefault(aa, 0)
if aa not in restypes:
return False
else:
counts[aa] += 1
total_len = sum([len(i) for i in seqs])
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / total_len
return largest_single_aa_prop <= max_single_aa_prop
def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
"""Check if the total combined sequence lengths are >= minimum_numer_of_residues"""
total_len = sum([len(i) for i in seqs])
return total_len >= minimum_number_of_residues
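Editor's note: a hedged example composing the three helpers above, mirroring how the deterministic filters below use them:

```python
seqs = ["MKTAYIAKQR", "GASGASGASGAS"]
keep = all([
    resolution_filter(resolution=3.2, max_resolution=9.),
    aa_count_filter(seqs=seqs, max_single_aa_prop=0.8),
    all_seq_len_filter(seqs=seqs, minimum_number_of_residues=2),
])
print(keep)  # True for this toy example
```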
class OpenFoldDataset(torch.utils.data.Dataset): class OpenFoldDataset(torch.utils.data.Dataset):
...@@ -335,68 +539,106 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -335,68 +539,106 @@ class OpenFoldDataset(torch.utils.data.Dataset):
length of an OpenFoldFilteredDataset is arbitrary. Samples are selected length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
and filtered once at initialization. and filtered once at initialization.
""" """
def __init__(self, def __init__(self,
datasets: Sequence[OpenFoldSingleDataset], datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
probabilities: Sequence[float], probabilities: Sequence[float],
epoch_len: int, epoch_len: int,
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
self.datasets = datasets self.datasets = datasets
self.probabilities = probabilities self.probabilities = probabilities
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
def looped_shuffled_dataset_idx(dataset_len): self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
while True: if _roll_at_init:
# Uniformly shuffle each dataset's indices self.reroll()
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s] @staticmethod
def deterministic_train_filter(
cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
*args, **kwargs
) -> bool:
# Hard filters
resolution = cache_entry.get("resolution", None)
seqs = [cache_entry["seq"]]
return all([resolution_filter(resolution=resolution,
max_resolution=max_resolution),
aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop)])
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
*args, **kwargs
) -> float:
# Stochastic filters
probabilities = []
cluster_size = cache_entry.get("cluster_size", None)
if cluster_size is not None and cluster_size > 0:
probabilities.append(1 / cluster_size)
chain_length = len(cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
def looped_shuffled_dataset_idx(self, dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if not self.deterministic_train_filter(chain_data_cache_entry):
continue
p = self.get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
for datapoint_idx in cache: samples = torch.multinomial(
yield datapoint_idx torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
self._samples = [looped_samples(i) for i in range(len(self.datasets))] cache = [i for i, s in zip(idx, samples) if s]
if(_roll_at_init): for datapoint_idx in cache:
self.reroll() yield datapoint_idx
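Editor's note: the keep/drop step inside `looped_samples` reduces to one Bernoulli draw per candidate, implemented with a two-column multinomial; a hedged toy version:

```python
import torch

idx = [10, 11, 12]
probs = [0.1, 0.8, 0.5]                        # per-candidate keep probabilities
weights = torch.tensor([[1. - p, p] for p in probs])
samples = torch.multinomial(weights, num_samples=1).squeeze()
kept = [i for i, s in zip(idx, samples) if s]  # kept when column 1 is drawn
```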
def __getitem__(self, idx): def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx] dataset_idx, datapoint_idx = self.datapoints[idx]
...@@ -412,7 +654,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -412,7 +654,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
replacement=True, replacement=True,
generator=self.generator, generator=self.generator,
) )
self.datapoints = [] self.datapoints = []
for dataset_idx in dataset_choices: for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx] samples = self._samples[dataset_idx]
...@@ -420,10 +661,102 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -420,10 +661,102 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.datapoints.append((dataset_idx, datapoint_idx)) self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldMultimerDataset(OpenFoldDataset):
"""
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleMultimerDataset],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True
):
super(OpenFoldMultimerDataset, self).__init__(datasets=datasets,
probabilities=probabilities,
epoch_len=epoch_len,
generator=generator,
_roll_at_init=_roll_at_init)
@staticmethod
def deterministic_train_filter(
cache_entry: Any,
is_distillation: bool,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
minimum_number_of_residues: int = 200,
*args, **kwargs
) -> bool:
"""
Implement multimer training filtering criteria described in
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
resolution = cache_entry.get("resolution", None)
seqs = cache_entry["seqs"]
return all([resolution_filter(resolution=resolution,
max_resolution=max_resolution),
aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop),
(not is_distillation or all_seq_len_filter(seqs=seqs,
minimum_number_of_residues=minimum_number_of_residues))])
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
*args, **kwargs
) -> list:
# Stochastic filters
cluster_sizes = cache_entry.get("cluster_sizes")
if cluster_sizes is not None:
return [1 / c if c > 0 else 1 for c in cluster_sizes]
num_chains = len(cache_entry["chain_ids"])
return [1.] * num_chains
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
is_distillation = dataset.treat_pdb_as_distillation
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
mmcif_data_cache = dataset.mmcif_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if not self.deterministic_train_filter(cache_entry=mmcif_data_cache_entry,
is_distillation=is_distillation):
continue
chain_probs = self.get_stochastic_train_filter_prob(
mmcif_data_cache_entry,
)
weights.extend([[1. - p, p] for p in chain_probs])
idx.extend([candidate_idx] * len(chain_probs))
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __call__(self, prots): def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, prots) return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader): class OpenFoldDataLoader(torch.utils.data.DataLoader):
...@@ -439,8 +772,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -439,8 +772,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage] stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters max_iters = self.config.common.max_recycling_iters
if(stage_cfg.uniform_recycling): if stage_cfg.uniform_recycling:
recycling_probs = [ recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1) 1. / (max_iters + 1) for _ in range(max_iters + 1)
] ]
...@@ -449,15 +782,15 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -449,15 +782,15 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
0. for _ in range(max_iters + 1) 0. for _ in range(max_iters + 1)
] ]
recycling_probs[-1] = 1. recycling_probs[-1] = 1.
keyed_probs.append( keyed_probs.append(
("no_recycling_iters", recycling_probs) ("no_recycling_iters", recycling_probs)
) )
keys, probs = zip(*keyed_probs) keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs]) max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs] padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys self.prop_keys = keys
self.prop_probs_tensor = torch.tensor( self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)], [p + pad for p, pad in zip(probs, padding)],
...@@ -465,9 +798,10 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -465,9 +798,10 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
) )
def _add_batch_properties(self, batch): def _add_batch_properties(self, batch):
gt_features = batch.pop('gt_features', None)
samples = torch.multinomial( samples = torch.multinomial(
self.prop_probs_tensor, self.prop_probs_tensor,
num_samples=1, # 1 per row num_samples=1, # 1 per row
replacement=True, replacement=True,
generator=self.generator generator=self.generator
) )
...@@ -479,8 +813,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -479,8 +813,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
for i, key in enumerate(self.prop_keys): for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0]) sample = int(samples[i][0])
sample_tensor = torch.tensor( sample_tensor = torch.tensor(
sample, sample,
device=aatype.device, device=aatype.device,
requires_grad=False requires_grad=False
) )
orig_shape = sample_tensor.shape orig_shape = sample_tensor.shape
...@@ -492,11 +826,12 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -492,11 +826,12 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
) )
batch[key] = sample_tensor batch[key] = sample_tensor
if(key == "no_recycling_iters"): if key == "no_recycling_iters":
no_recycling = sample no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1] resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch) batch = tensor_tree_map(resample_recycling, batch)
batch['gt_features'] = gt_features
return batch return batch
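Editor's note: for intuition, the `resample_recycling` lambda above simply truncates the trailing recycling dimension of every feature tensor; a hedged toy example:

```python
import torch

# Trailing dim of 4 = max_recycling_iters + 1 in this toy batch.
batch = {"msa_feat": torch.zeros(128, 256, 49, 4)}
no_recycling = 1  # the sampled number of recycling iterations
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = {k: resample_recycling(v) for k, v in batch.items()}
print(batch["msa_feat"].shape)  # torch.Size([128, 256, 49, 2])
```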
...@@ -512,31 +847,31 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -512,31 +847,31 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
class OpenFoldDataModule(pl.LightningDataModule): class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self, def __init__(self,
config: mlc.ConfigDict, config: mlc.ConfigDict,
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
train_data_dir: Optional[str] = None, train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None, train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None, train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None, distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None, distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None, distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None, val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None, val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None, predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
train_filter_path: Optional[str] = None, train_filter_path: Optional[str] = None,
distillation_filter_path: Optional[str] = None, distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000, train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None, _distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None, alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None, distillation_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
self.config = config self.config = config
...@@ -564,7 +899,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -564,7 +899,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.batch_seed = batch_seed self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None): if self.train_data_dir is None and self.predict_data_dir is None:
raise ValueError( raise ValueError(
'At least one of train_data_dir or predict_data_dir must be ' 'At least one of train_data_dir or predict_data_dir must be '
'specified' 'specified'
...@@ -572,65 +907,61 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -572,65 +907,61 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.training_mode = self.train_data_dir is not None self.training_mode = self.train_data_dir is not None
if(self.training_mode and train_alignment_dir is None): if self.training_mode and train_alignment_dir is None:
raise ValueError( raise ValueError(
'In training mode, train_alignment_dir must be specified' 'In training mode, train_alignment_dir must be specified'
) )
elif(not self.training_mode and predict_alignment_dir is None): elif not self.training_mode and predict_alignment_dir is None:
raise ValueError( raise ValueError(
'In inference mode, predict_alignment_dir must be specified' 'In inference mode, predict_alignment_dir must be specified'
) )
elif(val_data_dir is not None and val_alignment_dir is None): elif val_data_dir is not None and val_alignment_dir is None:
raise ValueError( raise ValueError(
'If val_data_dir is specified, val_alignment_dir must ' 'If val_data_dir is specified, val_alignment_dir must '
'be specified as well' 'be specified as well'
) )
# An ad-hoc measure for our particular filesystem restrictions # An ad-hoc measure for our particular filesystem restrictions
self._distillation_structure_index = None self._distillation_structure_index = None
if(_distillation_structure_index_path is not None): if _distillation_structure_index_path is not None:
with open(_distillation_structure_index_path, "r") as fp: with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp) self._distillation_structure_index = json.load(fp)
self.alignment_index = None self.alignment_index = None
if(alignment_index_path is not None): if alignment_index_path is not None:
with open(alignment_index_path, "r") as fp: with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp) self.alignment_index = json.load(fp)
self.distillation_alignment_index = None self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None): if distillation_alignment_index_path is not None:
with open(distillation_alignment_index_path, "r") as fp: with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp) self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date, max_template_date=self.max_template_date,
config=self.config, config=self.config,
kalign_binary_path=self.kalign_binary_path, kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path= template_release_dates_cache_path=self.template_release_dates_cache_path,
self.template_release_dates_cache_path, obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path, if self.training_mode:
)
if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path, chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered= shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
alignment_index=self.alignment_index, alignment_index=self.alignment_index,
) )
distillation_dataset = None distillation_dataset = None
if(self.distillation_data_dir is not None): if self.distillation_data_dir is not None:
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path, chain_data_cache_path=self.distillation_chain_data_cache_path,
...@@ -644,8 +975,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -644,8 +975,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None): if distillation_dataset is not None:
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
...@@ -654,10 +985,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -654,10 +985,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
probabilities = [1.] probabilities = [1.]
generator = None generator = None
if(self.batch_seed is not None): if self.batch_seed is not None:
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1) generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset( self.train_dataset = OpenFoldDataset(
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
...@@ -665,8 +996,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -665,8 +996,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
generator=generator, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
if(self.val_data_dir is not None): if self.val_data_dir is not None:
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
...@@ -676,7 +1007,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -676,7 +1007,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
else: else:
self.eval_dataset = None self.eval_dataset = None
else: else:
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir, alignment_dir=self.predict_alignment_dir,
...@@ -687,18 +1018,17 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -687,18 +1018,17 @@ class OpenFoldDataModule(pl.LightningDataModule):
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage):
generator = None generator = None
if(self.batch_seed is not None): if self.batch_seed is not None:
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed) generator = generator.manual_seed(self.batch_seed)
dataset = None if stage == "train":
if(stage == "train"):
dataset = self.train_dataset dataset = self.train_dataset
# Filter the dataset, if necessary # Filter the dataset, if necessary
dataset.reroll() dataset.reroll()
elif(stage == "eval"): elif stage == "eval":
dataset = self.eval_dataset dataset = self.eval_dataset
elif(stage == "predict"): elif stage == "predict":
dataset = self.predict_dataset dataset = self.predict_dataset
else: else:
raise ValueError("Invalid stage") raise ValueError("Invalid stage")
...@@ -718,15 +1048,121 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -718,15 +1048,121 @@ class OpenFoldDataModule(pl.LightningDataModule):
return dl return dl
def train_dataloader(self): def train_dataloader(self):
return self._gen_dataloader("train") return self._gen_dataloader("train")
def val_dataloader(self): def val_dataloader(self):
if(self.eval_dataset is not None): if self.eval_dataset is not None:
return self._gen_dataloader("eval") return self._gen_dataloader("eval")
return None return None
def predict_dataloader(self): def predict_dataloader(self):
return self._gen_dataloader("predict") return self._gen_dataloader("predict")
class OpenFoldMultimerDataModule(OpenFoldDataModule):
"""
Create a datamodule specifically for multimer training
Compared to OpenFoldDataModule, OpenFoldMultimerDataModule
requires mmcif_data_cache_path, which is the product of
scripts/generate_mmcif_cache.py. mmcif_data_cache_path should be
a file that records what chain(s) each mmCIF file has.
"""
def __init__(self, config: mlc.ConfigDict,
template_mmcif_dir: str, max_template_date: str,
train_data_dir: Optional[str] = None,
train_mmcif_data_cache_path: Optional[str] = None,
val_mmcif_data_cache_path: Optional[str] = None,
**kwargs):
super(OpenFoldMultimerDataModule, self).__init__(config,
template_mmcif_dir,
max_template_date,
train_data_dir,
**kwargs)
self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=self.template_release_dates_cache_path,
obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
if self.training_mode:
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
mmcif_data_cache_path=self.train_mmcif_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
alignment_index=self.alignment_index,
)
distillation_dataset = None
if self.distillation_data_dir is not None:
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
)
d_prob = self.config.train.distillation_prob
if distillation_dataset is not None:
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
else:
datasets = [train_dataset]
probabilities = [1.]
generator = None
if self.batch_seed is not None:
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldMultimerDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
generator=generator,
_roll_at_init=True,
)
if self.val_data_dir is not None:
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mmcif_data_cache_path=self.val_mmcif_data_cache_path,
filter_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
)
else:
self.eval_dataset = None
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
filter_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
class DummyDataset(torch.utils.data.Dataset): class DummyDataset(torch.utils.data.Dataset):
......
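Editor's note: a hedged sketch of wiring the multimer data module defined above for training (paths are placeholders; see the training entry point for the real wiring):

```python
from openfold.config import model_config

cfg = model_config("model_1_multimer_v3", train=True)
data_module = OpenFoldMultimerDataModule(
    config=cfg.data,
    template_mmcif_dir="mmcifs/",
    max_template_date="2021-10-10",
    train_data_dir="train_mmcifs/",
    train_alignment_dir="train_alignments/",
    train_mmcif_data_cache_path="mmcif_cache.json",
)
data_module.setup()
train_loader = data_module.train_dataloader()
```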
...@@ -14,38 +14,28 @@ ...@@ -14,38 +14,28 @@
# limitations under the License.

import os
import copy
import collections
import contextlib
import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union

import numpy as np
import torch

from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein

FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def make_template_features(
    input_sequence: str,
    hits: Sequence[Any],
    template_featurizer: Any,
) -> FeatureDict:
    hits_cat = sum(hits.values(), [])
    if(len(hits_cat) == 0 or template_featurizer is None):
...@@ -53,17 +43,10 @@ def make_template_features(
    else:
        templates_result = template_featurizer.get_templates(
            query_sequence=input_sequence,
            hits=hits_cat,
        )
        template_features = templates_result.features

    return template_features
...@@ -86,7 +69,7 @@ def unify_template_features(
        assert(new_shape[1] == n_res)
        new_shape[1] = sum(seq_lens)
        new_array = np.zeros(new_shape, dtype=v.dtype)

        if(k == "template_aatype"):
            new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
...@@ -172,13 +155,13 @@ def make_mmcif_features(
def _aatype_to_str_sequence(aatype):
    return ''.join([
        residue_constants.restypes_with_x[aatype[i]]
        for i in range(len(aatype))
    ])


def make_protein_features(
    protein_object: protein.Protein,
    description: str,
    _is_distillation: bool = False,
) -> FeatureDict:
...@@ -225,32 +208,35 @@ def make_pdb_features(
    return pdb_feats
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
    """Constructs a feature dict of MSA features."""
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    species_ids = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_index} must contain at least one sequence."
            )
        for sequence_index, sequence in enumerate(msa.sequences):
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
            )
            deletion_matrix.append(msa.deletion_matrix[sequence_index])
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[sequence_index]
            )
            species_ids.append(identifiers.species_id.encode('utf-8'))

    num_res = len(msas[0].sequences[0])
    num_alignments = len(int_msa)
    features = {}
    features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
...@@ -258,19 +244,41 @@ def make_msa_features(
    features["num_alignments"] = np.array(
        [num_alignments] * num_res, dtype=np.int32
    )
    features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
    return features
def run_msa_tool(
    msa_runner,
    fasta_path: str,
    msa_out_path: str,
    msa_format: str,
    max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]:
    """Runs an MSA tool, checking if output already exists first."""
    if(msa_format == "sto" and max_sto_sequences is not None):
        result = msa_runner.query(fasta_path, max_sto_sequences)[0]
    else:
        result = msa_runner.query(fasta_path)[0]

    assert msa_out_path.split('.')[-1] == msa_format
    with open(msa_out_path, "w") as f:
        f.write(result[msa_format])

    return result


def make_dummy_msa_obj(input_sequence) -> parsers.Msa:
    deletion_matrix = [[0 for _ in input_sequence]]
    return parsers.Msa(sequences=[input_sequence],
                       deletion_matrix=deletion_matrix,
                       descriptions=['dummy'])


# Generate 1-sequence MSA features having only the input sequence
def make_dummy_msa_feats(input_sequence) -> FeatureDict:
    msa_data_obj = make_dummy_msa_obj(input_sequence)
    return make_msa_features([msa_data_obj])
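
# For example, a 3-residue query yields a single-row MSA feature set
# (assuming the standard keys produced by make_msa_features above):
#
#   >>> feats = make_dummy_msa_feats("MKT")
#   >>> int(feats["num_alignments"][0]), feats["deletion_matrix_int"].shape
#   (1, (1, 3))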
def make_sequence_features_with_custom_template(
...@@ -290,10 +298,11 @@ def make_sequence_features_with_custom_template(
        num_res=num_res,
    )

    msa_data = [sequence]
    deletion_matrix = [[0 for _ in sequence]]
    msa_data_obj = parsers.Msa(sequences=msa_data, deletion_matrix=deletion_matrix, descriptions=None)
    msa_features = make_msa_features([msa_data_obj])

    template_features = get_custom_template_features(
        mmcif_path=mmcif_path,
        query_sequence=sequence,
...@@ -308,22 +317,25 @@ def make_sequence_features_with_custom_template(
        **template_features.features
    }
class AlignmentRunner:
    """Runs alignment tools and saves the results"""
    def __init__(
        self,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        uniref90_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        uniref30_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        uniprot_database_path: Optional[str] = None,
        template_searcher: Optional[TemplateSearcher] = None,
        use_small_bfd: Optional[bool] = None,
        no_cpus: Optional[int] = None,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
        uniprot_max_hits: int = 50000,
    ):
        """
        Args:
...@@ -331,8 +343,6 @@ class AlignmentRunner:
                Path to jackhmmer binary
            hhblits_binary_path:
                Path to hhblits binary
            uniref90_database_path:
                Path to uniref90 database. If provided, jackhmmer_binary_path
                must also be provided
...@@ -341,16 +351,17 @@ class AlignmentRunner:
                must also be provided
            bfd_database_path:
                Path to BFD database. Depending on the value of use_small_bfd,
                one of hhblits_binary_path or jackhmmer_binary_path must be
                provided.
            uniref30_database_path:
                Path to uniref30. Searched alongside BFD if use_small_bfd is
                false.
            uniclust30_database_path:
                Path to uniclust30. Searched alongside BFD if use_small_bfd is
                false.
            use_small_bfd:
                Whether to search the BFD database alone with jackhmmer or
                in conjunction with uniref30/uniclust30 with hhblits.
            no_cpus:
                The number of CPUs available for alignment. By default, all
                CPUs are used.
...@@ -358,6 +369,8 @@ class AlignmentRunner:
                Max number of uniref hits
            mgnify_max_hits:
                Max number of mgnify hits
            uniprot_max_hits:
                Max number of uniprot hits
        """
        db_map = {
            "jackhmmer": {
...@@ -366,6 +379,7 @@ class AlignmentRunner:
                    uniref90_database_path,
                    mgnify_database_path,
                    bfd_database_path if use_small_bfd else None,
                    uniprot_database_path,
                ],
            },
            "hhblits": {
...@@ -374,12 +388,6 @@ class AlignmentRunner:
                    bfd_database_path if not use_small_bfd else None,
                ],
            },
        }
        for name, dic in db_map.items():
...@@ -389,22 +397,16 @@ class AlignmentRunner:
                    f"{name} DBs provided but {name} binary is None"
                )

        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits
        self.uniprot_max_hits = uniprot_max_hits
        self.use_small_bfd = use_small_bfd

        if(no_cpus is None):
            no_cpus = cpu_count()

        self.jackhmmer_uniref90_runner = None
        if(jackhmmer_binary_path is not None and
            uniref90_database_path is not None
        ):
            self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
...@@ -412,9 +414,9 @@ class AlignmentRunner:
                database_path=uniref90_database_path,
                n_cpu=no_cpus,
            )

        self.jackhmmer_small_bfd_runner = None
        self.hhblits_bfd_unirefclust_runner = None
        if(bfd_database_path is not None):
            if use_small_bfd:
                self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
...@@ -424,9 +426,11 @@ class AlignmentRunner:
                )
            else:
                dbs = [bfd_database_path]
                if(uniref30_database_path is not None):
                    dbs.append(uniref30_database_path)
                if (uniclust30_database_path is not None):
                    dbs.append(uniclust30_database_path)
                self.hhblits_bfd_unirefclust_runner = hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=dbs,
                    n_cpu=no_cpus,
...@@ -440,14 +444,23 @@ class AlignmentRunner:
                n_cpu=no_cpus,
            )

        self.jackhmmer_uniprot_runner = None
        if(uniprot_database_path is not None):
            self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=uniprot_database_path,
                n_cpu=no_cpus
            )

        if(template_searcher is not None and
            self.jackhmmer_uniref90_runner is None
        ):
            raise ValueError(
                "Uniref90 runner must be specified to run template search"
            )

        self.template_searcher = template_searcher

    def run(
        self,
        fasta_path: str,
...@@ -455,52 +468,226 @@ class AlignmentRunner:
    ):
"""Runs alignment tools on a sequence""" """Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None): if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
fasta_path
)[0] jackhmmer_uniref90_result = run_msa_tool(
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( msa_runner=self.jackhmmer_uniref90_runner,
jackhmmer_uniref90_result["sto"], fasta_path=fasta_path,
max_sequences=self.uniref_max_hits msa_out_path=uniref90_out_path,
msa_format='sto',
max_sto_sequences=self.uniref_max_hits,
) )
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None): template_msa = jackhmmer_uniref90_result["sto"]
hhsearch_result = self.hhsearch_pdb70_runner.query( template_msa = parsers.deduplicate_stockholm_msa(template_msa)
uniref90_msa_as_a3m template_msa = parsers.remove_empty_columns_from_stockholm_msa(
) template_msa
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr") )
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result) if(self.template_searcher is not None):
if(self.template_searcher.input_format == "sto"):
pdb_templates_result = self.template_searcher.query(
template_msa,
output_dir=output_dir
)
elif(self.template_searcher.input_format == "a3m"):
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
template_msa
)
pdb_templates_result = self.template_searcher.query(
uniref90_msa_as_a3m,
output_dir=output_dir
)
else:
fmt = self.template_searcher.input_format
raise ValueError(
f"Unrecognized template input format: {fmt}"
)
if(self.jackhmmer_mgnify_runner is not None): if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
fasta_path jackhmmer_mgnify_result = run_msa_tool(
)[0] msa_runner=self.jackhmmer_mgnify_runner,
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( fasta_path=fasta_path,
jackhmmer_mgnify_result["sto"], msa_out_path=mgnify_out_path,
max_sequences=self.mgnify_max_hits msa_format='sto',
max_sto_sequences=self.mgnify_max_hits
) )
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None): if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto") bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f: jackhmmer_small_bfd_result = run_msa_tool(
f.write(jackhmmer_small_bfd_result["sto"]) msa_runner=self.jackhmmer_small_bfd_runner,
elif(self.hhblits_bfd_uniclust_runner is not None): fasta_path=fasta_path,
hhblits_bfd_uniclust_result = ( msa_out_path=bfd_out_path,
self.hhblits_bfd_uniclust_runner.query(fasta_path) msa_format="sto",
)
elif(self.hhblits_bfd_unirefclust_runner is not None):
uni_name = "uni"
for db_name in self.hhblits_bfd_unirefclust_runner.databases:
if "uniref" in db_name.lower():
uni_name = f"{uni_name}ref"
elif "uniclust" in db_name.lower():
uni_name = f"{uni_name}clust"
bfd_out_path = os.path.join(output_dir, f"bfd_{uni_name}_hits.a3m")
hhblits_bfd_unirefclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_unirefclust_runner,
fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="a3m",
)
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
fasta_path=fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
) )
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f: @dataclasses.dataclass(frozen=True)
f.write(hhblits_bfd_uniclust_result["a3m"]) class _FastaChain:
sequence: str
description: str
def _make_chain_id_map(
sequences: Sequence[str],
descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError('Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')
chain_id_map = {}
for chain_id, sequence, description in zip(
protein.PDB_CHAIN_IDS, sequences, descriptions
):
chain_id_map[chain_id] = _FastaChain(
sequence=sequence, description=description
)
return chain_id_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
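
# Usage sketch: the context manager yields a path to a temporary FASTA file
# that is cleaned up automatically on exit.
#
#   >>> with temp_fasta_file(">chain_A\nMKT\n") as fasta_path:
#   ...     open(fasta_path).read().startswith(">chain_A")
#   True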
def convert_monomer_features(
monomer_features: FeatureDict,
chain_id: str
) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'
}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
        A string that encodes the positive integer using reverse spreadsheet
        style naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ...
        This is the usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
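
# Worked examples of the reverse spreadsheet-style encoding:
#
#   >>> [int_id_to_str_id(i) for i in (1, 2, 26, 27, 28, 53)]
#   ['A', 'B', 'Z', 'AA', 'BA', 'AB']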
def add_assembly_features(
all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = (
chain_id * np.ones(seq_length)
).astype(np.int64)
chain_features['sym_id'] = (
sym_id * np.ones(seq_length)
).astype(np.int64)
chain_features['entity_id'] = (
entity_id * np.ones(seq_length)
).astype(np.int64)
chain_id += 1
return new_all_chain_features
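
# For example, a homodimer maps to keys "A_1" and "A_2", both with
# entity_id 1 but asym_id 1 and 2 respectively; a heterodimer maps to
# "A_1" and "B_1" with distinct entity_ids.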
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
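
# E.g. an 'msa' feature with 120 rows is padded with 392 all-zero rows to
# reach min_num_seq=512; inputs that already have at least min_num_seq
# sequences are returned unchanged.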
class DataPipeline:
...@@ -514,10 +701,10 @@ class DataPipeline:
    def _parse_msa_data(
        self,
        alignment_dir: str,
        alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        msa_data = {}
        if alignment_index is not None:
            fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")

            def read_msa(start, size):
...@@ -526,49 +713,47 @@ class DataPipeline:
                return msa

            for (name, start, size) in alignment_index["files"]:
                filename, ext = os.path.splitext(name)

                if ext == ".a3m":
                    msa = parsers.parse_a3m(
                        read_msa(start, size)
                    )
                # The "hmm_output" exception is a crude way to exclude
                # multimer template hits.
                # Multimer "uniprot_hits" processed separately.
                elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
                    msa = parsers.parse_stockholm(read_msa(start, size))
                else:
                    continue

                msa_data[name] = msa

            fp.close()
        else:
            for f in os.listdir(alignment_dir):
                path = os.path.join(alignment_dir, f)
                filename, ext = os.path.splitext(f)

                if ext == ".a3m":
                    with open(path, "r") as fp:
                        msa = parsers.parse_a3m(fp.read())
                elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
                    with open(path, "r") as fp:
                        msa = parsers.parse_stockholm(
                            fp.read()
                        )
                else:
                    continue

                msa_data[f] = msa

        return msa_data
    def _parse_template_hit_files(
        self,
        alignment_dir: str,
        input_sequence: str,
        alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        all_hits = {}
...@@ -585,6 +770,12 @@ class DataPipeline:
                if(ext == ".hhr"):
                    hits = parsers.parse_hhr(read_template(start, size))
                    all_hits[name] = hits
                elif(name == "hmmsearch_output.sto"):
                    hits = parsers.parse_hmmsearch_sto(
                        read_template(start, size),
                        input_sequence,
                    )
                    all_hits[name] = hits

            fp.close()
        else:
...@@ -596,13 +787,20 @@ class DataPipeline:
                    with open(path, "r") as fp:
                        hits = parsers.parse_hhr(fp.read())
                    all_hits[f] = hits
                elif(f == "hmm_output.sto"):
                    with open(path, "r") as fp:
                        hits = parsers.parse_hmmsearch_sto(
                            fp.read(),
                            input_sequence,
                        )
                    all_hits[f] = hits

        return all_hits
    def _get_msas(self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        alignment_index: Optional[Any] = None,
    ):
        msa_data = self._parse_msa_data(alignment_dir, alignment_index)

        if(len(msa_data) == 0):
...@@ -613,29 +811,23 @@ class DataPipeline:
                must be provided.
                """
            )

            msa_data["dummy"] = make_dummy_msa_obj(input_sequence)

        return list(msa_data.values())
    def _process_msa_feats(
        self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        msas = self._get_msas(
            alignment_dir, input_sequence, alignment_index
        )
        msa_features = make_msa_features(
            msas=msas
        )

        return msa_features
...@@ -660,10 +852,10 @@ class DataPipeline:
        self,
        fasta_path: str,
        alignment_dir: str,
        alignment_index: Optional[Any] = None,
        seqemb_mode: bool = False,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file"""
        with open(fasta_path) as f:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
...@@ -675,7 +867,12 @@ class DataPipeline:
        input_description = input_descs[0]
        num_res = len(input_sequence)

        hits = self._parse_template_hit_files(
            alignment_dir=alignment_dir,
            input_sequence=input_sequence,
            alignment_index=alignment_index,
        )

        template_features = make_template_features(
            input_sequence,
            hits,
...@@ -700,7 +897,7 @@ class DataPipeline:
            **sequence_features,
            **msa_features,
            **template_features,
            **sequence_embedding_features
        }

    def process_mmcif(
...@@ -708,7 +905,7 @@ class DataPipeline:
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
        alignment_index: Optional[Any] = None,
        seqemb_mode: bool = False,
    ) -> FeatureDict:
        """
...@@ -727,12 +924,15 @@ class DataPipeline:
        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        hits = self._parse_template_hit_files(
            alignment_dir=alignment_dir,
            input_sequence=input_sequence,
            alignment_index=alignment_index)

        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer
        )

        sequence_embedding_features = {}
...@@ -752,7 +952,7 @@ class DataPipeline:
        is_distillation: bool = True,
        chain_id: Optional[str] = None,
        _structure_index: Optional[str] = None,
        alignment_index: Optional[Any] = None,
        seqemb_mode: bool = False,
    ) -> FeatureDict:
        """
...@@ -772,15 +972,20 @@ class DataPipeline:
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str, chain_id)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)

        description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_pdb_features(
            protein_object,
            description,
            is_distillation=is_distillation
        )

        hits = self._parse_template_hit_files(
            alignment_dir=alignment_dir,
            input_sequence=input_sequence,
            alignment_index=alignment_index,
        )

        template_features = make_template_features(
            input_sequence,
            hits,
...@@ -801,7 +1006,7 @@ class DataPipeline:
        self,
        core_path: str,
        alignment_dir: str,
        alignment_index: Optional[Any] = None,
        seqemb_mode: bool = False,
    ) -> FeatureDict:
        """
...@@ -811,11 +1016,16 @@ class DataPipeline:
            core_str = f.read()

        protein_object = protein.from_proteinnet_string(core_str)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)

        description = os.path.splitext(os.path.basename(core_path))[0].upper()
        core_feats = make_protein_features(protein_object, description)

        hits = self._parse_template_hit_files(
            alignment_dir=alignment_dir,
            input_sequence=input_sequence,
            alignment_index=alignment_index,
        )

        template_features = make_template_features(
            input_sequence,
            hits,
...@@ -833,10 +1043,10 @@ class DataPipeline:
        return {**core_feats, **template_features, **msa_features, **sequence_embedding_features}

    def process_multiseq_fasta(self,
        fasta_path: str,
        super_alignment_dir: str,
        ri_gap: int = 200,
    ) -> FeatureDict:
        """
        Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
        hack from Twitter (a.k.a. AlphaFold-Gap).
...@@ -845,7 +1055,7 @@ class DataPipeline:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)

        # No whitespace allowed
        input_descs = [i.split()[0] for i in input_descs]
...@@ -872,14 +1082,15 @@ class DataPipeline:
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            msas = self._get_msas(
                alignment_dir, seq, None
            )
            msa_list.append([m.sequences for m in msas])
            deletion_mat_list.append([m.deletion_matrix for m in msas])

        final_msa = []
        final_deletion_mat = []
        final_msa_obj = []
        msa_it = enumerate(zip(msa_list, deletion_mat_list))
        for i, (msas, deletion_mats) in msa_it:
            prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
...@@ -887,18 +1098,19 @@ class DataPipeline:
                [prec * '-' + seq + post * '-' for seq in msa] for msa in msas
            ]
            deletion_mats = [
                [prec * [0] + dml + post * [0] for dml in deletion_mat]
                for deletion_mat in deletion_mats
            ]

            assert (len(msas[0][-1]) == len(input_sequence))

            final_msa.extend(msas)
            final_deletion_mat.extend(deletion_mats)
            final_msa_obj.extend(
                [parsers.Msa(sequences=msas[k], deletion_matrix=deletion_mats[k], descriptions=None)
                 for k in range(len(msas))]
            )

        msa_features = make_msa_features(
            msas=final_msa_obj
        )
        template_feature_list = []
...@@ -906,7 +1118,10 @@ class DataPipeline:
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            hits = self._parse_template_hit_files(
                alignment_dir=alignment_dir,
                input_sequence=seq,
                alignment_index=None,
            )
            template_features = make_template_features(
                seq,
                hits,
...@@ -918,6 +1133,228 @@ class DataPipeline:
        return {
            **sequence_features,
            **msa_features,
            **template_features,
        }
class DataPipelineMultimer:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: DataPipeline,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self._monomer_data_pipeline = monomer_data_pipeline
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
chain_alignment_dir: str,
chain_alignment_index: Optional[Any],
is_homomer_or_monomer: bool
) -> FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if chain_alignment_index is None and not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir,
alignment_index=chain_alignment_index
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(
chain_alignment_dir,
chain_alignment_index
)
chain_features.update(all_seq_msa_features)
return chain_features
@staticmethod
def _all_seq_msa_features(alignment_dir, alignment_index):
"""Get MSA features for unclustered uniprot, for pairing."""
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
start, size = next(iter((start, size) for name, start, size in alignment_index["files"]
if name == 'uniprot_hits.sto'))
msa = parsers.parse_stockholm(read_msa(start, size))
fp.close()
else:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. "
f"This is required for Multimer MSA pairing.")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
feats = {
f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats
}
return feats
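
    # The resulting keys carry an "_all_seq" suffix, e.g. "msa_all_seq" and
    # "msa_species_identifiers_all_seq" (assuming those base names appear in
    # msa_pairing.MSA_FEATURES); downstream pairing code matches chains across
    # the complex by these species identifiers.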
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> FeatureDict:
"""Creates features."""
with open(fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for desc, seq in zip(input_descs, input_seqs):
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
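
    # A minimal end-to-end sketch (hypothetical paths; the monomer pipeline's
    # construction is assumed from elsewhere in this module):
    #
    #   monomer_pipeline = DataPipeline(template_featurizer=None)
    #   multimer_pipeline = DataPipelineMultimer(
    #       monomer_data_pipeline=monomer_pipeline,
    #   )
    #   feats = multimer_pipeline.process_fasta(
    #       fasta_path="dimer.fasta",
    #       alignment_dir="alignments/",  # one subdirectory per chain description
    #   )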
def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
mmcif_feats = {}
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
mmcif_object.header["resolution"], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
alignment_index: Optional[Any] = None,
) -> FeatureDict:
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
...@@ -23,6 +26,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
...@@ -86,18 +89,17 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
    # Map one-hot to indices
    num_templates = protein["template_aatype"].shape[0]
    protein["template_aatype"] = torch.argmax(
        protein["template_aatype"], dim=-1
    )
    # Map hhsearch-aatype to our aatype.
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    new_order = torch.tensor(
        new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
    ).expand(num_templates, -1)
    protein["template_aatype"] = torch.gather(
        new_order, 1, index=protein["template_aatype"]
    )

    return protein
...@@ -447,13 +449,15 @@ def make_hhblits_profile(protein):
@curry1
def make_masked_msa(protein, config, replace_fraction, seed):
    """Create data for BERT on raw MSA."""
    device = protein["msa"].device

    # Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0],
        dtype=torch.float32,
        device=device
    )

    categorical_probs = (
...@@ -473,11 +477,18 @@ def make_masked_msa(protein, config, replace_fraction):
    assert mask_prob >= 0.0

    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_shapes, value=mask_prob,
    )

    sh = protein["msa"].shape

    g = None
    if seed is not None:
        g = torch.Generator(device=protein["msa"].device)
        g.manual_seed(seed)

    sample = torch.rand(sh, device=device, generator=g)
    mask_position = sample < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
...@@ -670,7 +681,7 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch):
    batch = tree_map(
        lambda n: torch.tensor(n, device="cpu"),
        batch,
        np.ndarray
    )
    out = make_atom14_masks(batch)
...@@ -736,7 +747,7 @@ def make_atom14_positions(protein):
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )
...@@ -782,10 +793,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8):
    is_multimer = "asym_id" in protein

    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    if is_multimer:
        all_atom_positions = Vec3Array.from_array(all_atom_positions)

    batch_dims = len(aatype.shape[:-1])

    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
...@@ -832,19 +847,37 @@ def atom37_to_frames(protein, eps=1e-8):
        no_batch_dims=batch_dims,
    )

    if is_multimer:
        base_atom_pos = [batched_gather(
            pos,
            residx_rigidgroup_base_atom37_idx,
            dim=-1,
            no_batch_dims=len(all_atom_positions.shape[:-1]),
        ) for pos in all_atom_positions]
        base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
    else:
        base_atom_pos = batched_gather(
            all_atom_positions,
            residx_rigidgroup_base_atom37_idx,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )

    if is_multimer:
        point_on_neg_x_axis = base_atom_pos[:, :, 0]
        origin = base_atom_pos[:, :, 1]
        point_on_xy_plane = base_atom_pos[:, :, 2]
        gt_rotation = Rot3Array.from_two_vectors(
            origin - point_on_neg_x_axis, point_on_xy_plane - origin)
        gt_frames = Rigid3Array(gt_rotation, origin)
    else:
        gt_frames = Rigid.from_3_points(
            p_neg_x_axis=base_atom_pos[..., 0, :],
            origin=base_atom_pos[..., 1, :],
            p_xy_plane=base_atom_pos[..., 2, :],
            eps=eps,
        )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
...@@ -865,9 +898,13 @@ def atom37_to_frames(protein, eps=1e-8):
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    if is_multimer:
        gt_frames = gt_frames.compose_rotation(
            Rot3Array.from_array(rots))
    else:
        rots = Rotation(rot_mats=rots)
        gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
...@@ -901,12 +938,18 @@ def atom37_to_frames(protein, eps=1e-8):
        no_batch_dims=batch_dims,
    )

    if is_multimer:
        ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)

        # Create the alternative ground truth frames.
        alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
    else:
        residx_rigidgroup_ambiguity_rot = Rotation(
            rot_mats=residx_rigidgroup_ambiguity_rot
        )
        alt_gt_frames = gt_frames.compose(
            Rigid(residx_rigidgroup_ambiguity_rot, None)
        )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
...
from typing import Sequence
import torch
from openfold.config import NUM_RES
from openfold.data.data_transforms import curry1
from openfold.np import residue_constants as rc
from openfold.utils.tensor_utils import masked_mean
def gumbel_noise(
shape: Sequence[int],
device: torch.device,
eps=1e-6,
generator=None,
) -> torch.Tensor:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise = torch.rand(
shape, dtype=torch.float32, device=device, generator=generator
)
gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
return gumbel
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
"""Samples from a probability distribution given by 'logits'.
This uses Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.nn.functional.one_hot(
torch.argmax(logits + z, dim=-1),
logits.shape[-1],
)
def gumbel_argsort_sample_idx(
logits: torch.Tensor,
generator=None
) -> torch.Tensor:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.argsort(logits + z, dim=-1, descending=True)
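
# A quick sanity sketch (CPU, seeded): the result is a permutation of all
# indices, biased so that higher-probability entries tend to come first.
#
#   >>> g = torch.Generator().manual_seed(0)
#   >>> logits = torch.log(torch.tensor([0.7, 0.2, 0.1]))
#   >>> idx = gumbel_argsort_sample_idx(logits, generator=g)
#   >>> sorted(idx.tolist())
#   [0, 1, 2]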
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0., 0.],
        device=batch['msa'].device
    )
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
categorical_probs = torch.nn.functional.pad(
categorical_probs, [0,1], value=mask_prob
)
sh = batch['msa'].shape
mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
mask_position *= batch['msa_mask'].to(mask_position.dtype)
logits = torch.log(categorical_probs + eps)
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
bert_msa = torch.where(
mask_position,
torch.argmax(bert_msa, dim=-1),
batch['msa']
)
bert_msa *= batch['msa_mask'].to(bert_msa.dtype)
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.to(torch.float32)
else:
batch['bert_mask'] = mask_position.to(torch.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device = batch["msa_mask"].device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
    weights = torch.tensor(
        [1.] * 21 + [gap_agreement_weight] + [0.],
        device=device,
    )
msa_mask = batch['msa_mask']
msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = torch.einsum(
'mrc, nrc->nm',
extra_one_hot_masked,
weights * msa_one_hot_masked
)
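    # A softmax at very low temperature (scale 1e3) approximates a hard
    # assignment of each extra sequence to its nearest cluster centre.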
cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = torch.sum(cluster_assignment, dim=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = torch.einsum(
'nm, mc->nc',
cluster_assignment,
extra_mask * extra_deletion_matrix
)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
batch['cluster_profile'] = cluster_profile
batch['cluster_deletion_mean'] = cluster_deletion_mean
return batch
def create_target_feat(batch):
"""Create the target features"""
batch["target_feat"] = torch.nn.functional.one_hot(
batch["aatype"], 21
).to(torch.float32)
return batch
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
device = batch["msa"]
msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
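    # Squash the unbounded deletion counts into [0, 1) via (2 / pi) * atan(d / 3).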
deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
deletion_mean_value = (
torch.atan(
batch['cluster_deletion_mean'] / 3.) *
(2. / pi)
)[..., None]
msa_feat = torch.cat(
[
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
],
dim=-1,
)
batch["msa_feat"] = msa_feat
return batch
def build_extra_msa_feat(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
num_extra_msa: Number of extra msa to use.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa']
deletion_matrix = batch['extra_deletion_matrix']
msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (
(torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
)
    return torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
Protein with sampled msa.
"""
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
logits = (torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.) * inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = torch.nn.functional.pad(
batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
(1, 0),
value=1.
)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * inf
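    # A Gumbel-perturbed argsort yields a random permutation weighted by the
    # logits; the first max_seq rows become the cluster centres and the rest
    # are set aside as the extra MSA.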
index_order = gumbel_argsort_sample_idx(logits, generator=g)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:][:max_extra_msa_seq]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch["msa_profile"] = masked_mean(
batch['msa_mask'][..., None],
torch.nn.functional.one_hot(batch['msa'], 22),
dim=-3,
)
return batch
def randint(lower, upper, generator, device):
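    # Samples an integer uniformly from the inclusive range [lower, upper].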
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = (diff_chain_mask[..., None] * pair_mask).bool()
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
return interface_residues_idxs
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
positions = protein["all_atom_positions"]
atom_mask = protein["all_atom_mask"]
asym_id = protein["asym_id"]
interface_residues = get_interface_residues(positions=positions,
atom_mask=atom_mask,
asym_id=asym_id,
interface_threshold=interface_threshold)
if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator)
target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1] - 1,
generator=generator,
device=positions.device)
target_res = interface_residues[target_res_idx]
ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
ca_mask = atom_mask[..., ca_idx].bool()
coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
to_target_distances = ca_pairwise_dists[target_res]
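    # Add a tiny, index-proportional offset so that ties in CA-CA distance are
    # broken deterministically in residue order.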
break_tie = (
torch.arange(
0, to_target_distances.shape[-1], device=positions.device
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
def get_contiguous_crop_idx(protein, crop_size, generator):
unique_asym_ids, chain_idxs, chain_lens = protein["asym_id"].unique(dim=-1,
return_inverse=True,
return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
_, idx_sorted = torch.sort(chain_idxs, stable=True)
cum_sum = chain_lens.cumsum(dim=0)
    cum_sum = torch.cat((cum_sum.new_zeros(1), cum_sum[:-1]), dim=0)
asym_offsets = idx_sorted[cum_sum]
num_budget = crop_size
num_remaining = int(protein["seq_length"])
crop_idxs = []
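    # Visit chains in random order, sampling a per-chain crop size between the
    # minimum needed to exhaust the budget and the maximum the chain can give,
    # so the selected residues always total min(crop_size, total length).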
for idx in shuffle_idx:
chain_len = int(chain_lens[idx])
num_remaining -= chain_len
crop_size_max = min(num_budget, chain_len)
crop_size_min = min(chain_len, max(0, num_budget - num_remaining))
chain_crop_size = randint(lower=crop_size_min,
upper=crop_size_max,
generator=generator,
device=chain_lens.device)
num_budget -= chain_crop_size
chain_start = randint(lower=0,
upper=chain_len - chain_crop_size,
generator=generator,
device=chain_lens.device)
asym_offset = asym_offsets[idx]
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
return torch.concat(crop_idxs).sort().values
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
spatial_crop_prob,
interface_threshold,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
if subsample_templates:
templates_crop_start = randint(lower=0,
upper=num_templates,
generator=g,
device=protein["seq_length"].device)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_res_crop_size = min(int(protein["seq_length"]), crop_size)
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
v = v[slice(templates_crop_start, templates_crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
protein[k] = v
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
@@ -20,7 +20,7 @@ import ml_collections
 import numpy as np
 import torch

-from openfold.data import input_pipeline
+from openfold.data import input_pipeline, input_pipeline_multimer

 FeatureDict = Mapping[str, np.ndarray]

@@ -80,11 +80,14 @@ def np_example_to_features(
     np_example: FeatureDict,
     config: ml_collections.ConfigDict,
     mode: str,
+    is_multimer: bool = False
 ):
     np_example = dict(np_example)
-    num_res = int(np_example["seq_length"][0])
-    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
+    seq_length = np_example["seq_length"]
+    num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
+    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

     if "deletion_matrix_int" in np_example:
         np_example["deletion_matrix"] = np_example.pop(
             "deletion_matrix_int"
@@ -93,12 +96,20 @@ def np_example_to_features(
     tensor_dict = np_to_tensor_dict(
         np_example=np_example, features=feature_names
     )

     with torch.no_grad():
-        features = input_pipeline.process_tensors_from_config(
-            tensor_dict,
-            cfg.common,
-            cfg[mode],
-        )
+        if is_multimer:
+            features = input_pipeline_multimer.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
+        else:
+            features = input_pipeline.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )

     if mode == "train":
         p = torch.rand(1).item()
@@ -128,10 +139,15 @@ class FeaturePipeline:
     def process_features(
         self,
         raw_features: FeatureDict,
         mode: str = "train",
+        is_multimer: bool = False,
     ) -> FeatureDict:
+        # if(is_multimer and mode != "predict"):
+        #     raise ValueError("Multimer mode is not currently trainable")
         return np_example_to_features(
             np_example=raw_features,
             config=self.config,
             mode=mode,
+            is_multimer=is_multimer,
         )
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List, Mapping
from openfold.data import msa_pairing
from openfold.np import residue_constants
import numpy as np
# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
) -> Mapping[str, np.ndarray]:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list
)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[Mapping[str, np.ndarray]],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int
) -> List[Mapping[str, np.ndarray]]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: Mapping[str, np.ndarray],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> Mapping[str, np.ndarray]:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
def process_final(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32
)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32
)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0
)
if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
@@ -107,7 +107,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
             # the masked locations and secret corrupted locations.
             transforms.append(
                 data_transforms.make_masked_msa(
-                    common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
+                    common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
+                    seed=(msa_seed + 1) if msa_seed else None,
                 )
             )
...
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
from openfold.data import (
data_transforms,
data_transforms_multimer,
)
def groundtruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles]
return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms_multimer.sample_msa(
max_msa_clusters,
max_extra_msa,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms_multimer.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
                seed=(msa_seed + 1) if msa_seed is not None else None,
)
)
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
if mode_cfg.crop:
transforms.append(
data_transforms_multimer.random_crop_to_size(
crop_size=mode_cfg.crop_size,
max_templates=mode_cfg.max_templates,
shape_schema=crop_feats,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
interface_threshold=mode_cfg.interface_threshold,
subsample_templates=mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
shape_schema=crop_feats,
msa_cluster_size=pad_msa_clusters,
extra_msa_size=mode_cfg.max_extra_msa,
num_res=mode_cfg.crop_size,
num_templates=mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
process_gt_feats = mode_cfg.supervised
gt_tensors = {}
if process_gt_feats:
gt_tensors = prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
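    # A single seed shared across all recycling iterations, so that ensembled
    # transforms that must agree between iterations (cropping, and MSA
    # sampling when resampling is disabled) can be made deterministic.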
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
if process_gt_feats:
tensors['gt_features'] = gt_tensors
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
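    # Stack each feature across ensemble members along a new trailing
    # dimension; recycling iterations are indexed via this final axis.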
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
@@ -16,6 +16,7 @@
 """Parses the mmCIF file format."""
 import collections
 import dataclasses
+import functools
 import io
 import json
 import logging
@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
     return {entry[index]: entry for entry in entries}


+@functools.lru_cache(16, typed=False)
 def parse(
     *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
 ) -> ParsingResult:
@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
             raw_resolution = parsed_info[res_key][0]
             header["resolution"] = float(raw_resolution)
         except ValueError:
-            logging.info(
+            logging.debug(
                 "Invalid resolution format: %s", parsed_info[res_key]
             )
@@ -474,6 +476,20 @@ def get_atom_coords(
             pos[residue_constants.atom_order["SD"]] = [x, y, z]
             mask[residue_constants.atom_order["SD"]] = 1.0

+        # Fix naming errors in arginine residues where NH2 is incorrectly
+        # assigned to be closer to CD than NH1.
+        cd = residue_constants.atom_order['CD']
+        nh1 = residue_constants.atom_order['NH1']
+        nh2 = residue_constants.atom_order['NH2']
+        if(
+            res.get_resname() == 'ARG' and
+            all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
+            (np.linalg.norm(pos[nh1] - pos[cd]) >
+             np.linalg.norm(pos[nh2] - pos[cd]))
+        ):
+            pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
+            mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
+
         all_atom_positions[res_index] = pos
         all_atom_mask[res_index] = mask
...
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
species_id=matches.group('SpeciesIdentifier')
)
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import collections
import functools
import string
from typing import Any, Dict, Iterable, List, Sequence, Mapping
import numpy as np
import pandas as pd
import scipy.linalg
from openfold.np import residue_constants
# TODO: This stuff should probably also be in a config
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[Mapping[str, np.ndarray]],
) -> List[Mapping[str, np.ndarray]]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[species] = species_df
return species_lookup
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
def pair_sequences(
examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
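    # The pairing dict is keyed by how many chains a species appears in; rows
    # that pair more chains are preferred when the rows are reordered later.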
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
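        # Skip species with very large MSAs to keep the pairing cost bounded.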
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e, all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
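        # Each chain's rows are sorted by decreasing similarity to its query,
        # so lower row indices mean better hits; the product of row indices is
        # a cheap proxy for the combined alignment quality (e-values).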
paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
def _correct_post_merged_feats(
np_example: Mapping[str, np.ndarray],
np_chains_list: Sequence[Mapping[str, np.ndarray]],
pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(
np_example['aatype'].shape[0],
dtype=np.int32
)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0],
dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list
]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0
)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0
)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag],
axis=0
)
return np_example
def _pad_templates(chains: Sequence[Mapping[str, np.ndarray]],
max_templates: int) -> Sequence[Mapping[str, np.ndarray]]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[Mapping[str, np.ndarray]],
pair_msa_sequences: bool) -> Mapping[str, np.ndarray]:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
            merged_example[feature_name] = np.sum(feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
def _merge_homomers_dense_msa(
chains: Iterable[Mapping[str, np.ndarray]]) -> Sequence[Mapping[str, np.ndarray]]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
return chains
def _concatenate_paired_and_unpaired_features(
example: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[Mapping[str, np.ndarray]],
pair_msa_sequences: bool,
max_templates: int) -> Mapping[str, np.ndarray]:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[Mapping[str, np.ndarray]]) -> List[Mapping[str, np.ndarray]]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
for chain in np_chains:
# Convert the msa_all_seq numpy array to a tuple for hashing.
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if tuple(seq) not in sequence_set:
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
chain[feature_name] = chain[feature_name][keep_rows]
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
@@ -16,14 +16,43 @@
 """Functions for parsing various file formats."""
 import collections
 import dataclasses
+import itertools
 import re
 import string
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

 DeletionMatrix = Sequence[Sequence[int]]


+@dataclasses.dataclass(frozen=True)
+class Msa:
+    """Class representing a parsed MSA file"""
+    sequences: Sequence[str]
+    deletion_matrix: DeletionMatrix
+    descriptions: Optional[Sequence[str]]
+
+    def __post_init__(self):
+        if(not (
+            len(self.sequences) ==
+            len(self.deletion_matrix) ==
+            len(self.descriptions)
+        )):
+            raise ValueError(
+                "All fields for an MSA must have the same length"
+            )
+
+    def __len__(self):
+        return len(self.sequences)
+
+    def truncate(self, max_seqs: int):
+        return Msa(
+            sequences=self.sequences[:max_seqs],
+            deletion_matrix=self.deletion_matrix[:max_seqs],
+            descriptions=self.descriptions[:max_seqs],
+        )
+
+
 @dataclasses.dataclass(frozen=True)
 class TemplateHit:
     """Class representing a template hit."""
@@ -31,7 +60,7 @@ class TemplateHit:
     index: int
     name: str
     aligned_cols: int
-    sum_probs: float
+    sum_probs: Optional[float]
     query: str
     hit_sequence: str
     indices_query: List[int]
@@ -69,9 +98,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     return sequences, descriptions


-def parse_stockholm(
-    stockholm_string: str,
-) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+def parse_stockholm(stockholm_string: str) -> Msa:
     """Parses sequences and deletion matrix from stockholm format alignment.

     Args:
@@ -126,10 +153,14 @@ def parse_stockholm(
             deletion_count = 0
         deletion_matrix.append(deletion_vec)

-    return msa, deletion_matrix, list(name_to_sequence.keys())
+    return Msa(
+        sequences=msa,
+        deletion_matrix=deletion_matrix,
+        descriptions=list(name_to_sequence.keys())
+    )


-def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+def parse_a3m(a3m_string: str) -> Msa:
     """Parses sequences and deletion matrix from a3m format alignment.

     Args:
@@ -144,7 +175,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
         at `deletion_matrix[i][j]` is the number of residues deleted from
         the aligned sequence i at residue position j.
     """
-    sequences, _ = parse_fasta(a3m_string)
+    sequences, descriptions = parse_fasta(a3m_string)
     deletion_matrix = []
     for msa_sequence in sequences:
         deletion_vec = []
@@ -160,7 +191,11 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
     # Make the MSA matrix out of aligned (deletion-free) sequences.
     deletion_table = str.maketrans("", "", string.ascii_lowercase)
     aligned_sequences = [s.translate(deletion_table) for s in sequences]
-    return aligned_sequences, deletion_matrix
+    return Msa(
+        sequences=aligned_sequences,
+        deletion_matrix=deletion_matrix,
+        descriptions=descriptions
+    )


 def _convert_sto_seq_to_a3m(
@@ -174,7 +209,9 @@ def _convert_sto_seq_to_a3m(

 def convert_stockholm_to_a3m(
-    stockholm_format: str, max_sequences: Optional[int] = None
+    stockholm_format: str,
+    max_sequences: Optional[int] = None,
+    remove_first_row_gaps: bool = True,
 ) -> str:
     """Converts MSA in Stockholm format to the A3M format."""
     descriptions = {}
@@ -212,13 +249,19 @@ def convert_stockholm_to_a3m(
     # Convert sto format to a3m line by line
     a3m_sequences = {}
-    # query_sequence is assumed to be the first sequence
-    query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != "-" for res in query_sequence]
+    if(remove_first_row_gaps):
+        # query_sequence is assumed to be the first sequence
+        query_sequence = next(iter(sequences.values()))
+        query_non_gaps = [res != "-" for res in query_sequence]

     for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = "".join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
-        )
+        # Dots are optional in a3m format and are commonly removed.
+        out_sequence = sto_sequence.replace('.', '')
+        if(remove_first_row_gaps):
+            out_sequence = ''.join(
+                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
+            )
+        a3m_sequences[seqname] = out_sequence

     fasta_chunks = (
         f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
@@ -227,6 +270,124 @@ def convert_stockholm_to_a3m(
     return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
"""Reads + truncates a Stockholm file while preventing excessive RAM usage."""
seqnames = set()
filtered_lines = []
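    # Two passes over the file: first collect the names of the first
    # max_sequences sequences, then keep only the lines that belong to them.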
with open(stockholm_msa_path) as f:
for line in f:
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
f.seek(0)
for line in f:
if _keep_line(line, seqnames):
filtered_lines.append(line)
return ''.join(filtered_lines)
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
def _get_hhr_line_regex_groups(
    regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
@@ -280,7 +441,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             "Could not parse section: %s. Expected this: \n%s to contain summary."
             % (detailed_lines, detailed_lines[2])
         )
-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
         float(x) for x in match.groups()
     ]
@@ -388,3 +549,115 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
         target_name = fields[0]
         e_values[target_name] = float(e_value)
     return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
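
# Worked example (not part of the original file):
#
#     _get_indices('AB-cD', start=0) == [0, 1, -1, 3]
#
# 'A' and 'B' are aligned residues (indices 0 and 1), '-' contributes the -1
# placeholder, and the lowercase 'c' is a deleted residue that advances the
# counter without emitting an index, so 'D' maps to 3.
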
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6]
)
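
# Worked example (not part of the original file), using Example 1 from the
# docstring above:
#
#     _parse_hmmsearch_description(
#         '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text'
#     )
#     # -> HitMetadata(pdb_id='4pqx', chain='A', start=2, end=217,
#     #                length=217, text='Free text')
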
def parse_hmmsearch_a3m(
query_sequence: str,
a3m_string: str,
skip_first: bool = True
) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
def parse_hmmsearch_sto(
output_string: str,
input_sequence: str
) -> Sequence[TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = convert_stockholm_to_a3m(
output_string,
remove_first_row_gaps=False
)
template_hits = parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False
)
return template_hits
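
# Usage sketch (hypothetical inputs): the raw hmmsearch Stockholm output is
# converted to A3M while keeping columns that are gaps in the query, and each
# protein hit becomes a TemplateHit whose indices_hit map template residues
# onto the query:
#
#     hits = parse_hmmsearch_sto(sto_string, query_sequence)
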
...@@ -14,8 +14,10 @@
# limitations under the License.

"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import json
import logging
...@@ -65,10 +67,6 @@ class DateError(PrefilterError):
    """An error indicating that the hit date was after the max allowed date."""
class PdbIdError(PrefilterError):
"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
    """An error indicating that the hit align ratio to the query was too small."""
...@@ -91,6 +89,24 @@ TEMPLATE_FEATURES = {
}
def empty_template_feats(n_res):
return {
"template_aatype": np.zeros(
(0, n_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_mask": np.zeros(
(0, n_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}
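
# Note (added comment): the leading dimension of 0 keeps downstream stacking
# and dtype logic working when no template survives filtering, while the
# domain-name and sequence fields hold a single empty placeholder entry.
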
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
    """Returns PDB id and chain id for an HHSearch Hit."""
    # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
...@@ -204,7 +220,6 @@ def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
query_pdb_code: Optional[str],
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
...@@ -218,7 +233,6 @@ def _assess_hhsearch_hit(
            different from the value in the actual hit since the original pdb might
            have become obsolete.
        query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
        release_dates: Dictionary mapping pdb codes to their structure release
            dates.
        release_date_cutoff: Max release date that is valid for this query.
...@@ -230,7 +244,6 @@ def _assess_hhsearch_hit(
    Raises:
        DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
        AlignRatioError: If the hit align ratio to the query was too small.
        DuplicateError: If the hit was an exact subsequence of the query.
        LengthError: If the hit was too short.
...@@ -241,13 +254,6 @@ def _assess_hhsearch_hit(
    template_sequence = hit.hit_sequence.replace("-", "")
    length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        date = release_dates[hit_pdb_code.upper()]
        raise DateError(
...@@ -255,16 +261,19 @@ def _assess_hhsearch_hit(
            f"({release_date_cutoff})."
        )
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError("PDB code identical to Query PDB code.")
    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            "Proportion of residues aligned to query too small. "
            f"Align ratio: {align_ratio}."
        )
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
    if duplicate:
        raise DuplicateError(
            "Template is an exact subsequence of query with large "
...@@ -424,9 +433,10 @@ def _realign_pdb_template_to_query(
    )
    try:
        (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
        parsed_a3m = parsers.parse_a3m(
            aligner.align([old_template_sequence, new_template_sequence])
        )
old_aligned_template, new_aligned_template = parsed_a3m.sequences
    except Exception as e:
        raise QueryToTemplateAlignError(
            "Could not align old template %s to template %s (%s_%s). Error: %s"
...@@ -768,7 +778,6 @@ class SingleHitResult:
def _prefilter_hit(
    query_sequence: str,
query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
...@@ -789,17 +798,14 @@ def _prefilter_hit(
            hit=hit,
            hit_pdb_code=hit_pdb_code,
            query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
            release_dates=release_dates,
            release_date_cutoff=max_template_date,
        )
    except PrefilterError as e:
        hit_name = f"{hit_pdb_code}_{hit_chain_id}"
        msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
        logging.info("%s: %s", query_pdb_code, msg)
        logging.info(msg)
        if strict_error_check and isinstance(
            e, (DateError, PdbIdError, DuplicateError)
        ):
        if strict_error_check and isinstance(e, (DateError, DuplicateError)):
            # In strict mode we treat some prefilter cases as errors.
            return PrefilterResult(valid=False, error=msg, warning=None)
...@@ -808,9 +814,16 @@ def _prefilter_hit(
    return PrefilterResult(valid=True, error=None, warning=None)
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
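
# The LRU cache keeps the 16 most recently read mmCIF strings in memory, so
# repeated template hits against the same PDB entry do not re-read the file
# from disk.
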
def _process_single_hit(
    query_sequence: str,
query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
...@@ -847,9 +860,9 @@ def _process_single_hit(
        query_sequence,
        template_sequence,
    )

    # Fail if we can't find the mmCIF file.
    with open(cif_path, "r") as cif_file:
        cif_string = cif_file.read()
    cif_string = _read_file(cif_path)
    parsing_result = mmcif_parsing.parse(
        file_id=hit_pdb_code, mmcif_string=cif_string
...@@ -882,7 +895,11 @@ def _process_single_hit(
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=_zero_center_positions,
    )
features["template_sum_probs"] = [hit.sum_probs]
if hit.sum_probs is None:
features["template_sum_probs"] = [0]
else:
features["template_sum_probs"] = [hit.sum_probs]
    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
...@@ -903,7 +920,7 @@ def _process_single_hit(
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
...@@ -920,7 +937,7 @@ def _process_single_hit(
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
...@@ -986,8 +1003,8 @@ class TemplateSearchResult:
    warnings: Sequence[str]

class TemplateHitFeaturizer:
    """A class for turning hhr hits to template features."""
class TemplateHitFeaturizer(abc.ABC):
    """An abstract base class for turning template hits to features."""
    def __init__(
        self,
        mmcif_dir: str,
...@@ -1036,7 +1053,7 @@ class TemplateHitFeaturizer:
            raise ValueError(
                "max_template_date must be set and have format YYYY-MM-DD."
            )
        self.max_hits = max_hits
        self._max_hits = max_hits
        self._kalign_binary_path = kalign_binary_path
        self._strict_error_check = strict_error_check
...@@ -1059,31 +1076,29 @@ class TemplateHitFeaturizer:
        self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
        self._zero_center_positions = _zero_center_positions
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
""" Computes the templates for a given query sequence """
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
    def get_templates(
        self,
        query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for given query sequence (more details above)."""
        logging.info("Searching for template for: %s", query_pdb_code)
        logging.info("Searching for template for: %s", query_sequence)
        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []
        # Always use a max_template_date. Set to query_release_date minus 60 days
        # if that's earlier.
        template_cutoff_date = self._max_template_date
        if query_release_date:
            delta = datetime.timedelta(days=60)
            if query_release_date - delta < template_cutoff_date:
                template_cutoff_date = query_release_date - delta
            assert template_cutoff_date < query_release_date
        assert template_cutoff_date <= self._max_template_date
        num_hits = 0
        already_seen = set()
        errors = []
        warnings = []
...@@ -1091,9 +1106,8 @@ class TemplateHitFeaturizer:
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
                hit=hit,
                max_template_date=template_cutoff_date,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
...@@ -1119,17 +1133,16 @@ class TemplateHitFeaturizer:
        for i in idx:
            # We got all the templates we wanted, stop processing hits.
            if num_hits >= self.max_hits:
            if len(already_seen) >= self._max_hits:
                break
            hit = filtered[i]
            result = _process_single_hit(
                query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=template_cutoff_date,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
...@@ -1153,22 +1166,124 @@ class TemplateHitFeaturizer:
                    result.warning,
                )
            else:
                # Increment the hit counter, since we got features out of this hit.
                num_hits += 1
                already_seen_key = result.features["template_sequence"]
                if(already_seen_key in already_seen):
                    continue
                already_seen.add(already_seen_key)
                for k in template_features:
                    template_features[k].append(result.features[k])
        for name in template_features:
            if num_hits > 0:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
            else:
                # Make sure the feature has correct dtype even if empty.
                template_features[name] = np.array(
                    [], dtype=TEMPLATE_FEATURES[name]
                )
        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = empty_template_feats(num_res)

        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )
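
# Minimal usage sketch (not part of the original file; the paths are
# hypothetical and the remaining constructor arguments are assumed from the
# attributes referenced above):
#
#     featurizer = HhsearchHitFeaturizer(
#         mmcif_dir="pdb_mmcif/mmcif_files",
#         max_template_date="2021-10-10",
#         max_hits=20,
#         kalign_binary_path="kalign",
#     )
#     result = featurizer.get_templates(query_sequence=seq, hits=hhr_hits)
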
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
logging.info("Searching for template for: %s", query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
already_seen = set()
errors = []
warnings = []
        # DISCREPANCY: This prefiltering scheme deviates from the original and saves time.
filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
hit=hit,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)
filtered = list(
sorted(
filtered, key=lambda x: x.sum_probs if x.sum_probs else 0., reverse=True
)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
if(len(already_seen) >= self._max_hits):
break
hit = filtered[i]
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path
)
if result.error:
errors.append(result.error)
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.debug(
"Skipped invalid hit %s, error: %s, warning: %s",
hit.name, result.error, result.warning,
)
else:
already_seen_key = result.features["template_sequence"]
if(already_seen_key in already_seen):
continue
                # Record the template sequence so later duplicate hits are skipped.
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = empty_template_feats(num_res)
return TemplateSearchResult(
features=template_features,
errors=errors,
warnings=warnings,
)
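
# The hmmsearch variant is wired the same way; only the hit source differs.
# A hedged sketch (hypothetical variable names):
#
#     hits = parsers.parse_hmmsearch_sto(hmmsearch_sto, query_sequence)
#     result = hmmsearch_featurizer.get_templates(
#         query_sequence=query_sequence, hits=hits
#     )
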