Commit 9b18d6a9 authored by Augustin Zidek

Release code for v2.3.0

PiperOrigin-RevId: 494507694
parent 4494af84
@@ -7,23 +7,30 @@
v2.0. This is a completely new model that was entered in CASP14 and published in
Nature. For simplicity, we refer to this model as AlphaFold throughout the rest
of this document.

We also provide:

1.  An implementation of AlphaFold-Multimer. This represents a work in progress
    and AlphaFold-Multimer isn't expected to be as stable as our monomer
    AlphaFold system.
    [Read the guide](#updating-existing-installation)
    for how to upgrade and update code.
2.  The [technical note](docs/technical_note_v2.3.0.md) containing the models
    and inference procedure for an updated AlphaFold v2.3.0.
3.  A [CASP15 baseline](docs/casp15_predictions.zip) set of predictions along
    with documentation of any manual interventions performed.

Any publication that discloses findings arising from using this source code or
the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if
applicable, the
[AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).

Please also refer to the
[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)
for a detailed description of the method.

**You can use a slightly simplified version of AlphaFold with
[this Colab notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb)**
or community-supported versions (see below).

If you have any questions, please contact the AlphaFold team at
@@ -58,9 +65,9 @@ The following steps are required in order to run AlphaFold:
    or take a look at the following
    [NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).

If you wish to run AlphaFold using Singularity (a common containerization
platform on HPC systems) we recommend using some of the third party Singularity
setups as linked in https://github.com/deepmind/alphafold/issues/10 or
https://github.com/deepmind/alphafold/issues/24.
### Genetic databases

@@ -74,7 +81,7 @@ AlphaFold needs multiple genetic (sequence) databases to run:
*   [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
*   [PDB](https://www.rcsb.org/) (structures in the mmCIF format),
*   [PDB seqres](https://www.rcsb.org/) – only for AlphaFold-Multimer,
*   [UniRef30 (FKA UniClust30)](https://uniclust.mmseqs.com/),
*   [UniProt](https://www.uniprot.org/uniprot/) – only for AlphaFold-Multimer,
*   [UniRef90](https://www.uniprot.org/help/uniref).
@@ -98,29 +105,34 @@ and set up all of these databases:
    will download a reduced version of the databases to be used with the
    `reduced_dbs` database preset.

:ledger: **Note: The download directory `<DOWNLOAD_DIR>` should *not* be a
subdirectory in the AlphaFold repository directory.** If it is, the Docker build
will be slow as the large databases will be copied during the image creation.

We don't provide exactly the database versions used in CASP14 – see the
[note on reproducibility](#note-on-casp14-reproducibility). Some of the
databases are mirrored for speed, see [mirrored databases](#mirrored-databases).

:ledger: **Note: The total download size for the full databases is around 415 GB
and the total size when unzipped is 2.62 TB. Please make sure you have a large
enough hard drive space, bandwidth and time to download. We recommend using an
SSD for better genetic search performance.**

:ledger: **Note: If the download directory and datasets don't have full read and
write permissions, it can cause errors with the MSA tools, with opaque
(external) error messages. Please ensure the required permissions are applied,
e.g. with the `sudo chmod 755 --recursive "$DOWNLOAD_DIR"` command.**

The `download_all_data.sh` script will also download the model parameter files.
Once the script has finished, you should have the following directory structure:

```
$DOWNLOAD_DIR/                             # Total: ~ 2.62 TB (download: 556 GB)
    bfd/                                   # ~ 1.8 TB (download: 271.6 GB)
        # 6 files.
    mgnify/                                # ~ 120 GB (download: 67 GB)
        mgy_clusters_2022_05.fa
    params/                                # ~ 5.3 GB (download: 5.3 GB)
        # 5 CASP14 models,
        # 5 pTM models,
        # 5 AlphaFold-Multimer models,

@@ -128,20 +140,19 @@ $DOWNLOAD_DIR/                             # Total: ~ 2.2 TB (download: 438 GB)
        # = 16 files.
    pdb70/                                 # ~ 56 GB (download: 19.5 GB)
        # 9 files.
    pdb_mmcif/                             # ~ 238 GB (download: 43 GB)
        mmcif_files/
            # About 199,000 .cif files.
        obsolete.dat
    pdb_seqres/                            # ~ 0.2 GB (download: 0.2 GB)
        pdb_seqres.txt
    small_bfd/                             # ~ 17 GB (download: 9.6 GB)
        bfd-first_non_consensus_sequences.fasta
    uniref30/                              # ~ 206 GB (download: 52.5 GB)
        # 7 files.
    uniprot/                               # ~ 105 GB (download: 53 GB)
        uniprot.fasta
    uniref90/                              # ~ 67 GB (download: 34 GB)
        uniref90.fasta
```
@@ -151,11 +162,12 @@ is only downloaded if you download the reduced databases.

### Model parameters

While the AlphaFold code is licensed under the Apache 2.0 License, the AlphaFold
parameters and CASP15 prediction data are made available under the terms of the
CC BY 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below
for more detail.

The AlphaFold parameters are available from
https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:

@@ -168,18 +180,25 @@ will download parameters for:
*   5 AlphaFold-Multimer models that produce pTM and PAE values alongside their
    structure predictions.
### Updating existing installation

If you have a previous version you can either reinstall fully from scratch
(remove everything and run the setup from scratch) or you can do an incremental
update that will be significantly faster but will require a bit more work. Make
sure you follow these steps in the exact order they are listed below:

1.  **Update the code.**
    *   Go to the directory with the cloned AlphaFold repository and run
        `git fetch origin main` to get all code updates.
1.  **Update the UniProt, UniRef, MGnify and PDB seqres databases.**
    *   Remove `<DOWNLOAD_DIR>/uniprot`.
    *   Run `scripts/download_uniprot.sh <DOWNLOAD_DIR>`.
    *   Remove `<DOWNLOAD_DIR>/uniclust30`.
    *   Run `scripts/download_uniref30.sh <DOWNLOAD_DIR>`.
    *   Remove `<DOWNLOAD_DIR>/uniref90`.
    *   Run `scripts/download_uniref90.sh <DOWNLOAD_DIR>`.
    *   Remove `<DOWNLOAD_DIR>/mgnify`.
    *   Run `scripts/download_mgnify.sh <DOWNLOAD_DIR>`.
    *   Remove `<DOWNLOAD_DIR>/pdb_mmcif`. PDB seqres and PDB must be from
        exactly the same date. Failure to do this step will result in
        potential errors when searching for templates when running

@@ -192,40 +211,21 @@ work. Make sure you follow these steps in the exact order they are listed below:
    `scripts/download_alphafold_params.sh <DOWNLOAD_DIR>`.
1.  **Follow [Running AlphaFold](#running-alphafold).**
#### Using deprecated model weights

To use the deprecated v2.2.0 AlphaFold-Multimer model weights:

1.  Change `SOURCE_URL` in `scripts/download_alphafold_params.sh` to
    `https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar`,
    and download the old parameters.
2.  Change the `_v3` to `_v2` in the multimer `MODEL_PRESETS` in `config.py`
    (see the sketch below).

To use the deprecated v2.1.0 AlphaFold-Multimer model weights:

1.  Change `SOURCE_URL` in `scripts/download_alphafold_params.sh` to
    `https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar`,
    and download the old parameters.
2.  Remove the `_v3` in the multimer `MODEL_PRESETS` in `config.py`.
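For illustration, a minimal sketch (not part of this commit) of what the
v2.2.0 downgrade edit to `MODEL_PRESETS` in `alphafold/model/config.py` looks
like, using the multimer model names from the config diff further below:

```python
# Hypothetical post-edit state for the v2.2.0 downgrade: the `_v3` suffixes
# in the multimer preset are replaced with `_v2`.
MODEL_PRESETS = {
    # ... monomer presets unchanged ...
    'multimer': (
        'model_1_multimer_v2',
        'model_2_multimer_v2',
        'model_3_multimer_v2',
        'model_4_multimer_v2',
        'model_5_multimer_v2',
    ),
}
```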
## Running AlphaFold

@@ -266,9 +266,7 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
    ```

1.  Make sure that the output directory exists (the default is `/tmp/alphafold`)
    and that you have sufficient permissions to write into it.

1.  Run `run_docker.py` pointing to a FASTA file containing the protein
    sequence(s) for which you wish to predict the structure. If you are

@@ -291,16 +289,16 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
    [GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
    for more details.

1.  You can control which AlphaFold model to run by adding the `--model_preset=`
    flag. We provide the following models:

    *   **monomer**: This is the original model used at CASP14 with no
        ensembling.

    *   **monomer\_casp14**: This is the original model used at CASP14 with
        `num_ensemble=8`, matching our CASP14 configuration. This is largely
        provided for reproducibility as it is 8x more computationally expensive
        for limited accuracy gain (+0.1 average GDT gain on CASP14 domains).

    *   **monomer\_ptm**: This is the original CASP14 model fine tuned with the
        pTM head, providing a pairwise confidence measure. It is slightly less

@@ -315,8 +313,8 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
    provide the following presets:

    *   **reduced\_dbs**: This preset is optimized for speed and lower hardware
        requirements. It runs with a reduced version of the BFD database. It
        requires 8 CPU cores (vCPUs), 8 GB of RAM, and 600 GB of disk space.

    *   **full\_dbs**: This runs with all genetic databases used at CASP14.
@@ -379,8 +377,8 @@ python3 docker/run_docker.py \
#### Folding a homomer

Say we have a homomer with 3 copies of the same sequence `<SEQUENCE>`. The input
fasta should be:

```fasta
>sequence_1
@@ -403,8 +401,8 @@ python3 docker/run_docker.py \
#### Folding a heteromer

Say we have an A2B3 heteromer, i.e. with 2 copies of `<SEQUENCE A>` and 3 copies
of `<SEQUENCE B>`. The input fasta should be:

```fasta
>sequence_1
@@ -470,12 +468,13 @@ The `--output_dir` directory will have the following structure:
    features.pkl
    ranked_{0,1,2,3,4}.pdb
    ranking_debug.json
    relax_metrics.json
    relaxed_model_{1,2,3,4,5}.pdb
    result_model_{1,2,3,4,5}.pkl
    timings.json
    unrelaxed_model_{1,2,3,4,5}.pdb
    msas/
        bfd_uniref_hits.a3m
        mgnify_hits.sto
        uniref90_hits.sto
```
@@ -499,6 +498,8 @@ The contents of each output file are as follows:
*   `ranking_debug.json` – A JSON format text file containing the pLDDT values
    used to perform the model ranking, and a mapping back to the original model
    names.
*   `relax_metrics.json` – A JSON format text file containing relax metrics, for
    instance remaining violations.
*   `timings.json` – A JSON format text file containing the times taken to run
    each section of the AlphaFold pipeline.
*   `msas/` - A directory containing the files describing the various genetic
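For illustration, a minimal sketch (not part of this commit) of listing models
in ranked order; the `"order"` and `"plddts"` key names are assumptions about
the monomer pipeline's `ranking_debug.json` layout, so treat this as a sketch
only:

```python
import json
import os

output_dir = '/tmp/alphafold/my_target'  # Hypothetical target directory.

# Assumed keys: "order" (ranked model names) and "plddts" (name -> mean pLDDT).
with open(os.path.join(output_dir, 'ranking_debug.json')) as f:
    ranking = json.load(f)

for rank, model_name in enumerate(ranking['order']):
    plddt = ranking['plddts'][model_name]
    print(f'ranked_{rank}.pdb <- {model_name} (mean pLDDT {plddt:.2f})')
```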
@@ -576,7 +577,8 @@ For genetics:
For templates:

*   PDB: (downloaded 2020-05-14)
*   PDB70:
    [2020-05-13](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200513.tar.gz)

An alternative for templates is to use the latest PDB and PDB70, but pass the
flag `--max_template_date=2020-05-14`, which restricts templates only to
@@ -601,6 +603,7 @@ If you use the code or data in this package, please cite:
In addition, if you use the AlphaFold-Multimer mode, please cite:

```bibtex
@article {AlphaFold-Multimer2021,
  author = {Evans, Richard and O{\textquoteright}Neill, Michael and Pritzel, Alexander and Antropova, Natasha and Senior, Andrew and Green, Tim and {\v{Z}}{\'\i}dek, Augustin and Bates, Russ and Blackwell, Sam and Yim, Jason and Ronneberger, Olaf and Bodenstein, Sebastian and Zielinski, Michal and Bridgland, Alex and Potapenko, Anna and Cowie, Andrew and Tunyasuvunakool, Kathryn and Jain, Rishub and Clancy, Ellen and Kohli, Pushmeet and Jumper, John and Hassabis, Demis},
@@ -619,10 +622,11 @@ In addition, if you use the AlphaFold-Multimer mode, please cite:
Colab notebooks provided by the community (please note that these notebooks may
vary from our full AlphaFold system and we did not validate their accuracy):

*   The
    [ColabFold AlphaFold2 notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb)
    by Martin Steinegger, Sergey Ovchinnikov and Milot Mirdita, which uses an
    API hosted at the Södinglab based on the MMseqs2 server
    [(Mirdita et al. 2019, Bioinformatics)](https://academic.oup.com/bioinformatics/article/35/16/2856/5280135)
    for the multiple sequence alignment creation.
## Acknowledgements

@@ -661,15 +665,15 @@ We thank all their contributors and maintainers!
If you have any questions not covered in this overview, please contact the
AlphaFold team at [alphafold@deepmind.com](mailto:alphafold@deepmind.com).

We would love to hear your feedback and understand how AlphaFold has been useful
in your research. Share your stories with us at
[alphafold@deepmind.com](mailto:alphafold@deepmind.com).

## License and Disclaimer

This is not an officially supported Google product.

Copyright 2022 DeepMind Technologies Limited.

### AlphaFold Code License
@@ -699,12 +703,26 @@ before use.
### Mirrored Databases

The following databases have been mirrored by DeepMind, and are available with
reference to the following:

*   [BFD](https://bfd.mmseqs.com/) (unmodified), by Steinegger M. and Söding J.,
    available under a
    [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).

*   [BFD](https://bfd.mmseqs.com/) (modified), by Steinegger M. and Söding J.,
    modified by DeepMind, available under a
    [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
    See the Methods section of the
    [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1)
    for details.

*   [Uniref30: v2021_03](http://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/)
    (unmodified), by Mirdita M. et al., available under a
    [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).

*   [MGnify: v2022_05](http://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/README.txt)
    (unmodified), by Mitchell AL et al., available free of all copyright
    restrictions and made fully and freely available for both non-commercial and
    commercial use under
    [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names

...
@@ -117,7 +117,7 @@ class DataPipeline:
      uniref90_database_path: str,
      mgnify_database_path: str,
      bfd_database_path: Optional[str],
      uniref30_database_path: Optional[str],
      small_bfd_database_path: Optional[str],
      template_searcher: TemplateSearcher,
      template_featurizer: templates.TemplateHitFeaturizer,

@@ -135,9 +135,9 @@ class DataPipeline:
          binary_path=jackhmmer_binary_path,
          database_path=small_bfd_database_path)
    else:
      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
          binary_path=hhblits_binary_path,
          databases=[bfd_database_path, uniref30_database_path])
    self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=mgnify_database_path)

@@ -211,14 +211,14 @@ class DataPipeline:
          use_precomputed_msas=self.use_precomputed_msas)
      bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
    else:
      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
      hhblits_bfd_uniref_result = run_msa_tool(
          msa_runner=self.hhblits_bfd_uniref_runner,
          input_fasta_path=input_fasta_path,
          msa_out_path=bfd_out_path,
          msa_format='a3m',
          use_precomputed_msas=self.use_precomputed_msas)
      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

    templates_result = self.template_featurizer.get_templates(
        query_sequence=input_sequence,

...
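For illustration, a hedged sketch (not part of this commit; the binary path and
database prefixes are hypothetical) of constructing the renamed BFD/UniRef30
runner the way the pipeline above wires it:

```python
from alphafold.data.tools import hhblits

# Hypothetical paths; substitute your own <DOWNLOAD_DIR> layout.
bfd_uniref_runner = hhblits.HHBlits(
    binary_path='/usr/bin/hhblits',
    databases=[
        '/data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt',
        '/data/uniref30/UniRef30_2021_03',
    ])
```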
@@ -128,3 +128,64 @@ class Linear(hk.Module):
    return output


class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keeping the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)
    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out
\ No newline at end of file
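A minimal usage sketch of the `LayerNorm` class above (not part of this commit;
the `hk.transform` scaffolding and shapes are illustrative):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # Uses the weight-compatible LayerNorm defined above.
  return LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                   name='demo_norm')(x)

f = hk.transform(forward)
x = jnp.ones((4, 8), dtype=jnp.bfloat16)
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, None, x)  # bfloat16 in/out; statistics in float32.
```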
@@ -26,11 +26,11 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES

def model_config(name: str) -> ml_collections.ConfigDict:
  """Get the ConfigDict of a CASP14 model."""

  if name not in CONFIG_DIFFS:
    raise ValueError(f'Invalid model name {name}.')
  if 'multimer' in name:
    cfg = copy.deepcopy(CONFIG_MULTIMER)
  else:
    cfg = copy.deepcopy(CONFIG)
  cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
  return cfg
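With this change every model name, multimer ones included, must have an entry
in `CONFIG_DIFFS`. A short illustrative call (a sketch; the model name and the
config path come from the `MODEL_PRESETS` and `CONFIG_DIFFS` entries below):

```python
# Multimer configs now start from CONFIG_MULTIMER and then apply the
# per-model diff from CONFIG_DIFFS.
cfg = model_config('model_4_multimer_v3')
print(cfg.model.embeddings_and_evoformer.num_extra_msa)  # 1152 per its diff.
```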
@@ -52,11 +52,11 @@ MODEL_PRESETS = {
        'model_5_ptm',
    ),
    'multimer': (
        'model_1_multimer_v3',
        'model_2_multimer_v3',
        'model_3_multimer_v3',
        'model_4_multimer_v3',
        'model_5_multimer_v3',
    ),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
    },
    'model_5_ptm': {
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_1_multimer_v3': {},
    'model_2_multimer_v3': {},
    'model_3_multimer_v3': {},
    'model_4_multimer_v3': {
        'model.embeddings_and_evoformer.num_extra_msa': 1152
    },
    'model_5_multimer_v3': {
        'model.embeddings_and_evoformer.num_extra_msa': 1152
    },
}

# Key differences between multimer v1/v2 and v3, mostly due to numerical
# optimisations in the TriangleMultiplication module.
common_updates = {
    'model.embeddings_and_evoformer.num_msa': 252,
    'model.embeddings_and_evoformer.num_extra_msa': 1152,
    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
}
CONFIG_DIFFS.update(
    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
CONFIG_DIFFS.update(
    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
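An illustrative consequence of the `CONFIG_DIFFS.update(...)` calls above (a
sketch, not from the commit):

```python
# v1/v2 multimer names now resolve to the pre-v3 settings, e.g. unfused
# TriangleMultiplication projections and the smaller MSA sizes.
cfg = model_config('model_1_multimer_v2')
evo = cfg.model.embeddings_and_evoformer
assert evo.num_msa == 252
assert not evo.evoformer.triangle_multiplication_incoming.fuse_projection_weights
```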
CONFIG = ml_collections.ConfigDict({
    'data': {

@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
            'equation': 'ikc,jkc->ijc',
            'num_intermediate_channel': 128,
            'orientation': 'per_row',
            'shared_dropout': True,
            'fuse_projection_weights': False,
        },
        'triangle_multiplication_incoming': {
            'dropout_rate': 0.25,
            'equation': 'kjc,kic->ijc',
            'num_intermediate_channel': 128,
            'orientation': 'per_row',
            'shared_dropout': True,
            'fuse_projection_weights': False,
        },
        'pair_transition': {
            'dropout_rate': 0.0,

@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
            'equation': 'ikc,jkc->ijc',
            'num_intermediate_channel': 64,
            'orientation': 'per_row',
            'shared_dropout': True,
            'fuse_projection_weights': False,
        },
        'triangle_multiplication_incoming': {
            'dropout_rate': 0.25,
            'equation': 'kjc,kic->ijc',
            'num_intermediate_channel': 64,
            'orientation': 'per_row',
            'shared_dropout': True,
            'fuse_projection_weights': False,
        },
        'pair_transition': {
            'dropout_rate': 0.0,

@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
        'multimer_mode': False,
        'subbatch_size': 4,
        'use_remat': False,
        'zero_init': True,
    },
    'heads': {
        'distogram': {
@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                'gating': True,
                'num_head': 4,
                'orientation': 'per_row',
                'shared_dropout': True,
            },
            'triangle_multiplication_incoming': {
                'dropout_rate': 0.25,
                'equation': 'kjc,kic->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
                'shared_dropout': True,
                'fuse_projection_weights': True,
            },
            'triangle_multiplication_outgoing': {
                'dropout_rate': 0.25,
                'equation': 'ikc,jkc->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
                'shared_dropout': True,
                'fuse_projection_weights': True,
            }
        },
        'extra_msa_channel': 64,
        'extra_msa_stack_num_block': 4,
        'num_msa': 508,
        'num_extra_msa': 2048,
        'masked_msa': {
            'profile_prob': 0.1,
            'replace_fraction': 0.15,

@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 64,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': True,
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 64,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': True,
                }
            }
        },
    },
    'global_config': {
        'bfloat16': True,
        'bfloat16_output': False,
        'deterministic': False,
        'multimer_mode': True,
        'subbatch_size': 4,
        'use_remat': False,
        'zero_init': True,
    },
    'heads': {
        'distogram': {

@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
            }
        },
        'num_ensemble_eval': 1,
        'num_recycle': 20,
        # A negative value indicates that no early stopping will occur, i.e.
        # the model will always run `num_recycle` number of recycling
        # iterations. A positive value will enable early stopping if the
        # difference in pairwise distances is less than the tolerance between
        # recycling steps.
        'recycle_early_stop_tolerance': 0.5,
        'resample_msa_in_recycling': True
    }
})
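A rough sketch of the early-stopping rule that the `recycle_early_stop_tolerance`
comment describes (illustrative only; the real check lives in the model code,
and the exact distance metric here is an assumption):

```python
import jax.numpy as jnp

def should_stop_recycling(prev_ca, curr_ca, tolerance=0.5):
  """Illustrative: stop when mean |pairwise-distance change| < tolerance."""
  prev_d = jnp.linalg.norm(prev_ca[:, None] - prev_ca[None, :], axis=-1)
  curr_d = jnp.linalg.norm(curr_ca[:, None] - curr_ca[None, :], axis=-1)
  # A non-positive tolerance disables early stopping, per the config comment.
  return (tolerance > 0) & (jnp.mean(jnp.abs(curr_d - prev_d)) < tolerance)
```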
@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
      safe_key, *sub_keys = safe_key.split(3)
      sub_keys = iter(sub_keys)
      act = safe_dropout_fn(act, next(sub_keys))
      act = common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,

@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
      act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,

@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,

@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
      'affine': affine.to_tensor(),
  }

  act_2d = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,

...
@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
      safe_key, *sub_keys = safe_key.split(3)
      sub_keys = iter(sub_keys)
      act = safe_dropout_fn(act, next(sub_keys))
      act = common_modules.LayerNorm(
          axis=-1,
          create_scale=True,
          create_offset=True,

@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
      act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
    act = common_modules.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,

@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]
  act = common_modules.LayerNorm(
      axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
          representations['single'])

@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
      rigid
  }

  act_2d = common_modules.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,

...
@@ -133,7 +133,7 @@ def flatten(instance):
  inner_treedefs = []
  num_arrays = []
  for array_like in array_likes:
    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
    inner_treedefs.append(inner_treedef)
    flat_array_likes += flat_array_like
    num_arrays.append(len(flat_array_like))

@@ -206,7 +206,7 @@ class StructOfArray:
      for num_array, inner_treedef, array_field in zip(num_arrays,
                                                       inner_treedefs,
                                                       array_fields):
        value_dict[array_field] = jax.tree_util.tree_unflatten(
            inner_treedef, data[array_start:array_start + num_array])
        array_start += num_array
      metadata_fields = get_metadata_fields(new_cls)

...
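These are straight renames to JAX's namespaced `jax.tree_util` API; a tiny
round-trip example of the two functions involved:

```python
import jax

tree = {'a': [1, 2], 'b': 3}
leaves, treedef = jax.tree_util.tree_flatten(tree)      # leaves == [1, 2, 3]
assert jax.tree_util.tree_unflatten(treedef, leaves) == tree
```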
@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):

def _expand_axes(axes, values, name='sharded_apply'):
  values_tree_def = jax.tree_util.tree_flatten(values)[1]
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace None's with PROXY
  flat_axes = [PROXY if x is None else x for x in flat_axes]
  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)


def sharded_map(

@@ -126,7 +126,7 @@ def sharded_apply(
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)

...
@@ -501,7 +501,7 @@ class Transition(hk.Module):
    num_intermediate = int(nc * self.config.num_intermediate_factor)
    mask = jnp.expand_dims(mask, axis=-1)

    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,

@@ -569,12 +569,15 @@ class Attention(hk.Module):
    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
        dtype=q_data.dtype,
        init=glorot_uniform())

    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)

@@ -595,10 +598,12 @@ class Attention(hk.Module):
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          dtype=q_data.dtype,
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          dtype=q_data.dtype,
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,

@@ -610,8 +615,11 @@ class Attention(hk.Module):
    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        dtype=q_data.dtype,
        init=init)
    o_bias = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        dtype=q_data.dtype,
        init=hk.initializers.Constant(0.0))

    output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
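The `dtype=q_data.dtype` additions matter because `hk.get_parameter` defaults
to float32 parameters; threading the activation dtype through lets parameters
follow it (e.g. bfloat16 when the new `global_config.bfloat16` flag is on). A
minimal sketch, not from the commit:

```python
import haiku as hk
import jax.numpy as jnp

def tiny_projection(q_data):
  # Parameter inherits the activation dtype instead of defaulting to float32.
  w = hk.get_parameter('w', shape=(q_data.shape[-1], 4),
                       dtype=q_data.dtype, init=jnp.zeros)
  return q_data @ w
```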
...@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module): ...@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
q_weights = hk.get_parameter( q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim), 'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform()) init=glorot_uniform())
k_weights = hk.get_parameter( k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], key_dim), 'key_w', shape=(m_data.shape[-1], key_dim),
dtype=q_data.dtype,
init=glorot_uniform()) init=glorot_uniform())
v_weights = hk.get_parameter( v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], value_dim), 'value_w', shape=(m_data.shape[-1], value_dim),
dtype=q_data.dtype,
init=glorot_uniform()) init=glorot_uniform())
v = jnp.einsum('bka,ac->bkc', m_data, v_weights) v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
...@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module): ...@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
o_weights = hk.get_parameter( o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim), 'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init) init=init)
o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0)) init=hk.initializers.Constant(0.0))
if self.config.gating: if self.config.gating:
gating_weights = hk.get_parameter( gating_weights = hk.get_parameter(
'gating_w', 'gating_w',
shape=(q_data.shape[-1], num_head, value_dim), shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0)) init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter( gating_bias = hk.get_parameter(
'gating_b', 'gating_b',
shape=(num_head, value_dim), shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0)) init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
...@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module): ...@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4 assert len(bias.shape) == 4
msa_act = hk.LayerNorm( msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')( axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act) msa_act)
pair_act = hk.LayerNorm( pair_act = common_modules.LayerNorm(
axis=[-1], axis=[-1],
create_scale=True, create_scale=True,
create_offset=True, create_offset=True,
...@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module): ...@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
weights = hk.get_parameter( weights = hk.get_parameter(
'feat_2d_weights', 'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head), shape=(pair_act.shape[-1], c.num_head),
dtype=msa_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor)) init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
...@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module): ...@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4 assert len(bias.shape) == 4
msa_act = hk.LayerNorm( msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')( axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act) msa_act)
...@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module): ...@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4 assert len(bias.shape) == 4
msa_act = hk.LayerNorm( msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')( axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act) msa_act)
...@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module): ...@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
bias = (1e9 * (pair_mask - 1.))[:, None, None, :] bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4 assert len(bias.shape) == 4
pair_act = hk.LayerNorm( pair_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')( axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
pair_act) pair_act)
...@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module): ...@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
weights = hk.get_parameter( weights = hk.get_parameter(
'feat_2d_weights', 'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head), shape=(pair_act.shape[-1], c.num_head),
dtype=pair_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor)) init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
...@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module): ...@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
""" """
act = representations['structure_module'] act = representations['structure_module']
act = hk.LayerNorm( act = common_modules.LayerNorm(
axis=[-1], axis=[-1],
create_scale=True, create_scale=True,
create_offset=True, create_offset=True,
@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
     return output


+def _layer_norm(axis=-1, name='layer_norm'):
+  return common_modules.LayerNorm(
+      axis=axis,
+      create_scale=True,
+      create_offset=True,
+      eps=1e-5,
+      use_fast_variance=True,
+      scale_init=hk.initializers.Constant(1.),
+      offset_init=hk.initializers.Constant(0.),
+      param_axis=axis,
+      name=name)
+
+
 class TriangleMultiplication(hk.Module):
   """Triangle multiplication layer ("outgoing" or "incoming").
@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
     self.config = config
     self.global_config = global_config

-  def __call__(self, act, mask, is_training=True):
+  def __call__(self, left_act, left_mask, is_training=True):
     """Builds TriangleMultiplication module.

     Arguments:
-      act: Pair activations, shape [N_res, N_res, c_z]
-      mask: Pair mask, shape [N_res, N_res].
+      left_act: Pair activations, shape [N_res, N_res, c_z]
+      left_mask: Pair mask, shape [N_res, N_res].
       is_training: Whether the module is in training mode.

     Returns:
-      Outputs, same shape/type as act.
+      Outputs, same shape/type as left_act.
     """
     del is_training

+    if self.config.fuse_projection_weights:
+      return self._fused_triangle_multiplication(left_act, left_mask)
+    else:
+      return self._triangle_multiplication(left_act, left_mask)
+
+  @hk.transparent
+  def _triangle_multiplication(self, left_act, left_mask):
+    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
     c = self.config
     gc = self.global_config

-    mask = mask[..., None]
+    mask = left_mask[..., None]

-    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
-                       name='layer_norm_input')(act)
+    act = common_modules.LayerNorm(axis=[-1], create_scale=True,
+                                   create_offset=True,
+                                   name='layer_norm_input')(left_act)
     input_act = act

     left_projection = common_modules.Linear(
@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
     # b = left_proj_act and a = right_proj_act
     act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):
     return act

+  @hk.transparent
+  def _fused_triangle_multiplication(self, left_act, left_mask):
+    """TriangleMultiplication with fused projection weights."""
+    mask = left_mask[..., None]
+    c = self.config
+    gc = self.global_config
+
+    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
+
+    # Both left and right projections are fused into projection.
+    projection = common_modules.Linear(
+        2*c.num_intermediate_channel, name='projection')
+    proj_act = mask * projection(left_act)
+
+    # Both left + right gate are fused into gate_values.
+    gate_values = common_modules.Linear(
+        2 * c.num_intermediate_channel,
+        name='gate',
+        bias_init=1.,
+        initializer=utils.final_init(gc))(left_act)
+
+    proj_act *= jax.nn.sigmoid(gate_values)
+
+    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
+    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
+    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
+
+    act = _layer_norm(axis=-1, name='center_norm')(act)
+
+    output_channel = int(left_act.shape[-1])
+
+    act = common_modules.Linear(
+        output_channel,
+        initializer=utils.final_init(gc),
+        name='output_projection')(act)
+
+    gate_values = common_modules.Linear(
+        output_channel,
+        bias_init=1.,
+        initializer=utils.final_init(gc),
+        name='gating_linear')(left_act)
+    act *= jax.nn.sigmoid(gate_values)
+
+    return act
+

 class DistogramHead(hk.Module):
   """Head to predict a distogram.
@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
     c = self.config

     mask = mask[..., None]
-    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)

     left_act = mask * common_modules.Linear(
         c.num_outer_channel,
@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
         'output_w',
         shape=(c.num_outer_channel, c.num_outer_channel,
                self.num_output_channel),
+        dtype=act.dtype,
         init=init_w)
     output_b = hk.get_parameter(
         'output_b', shape=(self.num_output_channel,),
+        dtype=act.dtype,
         init=hk.initializers.Constant(0.0))

     def compute_chunk(left_act):
@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
             dgram)

     if c.recycle_features:
-      prev_msa_first_row = hk.LayerNorm(
+      prev_msa_first_row = common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,
@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
               batch['prev_msa_first_row'])
       msa_activations = msa_activations.at[0].add(prev_msa_first_row)

-      pair_activations += hk.LayerNorm(
+      pair_activations += common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,
@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
           self.config.template_pair_stack, self.global_config)(
               act, mask_2d, is_training)

-    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
     return act
...
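> On `_fused_triangle_multiplication` above: fusing the left/right projections (and their gates) into single Linears of width `2 * c.num_intermediate_channel`, then splitting along the channel axis, is algebraically equivalent to the two separate projections in `_triangle_multiplication`. A sketch with toy dimensions and hypothetical weights (not the module's real parameters):

```python
import jax
import jax.numpy as jnp

C, c_z = 4, 8  # toy intermediate / pair channel sizes
w_left, w_right = jax.random.normal(jax.random.PRNGKey(0), (2, c_z, C))
act = jax.random.normal(jax.random.PRNGKey(1), (3, 3, c_z))  # toy pair acts

# Two separate projections...
left, right = act @ w_left, act @ w_right
# ...match one fused projection of width 2*C split along the last axis.
fused = act @ jnp.concatenate([w_left, w_right], axis=-1)
assert jnp.allclose(fused[..., :C], left, atol=1e-5)
assert jnp.allclose(fused[..., C:], right, atol=1e-5)
```

> The fused path issues fewer, larger matmuls. Note that the parameter names and shapes differ ('projection' vs. 'left_projection'/'right_projection'), so checkpoints are specific to the value of `fuse_projection_weights`.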
@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
       # Eval mode or tests: use the maximum number of iterations.
       num_iter = c.num_recycle

-      def recycle_body(i, x):
-        del i
-        prev, safe_key = x
+      def distances(points):
+        """Compute all pairwise distances for a set of points."""
+        return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
+                                axis=-1))
+
+      def recycle_body(x):
+        i, _, prev, safe_key = x
         safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate()  # pylint: disable=line-too-long
         ret = apply_network(prev=prev, safe_key=safe_key2)
-        return get_prev(ret), safe_key1
+        return i+1, prev, get_prev(ret), safe_key1

-      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
+      def recycle_cond(x):
+        i, prev, next_in, _ = x
+        ca_idx = residue_constants.atom_order['CA']
+        sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
+                             distances(next_in['prev_pos'][:, ca_idx, :]))
+        mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+        sq_diff = utils.mask_mean(mask, sq_diff)
+        # Early stopping criteria based on criteria used in
+        # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
+        diff = jnp.sqrt(sq_diff + 1e-8)  # avoid bad numerics giving negatives
+        less_than_max_recycles = (i < num_iter)
+        has_exceeded_tolerance = (
+            (i == 0) | (diff > c.recycle_early_stop_tolerance))
+        return less_than_max_recycles & has_exceeded_tolerance
+
+      if hk.running_init():
+        num_recycles, _, prev, safe_key = recycle_body(
+            (0, prev, prev, safe_key))
+      else:
+        num_recycles, _, prev, safe_key = hk.while_loop(
+            recycle_cond,
+            recycle_body,
+            (0, prev, prev, safe_key))
+    else:
+      # No recycling.
+      num_recycles = 0

     # Run extra iteration.
     ret = apply_network(prev=prev, safe_key=safe_key)

     if not return_representations:
       del ret['representations']
+    ret['num_recycles'] = num_recycles
+
     return ret
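> The new early-stopping rule is easy to test in isolation: recycling continues only while the masked RMS change of the pairwise Cα distance matrix between consecutive iterations exceeds `c.recycle_early_stop_tolerance`. A self-contained sketch with toy coordinates (illustrative names, and a tolerance of 0.5 Å assumed for the comparison rather than read from the config):

```python
import jax.numpy as jnp

def distances(points):  # (N, 3) -> (N, N) pairwise distances
  return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2, axis=-1))

def recycle_diff(prev_ca, next_ca, seq_mask):
  """Masked RMS change of the pairwise CA distance matrix."""
  sq_diff = jnp.square(distances(prev_ca) - distances(next_ca))
  mask_2d = seq_mask[:, None] * seq_mask[None, :]
  mean_sq = jnp.sum(mask_2d * sq_diff) / (jnp.sum(mask_2d) + 1e-10)
  return jnp.sqrt(mean_sq + 1e-8)

prev_ca = jnp.zeros((5, 3))
next_ca = prev_ca.at[0, 0].add(0.1)  # nudge one CA by 0.1 A
print(recycle_diff(prev_ca, next_ca, jnp.ones(5)) > 0.5)  # False: would stop
```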
@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
       Feature embedding using the features as described before.
     """
     c = self.config
+    gc = self.global_config
     rel_feats = []
     pos = batch['residue_index']
     asym_id = batch['asym_id']
     asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
     offset = pos[:, None] - pos[None, :]
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

     clipped_offset = jnp.clip(
         offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     rel_feat = jnp.concatenate(rel_feats, axis=-1)

+    rel_feat = rel_feat.astype(dtype)
     return common_modules.Linear(
         c.pair_channel,
         name='position_activations')(
@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     gc = self.global_config

     batch = dict(batch)
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

     if safe_key is None:
       safe_key = prng.SafeKey(hk.next_rng_key())
@@ -587,7 +622,8 @@ class EmbeddingsAndEvoformer(hk.Module):

     batch['msa_profile'] = make_msa_profile(batch)

-    target_feat = jax.nn.one_hot(batch['aatype'], 21)
+    with utils.bfloat16_context():
+      target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)

       preprocess_1d = common_modules.Linear(
           c.msa_channel, name='preprocess_1d')(
@@ -600,12 +636,11 @@ class EmbeddingsAndEvoformer(hk.Module):
       (batch['cluster_profile'],
        batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)

-      msa_feat = create_msa_feat(batch)
+      msa_feat = create_msa_feat(batch).astype(dtype)

       preprocess_msa = common_modules.Linear(
           c.msa_channel, name='preprocess_msa')(
               msa_feat)
-
       msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

       left_single = common_modules.Linear(
@@ -616,7 +651,7 @@ class EmbeddingsAndEvoformer(hk.Module):
               target_feat)
       pair_activations = left_single[:, None] + right_single[None]
       mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
-      mask_2d = mask_2d.astype(jnp.float32)
+      mask_2d = mask_2d.astype(dtype)

       if c.recycle_pos:
         prev_pseudo_beta = modules.pseudo_beta_fn(
@@ -624,25 +659,25 @@ class EmbeddingsAndEvoformer(hk.Module):
         dgram = modules.dgram_from_positions(
             prev_pseudo_beta, **self.config.prev_pos)
+        dgram = dgram.astype(dtype)
         pair_activations += common_modules.Linear(
             c.pair_channel, name='prev_pos_linear')(
                 dgram)

       if c.recycle_features:
-        prev_msa_first_row = hk.LayerNorm(
+        prev_msa_first_row = common_modules.LayerNorm(
             axis=[-1],
             create_scale=True,
             create_offset=True,
             name='prev_msa_first_row_norm')(
-                batch['prev_msa_first_row'])
+                batch['prev_msa_first_row']).astype(dtype)
         msa_activations = msa_activations.at[0].add(prev_msa_first_row)

-        pair_activations += hk.LayerNorm(
+        pair_activations += common_modules.LayerNorm(
             axis=[-1],
             create_scale=True,
             create_offset=True,
             name='prev_pair_norm')(
-                batch['prev_pair'])
+                batch['prev_pair']).astype(dtype)

       if c.max_relative_idx:
         pair_activations += self._relative_encoding(batch)
@@ -673,8 +708,8 @@ class EmbeddingsAndEvoformer(hk.Module):
       extra_msa_activations = common_modules.Linear(
           c.extra_msa_channel,
           name='extra_msa_activations')(
-              extra_msa_feat)
-      extra_msa_mask = extra_msa_mask.astype(jnp.float32)
+              extra_msa_feat).astype(dtype)
+      extra_msa_mask = extra_msa_mask.astype(dtype)

       extra_evoformer_input = {
           'msa': extra_msa_activations,
@@ -714,18 +749,19 @@ class EmbeddingsAndEvoformer(hk.Module):
           'msa': msa_activations,
           'pair': pair_activations,
       }
-      evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
-                         'pair': mask_2d}
+      evoformer_masks = {
+          'msa': batch['msa_mask'].astype(dtype),
+          'pair': mask_2d
+      }

       if c.template.enabled:
         template_features, template_masks = (
-            template_embedding_1d(batch=batch, num_channel=c.msa_channel))
+            template_embedding_1d(
+                batch=batch, num_channel=c.msa_channel, global_config=gc))

         evoformer_input['msa'] = jnp.concatenate(
             [evoformer_input['msa'], template_features], axis=0)
         evoformer_masks['msa'] = jnp.concatenate(
             [evoformer_masks['msa'], template_masks], axis=0)

       evoformer_iteration = modules.EvoformerIteration(
           c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
               msa_activations[0],
       })

+    # Convert back to float32 if we're not saving memory.
+    if not gc.bfloat16_output:
+      for k, v in output.items():
+        if v.dtype == jnp.bfloat16:
+          output[k] = v.astype(jnp.float32)
+
     return output
@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
       # backbone affine - i.e. in each residues local frame, what direction are
       # each of the other residues.
       raw_atom_pos = template_all_atom_positions
+      if gc.bfloat16:
+        # Vec3Arrays are required to be float32
+        raw_atom_pos = raw_atom_pos.astype(jnp.float32)

       atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
       rigid, backbone_mask = folding_multimer.make_backbone_affine(
@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
       unit_vector = rigid_vec.normalized()
       unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

+      if gc.bfloat16:
+        unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
+        backbone_mask = backbone_mask.astype(jnp.bfloat16)
+
       backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
       backbone_mask_2d *= multichain_mask_2d
       unit_vector = [x*backbone_mask_2d for x in unit_vector]
@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
       to_concat.extend([(x, 0) for x in unit_vector])
       to_concat.append((backbone_mask_2d, 0))

-      query_embedding = hk.LayerNorm(
+      query_embedding = common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,
@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
             template_iteration_fn)
     act, safe_key = template_stack((act, safe_subkey))

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
         name='output_layer_norm')(
             act)
+
     return act
@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                   name='triangle_attention_starting_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                   name='triangle_attention_ending_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.Transition(c.pair_transition, gc,
                            name='pair_transition'),
return act return act
def template_embedding_1d(batch, num_channel): def template_embedding_1d(batch, num_channel, global_config):
"""Embed templates into an (num_res, num_templates, num_channels) embedding. """Embed templates into an (num_res, num_templates, num_channels) embedding.
Args: Args:
...@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel): ...@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
template_all_atom_mask, (num_templates, num_residues, 37) atom mask for template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
each template. each template.
num_channel: The number of channels in the output. num_channel: The number of channels in the output.
global_config: The global_config.
Returns: Returns:
An embedding of shape (num_templates, num_res, num_channels) and a mask of An embedding of shape (num_templates, num_res, num_channels) and a mask of
...@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel): ...@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
template_mask = chi_mask[:, :, 0] template_mask = chi_mask[:, :, 0]
if global_config.bfloat16:
template_features = template_features.astype(jnp.bfloat16)
template_mask = template_mask.astype(jnp.bfloat16)
template_activations = common_modules.Linear( template_activations = common_modules.Linear(
num_channel, num_channel,
initializer='relu', initializer='relu',
......
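> Seen end to end, the modules_multimer.py changes thread a single `dtype` through the embedder: features and masks are cast to bfloat16 on the way in when `gc.bfloat16` is set, and outputs are cast back to float32 at the end unless `gc.bfloat16_output` keeps them compressed. A minimal sketch of the pattern (stand-in config and features, not the real ones):

```python
import jax.numpy as jnp

class FakeGlobalConfig:  # stand-in for the real global_config
  bfloat16 = True
  bfloat16_output = False

gc = FakeGlobalConfig()
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

feats = {'msa_mask': jnp.ones((4, 7)), 'pair': jnp.zeros((7, 7, 8))}
feats = {k: v.astype(dtype) for k, v in feats.items()}  # cast on the way in

output = dict(feats)  # stand-in for the evoformer output
if not gc.bfloat16_output:  # mirror of the conversion loop above
  output = {k: v.astype(jnp.float32) if v.dtype == jnp.bfloat16 else v
            for k, v in output.items()}
print({k: v.dtype for k, v in output.items()})  # all float32 again
```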
@@ -15,6 +15,7 @@
 """A collection of JAX utility functions for use in protein folding."""

 import collections
+import contextlib
 import functools
 import numbers
 from typing import Mapping
@@ -25,6 +26,27 @@
 import numpy as np


+def bfloat16_creator(next_creator, shape, dtype, init, context):
+  """Creates float32 variables when bfloat16 is requested."""
+  if context.original_dtype == jnp.bfloat16:
+    dtype = jnp.float32
+  return next_creator(shape, dtype, init)
+
+
+def bfloat16_getter(next_getter, value, context):
+  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
+  if context.original_dtype == jnp.bfloat16:
+    assert value.dtype == jnp.float32
+    value = value.astype(jnp.bfloat16)
+  return next_getter(value)
+
+
+@contextlib.contextmanager
+def bfloat16_context():
+  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
+    yield
+
+
 def final_init(config):
   if config.zero_init:
     return 'zeros'
...
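> `bfloat16_context` composes Haiku's parameter interceptors: when a module requests bfloat16 parameters, the creator stores a float32 master copy and the getter hands back a bfloat16 view at use time. A usage sketch with a toy model (hypothetical; assumes `bfloat16_context` from above is in scope):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  with bfloat16_context():  # defined in the hunk above
    return hk.Linear(4, name='demo')(x)

fwd = hk.transform(forward)
x = jnp.ones((2, 3), dtype=jnp.bfloat16)  # bf16 inputs => bf16 param requests
params = fwd.init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 storage
print(fwd.apply(params, None, x).dtype)                   # bfloat16 compute
```

> The assert in `bfloat16_getter` documents the invariant: anything stored under this context must be float32.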
@@ -13,7 +13,6 @@
 # limitations under the License.
 """Helper methods for the AlphaFold Colab notebook."""

-import enum
 import json
 from typing import Any, Mapping, Optional, Sequence, Tuple
@@ -23,13 +22,7 @@
 from matplotlib import pyplot as plt
 import numpy as np


-@enum.unique
-class ModelType(enum.Enum):
-  MONOMER = 0
-  MULTIMER = 1
-
-
-def clean_and_validate_sequence(
+def clean_and_validate_single_sequence(
     input_sequence: str, min_length: int, max_length: int) -> str:
   """Checks that the input sequence is ok and returns a clean version of it."""
   # Remove all whitespaces, tabs and end lines; upper-case.
@@ -54,41 +47,23 @@
   return clean_sequence


-def validate_input(
+def clean_and_validate_input_sequences(
     input_sequences: Sequence[str],
-    min_length: int,
-    max_length: int,
-    max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
-  """Validates and cleans input sequences and determines which model to use."""
+    min_sequence_length: int,
+    max_sequence_length: int) -> Sequence[str]:
+  """Validates and cleans input sequences."""
   sequences = []

   for input_sequence in input_sequences:
     if input_sequence.strip():
-      input_sequence = clean_and_validate_sequence(
+      input_sequence = clean_and_validate_single_sequence(
           input_sequence=input_sequence,
-          min_length=min_length,
-          max_length=max_length)
+          min_length=min_sequence_length,
+          max_length=max_sequence_length)
       sequences.append(input_sequence)

-  if len(sequences) == 1:
-    print('Using the single-chain model.')
-    return sequences, ModelType.MONOMER
-  elif len(sequences) > 1:
-    total_multimer_length = sum([len(seq) for seq in sequences])
-    if total_multimer_length > max_multimer_length:
-      raise ValueError(f'The total length of multimer sequences is too long: '
-                       f'{total_multimer_length}, while the maximum is '
-                       f'{max_multimer_length}. Please use the full AlphaFold '
-                       f'system for long multimers.')
-    elif total_multimer_length > 1536:
-      print('WARNING: The accuracy of the system has not been fully validated '
-            'above 1536 residues, and you may experience long running times or '
-            f'run out of memory for your complex with {total_multimer_length} '
-            'residues.')
-    print(f'Using the multimer model with {len(sequences)} sequences.')
-    return sequences, ModelType.MULTIMER
+  if sequences:
+    return sequences
   else:
     raise ValueError('No input amino acid sequence provided, please provide at '
                      'least one sequence.')
...
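> With `ModelType` gone, the notebook helpers no longer decide between the monomer and multimer models; they only clean and validate sequences. An illustrative call (made-up sequences):

```python
from alphafold.notebooks import notebook_utils

sequences = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=['mktay adilg\n', '  ', 'GHIKLMNPQR'],
    min_sequence_length=8,
    max_sequence_length=2500)
print(sequences)  # ['MKTAYADILG', 'GHIKLMNPQR']; blank entries are dropped
```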
@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
       ('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
   def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
-    clean = notebook_utils.clean_and_validate_sequence(
+    clean = notebook_utils.clean_and_validate_single_sequence(
         sequence, min_length=1, max_length=100)
     self.assertEqual(clean, exp_clean)
@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
   def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.clean_and_validate_sequence(
+      notebook_utils.clean_and_validate_single_sequence(
           sequence, min_length=4, max_length=8)

   @parameterized.parameters(
-      (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
-       notebook_utils.ModelType.MONOMER),
-      (['', 'A'], ['A'],
-       notebook_utils.ModelType.MONOMER),
-      (['A', 'C ', ''], ['A', 'C'],
-       notebook_utils.ModelType.MULTIMER),
-      (['', 'A', '', 'C '], ['A', 'C'],
-       notebook_utils.ModelType.MULTIMER))
-  def test_validate_input_ok(
-      self, input_sequences, exp_sequences, exp_model_type):
-    sequences, model_type = notebook_utils.validate_input(
+      (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A']),
+      (['', 'A'], ['A']),
+      (['A', 'C ', ''], ['A', 'C']),
+      (['', 'A', '', 'C '], ['A', 'C']))
+  def test_validate_input_ok(self, input_sequences, exp_sequences):
+    sequences = notebook_utils.clean_and_validate_input_sequences(
         input_sequences=input_sequences,
-        min_length=1, max_length=100, max_multimer_length=100)
+        min_sequence_length=1, max_sequence_length=100)
     self.assertSequenceEqual(sequences, exp_sequences)
-    self.assertEqual(model_type, exp_model_type)

   @parameterized.named_parameters(
       ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
       ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
-      ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
+      ('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
   def test_validate_input_bad(self, input_sequences, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.validate_input(
-          input_sequences=input_sequences,
-          min_length=4, max_length=8, max_multimer_length=6)
+      notebook_utils.clean_and_validate_input_sequences(
+          input_sequences=input_sequences, min_sequence_length=4,
+          max_sequence_length=8)

   def test_merge_chunked_msa_no_hits(self):
     results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
...
@@ -56,7 +56,8 @@ class AmberRelaxation(object):
     self._use_gpu = use_gpu

   def process(self, *,
-              prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
+              prot: protein.Protein
+              ) -> Tuple[str, Dict[str, Any], Sequence[float]]:
     """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
     out = amber_minimize.run_pipeline(
         prot=prot, max_iterations=self._max_iterations,
...@@ -73,12 +74,11 @@ class AmberRelaxation(object): ...@@ -73,12 +74,11 @@ class AmberRelaxation(object):
'attempts': out['min_attempts'], 'attempts': out['min_attempts'],
'rmsd': rmsd 'rmsd': rmsd
} }
pdb_str = amber_minimize.clean_protein(prot) min_pdb = out['min_pdb']
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types( utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask, protein.from_pdb_string(min_pdb).atom_mask,
prot.atom_mask) prot.atom_mask)
violations = out['structural_violations'][ violations = out['structural_violations'][
'total_per_residue_violations_mask'] 'total_per_residue_violations_mask'].tolist()
return min_pdb, debug_data, violations return min_pdb, debug_data, violations
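> Returning `violations` as a plain list (via `.tolist()`) rather than an `np.ndarray` matches the new `Sequence[float]` annotation and, presumably, makes the debug outputs directly serializable. A tiny standalone illustration of the type difference (not from the codebase):

```python
import json
import numpy as np

violations = np.array([0., 1., 0.]).tolist()  # the shape process() now returns
print(json.dumps({'violations': violations}))  # ok; an ndarray would raise TypeError
```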
@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
         0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
         0, 0, 0, 0])
     # Check no violations were added. Can't check exactly due to stochasticity.
-    self.assertTrue(np.all(num_violations <= exp_num_violations))
+    self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))


 if __name__ == '__main__':
...
@@ -17,17 +17,6 @@ import io
 from alphafold.common import residue_constants
 from Bio import PDB
 import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-
-
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-  pdb_file = io.StringIO(pdb_str)
-  structure = PdbStructure(pdb_file)
-  topology = openmm_app.PDBFile(structure).getTopology()
-  with io.StringIO() as f:
-    openmm_app.PDBFile.writeFile(topology, pos, f)
-    return f.getvalue()


 def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
...
@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
       cudatoolkit==${CUDA_VERSION} \
       pdbfixer \
       pip \
-      python=3.7 \
+      python=3.8 \
     && conda clean --all --force-pkgs-dirs --yes

 COPY . /app/alphafold
@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 # Apply OpenMM patch.
-WORKDIR /opt/conda/lib/python3.7/site-packages
+WORKDIR /opt/conda/lib/python3.8/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch

 # Add SETUID bit to the ldconfig binary so that non-root users can run it.
...
@@ -133,7 +133,7 @@ def main(argv):

   # Path to the MGnify database for use by JackHMMER.
   mgnify_database_path = os.path.join(
-      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
+      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')

   # Path to the BFD database for use by HHblits.
   bfd_database_path = os.path.join(
@@ -144,9 +144,9 @@ def main(argv):
   small_bfd_database_path = os.path.join(
       FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')

-  # Path to the Uniclust30 database for use by HHblits.
-  uniclust30_database_path = os.path.join(
-      FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
+  # Path to the Uniref30 database for use by HHblits.
+  uniref30_database_path = os.path.join(
+      FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')

   # Path to the PDB70 database for use by HHsearch.
   pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')
@@ -199,7 +199,7 @@ def main(argv):
     database_paths.append(('small_bfd_database_path', small_bfd_database_path))
   else:
     database_paths.extend([
-        ('uniclust30_database_path', uniclust30_database_path),
+        ('uniref30_database_path', uniref30_database_path),
         ('bfd_database_path', bfd_database_path),
     ])
   for name, path in database_paths:
...