Commit 13f8f163 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents a509a4c5 b5fa2ba3
Pipeline #235 failed with stages
in 0 seconds
name: Docker Image CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Build the Docker image
run: docker build . --file Dockerfile --tag openfold:$(date +%s)
\ No newline at end of file
name: undefined_names
on: [pull_request, push]
jobs:
undefined_names:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- run: pip install --upgrade pip
- run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
.vscode/
__pycache__/
*.egg-info
build
dist
# files from script downloads
data
openfold/resources/
tests/test_data/
cff-version: 1.2.0
preferred-citation:
authors:
- family-names: "Ahdritz"
given-names: "Gustaf"
orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta"
given-names: "Nazim"
orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan"
given-names: "Sachin"
orcid: https://orcid.org/0000-0002-6079-7627
- family-names: "Xia"
given-names: "Qinghui"
- family-names: "Gerecke"
given-names: "William"
orcid: https://orcid.org/0000-0002-9777-6192
- family-names: "O'Donnell"
given-names: "Timothy J"
orcid: https://orcid.org/0000-0002-9949-069X
- family-names: "Berenberg"
given-names: "Daniel"
orcid: https://orcid.org/0000-0003-4631-0947
- family-names: "Fisk"
given-names: "Ian"
- family-names: "Zanichelli"
given-names: "Niccolò"
orcid: https://orcid.org/0000-0002-3093-3587
- family-names: "Zhang"
given-names: "Bo"
orcid: https://orcid.org/0000-0002-9714-2827
- family-names: "Nowaczynski"
given-names: "Arkadiusz"
orcid: https://orcid.org/0000-0002-3351-9584
- family-names: "Wang"
given-names: "Bei"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Stepniewska-Dziubinska"
given-names: "Marta M"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Zhang"
given-names: "Shang"
orcid: https://orcid.org/0000-0003-0759-2080
- family-names: "Ojewole"
given-names: "Adegoke"
orcid: https://orcid.org/0000-0003-2661-4388
- family-names: "Guney"
given-names: "Murat Efe"
- family-names: "Biderman"
given-names: "Stella"
orcid: https://orcid.org/0000-0001-8228-1042
- family-names: "Watkins"
given-names: "Andrew M"
orcid: https://orcid.org/0000-0003-1617-1720
- family-names: "Ra"
given-names: "Stephen"
orcid: https://orcid.org/0000-0002-2820-0050
- family-names: "Lorenzo"
given-names: "Pablo Ribalta"
orcid: https://orcid.org/0000-0002-3657-8053
- family-names: "Nivon"
given-names: "Lucas"
- family-names: "Weitzner"
given-names: "Brian"
orcid: https://orcid.org/0000-0002-1909-0961
- family-names: "Ban"
given-names: "Yih-En"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Ban"
given-names: "Yih-En Andrew"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Sorger"
given-names: "Peter K"
orcid: https://orcid.org/0000-0002-3364-1838
- family-names: "Mostaque"
given-names: "Emad"
- family-names: "Zhang"
given-names: "Zhao"
orcid: https://orcid.org/0000-0001-5921-0035
- family-names: "Bonneau"
given-names: "Richard"
orcid: https://orcid.org/0000-0003-4354-7906
- family-names: "AlQuraishi"
given-names: "Mohammed"
orcid: https://orcid.org/0000-0001-6817-1322
title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
type: article
doi: 10.1101/2022.11.20.517210
doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12
url: "https://doi.org/10.1101/2022.11.20.517210"
FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
# metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniconda3-latest-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH
COPY environment.yml /opt/openfold/environment.yml
# installing into the base environment since the docker container wont do anything other than run openfold
RUN conda env update -n base --file /opt/openfold/environment.yml && conda clean --all
COPY openfold /opt/openfold/openfold
COPY scripts /opt/openfold/scripts
COPY run_pretrained_openfold.py /opt/openfold/run_pretrained_openfold.py
COPY train_openfold.py /opt/openfold/train_openfold.py
COPY setup.py /opt/openfold/setup.py
COPY lib/openmm.patch /opt/openfold/lib/openmm.patch
RUN wget -q -P /opt/openfold/openfold/resources \
https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
RUN patch -p0 -d /opt/conda/lib/python3.7/site-packages/ < /opt/openfold/lib/openmm.patch
WORKDIR /opt/openfold
RUN python3 setup.py install
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
![header ](imgs/of_banner.png)
_Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
# OpenFold
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Features
OpenFold carefully reproduces (almost) all of the features of the original open
source inference code (v2.0.1). The sole exception is model ensembling, which
fared poorly in DeepMind's own ablation testing and is being phased out in future
DeepMind experiments. It is omitted here for the sake of reducing clutter. In
cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
and we've trained it from scratch, matching the performance of the original.
We've publicly released model weights and our training data &mdash; some 400,000
MSAs and PDB70 template hit files &mdash; under a permissive license. Model weights
are available via scripts in this repository while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold also supports inference using AlphaFold's official parameters, and
vice versa (see `scripts/convert_of_weights_to_jax.py`).
OpenFold has the following advantages over the reference implementation:
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on (>= Ampere) GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
kernels support in-place attention during inference and training. They use
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.
## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.
For convenience, we provide a script that installs Miniconda locally, creates a
`conda` virtual environment, installs all Python dependencies, and downloads
useful resources, including both sets of model parameters. Run:
```bash
scripts/install_third_party_dependencies.sh
```
To activate the environment, run:
```bash
source scripts/activate_conda_env.sh
```
To deactivate it, run:
```bash
source scripts/deactivate_conda_env.sh
```
With the environment active, compile OpenFold's CUDA kernels with
```bash
python3 setup.py install
```
To install the HH-suite to `/usr/bin`, run
```bash
# scripts/install_hh_suite.sh
```
## Usage
If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use
DeepMind's MSA generation pipeline (w/ HMMR & HHblits) or
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
MMseqs2 instead. For the former, run:
```bash
bash scripts/download_alphafold_dbs.sh data/
```
For the latter, run:
```bash
bash scripts/download_mmseqs_dbs.sh data/ # downloads .tar files
bash scripts/prep_mmseqs_dbs.sh data/ # unpacks and preps the databases
```
Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system).
If you're using your own precomputed MSAs or MSAs from the RODA repository,
there's no need to download these alignment databases. Simply make sure that
the `alignment_dir` contains one directory per chain and that each of these
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain. You
can use `scripts/flatten_roda.sh` to reformat RODA downloads in this way.
Note that the RODA alignments are NOT compatible with the recent .cif ground
truth files downloaded by `scripts/download_alphafold_dbs.sh`. To fetch .cif
files that match the RODA MSAs, once the alignments are flattened, use
`scripts/download_roda_pdbs.sh`. That script outputs a list of alignment dirs
for which matching .cif files could not be found. These should be removed from
the alignment directory.
Alternatively, you can use raw MSAs from
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
that database, use `scripts/prep_proteinnet_msas.py` to convert the data
into a format recognized by the OpenFold parser. The resulting directory
becomes the `alignment_dir` used in subsequent steps. Use
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.
For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using
DeepMind's pretrained parameters, you will only be able to make changes that
do not affect the shapes of model parameters. For an example of initializing
the model, consult `run_pretrained_openfold.py`.
### Inference
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir ./ \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device "cuda:0" \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
respectively. For a breakdown of the differences between the different parameter
files, see the README downloaded to `openfold/resources/openfold_params/`. Since
OpenFold was trained under a newer training schedule than the one from which the
`model_n` config presets are derived, there is no clean correspondence between
`config_preset` settings and OpenFold checkpoints; the only restraints are that
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
checkpoints are only compatible with template-less presets (`model_3` and above).
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config. If a value is specified, OpenFold will attempt to
dynamically tune it, considering the chunk size specified in the config as a
minimum. This tuning process automatically ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. It is also recommended
to disable this feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
For large-scale batch inference, we offer an optional tracing mode, which
massively improves runtimes at the cost of a lengthy model compilation process.
To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch
instead.
To minimize memory usage during inference on long sequences, consider the
following changes:
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
solution offered by AlphaFold-Multimer, which is simply to average individual
template representations. Our version is modified slightly to accommodate
weights trained using the standard template algorithm. Using said weights, we
notice no significant difference in performance between our averaged template
embeddings and the standard ones. The second, `offload_templates`, temporarily
offloads individual template embeddings into CPU memory. The former is an
approximation while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Powerusers can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Staats & Rabe preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficent AlphaFold-Multimer more than double the time. Use the
`long_sequence_inference` config option to enable all of these interventions
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option
### Training
To train the model, you will first need to precompute protein alignments.
You have two options. You can use the same procedure DeepMind used by running
the following:
```bash
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--cpus_per_task 16 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
```
As noted before, you can skip the `binary_path` arguments if these binaries are
at `/usr/bin`. Expect this step to take a very long time, even for small
numbers of proteins.
Alternatively, you can generate MSAs with the ColabFold pipeline (and templates
with HHsearch) with:
```bash
python3 scripts/precompute_alignments_mmseqs.py input.fasta \
data/mmseqs_dbs \
uniref30_2103_db \
alignment_dir \
~/MMseqs2/build/bin/mmseqs \
/usr/bin/hhsearch \
--env_db colabfold_envdb_202108_db
--pdb70 data/pdb70/pdb70
```
where `input.fasta` is a FASTA file containing one or more query sequences. To
generate an input FASTA from a directory of mmCIF and/or ProteinNet .core
files, we provide `scripts/data_dir_to_fasta.py`.
Next, generate a cache of certain datapoints in the template mmCIF files:
```bash
python3 scripts/generate_mmcif_cache.py \
mmcif_dir/ \
mmcif_cache.json \
--no_workers 16
```
This cache is used to pre-filter templates.
Next, generate a separate chain-level cache with data used for training-time
data filtering:
```bash
python3 scripts/generate_chain_data_cache.py \
mmcif_dir/ \
chain_data_cache.json \
--cluster_file clusters-by-entity-40.txt \
--no_workers 16
```
where the `cluster_file` argument is a file of chain clusters, one cluster
per line.
Optionally, download an AlphaFold-style validation set from
[CAMEO](https://cameo3d.org) using `scripts/download_cameo.py`. Use the
resulting FASTA files to generate validation alignments and then specify
the validation set's location using the `--val_...` family of training script
flags.
Finally, call the training script:
```bash
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
2021-10-10 \
--template_release_dates_cache_path mmcif_cache.json \
--precision bf16 \
--gpus 8 --replace_sampler_ddp=True \
--seed 4242022 \ # in multi-gpu settings, the seed must be specified
--deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \
--resume_from_ckpt ckpt_dir/ \
--train_chain_data_cache_path chain_data_cache.json \
--obsolete_pdbs_file_path obsolete.dat
```
where `--template_release_dates_cache_path` is a path to the mmCIF cache.
Note that `template_mmcif_dir` can be the same as `mmcif_dir` which contains
training targets. A suitable DeepSpeed configuration file can be generated with
`scripts/build_deepspeed_config.py`. The training script is
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
and supports the full range of training options that entails, including
multi-node distributed training, validation, and so on. For more information,
consult PyTorch Lightning documentation and the `--help` flag of the training
script.
Note that, despite its variable name, `mmcif_dir` can also contain PDB files
or even ProteinNet .core files.
To emulate the AlphaFold training procedure, which uses a self-distillation set
subject to special preprocessing steps, use the family of `--distillation` flags.
In cases where it may be burdensome to create separate files for each chain's
alignments, alignment directories can be consolidated using the scripts in
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
consolidate an alignment directory into a pair of database and index files.
Once all alignment directories (or shards of a single alignment directory)
have been compiled, unify the indices with `unify_alignment_db_indices.py`. The
resulting index, `super.index`, can be passed to the training script flags
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
flags instead represent the directory containing the compiled alignment
databases. Both the training and distillation datasets can be compiled in this
way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
## Testing
To run unit tests, use
```bash
scripts/run_unit_tests.sh
```
The script is a thin wrapper around Python's `unittest` suite, and recognizes
`unittest` arguments. E.g., to run a specific test verbosely:
```bash
scripts/run_unit_tests.sh -v tests.test_model
```
Certain tests require that AlphaFold (v2.0.1) be installed in the same Python
environment. These run components of AlphaFold and OpenFold side by side and
ensure that output activations are adequately similar. For most modules, we
target a maximum pointwise difference of `1e-4`.
## Building and using the docker container
### Building the docker image
Openfold can be built as a docker container using the included dockerfile. To build it, run the following command from the root of this repository:
```bash
docker build -t openfold .
```
### Running the docker container
The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the Usage section above.
The docker container installs all conda components to the base conda environment in `/opt/conda`, and installs openfold itself in `/opt/openfold`,
Before running the docker container, you can verify that your docker installation is able to properly communicate with your GPU by running the following command:
```bash
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
```
Note the `--gpus all` option passed to `docker run`. This option is necessary in order for the container to use the GPUs on the host machine.
To run the inference code under docker, you can use a command like the one below. In this example, parameters and sequences from the alphafold dataset are being used and are located at `/mnt/alphafold_database` on the host machine, and the input files are located in the current working directory. You can adjust the volume mount locations as needed to reflect the locations of your data.
```bash
docker run \
--gpus all \
-v $PWD/:/data \
-v /mnt/alphafold_database/:/database \
-ti openfold:latest \
python3 /opt/openfold/run_pretrained_openfold.py \
/data/fasta_dir \
/database/pdb_mmcif/mmcif_files/ \
--uniref90_database_path /database/uniref90/uniref90.fasta \
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path /database/pdb70/pdb70 \
--uniclust30_database_path /database/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir /data \
--bfd_database_path /database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device cuda:0 \
--jackhmmer_binary_path /opt/conda/bin/jackhmmer \
--hhblits_binary_path /opt/conda/bin/hhblits \
--hhsearch_binary_path /opt/conda/bin/hhsearch \
--kalign_binary_path /opt/conda/bin/kalign \
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```
## Copyright notice
While AlphaFold's and, by extension, OpenFold's source code is licensed under
the permissive Apache Licence, Version 2.0, DeepMind's pretrained parameters
fall under the CC BY 4.0 license, a copy of which is downloaded to
`openfold/resources/params` by the installation script. Note that the latter
replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
## Contributing
If you encounter problems using OpenFold, feel free to create an issue! We also
welcome pull requests from the community.
## Citing this work
Please cite our paper:
```bibtex
@article {Ahdritz2022.11.20.517210,
author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
title = {OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization},
elocation-id = {2022.11.20.517210},
year = {2022},
doi = {10.1101/2022.11.20.517210},
publisher = {Cold Spring Harbor Laboratory},
abstract = {AlphaFold2 revolutionized structural biology with the ability to predict protein structures with exceptionally high accuracy. Its implementation, however, lacks the code and data required to train new models. These are necessary to (i) tackle new tasks, like protein-ligand complex structure prediction, (ii) investigate the process by which the model learns, which remains poorly understood, and (iii) assess the model{\textquoteright}s generalization capacity to unseen regions of fold space. Here we report OpenFold, a fast, memory-efficient, and trainable implementation of AlphaFold2, and OpenProteinSet, the largest public database of protein multiple sequence alignments. We use OpenProteinSet to train OpenFold from scratch, fully matching the accuracy of AlphaFold2. Having established parity, we assess OpenFold{\textquoteright}s capacity to generalize across fold space by retraining it using carefully designed datasets. We find that OpenFold is remarkably robust at generalizing despite extreme reductions in training set size and diversity, including near-complete elisions of classes of secondary structure elements. By analyzing intermediate structures produced by OpenFold during training, we also gain surprising insights into the manner in which the model learns to fold proteins, discovering that spatial dimensions are learned sequentially. Taken together, our studies demonstrate the power and utility of OpenFold, which we believe will prove to be a crucial new resource for the protein modeling community.},
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
journal = {bioRxiv}
}
```
Any work that cites OpenFold should also cite AlphaFold.
{
"fp16": {
"enabled": false,
"min_loss_scale": 1
},
"amp": {
"enabled": false,
"opt_level": "O2"
},
"bfloat16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"cpu_offload": true,
"contiguous_gradients": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"profile": false
},
"gradient_clipping": 0.1
}
name: openfold_venv
channels:
- conda-forge
- bioconda
- pytorch
dependencies:
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==11.3.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip:
- biopython==1.79
- deepspeed==0.5.10
- dm-tree==0.1.6
- ml-collections==0.1.0
- numpy==1.21.2
- PyYAML==5.4.1
- requests==2.26.0
- scipy==1.7.1
- tqdm==4.62.2
- typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10
- wandb==0.12.21
- modelcif==0.7
- git+https://github.com/NVIDIA/dllogger.git
Index: simtk/openmm/app/topology.py
===================================================================
--- simtk.orig/openmm/app/topology.py
+++ simtk/openmm/app/topology.py
@@ -356,19 +356,35 @@
def isCyx(res):
names = [atom.name for atom in res._atoms]
return 'SG' in names and 'HG' not in names
+ # This function is used to prevent multiple di-sulfide bonds from being
+ # assigned to a given atom. This is a DeepMind modification.
+ def isDisulfideBonded(atom):
+ for b in self._bonds:
+ if (atom in b and b[0].name == 'SG' and
+ b[1].name == 'SG'):
+ return True
+
+ return False
cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
atomNames = [[atom.name for atom in res._atoms] for res in cyx]
for i in range(len(cyx)):
sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
pos1 = positions[sg1.index]
+ candidate_distance, candidate_atom = 0.3*nanometers, None
for j in range(i):
sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
pos2 = positions[sg2.index]
delta = [x-y for (x,y) in zip(pos1, pos2)]
distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
- if distance < 0.3*nanometers:
- self.addBond(sg1, sg2)
+ if distance < candidate_distance and not isDisulfideBonded(sg2):
+ candidate_distance = distance
+ candidate_atom = sg2
+ # Assign bond to closest pair.
+ if candidate_atom:
+ self.addBond(sg1, candidate_atom)
+
+
class Chain(object):
"""A Chain object represents a chain within a Topology."""
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "OpenFold.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "pc5-mbsX9PZC"
},
"source": [
"# OpenFold Colab\n",
"\n",
"Runs a simplified version of [OpenFold](https://github.com/aqlaboratory/openfold) on a target sequence. Adapted from DeepMind's [official AlphaFold Colab](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb).\n",
"\n",
"**Differences to AlphaFold v2.0**\n",
"\n",
"OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling is excluded from OpenFold (recycling is enabled, however)).\n",
"\n",
"In this notebook, OpenFold is run with your choice of our original OpenFold parameters or DeepMind's publicly released parameters for AlphaFold 2.\n",
"\n",
"**Note**\n",
"\n",
"Like DeepMind's official Colab, this notebook uses **no templates (homologous structures)** and a selected portion of the full [BFD database](https://bfd.mmseqs.com/).\n",
"\n",
"**Citing this work**\n",
"\n",
"Any publication that discloses findings arising from using this notebook should [cite](https://github.com/deepmind/alphafold/#citing-this-work) DeepMind's [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).\n",
"\n",
"**Licenses**\n",
"\n",
"This Colab supports inference with the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n",
"\n",
"**More information**\n",
"\n",
"You can find more information about how AlphaFold/OpenFold works in DeepMind's two Nature papers:\n",
"\n",
"* [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2)\n",
"* [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1)\n",
"\n",
"FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "woIxeCPygt7K",
"cellView": "form"
},
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"from IPython.utils import io\n",
"import os\n",
"import subprocess\n",
"import tqdm.notebook\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
" %shell sudo apt install --quiet --yes hmmer\n",
"\n",
" # Install py3dmol.\n",
" %shell pip install py3dmol\n",
"\n",
" %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
"\n",
" PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n",
"\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q conda==4.13.0\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python=3.8 \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79\n",
"\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
"\n",
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VzJ5iMjTtoZw",
"cellView": "form"
},
"source": [
"#@title Install OpenFold\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, os.path.basename(ALPHAFOLD_PARAM_SOURCE_URL)\n",
")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
" # Run setup.py to install only Openfold.\n",
" %shell rm -rf openfold\n",
" %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" %shell conda install -y -q -c conda-forge openmm=7.5.1\n",
" # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.8/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n",
" %shell conda install -y -q -c conda-forge pdbfixer=1.7\n",
"\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
" %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
" elif(weight_set == 'OpenFold'):\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
" %shell aws s3 cp \\\n",
" --no-sign-request \\\n",
" --region us-east-1 \\\n",
" s3://openfold/openfold_params \"{OPENFOLD_PARAMS_DIR}\" \\\n",
" --recursive\n",
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"import unittest.mock\n",
"import sys\n",
"\n",
"sys.path.insert(0, '/usr/local/lib/python3.8/site-packages/')\n",
"sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
"\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
" \"dllogger\",\n",
" \"pytorch_lightning\",\n",
" \"pytorch_lightning.utilities\",\n",
" \"pytorch_lightning.callbacks.early_stopping\",\n",
" \"pytorch_lightning.utilities.seed\",\n",
"]\n",
"for unnecessary_module in unnecessary_modules:\n",
" sys.modules[unnecessary_module] = unittest.mock.MagicMock()\n",
"\n",
"import os\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
"from google.colab import files\n",
"import json\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"import shutil\n",
"\n",
"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
"import locale\n",
"def getpreferredencoding(do_setlocale = True):\n",
" return \"UTF-8\"\n",
"locale.getpreferredencoding = getpreferredencoding\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax.utils import overwrite_b_factors\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output"
],
"metadata": {
"id": "_FpxxMo-mvcP",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4JpOs6oA-QS"
},
"source": [
"## Making a prediction\n",
"\n",
"Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
},
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
"ex = futures.ThreadPoolExecutor(3)\n",
"def fetch(source):\n",
" request.urlretrieve(test_url_pattern.format(source))\n",
" return source\n",
"fs = [ex.submit(fetch, source) for source in ['', '-europe', '-asia']]\n",
"source = None\n",
"for f in futures.as_completed(fs):\n",
" source = f.result()\n",
" ex.shutdown()\n",
" break\n",
"\n",
"# --- Search against genetic databases ---\n",
"with open('target.fasta', 'wt') as f:\n",
" f.write(f'>query\\n{sequence}')\n",
"\n",
"# Run the search against chunks of genetic databases (since the genetic\n",
"# databases don't fit in Colab ramdisk).\n",
"\n",
"jackhmmer_binary_path = '/usr/bin/jackhmmer'\n",
"dbs = []\n",
"\n",
"num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}\n",
"total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())\n",
"with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" def jackhmmer_chunk_callback(i):\n",
" pbar.update(n=1)\n",
"\n",
" pbar.set_description('Searching uniref90')\n",
" jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=135301051)\n",
" dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching smallbfd')\n",
" jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=65984053)\n",
" dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n",
"\n",
" pbar.set_description('Searching mgnify')\n",
" jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n",
" binary_path=jackhmmer_binary_path,\n",
" database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n",
" get_tblout=True,\n",
" num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n",
" streaming_callback=jackhmmer_chunk_callback,\n",
" z_value=304820129)\n",
" dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n",
"\n",
"\n",
"# --- Extract the MSAs and visualize ---\n",
"# Extract the MSAs from the Stockholm files.\n",
"# NB: deduplication happens later in data_pipeline.make_msa_features.\n",
"\n",
"mgnify_max_hits = 501\n",
"\n",
"msas = []\n",
"deletion_matrices = []\n",
"full_msa = []\n",
"for db_name, db_results in dbs:\n",
" unsorted_results = []\n",
" for i, result in enumerate(db_results):\n",
" msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n",
" e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n",
" e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n",
" zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n",
" if i != 0:\n",
" # Only take query from the first chunk\n",
" zipped_results = [x for x in zipped_results if x[2] != 'query']\n",
" unsorted_results.extend(zipped_results)\n",
" sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n",
" db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n",
" if db_msas:\n",
" if db_name == 'mgnify':\n",
" db_msas = db_msas[:mgnify_max_hits]\n",
" db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n",
" full_msa.extend(db_msas)\n",
" msas.append(db_msas)\n",
" deletion_matrices.append(db_deletion_matrices)\n",
" msa_size = len(set(db_msas))\n",
" print(f'{msa_size} Sequences Found in {db_name}')\n",
"\n",
"deduped_full_msa = list(dict.fromkeys(full_msa))\n",
"total_msa_size = len(deduped_full_msa)\n",
"print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
"\n",
"aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n",
"msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n",
"num_alignments, num_res = msa_arr.shape\n",
"\n",
"fig = plt.figure(figsize=(12, 3))\n",
"plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n",
"plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n",
"plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XUo6foMQxwS2",
"cellView": "form"
},
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
"#@markdown Once this cell has been executed, a zip-archive with \n",
"#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown to your computer.\n",
"\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [\n",
" (0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')\n",
"]\n",
"\n",
"# --- Run the model ---\n",
"model_names = [ \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n",
"\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n",
" 'template_aatype': np.zeros((num_templates_, num_res_, 22), dtype=np.int64),\n",
" 'template_all_atom_positions': np.zeros((num_templates_, num_res_, 37, 3), dtype=np.float32),\n",
" 'template_all_atom_mask': np.zeros((num_templates_, num_res_, 37), dtype=np.float32),\n",
" 'template_domain_names': np.zeros((num_templates_,), dtype=np.float32),\n",
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n",
"\n",
"output_dir = 'prediction'\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"plddts = {}\n",
"pae_outputs = {}\n",
"unrelaxed_proteins = {}\n",
"\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for i, model_name in list(enumerate(model_names)):\n",
" pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n",
" else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
" else:\n",
" config_preset = \"model_1\"\n",
" if(\"_ptm_\" in model_name):\n",
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
" openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n",
" if(weight_set == \"AlphaFold\"):\n",
" params_name = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, f\"params_{config_preset}.npz\"\n",
" )\n",
" import_jax_weights_(openfold_model, params_name, version=config_preset)\n",
" elif(weight_set == \"OpenFold\"):\n",
" params_name = os.path.join(\n",
" OPENFOLD_PARAMS_DIR,\n",
" model_name,\n",
" )\n",
" d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n",
" else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n",
" openfold_model = openfold_model.cuda()\n",
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
" processed_feature_dict = pipeline.process_features(\n",
" feature_dict, mode='predict'\n",
" )\n",
"\n",
" processed_feature_dict = tensor_tree_map(\n",
" lambda t: t.cuda(), processed_feature_dict\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" prediction_result = openfold_model(processed_feature_dict)\n",
"\n",
" # Move the batch and output to np for further processing\n",
" processed_feature_dict = tensor_tree_map(\n",
" lambda t: np.array(t[..., -1].cpu()), processed_feature_dict\n",
" )\n",
" prediction_result = tensor_tree_map(\n",
" lambda t: np.array(t.cpu()), prediction_result\n",
" )\n",
"\n",
" mean_plddt = prediction_result['plddt'].mean()\n",
"\n",
" if 'predicted_aligned_error' in prediction_result:\n",
" pae_outputs[model_name] = (\n",
" prediction_result['predicted_aligned_error'],\n",
" prediction_result['max_predicted_aligned_error']\n",
" )\n",
" else:\n",
" # Get the pLDDT confidence metrics. Do not put pTM models here as they\n",
" # should never get selected.\n",
" plddts[model_name] = prediction_result['plddt']\n",
"\n",
" # Set the b-factors to the per-residue plddt.\n",
" final_atom_mask = prediction_result['final_atom_mask']\n",
" b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n",
" unrelaxed_protein = protein.from_prediction(\n",
" processed_feature_dict, prediction_result, b_factors=b_factors\n",
" )\n",
" unrelaxed_proteins[model_name] = unrelaxed_protein\n",
"\n",
" # Delete unused outputs to save memory.\n",
" del openfold_model\n",
" del processed_feature_dict\n",
" del prediction_result\n",
" pbar.update(n=1)\n",
"\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n",
" if(relax_prediction):\n",
" pbar.set_description(f'AMBER relaxation')\n",
" amber_relaxer = relax.AmberRelaxation(\n",
" max_iterations=0,\n",
" tolerance=2.39,\n",
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=False,\n",
" )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
" )\n",
" best_pdb = relaxed_pdb\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(best_pdb)\n",
"\n",
" pbar.update(n=1) # Finished AMBER relax.\n",
"\n",
"# Construct multiclass b-factors to indicate confidence bands\n",
"# 0=very low, 1=low, 2=confident, 3=very high\n",
"banded_b_factors = []\n",
"for plddt in plddts[best_model_name]:\n",
" for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):\n",
" if plddt >= min_val and plddt <= max_val:\n",
" banded_b_factors.append(idx)\n",
" break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)\n",
"\n",
"# --- Visualise the prediction & confidence ---\n",
"show_sidechains = True\n",
"def plot_plddt_legend():\n",
" \"\"\"Plots the legend for pLDDT.\"\"\"\n",
" thresh = [\n",
" 'Very low (pLDDT < 50)',\n",
" 'Low (70 > pLDDT > 50)',\n",
" 'Confident (90 > pLDDT > 70)',\n",
" 'Very high (pLDDT > 90)']\n",
"\n",
" colors = [x[2] for x in PLDDT_BANDS]\n",
"\n",
" plt.figure(figsize=(2, 2))\n",
" for c in colors:\n",
" plt.bar(0, 0, color=c)\n",
" plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" ax = plt.gca()\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n",
"view.addModelsAsFrames(to_visualize_pdb)\n",
"style = {'cartoon': {\n",
" 'colorscheme': {\n",
" 'prop': 'b',\n",
" 'map': color_map}\n",
" }}\n",
"if show_sidechains:\n",
" style['stick'] = {}\n",
"view.setStyle({'model': -1}, style)\n",
"view.zoomTo()\n",
"\n",
"grid = GridspecLayout(1, 2)\n",
"out = Output()\n",
"with out:\n",
" view.show()\n",
"grid[0, 0] = out\n",
"\n",
"out = Output()\n",
"with out:\n",
" plot_plddt_legend().show()\n",
"grid[0, 1] = out\n",
"\n",
"display.display(grid)\n",
"\n",
"# Display pLDDT and predicted aligned error (if output by the model).\n",
"if pae_outputs:\n",
" num_plots = 2\n",
"else:\n",
" num_plots = 1\n",
"\n",
"plt.figure(figsize=[8 * num_plots, 6])\n",
"plt.subplot(1, num_plots, 1)\n",
"plt.plot(plddts[best_model_name])\n",
"plt.title('Predicted LDDT')\n",
"plt.xlabel('Residue')\n",
"plt.ylabel('pLDDT')\n",
"\n",
"if num_plots == 2:\n",
" plt.subplot(1, 2, 2)\n",
" pae, max_pae = list(pae_outputs.values())[0]\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n",
" plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')\n",
"\n",
"# Save pLDDT and predicted aligned error (if it exists)\n",
"pae_output_path = os.path.join(output_dir, 'predicted_aligned_error.json')\n",
"if pae_outputs:\n",
" # Save predicted aligned error in the same format as the AF EMBL DB\n",
" rounded_errors = np.round(pae.astype(np.float64), decimals=1)\n",
" indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1\n",
" indices_1 = indices[0].flatten().tolist()\n",
" indices_2 = indices[1].flatten().tolist()\n",
" pae_data = json.dumps([{\n",
" 'residue1': indices_1,\n",
" 'residue2': indices_2,\n",
" 'distance': rounded_errors.flatten().tolist(),\n",
" 'max_predicted_aligned_error': max_pae.item()\n",
" }],\n",
" indent=None,\n",
" separators=(',', ':'))\n",
" with open(pae_output_path, 'w') as f:\n",
" f.write(pae_data)\n",
"\n",
"\n",
"# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lUQAn5LYC5n4"
},
"source": [
"### Interpreting the prediction\n",
"\n",
"Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [DeepMind's FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold/OpenFold predictions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jeb2z8DIA4om"
},
"source": [
"## FAQ & Troubleshooting\n",
"\n",
"\n",
"* How do I get a predicted protein structure for my protein?\n",
" * Click on the _Connect_ button on the top right to get started.\n",
" * Paste the amino acid sequence of your protein (without any headers) into the “Enter the amino acid sequence to fold”.\n",
" * Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ > _Run all._\n",
" * The predicted protein structure will be downloaded once all cells have been executed. Note: This can take minutes to hours - see below.\n",
"* How long will this take?\n",
" * Downloading the OpenFold source code can take up to a few minutes.\n",
" * Downloading and installing the third-party software can take up to a few minutes.\n",
" * The search against genetic databases can take minutes to hours.\n",
" * Running OpenFold and generating the prediction can take minutes to hours, depending on the length of your protein and on which GPU-type Colab has assigned you.\n",
"* My Colab no longer seems to be doing anything, what should I do?\n",
" * Some steps may take minutes to hours to complete.\n",
" * Sometimes, running the \"installation\" cells more than once can corrupt the OpenFold installation.\n",
" * If nothing happens or if you receive an error message, try restarting your Colab runtime via _Runtime_ > _Restart runtime_.\n",
" * If this doesn’t help, reset your Colab runtime via _Runtime_ > _Factory reset runtime_.\n",
"* How does what's run in this notebook compare to the full versions of Alphafold/Openfold?\n",
" * This Colab version of OpenFold searches a selected portion of the BFD dataset and currently doesn’t use templates, so its accuracy is reduced in comparison to the full version, which is analogous to what's described in the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and [Github repo](https://github.com/deepmind/alphafold/). The full version of OpenFold can be run from our own [GitHub repo](https://github.com/aqlaboratory/openfold).\n",
"* What is a Colab?\n",
" * See the [Colab FAQ](https://research.google.com/colaboratory/faq.html).\n",
"* I received a warning “Notebook requires high RAM”, what do I do?\n",
" * The resources allocated to your Colab vary. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * You can execute the Colab nonetheless.\n",
"* I received an error “Colab CPU runtime not supported” or “No GPU/TPU found”, what do I do?\n",
" * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n",
" * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n",
" * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n",
" * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs. \n",
"* Does this tool install anything on my computer?\n",
" * No, everything happens in the cloud on Google Colab.\n",
" * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n",
"* How should I share feedback and bug reports?\n",
" * Please share any feedback and bug reports as an [issue](https://github.com/aqlaboratory/openfold/issues) on Github.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YfPhvYgKC81B"
},
"source": [
"# License and Disclaimer\n",
"\n",
"This Colab notebook and other information provided is for theoretical modelling only, caution should be exercised in its use. It is provided ‘as-is’ without any warranty of any kind, whether expressed or implied. Information is not intended to be a substitute for professional medical advice, diagnosis, or treatment, and does not constitute medical or other professional advice.\n",
"\n",
"## AlphaFold/OpenFold Code License\n",
"\n",
"Copyright 2021 AlQuraishi Laboratory\n",
"\n",
"Copyright 2021 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
"\n",
"## Model Parameters License\n",
"\n",
"DeepMind's AlphaFold parameters are made available under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode\n",
"\n",
"\n",
"## Third-party software\n",
"\n",
"Use of the third-party software, libraries or code referred to in this notebook may be governed by separate terms and conditions or license provisions. Your use of the third-party software, libraries or code is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use.\n",
"\n",
"\n",
"## Mirrored Databases\n",
"\n",
"The following databases have been mirrored by DeepMind, and are available with reference to the following:\n",
"* UniRef90: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n",
"* MGnify: v2019\\_05 (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n",
"* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details."
]
}
]
}
\ No newline at end of file
name: openfold_venv
channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
from . import model
from . import utils
from . import np
from . import resources
__all__ = ["model", "utils", "np", "data", "resources"]
import copy
import importlib
import ml_collections as mlc
def set_inf(c, inf):
for k, v in c.items():
if isinstance(v, mlc.ConfigDict):
set_inf(v, inf)
elif k == "inf":
c[k] = inf
def enforce_config_constraints(config):
def string_to_setting(s):
path = s.split('.')
setting = config
for p in path:
setting = setting[p]
return setting
mutually_exclusive_bools = [
(
"model.template.average_templates",
"model.template.offload_templates"
),
(
"globals.use_lma",
"globals.use_flash",
),
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_ptm":
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# INFERENCE PRESETS
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_2":
# AF2 Suppl. Table 5, Model 1.1.2
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False
elif name == "model_1_ptm":
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_2_ptm":
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_3_ptm":
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_4_ptm":
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_5_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
c.globals.use_lma = False
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
enforce_config_constraints(c)
return c
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
config = mlc.ConfigDict(
{
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
"alt_chi_angles": [NUM_RES, None],
"atom14_alt_gt_exists": [NUM_RES, None],
"atom14_alt_gt_positions": [NUM_RES, None, None],
"atom14_atom_exists": [NUM_RES, None],
"atom14_atom_is_ambiguous": [NUM_RES, None],
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_rigid_mask": [NUM_RES],
"backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_row_mask": [NUM_EXTRA_SEQ],
"is_distillation": [],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ],
"no_recycling_iters": [],
"pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
"rigidgroups_group_exists": [NUM_RES, None],
"rigidgroups_group_is_ambiguous": [NUM_RES, None],
"rigidgroups_gt_exists": [NUM_RES, None],
"rigidgroups_gt_frames": [NUM_RES, None, None, None],
"seq_length": [],
"seq_mask": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
"template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
"template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
"template_sum_probs": [NUM_TEMPLATES, None],
"template_torsion_angles_mask": [
NUM_TEMPLATES, NUM_RES, None,
],
"template_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
"template_features": [
"template_all_atom_positions",
"template_sum_probs",
"template_aatype",
"template_all_atom_mask",
],
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"supervised": {
"clamp_prob": 0.9,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
"resolution",
"use_clamped_fape",
"is_distillation",
],
},
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": False,
"uniform_recycling": False,
},
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": True,
"uniform_recycling": False,
},
"train": {
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"supervised": True,
"clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
"distillation_prob": 0.75,
},
"data_module": {
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
"pin_memory": True,
},
},
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"eps": eps,
},
"model": {
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"recycling_embedder": {
"c_z": c_z,
"c_m": c_m,
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
},
"template_pair_embedder": {
"c_in": 88,
"c_out": c_t,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
"c_z": c_z,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"c_hidden": 16,
"no_heads": 4,
"inf": 1e5, # 1e9,
},
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": False,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates": False,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False,
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
"c_out": 23,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"plddt_loss": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": eps, # 1e-6,
"weight": 0.0,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.,
"enabled": tm_enabled,
},
"eps": eps,
},
"ema": {"decay": 0.999},
}
)
import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from openfold.data import (
data_pipeline,
feature_pipeline,
mmcif_parsing,
templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
else:
self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
)
else:
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
)
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
def __len__(self):
return len(self._chain_ids)
def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = chain_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> List[float]:
# Stochastic filters
probabilities = []
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
class OpenFoldDataset(torch.utils.data.Dataset):
"""
Implements the stochastic filters applied during AlphaFold's training.
Because samples are selected from constituent datasets randomly, the
length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
and filtered once at initialization.
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[float],
epoch_len: int,
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldBatchCollator:
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator
)
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
requires_grad=False
)
orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
if(key == "no_recycling_iters"):
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
def __iter__(self):
it = super().__iter__()
def _batch_prop_gen(iterator):
for batch in iterator:
yield self._add_batch_properties(batch)
return _batch_prop_gen(it)
class OpenFoldDataModule(pl.LightningDataModule):
def __init__(self,
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_filter_path: Optional[str] = None,
distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
self.config = config
self.template_mmcif_dir = template_mmcif_dir
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_chain_data_cache_path = train_chain_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_chain_data_cache_path = (
distillation_chain_data_cache_path
)
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_filter_path = train_filter_path
self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed
self.train_epoch_len = train_epoch_len
if(self.train_data_dir is None and self.predict_data_dir is None):
raise ValueError(
'At least one of train_data_dir or predict_data_dir must be '
'specified'
)
self.training_mode = self.train_data_dir is not None
if(self.training_mode and train_alignment_dir is None):
raise ValueError(
'In training mode, train_alignment_dir must be specified'
)
elif(not self.training_mode and predict_alignment_dir is None):
raise ValueError(
'In inference mode, predict_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
self._distillation_structure_index = None
if(_distillation_structure_index_path is not None):
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
max_template_date=self.max_template_date,
config=self.config,
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
alignment_index=self.alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
)
d_prob = self.config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1. - d_prob, d_prob]
else:
datasets = [train_dataset]
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
generator=generator,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
filter_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
)
else:
self.eval_dataset = None
else:
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
filter_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
def _gen_dataloader(self, stage):
generator = torch.Generator()
if(self.batch_seed is not None):
generator = generator.manual_seed(self.batch_seed)
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
dataset = self.eval_dataset
elif(stage == "predict"):
dataset = self.predict_dataset
else:
raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader(
dataset,
config=self.config,
stage=stage,
generator=generator,
batch_size=self.config.data_module.data_loaders.batch_size,
num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=batch_collator,
)
return dl
def train_dataloader(self):
return self._gen_dataloader("train")
def val_dataloader(self):
if(self.eval_dataset is not None):
return self._gen_dataloader("eval")
return None
def predict_dataloader(self):
return self._gen_dataloader("predict")
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
with open(batch_path, "rb") as f:
self.batch = pickle.load(f)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
class DummyDataLoader(pl.LightningDataModule):
def __init__(self, batch_path):
super().__init__()
self.dataset = DummyDataset(batch_path)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.dataset)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datetime
from multiprocessing import cpu_count
from typing import Mapping, Optional, Sequence, Any
import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.templates import get_custom_template_features
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Any,
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
query_release_date=query_release_date,
hits=hits_cat,
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features
def unify_template_features(
template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
out_dicts = []
seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
for i, fd in enumerate(template_feature_list):
out_dict = {}
n_templates, n_res = fd["template_aatype"].shape[:2]
for k,v in fd.items():
seq_keys = [
"template_aatype",
"template_all_atom_positions",
"template_all_atom_mask",
]
if(k in seq_keys):
new_shape = list(v.shape)
assert(new_shape[1] == n_res)
new_shape[1] = sum(seq_lens)
new_array = np.zeros(new_shape, dtype=v.dtype)
if(k == "template_aatype"):
new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
offset = sum(seq_lens[:i])
new_array[:, offset:offset + seq_lens[i]] = v
out_dict[k] = new_array
else:
out_dict[k] = v
chain_indices = np.array(n_templates * [i])
out_dict["template_chain_index"] = chain_indices
if(n_templates != 0):
out_dicts.append(out_dict)
if(len(out_dicts) > 0):
out_dict = {
k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
}
else:
out_dict = empty_template_feats(sum(seq_lens))
return out_dict
def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def _aatype_to_str_sequence(aatype):
return ''.join([
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
def make_protein_features(
protein_object: protein.Protein,
description: str,
_is_distillation: bool = False,
) -> FeatureDict:
pdb_feats = {}
aatype = protein_object.aatype
sequence = _aatype_to_str_sequence(aatype)
pdb_feats.update(
make_sequence_features(
sequence=sequence,
description=description,
num_res=len(protein_object.aatype),
)
)
all_atom_positions = protein_object.atom_positions
all_atom_mask = protein_object.atom_mask
pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(
1. if _is_distillation else 0.
).astype(np.float32)
return pdb_feats
def make_pdb_features(
protein_object: protein.Protein,
description: str,
is_distillation: bool = True,
confidence_threshold: float = 50.,
) -> FeatureDict:
pdb_feats = make_protein_features(
protein_object, description, _is_distillation=True
)
if(is_distillation):
high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1)
pdb_feats["all_atom_mask"] *= high_confidence[..., None]
return pdb_feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
return features
def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str) -> FeatureDict:
"""
process a single fasta file using features derived from a single template rather than an alignment
"""
num_res = len(sequence)
sequence_features = make_sequence_features(
sequence=sequence,
description=pdb_id,
num_res=num_res,
)
msa_data = [[sequence]]
deletion_matrix = [[[0 for _ in sequence]]]
msa_features = make_msa_features(msa_data, deletion_matrix)
template_features = get_custom_template_features(
mmcif_path=mmcif_path,
query_sequence=sequence,
pdb_id=pdb_id,
chain_id=chain_id,
kalign_binary_path=kalign_binary_path
)
return {
**sequence_features,
**msa_features,
**template_features.features
}
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniclust30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
"""
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniclust_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniclust30_database_path is not None):
dbs.append(uniclust30_database_path)
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = None
if(pdb70_database_path is not None):
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path],
n_cpu=no_cpus,
)
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result["sto"],
max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None):
hhsearch_result = self.hhsearch_pdb70_runner.query(
uniref90_msa_as_a3m
)
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result["sto"],
max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
elif(self.hhblits_bfd_uniclust_runner is not None):
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
class DataPipeline:
"""Assembles input features."""
def __init__(
self,
template_featurizer: Optional[templates.TemplateHitFeaturizer],
):
self.template_featurizer = template_featurizer
def _parse_msa_data(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
return msa_data
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
"""
If the alignment dir contains no MSAs, an input sequence
must be provided.
"""
)
msa_data["dummy"] = {
"msa": [input_sequence],
"deletion_matrix": [[0 for _ in input_sequence]],
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
return msa_features
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {
**sequence_features,
**msa_features,
**template_features
}
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
def process_pdb(
self,
pdb_path: str,
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
"""
if(_structure_index is not None):
db_dir = os.path.dirname(pdb_path)
db = _structure_index["db"]
db_path = os.path.join(db_dir, db)
fp = open(db_path, "rb")
_, offset, length = _structure_index["files"][0]
fp.seek(offset)
pdb_str = fp.read(length).decode("utf-8")
fp.close()
else:
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
pdb_feats = make_pdb_features(
protein_object,
description,
is_distillation=is_distillation
)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features}
def process_core(
self,
core_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
"""
with open(core_path, 'r') as f:
core_str = f.read()
protein_object = protein.from_proteinnet_string(core_str)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {**core_feats, **template_features, **msa_features}
def process_multiseq_fasta(self,
fasta_path: str,
super_alignment_dir: str,
ri_gap: int = 200,
) -> FeatureDict:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter (a.k.a. AlphaFold-Gap).
"""
with open(fasta_path, 'r') as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
# No whitespace allowed
input_descs = [i.split()[0] for i in input_descs]
# Stitch all of the sequences together
input_sequence = ''.join(input_seqs)
input_description = '-'.join(input_descs)
num_res = len(input_sequence)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
seq_lens = [len(s) for s in input_seqs]
total_offset = 0
for sl in seq_lens:
total_offset += sl
sequence_features["residue_index"][total_offset:] += ri_gap
msa_list = []
deletion_mat_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
final_msa = []
final_deletion_mat = []
msa_it = enumerate(zip(msa_list, deletion_mat_list))
for i, (msas, deletion_mats) in msa_it:
prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
msas = [
[prec * '-' + seq + post * '-' for seq in msa] for msa in msas
]
deletion_mats = [
[prec * [0] + dml + post * [0] for dml in deletion_mat]
for deletion_mat in deletion_mats
]
assert(len(msas[0][-1]) == len(input_sequence))
final_msa.extend(msas)
final_deletion_mat.extend(deletion_mats)
msa_features = make_msa_features(
msas=final_msa,
deletion_matrices=final_deletion_mat,
)
template_feature_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
template_features = make_template_features(
seq,
hits,
self.template_featurizer,
)
template_feature_list.append(template_features)
template_features = unify_template_features(template_feature_list)
return {
**sequence_features,
**msa_features,
**template_features,
}
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from functools import reduce, wraps
from operator import add
import numpy as np
import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
batched_gather,
)
MSA_FEATURE_NAMES = [
"msa",
"deletion_matrix",
"msa_mask",
"msa_row_mask",
"bert_mask",
"true_msa",
]
def cast_to_64bit_ints(protein):
# We keep all ints as int64
for k, v in protein.items():
if v.dtype == torch.int32:
protein[k] = v.type(torch.int64)
return protein
def make_one_hot(x, num_classes):
x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
return x_one_hot
def make_seq_mask(protein):
protein["seq_mask"] = torch.ones(
protein["aatype"].shape, dtype=torch.float32
)
return protein
def make_template_mask(protein):
protein["template_mask"] = torch.ones(
protein["template_aatype"].shape[0], dtype=torch.float32
)
return protein
def curry1(f):
"""Supply all arguments but the first."""
@wraps(f)
def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)
return fc
def make_all_atom_aatype(protein):
protein["all_atom_aatype"] = protein["aatype"]
return protein
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
if(num_templates > 0):
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as rc."""
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
[new_order_list] * protein["msa"].shape[1],
device=protein["msa"].device,
).transpose(0, 1)
protein["msa"] = torch.gather(new_order, 0, protein["msa"])
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
for k in protein:
if "profile" in k:
num_dim = protein[k].shape.as_list()[-1]
assert num_dim in [
20,
21,
22,
], "num_dim for %s out of expected range: %s" % (k, num_dim)
protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
return protein
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a portion of the MSA with 'X'."""
msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
x_idx = 20
gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
protein["msa"] = torch.where(
msa_mask,
torch.ones_like(protein["msa"]) * x_idx,
protein["msa"]
)
aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
protein["aatype"] = torch.where(
aatype_mask,
torch.ones_like(protein["aatype"]) * x_idx,
protein["aatype"],
)
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0]
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
(torch.tensor([0], device=shuffled.device), shuffled),
dim=0
)
num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(
index_order, [num_sel, num_seq - num_sel]
)
for k in MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein["extra_" + k] = torch.index_select(
protein[k], 0, not_sel_seq
)
protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = distillation
return protein
@curry1
def sample_msa_distillation(protein, max_seq):
if(protein["is_distillation"] == 1):
protein = sample_msa(max_seq, keep_extra=False)(protein)
return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
num_seq = protein["extra_msa"].shape[0]
num_sel = min(max_extra_msa, num_seq)
select_indices = torch.randperm(num_seq)[:num_sel]
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
protein["extra_" + k] = torch.index_select(
protein["extra_" + k], 0, select_indices
)
return protein
def delete_extra_msa(protein):
for k in MSA_FEATURE_NAMES:
if "extra_" + k in protein:
del protein["extra_" + k]
return protein
# Not used in inference
@curry1
def block_delete_msa(protein, config):
num_seq = protein["msa"].shape[0]
block_num_seq = torch.floor(
torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
* config.msa_fraction_per_block
).to(torch.int32)
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
for k in MSA_FEATURE_NAMES:
if k in protein:
protein[k] = torch.gather(protein[k], keep_indices)
return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
weights = torch.cat(
[
torch.ones(21, device=protein["msa"].device),
gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
torch.zeros(1, device=protein["msa"].device)
],
0,
)
# Make agreement score as weighted Hamming distance
msa_one_hot = make_one_hot(protein["msa"], 23)
sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
num_seq, num_res, _ = sample_one_hot.shape
extra_num_seq, _, _ = extra_one_hot.shape
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul(
torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1),
)
# Assign each sequence in the extra sequences to the closest MSA sample
protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
torch.int64
)
return protein
def unsorted_segment_sum(data, segment_ids, num_segments):
"""
Computes the sum along segments of a tensor. Similar to
tf.unsorted_segment_sum, but only supports 1-D indices.
:param data: A tensor whose segments are to be summed.
:param segment_ids: The 1-D segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
assert (
len(segment_ids.shape) == 1 and
segment_ids.shape[0] == data.shape[0]
)
segment_ids = segment_ids.view(
segment_ids.shape[0], *((1,) * len(data.shape[1:]))
)
segment_ids = segment_ids.expand(data.shape)
shape = [num_segments] + list(data.shape[1:])
tensor = (
torch.zeros(*shape, device=segment_ids.device)
.scatter_add_(0, segment_ids, data.float())
)
tensor = tensor.type(data.dtype)
return tensor
@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = protein["msa"].shape[0]
def csum(x):
return unsorted_segment_sum(
x, protein["extra_cluster_assignment"], num_seq
)
mask = protein["extra_msa_mask"]
mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein["extra_deletion_matrix"])
del_sum += protein["deletion_matrix"] # Original sequence
protein["cluster_deletion_mean"] = del_sum / mask_counts
del del_sum
return protein
def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
protein["msa_row_mask"] = torch.ones(
(protein["msa"].shape[0]), dtype=torch.float32
)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
"""Create pseudo beta features."""
is_gly = torch.eq(aatype, rc.restype_order["G"])
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_mask is not None:
pseudo_beta_mask = torch.where(
is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=""):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ["", "template_"]
(
protein[prefix + "pseudo_beta"],
protein[prefix + "pseudo_beta_mask"],
) = pseudo_beta_fn(
protein["template_aatype" if prefix else "aatype"],
protein[prefix + "all_atom_positions"],
protein["template_all_atom_mask" if prefix else "all_atom_mask"],
)
return protein
@curry1
def add_constant_field(protein, key, value):
protein[key] = torch.tensor(value, device=protein["msa"].device)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = probs.shape
num_classes = ds[-1]
distribution = torch.distributions.categorical.Categorical(
torch.reshape(probs + epsilon, [-1, num_classes])
)
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if "hhblits_profile" in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
msa_one_hot = make_one_hot(protein["msa"], 22)
protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor(
[0.05] * 20 + [0.0, 0.0],
dtype=torch.float32,
device=protein["aatype"].device
)
categorical_probs = (
config.uniform_prob * random_aa
+ config.profile_prob * protein["hhblits_profile"]
+ config.same_prob * make_one_hot(protein["msa"], 22)
)
# Put all remaining probability on [MASK] which is a new column
pad_shapes = list(
reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
)
pad_shapes[1] = 1
mask_prob = (
1.0 - config.profile_prob - config.same_prob - config.uniform_prob
)
assert mask_prob >= 0.0
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
sh = protein["msa"].shape
mask_position = torch.rand(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
# Mix real and masked MSA
protein["bert_mask"] = mask_position.to(torch.float32)
protein["true_msa"] = protein["msa"]
protein["msa"] = bert_msa
return protein
@curry1
def make_fixed_size(
protein,
shape_schema,
msa_cluster_size,
extra_msa_size,
num_res=0,
num_templates=0,
):
"""Guess at the MSA and sequence dimension to make fixed size."""
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}
for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == "extra_cluster_assignment":
continue
shape = list(v.shape)
schema = shape_schema[k]
msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
padding.reverse()
padding = list(itertools.chain(*padding))
if padding:
protein[k] = torch.nn.functional.pad(v, padding)
protein[k] = torch.reshape(protein[k], pad_size)
return protein
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets.
has_break = torch.clip(
protein["between_segment_residues"].to(torch.float32), 0, 1
)
aatype_1hot = make_one_hot(protein["aatype"], 21)
target_feat = [
torch.unsqueeze(has_break, dim=-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = make_one_hot(protein["msa"], 23)
has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
2.0 / np.pi
)
msa_feat = [
msa_1hot,
torch.unsqueeze(has_deletion, dim=-1),
torch.unsqueeze(deletion_value, dim=-1),
]
if "cluster_profile" in protein:
deletion_mean_value = torch.atan(
protein["cluster_deletion_mean"] / 3.0
) * (2.0 / np.pi)
msa_feat.extend(
[
protein["cluster_profile"],
torch.unsqueeze(deletion_mean_value, dim=-1),
]
)
if "extra_deletion_matrix" in protein:
protein["extra_has_deletion"] = torch.clip(
protein["extra_deletion_matrix"], 0.0, 1.0
)
protein["extra_deletion_value"] = torch.atan(
protein["extra_deletion_matrix"] / 3.0
) * (2.0 / np.pi)
protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
protein["target_feat"] = torch.cat(target_feat, dim=-1)
return protein
@curry1
def select_feat(protein, feature_list):
return {k: v for k, v in protein.items() if k in feature_list}
@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
if k.startswith("template_"):
protein[k] = v[:max_templates]
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append(
[(rc.atom_order[name] if name else 0) for name in atom_names]
)
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in rc.atom_types
]
)
restype_atom14_mask.append(
[(1.0 if name else 0.0) for name in atom_names]
)
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
dtype=torch.float32,
device=protein["aatype"].device,
)
protein_aatype = protein['aatype'].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros(
[21, 37], dtype=torch.float32, device=protein["aatype"].device
)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch):
batch = tree_map(
lambda n: torch.tensor(n, device="cpu"),
batch,
np.ndarray
)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
def make_atom14_positions(protein):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
residx_atom14_mask = protein["atom14_atom_exists"]
residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
protein["all_atom_mask"],
residx_atom14_to_atom37,
dim=-1,
no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
batched_gather(
protein["all_atom_positions"],
residx_atom14_to_atom37,
dim=-2,
no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
)
)
protein["atom14_atom_exists"] = residx_atom14_mask
protein["atom14_gt_exists"] = residx_atom14_gt_mask
protein["atom14_gt_positions"] = residx_atom14_gt_positions
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {
res: torch.eye(
14,
dtype=protein["all_atom_mask"].dtype,
device=protein["all_atom_mask"].device,
)
for res in restype_3
}
for resname, swap in rc.residue_atom_renaming_swaps.items():
correspondences = torch.arange(
14, device=protein["all_atom_mask"].device
)
for source_atom_swap, target_atom_swap in swap.items():
source_index = rc.restype_name_to_atom14_names[resname].index(
source_atom_swap
)
target_index = rc.restype_name_to_atom14_names[resname].index(
target_atom_swap
)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3]
)
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[protein["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = torch.einsum(
"...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
)
protein["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = torch.einsum(
"...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
)
protein["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
for resname, swap in rc.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = rc.restype_order[rc.restype_3to1[resname]]
atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
atom_name1
)
atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
atom_name2
)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
protein["aatype"]
]
return protein
def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if rc.chi_angles_mask[restype][chi_idx]:
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
rc.chi_angles_mask
)
lookuptable = rc.atom_order.copy()
lookuptable[""] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
)
restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
restype_rigidgroup_base_atom37_idx,
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
)
)
residx_rigidgroup_base_atom37_idx = batched_gather(
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
gt_atoms_exist = batched_gather(
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1]),
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[rc.restype_3to1[resname]]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather(
restype_rigidgroup_is_ambiguous,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists
protein["rigidgroups_group_exists"] = group_exists
protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
return protein
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append(
[0, 0, 0, 0]
) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return chi_atom_indices
@curry1
def atom37_to_torsion_angles(
protein,
prefix="",
):
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
Dict containing:
* (prefix)aatype:
[*, N_res] residue indices
* (prefix)all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
* (prefix)all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
The same dictionary updated with the following features:
"(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"(prefix)torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype = protein[prefix + "aatype"]
all_atom_positions = protein[prefix + "all_atom_positions"]
all_atom_mask = protein[prefix + "all_atom_mask"]
aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
dim=-2,
)
phi_atom_pos = torch.cat(
[prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
dim=-2,
)
psi_atom_pos = torch.cat(
[all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
dim=-2,
)
pre_omega_mask = torch.prod(
prev_all_atom_mask[..., 1:3], dim=-1
) * torch.prod(all_atom_mask[..., :2], dim=-1)
phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
* all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
get_chi_atom_indices(), device=aatype.device
)
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2]),
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
)
chis_mask = chis_mask * chi_angle_atoms_mask
torsions_atom_pos = torch.cat(
[
pre_omega_atom_pos[..., None, :, :],
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
],
dim=-3,
)
torsion_angles_mask = torch.cat(
[
pre_omega_mask[..., None],
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
],
dim=-1,
)
torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
eps=1e-8,
)
fourth_atom_rel_pos = torsion_frames.invert().apply(
torsions_atom_pos[..., 3, :]
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True,
)
+ 1e-8
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1.0 - 2.0 * chi_is_ambiguous,
],
dim=-1,
)
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[..., None]
)
protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
return protein
def get_backbone_frames(protein):
# DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, :
]
protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein
def get_chi_angles(protein):
dtype = protein["all_atom_mask"].dtype
protein["chi_angles_sin_cos"] = (
protein["torsion_angles_sin_cos"][..., 3:, :]
).to(dtype)
protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
return protein
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
seq_length = protein["seq_length"]
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
num_res_crop_size = min(int(seq_length), crop_size)
def _randint(lower, upper):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=protein["seq_length"].device,
generator=g,
)[0])
if subsample_templates:
templates_crop_start = _randint(0, num_templates)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
n = seq_length - num_res_crop_size
if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
right_anchor = n
else:
x = _randint(0, n)
right_anchor = n - x
num_res_crop_start = _randint(0, right_anchor)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
slices = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
else:
crop_start = num_res_crop_start if is_num_res else 0
crop_size = num_res_crop_size if is_num_res else dim
slices.append(slice(crop_start, crop_start + crop_size))
protein[k] = v[slices]
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General-purpose errors used throughout the data pipeline"""
class Error(Exception):
"""Base class for exceptions."""
class MultipleChainsError(Error):
"""An error indicating that multiple chains were found for a given ID."""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment