Commit 9b18d6a9 authored by Augustin Zidek

Release code for v2.3.0

PiperOrigin-RevId: 494507694
parent 4494af84
......@@ -7,23 +7,30 @@ v2.0. This is a completely new model that was entered in CASP14 and published in
Nature. For simplicity, we refer to this model as AlphaFold throughout the rest
of this document.
We also provide:
1. An implementation of AlphaFold-Multimer. This represents a work in progress
and AlphaFold-Multimer isn't expected to be as stable as our monomer
AlphaFold system.
[Read the guide](#updating-existing-installation)
for how to upgrade and update code.
2. The [technical note](docs/technical_note_v2.3.0.md) containing the models
and inference procedure for an updated AlphaFold v2.3.0.
3. A [CASP15 baseline](docs/casp15_predictions.zip) set of predictions along
with documentation of any manual interventions performed.
Any publication that discloses findings arising from using this source code or
the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if
applicable, the
[AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).
Please also refer to the
[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)
for a detailed description of the method.
**You can use a slightly simplified version of AlphaFold with
[this Colab notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb)**
or community-supported versions (see below).
If you have any questions, please contact the AlphaFold team at
......@@ -58,9 +65,9 @@ The following steps are required in order to run AlphaFold:
or take a look at the following
[NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).
If you wish to run AlphaFold using Singularity (a common containerization
platform on HPC systems), we recommend using one of the third-party Singularity
setups linked in https://github.com/deepmind/alphafold/issues/10 or
https://github.com/deepmind/alphafold/issues/24.
### Genetic databases
......@@ -74,7 +81,7 @@ AlphaFold needs multiple genetic (sequence) databases to run:
* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format),
* [PDB seqres](https://www.rcsb.org/) – only for AlphaFold-Multimer,
* [UniRef30 (FKA UniClust30)](https://uniclust.mmseqs.com/),
* [UniProt](https://www.uniprot.org/uniprot/) – only for AlphaFold-Multimer,
* [UniRef90](https://www.uniprot.org/help/uniref).
......@@ -98,29 +105,34 @@ and set up all of these databases:
will download a reduced version of the databases to be used with the
`reduced_dbs` database preset.
:ledger: **Note: The download directory `<DOWNLOAD_DIR>` should *not* be a
subdirectory in the AlphaFold repository directory.** If it is, the Docker build
will be slow as the large databases will be copied during the image creation.
We don't provide exactly the database versions used in CASP14 – see the
[note on reproducibility](#note-on-casp14-reproducibility). Some of the databases are
mirrored for speed, see [mirrored databases](#mirrored-databases).
:ledger: **Note: The total download size for the full databases is around 556 GB
and the total size when unzipped is 2.62 TB. Please make sure you have enough
hard drive space, bandwidth and time to download. We recommend using an SSD for
better genetic search performance.**
:ledger: **Note: If the download directory and datasets don't have full read and
write permissions, it can cause errors with the MSA tools, with opaque
(external) error messages. Please ensure the required permissions are applied,
e.g. with the `sudo chmod 755 --recursive "$DOWNLOAD_DIR"` command.**
The `download_all_data.sh` script will also download the model parameter files.
Once the script has finished, you should have the following directory structure:
```
$DOWNLOAD_DIR/ # Total: ~ 2.62 TB (download: 556 GB)
bfd/ # ~ 1.8 TB (download: 271.6 GB)
# 6 files.
mgnify/ # ~ 120 GB (download: 67 GB)
mgy_clusters_2022_05.fa
params/ # ~ 5.3 GB (download: 5.3 GB)
# 5 CASP14 models,
# 5 pTM models,
# 5 AlphaFold-Multimer models,
......@@ -128,20 +140,19 @@ $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB)
# = 16 files.
pdb70/ # ~ 56 GB (download: 19.5 GB)
# 9 files.
pdb_mmcif/ # ~ 238 GB (download: 43 GB)
mmcif_files/
# About 199,000 .cif files.
obsolete.dat
pdb_seqres/ # ~ 0.2 GB (download: 0.2 GB)
pdb_seqres.txt
small_bfd/ # ~ 17 GB (download: 9.6 GB)
bfd-first_non_consensus_sequences.fasta
uniref30/ # ~ 206 GB (download: 52.5 GB)
# 7 files.
uniprot/ # ~ 105 GB (download: 53 GB)
uniprot.fasta
uniref90/ # ~ 67 GB (download: 34 GB)
uniref90.fasta
```
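If you want a quick sanity check that this layout is in place before running
AlphaFold, a few lines of Python suffice. This is a hypothetical helper, not
part of the repository; note that `small_bfd/` is only present if you
downloaded the reduced databases:

```python
# Hypothetical helper (not part of the AlphaFold repo): check that the
# expected database directories exist under the download directory.
import os
import sys

EXPECTED_DIRS = [
    'bfd', 'mgnify', 'params', 'pdb70', 'pdb_mmcif', 'pdb_seqres',
    'uniref30', 'uniprot', 'uniref90',
]


def check_download_dir(download_dir: str) -> None:
  missing = [d for d in EXPECTED_DIRS
             if not os.path.isdir(os.path.join(download_dir, d))]
  if missing:
    sys.exit('Missing database directories: ' + ', '.join(missing))
  print('All expected database directories are present.')


if __name__ == '__main__':
  check_download_dir(sys.argv[1])
```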
......@@ -151,11 +162,12 @@ is only downloaded if you download the reduced databases.
### Model parameters
While the AlphaFold code is licensed under the Apache 2.0 License, the AlphaFold
parameters and CASP15 prediction data are made available under the terms of the
CC BY 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below
for more detail.
The AlphaFold parameters are available from
https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:
......@@ -168,18 +180,25 @@ will download parameters for:
* 5 AlphaFold-Multimer models that produce pTM and PAE values alongside their
structure predictions.
### Updating existing installation
If you have a previous version, you can either reinstall fully (remove
everything and run the setup from scratch) or do an incremental update, which
will be significantly faster but requires a bit more work. Make sure you follow
these steps in the exact order they are listed below:
1. **Update the code.**
* Go to the directory with the cloned AlphaFold repository and run `git
fetch origin main` to get all code updates.
1. **Update the UniProt, UniRef, MGnify and PDB seqres databases.**
* Remove `<DOWNLOAD_DIR>/uniprot`.
* Run `scripts/download_uniprot.sh <DOWNLOAD_DIR>`.
* Remove `<DOWNLOAD_DIR>/uniclust30`.
* Run `scripts/download_uniref30.sh <DOWNLOAD_DIR>`.
* Remove `<DOWNLOAD_DIR>/uniref90`.
* Run `scripts/download_uniref90.sh <DOWNLOAD_DIR>`.
* Remove `<DOWNLOAD_DIR>/mgnify`.
* Run `scripts/download_mgnify.sh <DOWNLOAD_DIR>`.
* Remove `<DOWNLOAD_DIR>/pdb_mmcif`. This is needed because PDB SeqRes and
the PDB mmCIF files must be from exactly the same date. Failure to do this
step will result in potential errors when searching for templates when running
......@@ -192,40 +211,21 @@ work. Make sure you follow these steps in the exact order they are listed below:
`scripts/download_alphafold_params.sh <DOWNLOAD_DIR>`.
1. **Follow [Running AlphaFold](#running-alphafold).**
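For reference, the incremental update above can be scripted. The sketch below
is a minimal illustration, not an official tool: it assumes it is run from the
root of the cloned repository and uses only the download scripts named in the
steps above (the PDB mmCIF and PDB seqres steps are omitted):

```python
# Minimal sketch of the incremental database update steps above.
# Illustrative only; run from the root of the cloned AlphaFold repository.
import shutil
import subprocess
from pathlib import Path

DOWNLOAD_DIR = Path('/path/to/download_dir')  # Placeholder.

# (directory to remove, download script to run) pairs from the steps above.
UPDATES = [
    ('uniprot', 'download_uniprot.sh'),
    ('uniclust30', 'download_uniref30.sh'),  # UniRef30 replaces Uniclust30.
    ('uniref90', 'download_uniref90.sh'),
    ('mgnify', 'download_mgnify.sh'),
]

for old_dir, script in UPDATES:
  stale = DOWNLOAD_DIR / old_dir
  if stale.exists():
    shutil.rmtree(stale)  # Remove the outdated database.
  subprocess.run(['bash', f'scripts/{script}', str(DOWNLOAD_DIR)], check=True)

# Re-download the model parameters (see the step above).
subprocess.run(
    ['bash', 'scripts/download_alphafold_params.sh', str(DOWNLOAD_DIR)],
    check=True)
```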
#### Using deprecated model weights
To use the deprecated v2.2.0 AlphaFold-Multimer model weights:
1. Change `SOURCE_URL` in `scripts/download_alphafold_params.sh` to
`https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar`,
and download the old parameters.
2. Change the `_v3` to `_v2` in the multimer `MODEL_PRESETS` in `config.py`.
To use the deprecated v2.1.0 AlphaFold-Multimer model weights:
1. Change `SOURCE_URL` in `scripts/download_alphafold_params.sh` to
`https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar`,
and download the old parameters.
2. Remove the `_v3` in the multimer `MODEL_PRESETS` in `config.py`.
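For example, after step 2 for the v2.2.0 weights, the multimer entry in
`MODEL_PRESETS` in `alphafold/model/config.py` would read:

```python
# The multimer entry in MODEL_PRESETS after changing `_v3` to `_v2`.
MODEL_PRESETS = {
    'multimer': (
        'model_1_multimer_v2',
        'model_2_multimer_v2',
        'model_3_multimer_v2',
        'model_4_multimer_v2',
        'model_5_multimer_v2',
    ),
}
```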
## Running AlphaFold
......@@ -266,9 +266,7 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
```
1. Make sure that the output directory exists (the default is `/tmp/alphafold`)
and that you have sufficient permissions to write into it.
1. Run `run_docker.py` pointing to a FASTA file containing the protein
sequence(s) for which you wish to predict the structure. If you are
......@@ -291,32 +289,32 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
[GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
for more details.
1. You can control which AlphaFold model to run by adding the `--model_preset=`
flag. We provide the following models:
* **monomer**: This is the original model used at CASP14 with no
ensembling.
* **monomer\_casp14**: This is the original model used at CASP14 with
`num_ensemble=8`, matching our CASP14 configuration. This is largely
provided for reproducibility as it is 8x more computationally expensive
for limited accuracy gain (+0.1 average GDT gain on CASP14 domains).
* **monomer\_ptm**: This is the original CASP14 model fine-tuned with the
pTM head, providing a pairwise confidence measure. It is slightly less
accurate than the normal monomer model.
* **multimer**: This is the [AlphaFold-Multimer](#citing-this-work) model.
To use this model, provide a multi-sequence FASTA file. In addition, the
UniProt database should have been downloaded.
1. You can control MSA speed/quality tradeoff by adding
`--db_preset=reduced_dbs` or `--db_preset=full_dbs` to the run command. We
provide the following presets:
* **reduced\_dbs**: This preset is optimized for speed and lower hardware
requirements. It runs with a reduced version of the BFD database. It
requires 8 CPU cores (vCPUs), 8 GB of RAM, and 600 GB of disk space.
* **full\_dbs**: This runs with all genetic databases used at CASP14.
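For example, combining the two flags above, a monomer run against the reduced
databases might look like the following sketch (paths and the template cutoff
date are placeholders; the other flags are used as described in this README):

```
python3 docker/run_docker.py \
  --fasta_paths=<FASTA_PATH> \
  --max_template_date=2022-01-01 \
  --model_preset=monomer \
  --db_preset=reduced_dbs \
  --data_dir=<DOWNLOAD_DIR>
```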
......@@ -350,7 +348,7 @@ python3 docker/run_docker.py \
```
By default the multimer system will run 5 seeds per model (25 total
predictions). For a small drop in accuracy, you may wish to run a single seed
per model. This
can be done via the `--num_multimer_predictions_per_model` flag, e.g. set it to
`--num_multimer_predictions_per_model=1` to run a single seed per model.
......@@ -379,8 +377,8 @@ python3 docker/run_docker.py \
#### Folding a homomer
Say we have a homomer with 3 copies of the same sequence `<SEQUENCE>`. The input
fasta should be:
```fasta
>sequence_1
......@@ -403,8 +401,8 @@ python3 docker/run_docker.py \
#### Folding a heteromer
Say we have an A2B3 heteromer, i.e. with 2 copies of `<SEQUENCE A>` and 3 copies
of `<SEQUENCE B>`. The input fasta should be:
```fasta
>sequence_1
......@@ -470,12 +468,13 @@ The `--output_dir` directory will have the following structure:
features.pkl
ranked_{0,1,2,3,4}.pdb
ranking_debug.json
relax_metrics.json
relaxed_model_{1,2,3,4,5}.pdb
result_model_{1,2,3,4,5}.pkl
timings.json
unrelaxed_model_{1,2,3,4,5}.pdb
msas/
bfd_uniref_hits.a3m
mgnify_hits.sto
uniref90_hits.sto
```
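As a quick illustration (a sketch, not part of the repository), the JSON
outputs above can be inspected with a few lines of Python; the `order` key
below is an assumption about the layout of `ranking_debug.json`:

```python
# Sketch: inspect the JSON outputs of a finished run (illustrative only).
import json
import os

output_dir = '/tmp/alphafold/<target_name>'  # Placeholder.

# Model ranking ('order' is an assumed key listing models by pLDDT).
with open(os.path.join(output_dir, 'ranking_debug.json')) as f:
  ranking = json.load(f)
print('Model ranking:', ranking['order'])

# Per-section runtimes of the pipeline.
with open(os.path.join(output_dir, 'timings.json')) as f:
  timings = json.load(f)
for section, seconds in timings.items():
  print(f'{section}: {seconds:.1f} s')
```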
......@@ -499,6 +498,8 @@ The contents of each output file are as follows:
* `ranking_debug.json` – A JSON format text file containing the pLDDT values
used to perform the model ranking, and a mapping back to the original model
names.
* `relax_metrics.json` – A JSON format text file containing relaxation
metrics, for instance the number of remaining violations.
* `timings.json` – A JSON format text file containing the times taken to run
each section of the AlphaFold pipeline.
* `msas/` - A directory containing the files describing the various genetic
......@@ -576,7 +577,8 @@ For genetics:
For templates:
* PDB: (downloaded 2020-05-14)
* PDB70:
[2020-05-13](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200513.tar.gz)
An alternative for templates is to use the latest PDB and PDB70, but pass the
flag `--max_template_date=2020-05-14`, which restricts templates only to
......@@ -601,6 +603,7 @@ If you use the code or data in this package, please cite:
In addition, if you use the AlphaFold-Multimer mode, please cite:
```bibtex
@article {AlphaFold-Multimer2021,
author = {Evans, Richard and O{\textquoteright}Neill, Michael and Pritzel, Alexander and Antropova, Natasha and Senior, Andrew and Green, Tim and {\v{Z}}{\'\i}dek, Augustin and Bates, Russ and Blackwell, Sam and Yim, Jason and Ronneberger, Olaf and Bodenstein, Sebastian and Zielinski, Michal and Bridgland, Alex and Potapenko, Anna and Cowie, Andrew and Tunyasuvunakool, Kathryn and Jain, Rishub and Clancy, Ellen and Kohli, Pushmeet and Jumper, John and Hassabis, Demis},
......@@ -619,10 +622,11 @@ In addition, if you use the AlphaFold-Multimer mode, please cite:
Colab notebooks provided by the community (please note that these notebooks may
vary from our full AlphaFold system and we did not validate their accuracy):
* The
[ColabFold AlphaFold2 notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb)
by Martin Steinegger, Sergey Ovchinnikov and Milot Mirdita, which uses an
API hosted at the Södinglab based on the MMseqs2 server
[(Mirdita et al. 2019, Bioinformatics)](https://academic.oup.com/bioinformatics/article/35/16/2856/5280135)
for the multiple sequence alignment creation.
## Acknowledgements
......@@ -661,15 +665,15 @@ We thank all their contributors and maintainers!
If you have any questions not covered in this overview, please contact the
AlphaFold team at [alphafold@deepmind.com](mailto:alphafold@deepmind.com).
We would love to hear your feedback and understand how AlphaFold has been useful
in your research. Share your stories with us at
[alphafold@deepmind.com](mailto:alphafold@deepmind.com).
## License and Disclaimer
This is not an officially supported Google product.
Copyright 2022 DeepMind Technologies Limited.
### AlphaFold Code License
......@@ -699,12 +703,26 @@ before use.
### Mirrored Databases
The following databases have been mirrored by DeepMind, and are available with
reference to the following:
* [BFD](https://bfd.mmseqs.com/) (unmodified), by Steinegger M. and Söding J.,
available under a
[Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
* [BFD](https://bfd.mmseqs.com/) (modified), by Steinegger M. and Söding J.,
modified by DeepMind, available under a
[Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
See the Methods section of the
[AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1)
for details.
* [UniRef30: v2021_03](http://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/)
(unmodified), by Mirdita M. et al., available under a
[Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
* [MGnify: v2022_05](http://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/README.txt)
(unmodified), by Mitchell AL et al., available free of all copyright
restrictions and made fully and freely available for both non-commercial and
commercial use under
[CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
......@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names
......
......@@ -117,7 +117,7 @@ class DataPipeline:
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniref30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
......@@ -135,9 +135,9 @@ class DataPipeline:
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniref30_database_path])
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
......@@ -211,14 +211,14 @@ class DataPipeline:
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
......
......@@ -128,3 +128,64 @@ class Linear(hk.Module):
return output
class LayerNorm(hk.LayerNorm):
"""LayerNorm module.
Equivalent to hk.LayerNorm but with different parameter shapes: they are
always vectors rather than possibly higher-rank tensors. This makes it easier
to change the layout whilst keeping the model weight-compatible.
"""
def __init__(self,
axis,
create_scale: bool,
create_offset: bool,
eps: float = 1e-5,
scale_init=None,
offset_init=None,
use_fast_variance: bool = False,
name=None,
param_axis=None):
super().__init__(
axis=axis,
create_scale=False,
create_offset=False,
eps=eps,
scale_init=None,
offset_init=None,
use_fast_variance=use_fast_variance,
name=name,
param_axis=param_axis)
self._temp_create_scale = create_scale
self._temp_create_offset = create_offset
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
is_bf16 = (x.dtype == jnp.bfloat16)
if is_bf16:
x = x.astype(jnp.float32)
param_axis = self.param_axis[0] if self.param_axis else -1
param_shape = (x.shape[param_axis],)
param_broadcast_shape = [1] * x.ndim
param_broadcast_shape[param_axis] = x.shape[param_axis]
scale = None
offset = None
if self._temp_create_scale:
scale = hk.get_parameter(
'scale', param_shape, x.dtype, init=self.scale_init)
scale = scale.reshape(param_broadcast_shape)
if self._temp_create_offset:
offset = hk.get_parameter(
'offset', param_shape, x.dtype, init=self.offset_init)
offset = offset.reshape(param_broadcast_shape)
out = super().__call__(x, scale=scale, offset=offset)
if is_bf16:
out = out.astype(jnp.bfloat16)
return out
\ No newline at end of file
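# Illustrative usage sketch (annotation, not part of this file): the module is
# constructed exactly like hk.LayerNorm, inside an hk.transform-ed function.
#
#   def forward(x):
#     return LayerNorm(axis=[-1], create_scale=True, create_offset=True,
#                      name='example_norm')(x)
#
#   rng = jax.random.PRNGKey(0)
#   params = hk.transform(forward).init(rng, jnp.zeros((4, 8), jnp.bfloat16))
#
# The 'scale' and 'offset' parameters are created as vectors of shape (8,), and
# bfloat16 inputs are normalised in float32 before being cast back.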
......@@ -26,12 +26,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
if 'multimer' in name:
cfg = copy.deepcopy(CONFIG_MULTIMER)
else:
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg
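# Illustrative usage (annotation, not part of this file):
#   cfg = model_config('model_1_multimer_v3')
# deep-copies CONFIG_MULTIMER and applies CONFIG_DIFFS['model_1_multimer_v3']
# (empty for this preset), so e.g.
# cfg.model.embeddings_and_evoformer.num_msa == 508.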
......@@ -52,11 +52,11 @@ MODEL_PRESETS = {
'model_5_ptm',
),
'multimer': (
'model_1_multimer_v3',
'model_2_multimer_v3',
'model_3_multimer_v3',
'model_4_multimer_v3',
'model_5_multimer_v3',
),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
......@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_1_multimer_v3': {},
'model_2_multimer_v3': {},
'model_3_multimer_v3': {},
'model_4_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
'model_5_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
}
# Key differences between multimer v1/v2 and v3, mostly due to numerical
# optimisations in the TriangleMultiplication module.
common_updates = {
'model.embeddings_and_evoformer.num_msa': 252,
'model.embeddings_and_evoformer.num_extra_msa': 1152,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
}
CONFIG_DIFFS.update(
{f'model_{i}_multimer': common_updates for i in range(1, 6)})
CONFIG_DIFFS.update(
{f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
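# Illustrative effect (annotation, not part of this file): with the updates
# above, model_config('model_1_multimer_v2') yields num_msa == 252 and
# unfused triangle-multiplication weights, matching the v1/v2 checkpoints,
# while the v3 presets keep the CONFIG_MULTIMER defaults (num_msa == 508,
# fused projection weights).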
CONFIG = ml_collections.ConfigDict({
'data': {
......@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
......@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
......@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True,
},
'heads': {
'distogram': {
......@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'num_msa': 508,
'num_extra_msa': 2048,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
......@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True,
'fuse_projection_weights': True,
}
}
},
},
'global_config': {
'bfloat16': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True,
},
'heads': {
'distogram': {
......@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
}
},
'num_ensemble_eval': 1,
'num_recycle': 20,
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `num_recycle` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
'recycle_early_stop_tolerance': 0.5,
'resample_msa_in_recycling': True
}
})
......@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
c = config
sequence_mask = batch['seq_mask'][:, None]
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
'affine': affine.to_tensor(),
}
act_2d = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......
......@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = common_modules.LayerNorm(
axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
representations['single'])
......@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
rigid
}
act_2d = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......
......@@ -133,7 +133,7 @@ def flatten(instance):
inner_treedefs = []
num_arrays = []
for array_like in array_likes:
flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
inner_treedefs.append(inner_treedef)
flat_array_likes += flat_array_like
num_arrays.append(len(flat_array_like))
......@@ -206,7 +206,7 @@ class StructOfArray:
for num_array, inner_treedef, array_field in zip(num_arrays,
inner_treedefs,
array_fields):
value_dict[array_field] = jax.tree_util.tree_unflatten(
inner_treedef, data[array_start:array_start + num_array])
array_start += num_array
metadata_fields = get_metadata_fields(new_cls)
......
......@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):
def _expand_axes(axes, values, name='sharded_apply'):
values_tree_def = jax.tree_util.tree_flatten(values)[1]
flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
# Replace None's with PROXY
flat_axes = [PROXY if x is None else x for x in flat_axes]
return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
def sharded_map(
......@@ -126,7 +126,7 @@ def sharded_apply(
in_axes_ = _expand_axes(in_axes, args)
in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
in_size = max(flat_sizes)
assert all(i in {in_size, -1} for i in flat_sizes)
......
......@@ -501,7 +501,7 @@ class Transition(hk.Module):
num_intermediate = int(nc * self.config.num_intermediate_factor)
mask = jnp.expand_dims(mask, axis=-1)
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -569,12 +569,15 @@ class Attention(hk.Module):
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
......@@ -595,10 +598,12 @@ class Attention(hk.Module):
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
......@@ -610,9 +615,12 @@ class Attention(hk.Module):
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
......@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
......@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
......@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
pair_act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=msa_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
......@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
......@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
......@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
pair_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
pair_act)
......@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=pair_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
......@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
"""
act = representations['structure_module']
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
return output
def _layer_norm(axis=-1, name='layer_norm'):
return common_modules.LayerNorm(
axis=axis,
create_scale=True,
create_offset=True,
eps=1e-5,
use_fast_variance=True,
scale_init=hk.initializers.Constant(1.),
offset_init=hk.initializers.Constant(0.),
param_axis=axis,
name=name)
class TriangleMultiplication(hk.Module):
"""Triangle multiplication layer ("outgoing" or "incoming").
......@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
self.config = config
self.global_config = global_config
def __call__(self, left_act, left_mask, is_training=True):
"""Builds TriangleMultiplication module.
Arguments:
left_act: Pair activations, shape [N_res, N_res, c_z]
left_mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
Returns:
Outputs, same shape/type as left_act.
"""
del is_training
if self.config.fuse_projection_weights:
return self._fused_triangle_multiplication(left_act, left_mask)
else:
return self._triangle_multiplication(left_act, left_mask)
@hk.transparent
def _triangle_multiplication(self, left_act, left_mask):
"""Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
c = self.config
gc = self.global_config
mask = left_mask[..., None]
act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(left_act)
input_act = act
left_projection = common_modules.Linear(
......@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
# b = left_proj_act and a = right_proj_act
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):
return act
@hk.transparent
def _fused_triangle_multiplication(self, left_act, left_mask):
"""TriangleMultiplication with fused projection weights."""
mask = left_mask[..., None]
c = self.config
gc = self.global_config
left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
# Both left and right projections are fused into projection.
projection = common_modules.Linear(
2*c.num_intermediate_channel, name='projection')
proj_act = mask * projection(left_act)
# Both left + right gate are fused into gate_values.
gate_values = common_modules.Linear(
2 * c.num_intermediate_channel,
name='gate',
bias_init=1.,
initializer=utils.final_init(gc))(left_act)
proj_act *= jax.nn.sigmoid(gate_values)
left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = _layer_norm(axis=-1, name='center_norm')(act)
output_channel = int(left_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(left_act)
act *= jax.nn.sigmoid(gate_values)
return act
class DistogramHead(hk.Module):
"""Head to predict a distogram.
......@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
c = self.config
mask = mask[..., None]
act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)
left_act = mask * common_modules.Linear(
c.num_outer_channel,
......@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
'output_w',
shape=(c.num_outer_channel, c.num_outer_channel,
self.num_output_channel),
dtype=act.dtype,
init=init_w)
output_b = hk.get_parameter(
'output_b', shape=(self.num_output_channel,),
dtype=act.dtype,
init=hk.initializers.Constant(0.0))
def compute_chunk(left_act):
......@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
dgram)
if c.recycle_features:
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
self.config.template_pair_stack, self.global_config)(
act, mask_2d, is_training)
act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
return act
......
......@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
# Eval mode or tests: use the maximum number of iterations.
num_iter = c.num_recycle
def distances(points):
"""Compute all pairwise distances for a set of points."""
return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
axis=-1))
def recycle_body(x):
i, _, prev, safe_key = x
safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long
ret = apply_network(prev=prev, safe_key=safe_key2)
return i+1, prev, get_prev(ret), safe_key1
def recycle_cond(x):
i, prev, next_in, _ = x
ca_idx = residue_constants.atom_order['CA']
sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
distances(next_in['prev_pos'][:, ca_idx, :]))
mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
sq_diff = utils.mask_mean(mask, sq_diff)
# Early stopping criteria based on criteria used in
# AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
diff = jnp.sqrt(sq_diff + 1e-8) # avoid bad numerics giving negatives
less_than_max_recycles = (i < num_iter)
has_exceeded_tolerance = (
(i == 0) | (diff > c.recycle_early_stop_tolerance))
return less_than_max_recycles & has_exceeded_tolerance
if hk.running_init():
num_recycles, _, prev, safe_key = recycle_body(
(0, prev, prev, safe_key))
else:
num_recycles, _, prev, safe_key = hk.while_loop(
recycle_cond,
recycle_body,
(0, prev, prev, safe_key))
else:
# No recycling.
num_recycles = 0
# Run extra iteration.
ret = apply_network(prev=prev, safe_key=safe_key)
if not return_representations:
del ret['representations']
ret['num_recycles'] = num_recycles
return ret
......@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
Feature embedding using the features as described before.
"""
c = self.config
gc = self.global_config
rel_feats = []
pos = batch['residue_index']
asym_id = batch['asym_id']
asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
offset = pos[:, None] - pos[None, :]
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
clipped_offset = jnp.clip(
offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
......@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
rel_feat = jnp.concatenate(rel_feats, axis=-1)
rel_feat = rel_feat.astype(dtype)
return common_modules.Linear(
c.pair_channel,
name='position_activations')(
......@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
gc = self.global_config
batch = dict(batch)
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
......@@ -587,177 +622,178 @@ class EmbeddingsAndEvoformer(hk.Module):
batch['msa_profile'] = make_msa_profile(batch)
with utils.bfloat16_context():
target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
target_feat)
safe_key, sample_key, mask_key = safe_key.split(3)
batch = sample_msa(sample_key, batch, c.num_msa)
batch = make_masked_msa(batch, mask_key, c.masked_msa)
(batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
msa_feat = create_msa_feat(batch).astype(dtype)
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
msa_feat)
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
target_feat)
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
target_feat)
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(dtype)
if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = modules.dgram_from_positions(
prev_pseudo_beta, **self.config.prev_pos)
dgram = dgram.astype(dtype)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row']).astype(dtype)
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair']).astype(dtype)
if c.max_relative_idx:
pair_activations += self._relative_encoding(batch)
if c.template.enabled:
template_module = TemplateEmbedding(c.template, gc)
template_batch = {
'template_aatype': batch['template_aatype'],
'template_all_atom_positions': batch['template_all_atom_positions'],
'template_all_atom_mask': batch['template_all_atom_mask']
}
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat).astype(dtype)
extra_msa_mask = extra_msa_mask.astype(dtype)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
extra_evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_evoformer_iteration(
activations=act,
masks=extra_masks,
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_evoformer_fn = hk.remat(extra_evoformer_fn)
safe_key, safe_subkey = safe_key.split()
extra_evoformer_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_evoformer_fn)
extra_evoformer_output, safe_key = extra_evoformer_stack(
(extra_evoformer_input, safe_subkey))
pair_activations = extra_evoformer_output['pair']
# Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {
'msa': batch['msa_mask'].astype(dtype),
'pair': mask_2d
}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(
batch=batch, num_channel=c.msa_channel, global_config=gc))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output
def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output
evoformer_output = run_evoformer(evoformer_input)
evoformer_output = run_evoformer(evoformer_input)
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
output.update({
'single':
......@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
msa_activations[0],
})
# Convert back to float32 if we're not saving memory.
if not gc.bfloat16_output:
  for k, v in output.items():
    if v.dtype == jnp.bfloat16:
      output[k] = v.astype(jnp.float32)
return output
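As an aside, the trunk above combines `layer_stack` (every block shares one traced body but gets its own parameters) with `hk.remat` (recompute activations in the backward pass instead of storing them). Below is a minimal sketch of that pattern with a toy block standing in for `EvoformerIteration`; everything in it is illustrative and uses Haiku's public `layer_stack` rather than the repo's local copy:

import haiku as hk
import jax
import jax.numpy as jnp

NUM_BLOCKS = 4  # analogous to c.evoformer_num_block

def forward(x):
  def block_fn(act):
    # One iteration of the stack; layer_stack gives each of the
    # NUM_BLOCKS applications its own parameters.
    return jax.nn.relu(hk.Linear(act.shape[-1], name='block')(act))

  # Analogous to gc.use_remat: trade recomputation for activation memory.
  block_fn = hk.remat(block_fn)
  stack = hk.experimental.layer_stack(NUM_BLOCKS)(block_fn)
  return stack(x)

fwd = hk.transform(forward)
params = fwd.init(jax.random.PRNGKey(0), jnp.ones([8, 32]))
out = fwd.apply(params, None, jnp.ones([8, 32]))  # shape (8, 32)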
......@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
# backbone affine - i.e. in each residue's local frame, in which direction
# each of the other residues lies.
raw_atom_pos = template_all_atom_positions
if gc.bfloat16:
  # Vec3Arrays are required to be float32.
  raw_atom_pos = raw_atom_pos.astype(jnp.float32)
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = folding_multimer.make_backbone_affine(
......@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
unit_vector = rigid_vec.normalized()
unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
if gc.bfloat16:
  unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
  backbone_mask = backbone_mask.astype(jnp.bfloat16)
backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
backbone_mask_2d *= multichain_mask_2d
unit_vector = [x * backbone_mask_2d for x in unit_vector]
......@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
to_concat.extend([(x, 0) for x in unit_vector])
to_concat.append((backbone_mask_2d, 0))
query_embedding = hk.LayerNorm(
query_embedding = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
template_iteration_fn)
act, safe_key = template_stack((act, safe_subkey))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='output_layer_norm')(
act)
return act
......@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.Transition(c.pair_transition, gc,
name='pair_transition'),
......@@ -1069,7 +1116,7 @@ class TemplateEmbeddingIteration(hk.Module):
return act
def template_embedding_1d(batch, num_channel):
def template_embedding_1d(batch, num_channel, global_config):
"""Embed templates into an (num_res, num_templates, num_channels) embedding.
Args:
......@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
each template.
num_channel: The number of channels in the output.
global_config: The global configuration; used to decide whether template
  features are cast to bfloat16.
Returns:
An embedding of shape (num_templates, num_res, num_channels) and a mask of
......@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
template_mask = chi_mask[:, :, 0]

if global_config.bfloat16:
  template_features = template_features.astype(jnp.bfloat16)
  template_mask = template_mask.astype(jnp.bfloat16)

template_activations = common_modules.Linear(
    num_channel,
    initializer='relu',
......
......@@ -15,6 +15,7 @@
"""A collection of JAX utility functions for use in protein folding."""
import collections
import contextlib
import functools
import numbers
from typing import Mapping
......@@ -25,6 +26,27 @@ import jax.numpy as jnp
import numpy as np
def bfloat16_creator(next_creator, shape, dtype, init, context):
  """Creates float32 variables when bfloat16 is requested."""
  if context.original_dtype == jnp.bfloat16:
    dtype = jnp.float32
  return next_creator(shape, dtype, init)


def bfloat16_getter(next_getter, value, context):
  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
  if context.original_dtype == jnp.bfloat16:
    assert value.dtype == jnp.float32
    value = value.astype(jnp.bfloat16)
  return next_getter(value)


@contextlib.contextmanager
def bfloat16_context():
  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
    yield
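Together these interceptors keep a float32 master copy of every parameter while letting modules compute in bfloat16. A small usage sketch, assuming the `bfloat16_context` defined above is in scope (the `forward` function and shapes below are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  with bfloat16_context():
    # hk.Linear derives its parameter dtype from its input, so a bfloat16
    # input requests bfloat16 parameters; the creator stores them as float32
    # and the getter casts them back to bfloat16 at use time.
    return hk.Linear(16)(x.astype(jnp.bfloat16))

fwd = hk.transform(forward)
params = fwd.init(jax.random.PRNGKey(0), jnp.ones([4, 8]))
# The stored parameters stay float32 even though compute ran in bfloat16.
assert all(p.dtype == jnp.float32
           for p in jax.tree_util.tree_leaves(params))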
def final_init(config):
if config.zero_init:
return 'zeros'
......
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Helper methods for the AlphaFold Colab notebook."""
import enum
import json
from typing import Any, Mapping, Optional, Sequence, Tuple
......@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt
import numpy as np
@enum.unique
class ModelType(enum.Enum):
MONOMER = 0
MULTIMER = 1
def clean_and_validate_sequence(
def clean_and_validate_single_sequence(
input_sequence: str, min_length: int, max_length: int) -> str:
"""Checks that the input sequence is ok and returns a clean version of it."""
# Remove all whitespaces, tabs and end lines; upper-case.
......@@ -54,41 +47,23 @@ def clean_and_validate_sequence(
return clean_sequence
def validate_input(
def clean_and_validate_input_sequences(
input_sequences: Sequence[str],
min_length: int,
max_length: int,
max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
"""Validates and cleans input sequences and determines which model to use."""
min_sequence_length: int,
max_sequence_length: int) -> Sequence[str]:
"""Validates and cleans input sequences."""
sequences = []
for input_sequence in input_sequences:
if input_sequence.strip():
input_sequence = clean_and_validate_sequence(
input_sequence = clean_and_validate_single_sequence(
input_sequence=input_sequence,
min_length=min_length,
max_length=max_length)
min_length=min_sequence_length,
max_length=max_sequence_length)
sequences.append(input_sequence)
if len(sequences) == 1:
print('Using the single-chain model.')
return sequences, ModelType.MONOMER
elif len(sequences) > 1:
total_multimer_length = sum([len(seq) for seq in sequences])
if total_multimer_length > max_multimer_length:
raise ValueError(f'The total length of multimer sequences is too long: '
f'{total_multimer_length}, while the maximum is '
f'{max_multimer_length}. Please use the full AlphaFold '
f'system for long multimers.')
elif total_multimer_length > 1536:
print('WARNING: The accuracy of the system has not been fully validated '
'above 1536 residues, and you may experience long running times or '
f'run out of memory for your complex with {total_multimer_length} '
'residues.')
print(f'Using the multimer model with {len(sequences)} sequences.')
return sequences, ModelType.MULTIMER
if sequences:
return sequences
else:
raise ValueError('No input amino acid sequence provided, please provide at '
'least one sequence.')
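For reference, a usage sketch of the renamed helper (input values here are illustrative):

from alphafold.notebooks import notebook_utils

sequences = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=['mkt aia\n', '', '  GVLAA\t'],
    min_sequence_length=4,
    max_sequence_length=100)
# -> ['MKTAIA', 'GVLAA']: whitespace is stripped, sequences are upper-cased,
# empty entries are dropped; a ValueError is raised if nothing valid remains.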
......
......@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
clean = notebook_utils.clean_and_validate_sequence(
clean = notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=1, max_length=100)
self.assertEqual(clean, exp_clean)
......@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.clean_and_validate_sequence(
notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=4, max_length=8)
@parameterized.parameters(
    (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
     notebook_utils.ModelType.MONOMER),
    (['', 'A'], ['A'],
     notebook_utils.ModelType.MONOMER),
    (['A', 'C ', ''], ['A', 'C'],
     notebook_utils.ModelType.MULTIMER),
    (['', 'A', '', 'C '], ['A', 'C'],
     notebook_utils.ModelType.MULTIMER))
def test_validate_input_ok(
    self, input_sequences, exp_sequences, exp_model_type):
  sequences, model_type = notebook_utils.validate_input(
      input_sequences=input_sequences,
      min_length=1, max_length=100, max_multimer_length=100)
  self.assertSequenceEqual(sequences, exp_sequences)
  self.assertEqual(model_type, exp_model_type)

@parameterized.parameters(
    (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A']),
    (['', 'A'], ['A']),
    (['A', 'C ', ''], ['A', 'C']),
    (['', 'A', '', 'C '], ['A', 'C']))
def test_validate_input_ok(self, input_sequences, exp_sequences):
  sequences = notebook_utils.clean_and_validate_input_sequences(
      input_sequences=input_sequences,
      min_sequence_length=1, max_sequence_length=100)
  self.assertSequenceEqual(sequences, exp_sequences)
@parameterized.named_parameters(
('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
def test_validate_input_bad(self, input_sequences, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.validate_input(
input_sequences=input_sequences,
min_length=4, max_length=8, max_multimer_length=6)
notebook_utils.clean_and_validate_input_sequences(
input_sequences=input_sequences, min_sequence_length=4,
max_sequence_length=8)
def test_merge_chunked_msa_no_hits(self):
results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
......
......@@ -56,7 +56,8 @@ class AmberRelaxation(object):
self._use_gpu = use_gpu
def process(self, *,
prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
prot: protein.Protein
) -> Tuple[str, Dict[str, Any], Sequence[float]]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot, max_iterations=self._max_iterations,
......@@ -73,12 +74,11 @@ class AmberRelaxation(object):
'attempts': out['min_attempts'],
'rmsd': rmsd
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = out['min_pdb']
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask,
prot.atom_mask)
violations = out['structural_violations'][
'total_per_residue_violations_mask']
'total_per_residue_violations_mask'].tolist()
return min_pdb, debug_data, violations
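With this change the relaxed structure is taken directly from `out['min_pdb']` and the per-residue violations come back as a plain Python list, which serialises cleanly to JSON. A consumption sketch follows; the constructor values mirror the usual run-script defaults and are shown only for illustration:

from alphafold.common import protein
from alphafold.relax import relax

def relax_prediction(unrelaxed_protein: protein.Protein):
  relaxer = relax.AmberRelaxation(
      max_iterations=0, tolerance=2.39, stiffness=10.0,
      exclude_residues=[], max_outer_iterations=3, use_gpu=False)
  relaxed_pdb, debug_data, violations = relaxer.process(
      prot=unrelaxed_protein)
  # violations is now a list of floats (one per residue), not an ndarray.
  return relaxed_pdb, debug_data, violations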
......@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0])
# Check no violations were added. Can't check exactly due to stochasticity.
self.assertTrue(np.all(num_violations <= exp_num_violations))
self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
if __name__ == '__main__':
......
......@@ -17,17 +17,6 @@ import io
from alphafold.common import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
......
......@@ -21,7 +21,7 @@ ARG CUDA
# Use bash to support string substitution.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
build-essential \
cmake \
......@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
cudatoolkit==${CUDA_VERSION} \
pdbfixer \
pip \
python=3.7 \
python=3.8 \
&& conda clean --all --force-pkgs-dirs --yes
COPY . /app/alphafold
......@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Apply OpenMM patch.
WORKDIR /opt/conda/lib/python3.7/site-packages
WORKDIR /opt/conda/lib/python3.8/site-packages
RUN patch -p0 < /app/alphafold/docker/openmm.patch
# Add SETUID bit to the ldconfig binary so that non-root users can run it.
......
......@@ -133,7 +133,7 @@ def main(argv):
# Path to the MGnify database for use by JackHMMER.
mgnify_database_path = os.path.join(
FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')
# Path to the BFD database for use by HHblits.
bfd_database_path = os.path.join(
......@@ -144,9 +144,9 @@ def main(argv):
small_bfd_database_path = os.path.join(
FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
# Path to the Uniclust30 database for use by HHblits.
uniclust30_database_path = os.path.join(
FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
# Path to the Uniref30 database for use by HHblits.
uniref30_database_path = os.path.join(
FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')
# Path to the PDB70 database for use by HHsearch.
pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')
......@@ -199,7 +199,7 @@ def main(argv):
database_paths.append(('small_bfd_database_path', small_bfd_database_path))
else:
database_paths.extend([
('uniclust30_database_path', uniclust30_database_path),
('uniref30_database_path', uniref30_database_path),
('bfd_database_path', bfd_database_path),
])
for name, path in database_paths:
......