Unverified commit bb3f51e5 authored by Christina Floristean, committed by GitHub

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
.vscode/
.idea/
__pycache__/
*.egg-info
build
@@ -8,3 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass/
@@ -7,13 +7,31 @@ _Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental s
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Contents
- [OpenFold](#openfold)
- [Contents](#contents)
- [Features](#features)
- [Installation (Linux)](#installation-linux)
- [Download Alignment Databases](#download-alignment-databases)
- [Inference](#inference)
  - [Monomer Inference](#monomer-inference)
- [Multimer Inference](#multimer-inference)
  - [SoloSeq Inference](#soloseq-inference)
- [Training](#training)
- [Testing](#testing)
- [Building and Using the Docker Container](#building-and-using-the-docker-container)
- [Copyright Notice](#copyright-notice)
- [Contributing](#contributing)
- [Citing this Work](#citing-this-work)
## Features

OpenFold carefully reproduces (almost) all of the features of the original open
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
latter.

OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,

@@ -63,7 +81,7 @@ To install:

For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need to repeat it in each new bash session.
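For example, assuming the environment lives at `lib/conda/envs/openfold_venv` (the layout used by the install script), the append would look something like this:

```bash
# Append the Conda env's libraries to the runtime library search path.
# The path is an assumption based on the install script's default layout.
export LD_LIBRARY_PATH="$(pwd)/lib/conda/envs/openfold_venv/lib:$LD_LIBRARY_PATH"
```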
## Download Alignment Databases

If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use

@@ -112,7 +130,16 @@ DeepMind's pretrained parameters, you will only be able to make changes that
do not affect the shapes of model parameters. For an example of initializing
the model, consult `run_pretrained_openfold.py`.

## Inference
OpenFold now supports three inference modes:
- [Monomer Inference](#monomer-inference): OpenFold reproduction of AlphaFold2. Inference available with either DeepMind's pretrained parameters or OpenFold trained parameters.
- [Multimer Inference](#multimer-inference): OpenFold reproduction of AlphaFold-Multimer. Inference available with DeepMind's pretrained parameters.
- [Single Sequence Inference (SoloSeq)](#soloseq-inference): Language Model based structure prediction, using [ESM-1b](https://github.com/facebookresearch/esm) embeddings.
More instructions for each inference mode are provided below:
### Monomer Inference
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, first download the OpenFold weights e.g.:

@@ -131,14 +158,14 @@ python3 run_pretrained_openfold.py \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
@@ -176,13 +203,6 @@ To enable it, add `--trace_model` to the inference command.

To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
To minimize memory usage during inference on long sequences, consider the
following changes:

@@ -221,7 +241,78 @@ efficient AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
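For example, a minimal sketch (the database and binary flags are omitted here for brevity; they are the same as in the monomer example above):

```bash
# Sketch only: reuse the full monomer command above and append the flag.
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --config_preset "model_1_ptm" \
    --model_device "cuda:0" \
    --output_dir ./ \
    --long_sequence_inference
```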
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
### Multimer Inference
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--uniprot_database_path data/uniprot/uniprot.fasta \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
```
As with monomer inference, if you've already computed alignments for the query, you can use
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
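For instance, a minimal sketch of a multimer run that reuses alignments written to a hypothetical `alignments/` directory by an earlier run (database and binary flags omitted, since the searches are skipped):

```bash
# Sketch only: with --use_precomputed_alignments, the sequence databases are
# not searched again; the alignments/ directory name here is an assumption.
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --use_precomputed_alignments alignments \
    --config_preset "model_1_multimer_v3" \
    --model_device "cuda:0" \
    --output_dir ./
```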
**Upgrade from an existing OpenFold installation**
The above command requires several upgrades to an existing OpenFold installation:
1. Re-download the AlphaFold parameters to get the latest
AlphaFold-Multimer v3 weights:
```bash
bash scripts/download_alphafold_params.sh openfold/resources
```
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
and [PDB SeqRes](https://www.rcsb.org/) databases:
```bash
bash scripts/download_uniprot.sh data/
```
The PDB SeqRes and PDB databases must be from the same date to avoid potential
errors during template searching. Remove the existing `data/pdb_mmcif` directory
and download both databases:
```bash
bash scripts/download_pdb_mmcif.sh data/
bash scripts/download_pdb_seqres.sh data/
```
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
```bash
bash scripts/download_uniref30.sh data/
bash scripts/download_mgnify.sh data/
```
Multimer inference can also run with the older database versions if desired.
### SoloSeq Inference
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.

For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:
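The exact command is elided by the diff above; as a rough, hypothetical sketch (the positional arguments are assumptions, so check `python3 scripts/precompute_embeddings.py --help` for the real interface):

```bash
# Hypothetical invocation: fasta_dir/ holds one sequence per FASTA file and
# embeddings_dir/ receives the ESM-1b embeddings in SoloSeq's expected layout.
python3 scripts/precompute_embeddings.py fasta_dir/ embeddings_dir/
```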
@@ -260,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \

@@ -274,7 +365,7 @@ SoloSeq allows you to use the same flags and optimizations as the MSA-based Open

**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
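To catch over-length inputs before running the model, a simple pre-check along these lines (assuming one sequence per FASTA file, as above) will flag them:

```bash
# Warn about FASTA files whose sequence exceeds the 1022-residue ESM-1b window.
for f in fasta_dir/*.fasta; do
    len=$(grep -v '^>' "$f" | tr -d '\n' | wc -c)
    if [ "$len" -gt 1022 ]; then
        echo "WARNING: $f has $len residues; SoloSeq will truncate it to 1022"
    fi
done
```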
## Training

To train the model, you will first need to precompute protein alignments.
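OpenFold ships a helper for this under `scripts/`; the following is a hypothetical sketch (the flag names are assumptions mirroring the inference commands above, so consult the script's `--help` for the real interface):

```bash
# Hypothetical sketch: precompute alignments for every FASTA file in fasta_dir,
# writing one alignment directory per sequence into alignments_dir.
python3 scripts/precompute_alignments.py fasta_dir alignments_dir \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer
```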
@@ -412,9 +503,9 @@ environment. These run components of AlphaFold and OpenFold side by side and
ensure that output activations are adequately similar. For most modules, we
target a maximum pointwise difference of `1e-4`.
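To run the suite yourself from the repository root (the script path is assumed from the repo's layout):

```bash
# Runs the unit tests, including the AlphaFold comparison tests when the
# AlphaFold dependencies are available in the environment.
bash scripts/run_unit_tests.sh
```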
## Building and Using the Docker Container

**Building the Docker Image**

OpenFold can be built as a Docker container using the included Dockerfile. To build it, run the following command from the root of this repository:
@@ -422,7 +513,7 @@ Openfold can be built as a docker container using the included dockerfile. To bu
docker build -t openfold .
```

**Running the Docker Container**

The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the [Download Alignment Databases](#download-alignment-databases) section above.
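As an illustrative sketch, a run might mount the host's downloaded databases and parameters into the container like so (the host paths and the `/data`/`/database` mount points are assumptions, chosen to match the command below):

```bash
# Mount databases and parameters from the host, then run inference in-container;
# --help is used here as a smoke test, substitute the full command below.
docker run --gpus all \
    -v "$(pwd)/data:/data" \
    -v "$(pwd)/openfold/resources:/database" \
    -ti openfold:latest \
    python3 /opt/openfold/run_pretrained_openfold.py --help
```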
@@ -462,7 +553,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```

## Copyright Notice

While AlphaFold's and, by extension, OpenFold's source code is licensed under
the permissive Apache License, Version 2.0, DeepMind's pretrained parameters

@@ -475,7 +566,7 @@ replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
If you encounter problems using OpenFold, feel free to create an issue! We also
welcome pull requests from the community.

## Citing this Work

Please cite our paper:

@@ -504,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```

Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
@@ -14,6 +14,7 @@ dependencies:
  - pytorch-lightning==1.5.10
  - biopython==1.79
  - numpy==1.21
  - pandas==2.0
  - PyYAML==5.4.1
  - requests
  - scipy==1.7
...
@@ -3,15 +3,15 @@ channels:
  - conda-forge
  - bioconda
dependencies:
  - openmm=7.7
  - pdbfixer
  - ml-collections
  - PyYAML==5.4.1
  - requests
  - typing-extensions
  - bioconda::hmmer==3.3.2
  - bioconda::hhsuite==3.3.0
  - bioconda::kalign2==2.04
  - pip:
      - biopython==1.79
      - dm-tree==0.1.6
from . import model
from . import utils
from . import data
from . import np
from . import resources
...
import re
import copy
import importlib
import ml_collections as mlc

@@ -16,7 +17,7 @@ def enforce_config_constraints(config):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting

@@ -161,8 +162,9 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name.startswith("seq"): # SINGLE SEQUENCE EMBEDDING PRESETS
c.update(seq_mode_config.copy_and_resolve_references())
if name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
@@ -187,18 +189,43 @@ def model_config(
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name: # MULTIMER PRESETS
c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 252
c.data.eval.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
raise ValueError("Invalid model name")
if name.startswith("seq"):
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.update(seq_mode_config.copy_and_resolve_references())
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
@@ -380,6 +407,8 @@ config = mlc.ConfigDict(
"max_templates": 4,
"crop": False,
"crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": False, "supervised": False,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -394,6 +423,8 @@ config = mlc.ConfigDict( ...@@ -394,6 +423,8 @@ config = mlc.ConfigDict(
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
"crop_size": None, "crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True, "supervised": True,
"uniform_recycling": False, "uniform_recycling": False,
}, },
...@@ -409,6 +440,8 @@ config = mlc.ConfigDict( ...@@ -409,6 +440,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True, "supervised": True,
"clamp_prob": 0.9, "clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
...@@ -426,7 +459,6 @@ config = mlc.ConfigDict( ...@@ -426,7 +459,6 @@ config = mlc.ConfigDict(
}, },
# Recurring FieldReferences that can be changed globally here # Recurring FieldReferences that can be changed globally here
"globals": { "globals": {
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually # Use DeepSpeed memory-efficient attention kernel. Mutually
@@ -446,6 +478,8 @@ config = mlc.ConfigDict(
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"is_multimer": False,
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
},
"model": {
"_mask_trans": False,

@@ -470,7 +504,7 @@ config = mlc.ConfigDict(
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": { "template_single_embedder": {
# DISCREPANCY: c_in is supposed to be 51. # DISCREPANCY: c_in is supposed to be 51.
"c_in": 57, "c_in": 57,
"c_out": c_m, "c_out": c_m,
@@ -489,6 +523,8 @@ config = mlc.ConfigDict(
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
...@@ -537,6 +573,8 @@ config = mlc.ConfigDict( ...@@ -537,6 +573,8 @@ config = mlc.ConfigDict(
"transition_n": 4, "transition_n": 4,
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
...@@ -560,6 +598,8 @@ config = mlc.ConfigDict( ...@@ -560,6 +598,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"no_column_attention": False, "no_column_attention": False,
"opm_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size, "tune_chunk_size": tune_chunk_size,
...@@ -607,6 +647,12 @@ config = mlc.ConfigDict( ...@@ -607,6 +647,12 @@ config = mlc.ConfigDict(
"c_out": 37, "c_out": 37,
}, },
}, },
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1.
},
"relax": {
"max_iterations": 0, # no max

@@ -652,6 +698,7 @@ config = mlc.ConfigDict(
"weight": 0.01,
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 2.0, "weight": 2.0,
}, },
...@@ -664,6 +711,7 @@ config = mlc.ConfigDict( ...@@ -664,6 +711,7 @@ config = mlc.ConfigDict(
"violation": { "violation": {
"violation_tolerance_factor": 12.0, "violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5, "clash_overlap_tolerance": 1.5,
"average_clashes": False,
"eps": eps, # 1e-6, "eps": eps, # 1e-6,
"weight": 0.0, "weight": 0.0,
}, },
...@@ -676,12 +724,199 @@ config = mlc.ConfigDict( ...@@ -676,12 +724,199 @@ config = mlc.ConfigDict(
"weight": 0., "weight": 0.,
"enabled": tm_enabled, "enabled": tm_enabled,
}, },
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.,
"eps": eps,
"enabled": False,
},
"eps": eps, "eps": eps,
}, },
"ema": {"decay": 0.999}, "ema": {"decay": 0.999},
} }
) )
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True
},
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
# "all_crops_all_chains_mask": [],
# "all_crops_all_chains_positions": [],
# "all_crops_all_chains_residue_ids": [],
"assembly_num_chains": [],
"asym_id": [NUM_RES],
"atom14_atom_exists": [NUM_RES, None],
"atom37_atom_exists": [NUM_RES, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"cluster_bias_mask": [NUM_MSA_SEQ],
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
"deletion_mean": [NUM_RES],
"entity_id": [NUM_RES],
"entity_mask": [NUM_RES],
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
# "mem_peak": [],
"msa": [NUM_MSA_SEQ, NUM_RES],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_profile": [NUM_RES, None],
"num_alignments": [],
"num_templates": [],
# "queue_size": [],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"seq_length": [],
"seq_mask": [NUM_RES],
"sym_id": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
# Additional multimer features
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
},
"supervised": {
"clamp_prob": 1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"eval": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"block_delete_msa" : False,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.,
"clamp_prob": 1.,
},
},
"model": {
"input_embedder": {
"tf_dim": 21,
#"num_msa": 508,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True
},
"template": {
"template_single_embedder": {
"c_in": 34,
"c_out": c_m
},
"template_pair_embedder": {
"c_in": c_z,
"c_out": c_t,
"c_dgram": 39,
"c_aatype": 22
},
"template_pair_stack": {
"tri_mul_first": True,
"fuse_projection_weights": True
},
"c_t": c_t,
"c_z": c_z,
"use_unit_vector": True
},
"extra_msa": {
# "extra_msa_embedder": {
# "num_extra_msa": 2048
# },
"extra_msa_stack": {
"opm_first": True,
"fuse_projection_weights": True
}
},
"evoformer_stack": {
"opm_first": True,
"fuse_projection_weights": True
},
"structure_module": {
"trans_scale_factor": 20
},
"heads": {
"tm": {
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True
},
"masked_msa": {
"c_out": 22
},
},
"recycle_early_stop_tolerance": 0.5 # For training, value is -1.
},
"loss": {
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5
},
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5
}
},
"masked_msa": {
"num_classes": 22
},
"violation": {
"average_clashes": True,
"weight": 0.03 # Not finetuning
},
"tm": {
"weight": 0.1,
"enabled": True
},
"chain_center_of_mass": {
"weight": 0.05,
"enabled": True
}
}
})
seq_mode_config = mlc.ConfigDict({
"data": {
"common": {

@@ -707,5 +942,11 @@ seq_mode_config = mlc.ConfigDict({
"c_m": c_m,
"relpos_k": 32,
},
"extra_msa": {
"enabled": False # Disable Extra MSA Stack
},
"evoformer_stack": {
"no_column_attention": True # Turn off Evoformer's column attention
},
}
})
@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,

@@ -86,14 +89,13 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
protein["template_aatype"] = torch.argmax( protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1 protein["template_aatype"], dim=-1
) )
# Map hhsearch-aatype to our aatype. # Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor( new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["aatype"].device, new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1) ).expand(num_templates, -1)
protein["template_aatype"] = torch.gather( protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"] new_order, 1, index=protein["template_aatype"]
@@ -447,13 +449,15 @@ def make_hhblits_profile(protein):
@curry1
def make_masked_msa(protein, config, replace_fraction, seed):
"""Create data for BERT on raw MSA."""
device = protein["msa"].device
# Add a random amino acid uniformly.
random_aa = torch.tensor(
[0.05] * 20 + [0.0, 0.0],
dtype=torch.float32,
device=device
)

categorical_probs = (
@@ -473,11 +477,18 @@ def make_masked_msa(protein, config, replace_fraction):
assert mask_prob >= 0.0

categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob,
)

sh = protein["msa"].shape
g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
sample = torch.rand(sh, device=device, generator=g)
mask_position = sample < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])

@@ -782,10 +793,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8):
is_multimer = "asym_id" in protein
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
if is_multimer:
all_atom_positions = Vec3Array.from_array(all_atom_positions)
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)

@@ -832,6 +847,15 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
if is_multimer:
base_atom_pos = [batched_gather(
pos,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_positions.shape[:-1]),
) for pos in all_atom_positions]
base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
else:
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,

@@ -839,6 +863,15 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
if is_multimer:
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = Rigid3Array(gt_rotation, origin)
else:
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],

@@ -865,8 +898,12 @@ def atom37_to_frames(protein, eps=1e-8):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
if is_multimer:
gt_frames = gt_frames.compose_rotation(
Rot3Array.from_array(rots))
else:
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))

restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(

@@ -901,6 +938,12 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
if is_multimer:
ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
else:
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
...
from typing import Sequence
import torch
from openfold.config import NUM_RES
from openfold.data.data_transforms import curry1
from openfold.np import residue_constants as rc
from openfold.utils.tensor_utils import masked_mean
def gumbel_noise(
shape: Sequence[int],
device: torch.device,
eps=1e-6,
generator=None,
) -> torch.Tensor:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise = torch.rand(
shape, dtype=torch.float32, device=device, generator=generator
)
gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
return gumbel
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
"""Samples from a probability distribution given by 'logits'.
This uses the Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.nn.functional.one_hot(
torch.argmax(logits + z, dim=-1),
logits.shape[-1],
)
def gumbel_argsort_sample_idx(
logits: torch.Tensor,
generator=None
) -> torch.Tensor:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.argsort(logits + z, dim=-1, descending=True)
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor(
[0.05] * 20 + [0., 0.],
device=batch['msa'].device
)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
categorical_probs = torch.nn.functional.pad(
categorical_probs, [0,1], value=mask_prob
)
sh = batch['msa'].shape
mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
mask_position *= batch['msa_mask'].to(mask_position.dtype)
logits = torch.log(categorical_probs + eps)
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
bert_msa = torch.where(
mask_position,
torch.argmax(bert_msa, dim=-1),
batch['msa']
)
bert_msa *= batch['msa_mask'].to(bert_msa.dtype)
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.to(torch.float32)
else:
batch['bert_mask'] = mask_position.to(torch.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device = batch["msa_mask"].device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
weights = torch.tensor(
[1.] * 21 + [gap_agreement_weight] + [0.],
device=device,
)
msa_mask = batch['msa_mask']
msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = torch.einsum(
'mrc, nrc->nm',
extra_one_hot_masked,
weights * msa_one_hot_masked
)
cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = torch.sum(cluster_assignment, dim=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = torch.einsum(
'nm, mc->nc',
cluster_assignment,
extra_mask * extra_deletion_matrix
)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
batch['cluster_profile'] = cluster_profile
batch['cluster_deletion_mean'] = cluster_deletion_mean
return batch
def create_target_feat(batch):
"""Create the target features"""
batch["target_feat"] = torch.nn.functional.one_hot(
batch["aatype"], 21
).to(torch.float32)
return batch
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
device = batch["msa"]
msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
deletion_mean_value = (
torch.atan(
batch['cluster_deletion_mean'] / 3.) *
(2. / pi)
)[..., None]
msa_feat = torch.cat(
[
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
],
dim=-1,
)
batch["msa_feat"] = msa_feat
return batch
def build_extra_msa_feat(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa']
deletion_matrix = batch['extra_deletion_matrix']
msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (
(torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
)
extra_msa_mask = batch['extra_msa_mask']
catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
return catted
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
Protein with sampled msa.
"""
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
logits = (torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.) * inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = torch.nn.functional.pad(
batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
(1, 0),
value=1.
)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * inf
index_order = gumbel_argsort_sample_idx(logits, generator=g)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:][:max_extra_msa_seq]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch["msa_profile"] = masked_mean(
batch['msa_mask'][..., None],
torch.nn.functional.one_hot(batch['msa'], 22),
dim=-3,
)
return batch
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = (diff_chain_mask[..., None] * pair_mask).bool()
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
return interface_residues_idxs
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
positions = protein["all_atom_positions"]
atom_mask = protein["all_atom_mask"]
asym_id = protein["asym_id"]
interface_residues = get_interface_residues(positions=positions,
atom_mask=atom_mask,
asym_id=asym_id,
interface_threshold=interface_threshold)
if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator)
target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1] - 1,
generator=generator,
device=positions.device)
target_res = interface_residues[target_res_idx]
ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
ca_mask = atom_mask[..., ca_idx].bool()
coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
to_target_distances = ca_pairwise_dists[target_res]
break_tie = (
torch.arange(
0, to_target_distances.shape[-1], device=positions.device
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
def get_contiguous_crop_idx(protein, crop_size, generator):
unique_asym_ids, chain_idxs, chain_lens = protein["asym_id"].unique(dim=-1,
return_inverse=True,
return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
_, idx_sorted = torch.sort(chain_idxs, stable=True)
cum_sum = chain_lens.cumsum(dim=0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]), dim=0)
asym_offsets = idx_sorted[cum_sum]
num_budget = crop_size
num_remaining = int(protein["seq_length"])
crop_idxs = []
for idx in shuffle_idx:
chain_len = int(chain_lens[idx])
num_remaining -= chain_len
crop_size_max = min(num_budget, chain_len)
crop_size_min = min(chain_len, max(0, num_budget - num_remaining))
chain_crop_size = randint(lower=crop_size_min,
upper=crop_size_max,
generator=generator,
device=chain_lens.device)
num_budget -= chain_crop_size
chain_start = randint(lower=0,
upper=chain_len - chain_crop_size,
generator=generator,
device=chain_lens.device)
asym_offset = asym_offsets[idx]
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
return torch.concat(crop_idxs).sort().values
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
spatial_crop_prob,
interface_threshold,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
if subsample_templates:
templates_crop_start = randint(lower=0,
upper=num_templates,
generator=g,
device=protein["seq_length"].device)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_res_crop_size = min(int(protein["seq_length"]), crop_size)
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
v = v[slice(templates_crop_start, templates_crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
protein[k] = v
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
@@ -20,7 +20,7 @@ import ml_collections
import numpy as np
import torch

from openfold.data import input_pipeline, input_pipeline_multimer

FeatureDict = Mapping[str, np.ndarray]
@@ -80,9 +80,12 @@ def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
is_multimer: bool = False
):
np_example = dict(np_example)
seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

if "deletion_matrix_int" in np_example:

@@ -93,7 +96,15 @@ def np_example_to_features(
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)

with torch.no_grad():
if is_multimer:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,

@@ -129,9 +140,14 @@ class FeaturePipeline:
self,
raw_features: FeatureDict,
mode: str = "train",
is_multimer: bool = False,
) -> FeatureDict:
# if(is_multimer and mode != "predict"):
# raise ValueError("Multimer mode is not currently trainable")
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
is_multimer=is_multimer,
)
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List, Mapping
from openfold.data import msa_pairing
from openfold.np import residue_constants
import numpy as np
# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
) -> Mapping[str, np.ndarray]:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list
)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[Mapping[str, np.ndarray]],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int
) -> List[Mapping[str, np.ndarray]]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: Mapping[str, np.ndarray],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> Mapping[str, np.ndarray]:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
        # We reduce the number of unpaired sequences by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This
        # keeps the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
        # Restrict the unpaired crop size so that paired+unpaired sequences do
        # not exceed msa_crop_size for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
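To make the budget arithmetic above concrete, a small worked example with invented sizes (MSA_CROP_SIZE = 2048, so paired rows are capped at 1024):

import numpy as np

# Hypothetical sizes: 5000 paired rows available, and 800 of the kept paired
# rows are non-gapped for this chain.
msa_crop_size = 2048
msa_crop_size_all_seq = np.minimum(5000, msa_crop_size // 2)             # 1024
num_non_gapped_pairs = 800
max_unpaired_crop = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)  # 1248
# Paired (1024, of which 800 non-gapped) + unpaired (<= 1248) stays near 2048.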
def process_final(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32
)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32
)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0
)
if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
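As a quick illustration of the deletion-matrix conversion performed above, with an invented 2-sequence, 3-residue chain:

import numpy as np

chain = {'deletion_matrix_int': np.array([[0, 2, 0], [1, 0, 0]])}
chain['deletion_matrix'] = np.asarray(
    chain.pop('deletion_matrix_int'), dtype=np.float32
)
chain['deletion_mean'] = np.mean(chain['deletion_matrix'], axis=0)
print(chain['deletion_mean'])  # [0.5 1.  0. ]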
@@ -107,7 +107,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
         # the masked locations and secret corrupted locations.
         transforms.append(
             data_transforms.make_masked_msa(
-                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
+                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
+                seed=(msa_seed + 1) if msa_seed else None,
             )
         )
...
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
from openfold.data import (
data_transforms,
data_transforms_multimer,
)
def groundtruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles]
return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms_multimer.sample_msa(
max_msa_clusters,
max_extra_msa,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms_multimer.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed else None,
)
)
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
if mode_cfg.crop:
transforms.append(
data_transforms_multimer.random_crop_to_size(
crop_size=mode_cfg.crop_size,
max_templates=mode_cfg.max_templates,
shape_schema=crop_feats,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
interface_threshold=mode_cfg.interface_threshold,
subsample_templates=mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
shape_schema=crop_feats,
msa_cluster_size=pad_msa_clusters,
extra_msa_size=mode_cfg.max_extra_msa,
num_res=mode_cfg.crop_size,
num_templates=mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
process_gt_feats = mode_cfg.supervised
gt_tensors = {}
if process_gt_feats:
gt_tensors = prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
if process_gt_feats:
tensors['gt_features'] = gt_tensors
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
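A minimal sketch of how map_fn stacks per-recycling-iteration features along a new trailing dimension (fake_ensemble_fn is an invented stand-in for wrap_ensemble_fn):

import torch

def fake_ensemble_fn(i):
    # Each "iteration" returns a dict of identically shaped tensors.
    return {'feat': torch.full((2,), float(i))}

stacked = map_fn(fake_ensemble_fn, torch.arange(3))
print(stacked['feat'].shape)  # torch.Size([2, 3]) -- recycling dim is last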
@@ -16,6 +16,7 @@
 """Parses the mmCIF file format."""
 import collections
 import dataclasses
+import functools
 import io
 import json
 import logging
@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
     return {entry[index]: entry for entry in entries}

+@functools.lru_cache(16, typed=False)
 def parse(
     *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
 ) -> ParsingResult:
@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
         raw_resolution = parsed_info[res_key][0]
         header["resolution"] = float(raw_resolution)
     except ValueError:
-        logging.info(
+        logging.debug(
             "Invalid resolution format: %s", parsed_info[res_key]
         )
@@ -474,6 +476,20 @@ def get_atom_coords(
                 pos[residue_constants.atom_order["SD"]] = [x, y, z]
                 mask[residue_constants.atom_order["SD"]] = 1.0

+            # Fix naming errors in arginine residues where NH2 is incorrectly
+            # assigned to be closer to CD than NH1
+            cd = residue_constants.atom_order['CD']
+            nh1 = residue_constants.atom_order['NH1']
+            nh2 = residue_constants.atom_order['NH2']
+            if(
+                res.get_resname() == 'ARG' and
+                all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
+                (np.linalg.norm(pos[nh1] - pos[cd]) >
+                 np.linalg.norm(pos[nh2] - pos[cd]))
+            ):
+                pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
+                mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
+
             all_atom_positions[res_index] = pos
             all_atom_mask[res_index] = mask
...
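A worked sketch of the arginine fix above with invented coordinates: NH1 is taken to be the guanidinium nitrogen nearer CD, so mislabeled pairs are swapped back.

import numpy as np

cd  = np.array([0.0, 0.0, 0.0])
nh1 = np.array([3.0, 0.0, 0.0])  # mislabeled: farther from CD
nh2 = np.array([2.0, 0.0, 0.0])  # mislabeled: closer to CD
if np.linalg.norm(nh1 - cd) > np.linalg.norm(nh2 - cd):
    nh1, nh2 = nh2.copy(), nh1.copy()
print(nh1[0], nh2[0])  # 2.0 3.0 -- NH1 is now the closer nitrogen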
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
species_id=matches.group('SpeciesIdentifier')
)
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
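Example usage, reusing the UniProtKB identifiers quoted in the comment above; the last header is invented to show the no-match case:

print(get_identifiers('tr|A0A146SKV9|A0A146SKV9_FUNHE').species_id)       # FUNHE
print(get_identifiers('sp|P0C2L1|A3X1_LOXLA some free text').species_id)  # LOXLA
print(repr(get_identifiers('not-a-uniprot-header').species_id))           # ''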
@@ -16,14 +16,43 @@
 """Functions for parsing various file formats."""
 import collections
 import dataclasses
+import itertools
 import re
 import string
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

 DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file"""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Optional[Sequence[str]]
def __post_init__(self):
if(not (
len(self.sequences) ==
len(self.deletion_matrix) ==
len(self.descriptions)
)):
raise ValueError(
"All fields for an MSA must have the same length"
)
def __len__(self):
return len(self.sequences)
def truncate(self, max_seqs: int):
return Msa(
sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs],
)
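A minimal sketch of the Msa container with invented sequences:

msa = Msa(
    sequences=['MKV', 'MRV'],
    deletion_matrix=[[0, 0, 0], [0, 1, 0]],
    descriptions=['query', 'hit1'],
)
print(len(msa))              # 2
print(len(msa.truncate(1)))  # 1 -- keeps only the first sequence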
 @dataclasses.dataclass(frozen=True)
 class TemplateHit:
     """Class representing a template hit."""
@@ -31,7 +60,7 @@ class TemplateHit:
     index: int
     name: str
     aligned_cols: int
-    sum_probs: float
+    sum_probs: Optional[float]
     query: str
     hit_sequence: str
     indices_query: List[int]
@@ -69,9 +98,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     return sequences, descriptions

-def parse_stockholm(
-    stockholm_string: str,
-) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+def parse_stockholm(stockholm_string: str) -> Msa:
     """Parses sequences and deletion matrix from stockholm format alignment.

     Args:
@@ -126,10 +153,14 @@ def parse_stockholm(stockholm_string: str) -> Msa:
             deletion_count = 0
         deletion_matrix.append(deletion_vec)

-    return msa, deletion_matrix, list(name_to_sequence.keys())
+    return Msa(
+        sequences=msa,
+        deletion_matrix=deletion_matrix,
+        descriptions=list(name_to_sequence.keys())
+    )

-def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+def parse_a3m(a3m_string: str) -> Msa:
     """Parses sequences and deletion matrix from a3m format alignment.

     Args:
@@ -144,7 +175,7 @@ def parse_a3m(a3m_string: str) -> Msa:
         at `deletion_matrix[i][j]` is the number of residues deleted from
         the aligned sequence i at residue position j.
     """
-    sequences, _ = parse_fasta(a3m_string)
+    sequences, descriptions = parse_fasta(a3m_string)
     deletion_matrix = []
     for msa_sequence in sequences:
         deletion_vec = []
@@ -160,7 +191,11 @@ def parse_a3m(a3m_string: str) -> Msa:
     # Make the MSA matrix out of aligned (deletion-free) sequences.
     deletion_table = str.maketrans("", "", string.ascii_lowercase)
     aligned_sequences = [s.translate(deletion_table) for s in sequences]
-    return aligned_sequences, deletion_matrix
+    return Msa(
+        sequences=aligned_sequences,
+        deletion_matrix=deletion_matrix,
+        descriptions=descriptions
+    )

 def _convert_sto_seq_to_a3m(
@@ -174,7 +209,9 @@ def _convert_sto_seq_to_a3m(
 def convert_stockholm_to_a3m(
-    stockholm_format: str, max_sequences: Optional[int] = None
+    stockholm_format: str,
+    max_sequences: Optional[int] = None,
+    remove_first_row_gaps: bool = True,
 ) -> str:
     """Converts MSA in Stockholm format to the A3M format."""
     descriptions = {}
@@ -212,13 +249,19 @@ def convert_stockholm_to_a3m(
     # Convert sto format to a3m line by line
     a3m_sequences = {}
-    # query_sequence is assumed to be the first sequence
-    query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != "-" for res in query_sequence]
-    for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = "".join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
-        )
+    if(remove_first_row_gaps):
+        # query_sequence is assumed to be the first sequence
+        query_sequence = next(iter(sequences.values()))
+        query_non_gaps = [res != "-" for res in query_sequence]
+    for seqname, sto_sequence in sequences.items():
+        # Dots are optional in a3m format and are commonly removed.
+        out_sequence = sto_sequence.replace('.', '')
+        if(remove_first_row_gaps):
+            out_sequence = ''.join(
+                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
+            )
+        a3m_sequences[seqname] = out_sequence

     fasta_chunks = (
         f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
@@ -227,6 +270,124 @@ def convert_stockholm_to_a3m(
     return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
"""Reads + truncates a Stockholm file while preventing excessive RAM usage."""
seqnames = set()
filtered_lines = []
with open(stockholm_msa_path) as f:
for line in f:
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
f.seek(0)
for line in f:
if _keep_line(line, seqnames):
filtered_lines.append(line)
return ''.join(filtered_lines)
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
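A toy example of the deduplication above (alignment invented): seq2 matches seq1 once columns are compared against the query, so it is dropped while markup and the query row are kept.

sto = '\n'.join([
    '# STOCKHOLM 1.0',
    'query  MKVL',
    'seq1   MKIA',
    'seq2   MKIA',  # duplicate of seq1 -> removed
    '//',
])
print(deduplicate_stockholm_msa(sto))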
 def _get_hhr_line_regex_groups(
     regex_pattern: str, line: str
 ) -> Sequence[Optional[str]]:
@@ -280,7 +441,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             "Could not parse section: %s. Expected this: \n%s to contain summary."
             % (detailed_lines, detailed_lines[2])
         )
-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
         float(x) for x in match.groups()
     ]
@@ -388,3 +549,115 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
         target_name = fields[0]
         e_values[target_name] = float(e_value)
     return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
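For example (invented alignment string), gaps map to -1 placeholders and lowercase deletions advance the residue counter without emitting an index:

# 'A' -> 0, 'B' -> 1, '-' -> -1 (gap), 'c' consumes index 2 silently, 'D' -> 3.
assert _get_indices('AB-cD', start=0) == [0, 1, -1, 3]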
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6]
)
def parse_hmmsearch_a3m(
query_sequence: str,
a3m_string: str,
skip_first: bool = True
) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
def parse_hmmsearch_sto(
output_string: str,
input_sequence: str
) -> Sequence[TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = convert_stockholm_to_a3m(
output_string,
remove_first_row_gaps=False
)
template_hits = parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False
)
return template_hits
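An end-to-end sketch with an invented single-hit A3M in the hmmsearch description format quoted above (parse_fasta is defined earlier in this module):

a3m = '\n'.join([
    '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text',
    'MKVL',
])
hits = parse_hmmsearch_a3m(
    query_sequence='MKVL', a3m_string=a3m, skip_first=False
)
print(hits[0].name, hits[0].aligned_cols, hits[0].indices_hit)
# 4pqx_A 4 [1, 2, 3, 4]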