Commit 3ad85a19 authored by Sam DeLuca

Merge remote-tracking branch 'cyrus/main' into run-multiple-models

parents 43b8c6f9 6da2cdaf
![header](imgs/of_banner.png)

# OpenFold

A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).

## Features

@@ -14,20 +14,27 @@ DeepMind experiments. It is omitted here for the sake of reducing clutter. In
cases where the *Nature* paper differs from the source, we always defer to the
latter.

OpenFold is trainable in full precision or `bfloat16` with or without DeepSpeed,
and we've trained it from scratch, matching the performance of the original.
We've publicly released model weights and our training data — some 400,000
MSAs and PDB70 template hit files — under a permissive license. Model weights
are available via scripts in this repository while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold also supports inference using AlphaFold's official parameters.

OpenFold has the following advantages over the reference implementation:

- **Faster inference** on GPU for chains with < 1500 residues.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
kernels support in-place attention during inference and training. They use
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.

## Installation (Linux)

@@ -70,7 +77,7 @@ To install the HH-suite to `/usr/bin`, run

## Usage

To download the databases used to train OpenFold and AlphaFold, run:

```bash
bash scripts/download_data.sh data/
```
@@ -96,12 +103,13 @@ Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system).

Alternatively, you can use raw MSAs from our aforementioned MSA database or
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
the latter database, use `scripts/prep_proteinnet_msas.py` to convert the data
into a format recognized by the OpenFold parser. The resulting directory
becomes the `alignment_dir` used in subsequent steps. Use
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.

For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using
@@ -124,30 +132,41 @@ python3 run_pretrained_openfold.py \
    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --output_dir ./ \
    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --model_device "cuda:0" \
    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
    --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
    --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
    --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
    --config_preset "model_1_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/finetuning_2_ptm.pt
```
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.

If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
Exactly one of `--openfold_checkpoint_path` or `--jax_param_path` must be specified
to run the inference script. These accept .pt/DeepSpeed OpenFold checkpoints
and AlphaFold's .npz JAX parameter files, respectively. For a breakdown of the
differences between the available parameter files, see the README downloaded to
`openfold/resources/openfold_params/`. Since OpenFold was trained under a
newer training schedule than the one from which the `model_n` config
presets are derived, there is no clean correspondence between `config_preset`
settings and OpenFold checkpoints; the only restriction is that `*_ptm`
checkpoints must be run with `*_ptm` config presets.
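As a sketch, both invocations look like the following; the checkpoint, parameter,
and alignment paths are illustrative, and the database/binary flags shown in the
example above are omitted here for brevity:

```bash
# With an OpenFold checkpoint (precomputed alignments assumed in alignments/)
python3 run_pretrained_openfold.py \
    --model_device "cuda:0" \
    --use_precomputed_alignments alignments/ \
    --config_preset "model_1_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/finetuning_2_ptm.pt

# Or, with AlphaFold's JAX parameters instead of an OpenFold checkpoint:
#     --jax_param_path openfold/resources/params/params_model_1_ptm.npz
```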
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config. If a value is specified, OpenFold will attempt to
dynamically tune it, considering the chunk size specified in the config as a
minimum. This tuning process automatically ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. It is also recommended
to disable this feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
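As a sketch (the exact locations of `tune_chunk_size` in the config tree follow
the `config.py` changes later in this diff, so treat the key paths as assumptions):

```python
from openfold.config import model_config

config = model_config("model_1_ptm")

# Disable chunking entirely...
config.globals.chunk_size = None

# ...or keep chunking but switch off dynamic chunk-size tuning
config.model.evoformer_stack.tune_chunk_size = False
config.model.extra_msa.extra_msa_stack.tune_chunk_size = False
config.model.template.template_pair_stack.tune_chunk_size = False
```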
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
@@ -156,15 +175,10 @@ the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch
instead.

To minimize memory usage during inference on long sequences, consider the
following changes (a combined config sketch follows this list):
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
@@ -178,6 +192,22 @@ approximation while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Power users can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Rabe & Staats preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficient AlphaFold-Multimer more than double the time.
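A minimal sketch combining these options (key paths mirror the `config.py`
changes later in this diff; treat them as assumptions rather than a stable API):

```python
from openfold.config import model_config

config = model_config("model_1_ptm")

# Low-memory attention (Rabe & Staats) throughout the model
config.globals.use_lma = True

# Pick ONE of the two template memory modes; they are mutually exclusive
config.model.template.offload_templates = True
# config.model.template.average_templates = True

# Chunk-size tuning only wastes time on very long chains
config.model.evoformer_stack.tune_chunk_size = False

# Last resort: extensive CPU offloading at the model's bottlenecks
config.globals.offload_inference = True
```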
### Training

@@ -344,7 +374,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
    --hhblits_binary_path /opt/conda/bin/hhblits \
    --hhsearch_binary_path /opt/conda/bin/hhsearch \
    --kalign_binary_path /opt/conda/bin/kalign \
    --openfold_checkpoint_path /database/openfold_params/finetuning_2_ptm.pt
```

## Copyright notice

...
@@ -4,7 +4,7 @@
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Copy of OpenFold.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
@@ -111,7 +111,7 @@
    "  %shell wget -q -P /content \\\n",
    "    https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
    "  pbar.update(1)\n",
    "except subprocess.CalledProcessError as captured:\n",
    "  print(captured)\n",
    "  raise"
],
@@ -132,9 +132,16 @@
    "\n",
    "GIT_REPO = 'https://github.com/aqlaboratory/openfold'\n",
    "\n",
    "OPENFOLD_PARAM_FILE_ID = \"1OpeMrfWEUSD_KqffbPqd5p7WsJjlC3ZE\"\n",
    "ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
    "OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/'\n",
    "ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
    "OPENFOLD_PARAMS_PATH = os.path.join(\n",
    "    OPENFOLD_PARAMS_DIR, \"openfold_params.tar.gz\"\n",
    ")\n",
    "ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
    "    ALPHAFOLD_PARAMS_DIR, os.path.basename(ALPHAFOLD_PARAM_SOURCE_URL)\n",
    ")\n",
    "\n",
    "try:\n",
    "  with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
@@ -144,6 +151,10 @@
    "    pbar.update(8)\n",
    "    # Install the required versions of all dependencies.\n",
    "    %shell conda env update -n base --file openfold/environment.yml\n",
    "\n",
    "    %shell mkdir -p /content/openfold/openfold/resources\n",
    "    %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
    "\n",
    "    # Run setup.py to install only Openfold.\n",
    "    %shell pip3 install --no-dependencies ./openfold\n",
    "    pbar.update(10)\n",
@@ -152,17 +163,20 @@
    "    %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
    "    patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
    "    popd\n",
    "\n",
    "    %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
    "    %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
    "    pbar.update(27)\n",
    "\n",
    "    %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
    "      --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
    "    %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
    "\n",
    "    %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
    "    %shell gdown --id \"{OPENFOLD_PARAM_FILE_ID}\" -O \"{OPENFOLD_PARAMS_PATH}\"\n",
    "    %shell tar --extract --verbose --file=\"{OPENFOLD_PARAMS_PATH}\" \\\n",
    "      --directory=\"{OPENFOLD_PARAMS_DIR}\" --preserve-permissions\n",
    "    %shell rm \"{OPENFOLD_PARAMS_PATH}\"\n",
    "    pbar.update(55)\n",
    "except subprocess.CalledProcessError as captured:\n",
    "  print(captured)\n",
@@ -171,6 +185,62 @@
  "execution_count": null,
  "outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
"from google.colab import files\n",
"import json\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output"
],
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"execution_count": null,
"outputs": []
},
{
  "cell_type": "markdown",
  "metadata": {
@@ -191,12 +261,16 @@
    "cellView": "form"
  },
  "source": [
    "#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
    "sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
    "\n",
    "MIN_SEQUENCE_LENGTH = 16\n",
    "MAX_SEQUENCE_LENGTH = 2500\n",
    "\n",
    "#@markdown ### Choose between OpenFold and AlphaFold model parameters ⬇️\n",
    "\n",
    "weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
    "\n",
    "# Remove all whitespaces, tabs and end lines; upper-case\n",
    "sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
    "aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
@@ -225,39 +299,6 @@
    "#@markdown you’ll see how well each residue is covered by similar \n",
    "#@markdown sequences in the MSA.\n",
"# --- Python imports ---\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n",
"os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n",
"os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
"from google.colab import files\n",
"import json\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n",
"\n",
"# Color bands for visualizing plddt\n", "# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n", "PLDDT_BANDS = [(0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n", " (50, 70, '#FFDB13'),\n",
@@ -420,8 +461,25 @@
    "  cfg = config.model_config(model_name)\n",
    "  openfold_model = model.AlphaFold(cfg)\n",
    "  openfold_model = openfold_model.eval()\n",
    "  if(weight_set == \"AlphaFold\"):\n",
    "    params_name = os.path.join(ALPHAFOLD_PARAMS_DIR, f\"params_{model_name}.npz\")\n",
    "    import_jax_weights_(openfold_model, params_name, version=model_name)\n",
    "  elif(weight_set == \"OpenFold\"):\n",
    "    model_name_spl = model_name.split(\"_\")\n",
    "    if(model_name_spl[-1] == \"ptm\"):\n",
    "      of_model_name = \"finetuning_ptm_2.pt\"\n",
    "    else:\n",
    "      of_model_name = f\"finetuning_{model_name_spl[-1]}.pt\"\n",
    "    params_name = os.path.join(\n",
    "      OPENFOLD_PARAMS_DIR,\n",
    "      \"openfold_params\",\n",
    "      of_model_name\n",
    "    )\n",
    "    d = torch.load(params_name)\n",
    "    openfold_model.load_state_dict(d)\n",
    "  else:\n",
    "    raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
    "\n",
    "  openfold_model = openfold_model.cuda()\n",
    "\n",
    "  pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
@@ -699,4 +757,4 @@
    ]
  }
 ]
}
\ No newline at end of file
@@ -119,7 +119,9 @@ def model_config(name, train=False, low_prec=False):
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
        c.globals.use_lma = False
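        # Inference-time memory savers (see the README above): CPU offloading
        # plus the two mutually exclusive low-memory template modes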
        c.globals.offload_inference = False
        c.model.template.average_templates = False
        c.model.template.offload_templates = False
    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
@@ -314,6 +316,7 @@ config = mlc.ConfigDict(
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            "use_lma": False,
            "offload_inference": False,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
@@ -364,6 +367,7 @@ config = mlc.ConfigDict(
                "pair_transition_n": 2,
                "dropout_rate": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "tune_chunk_size": tune_chunk_size,
                "inf": 1e9,
            },
            "template_pointwise_attention": {
@@ -409,7 +413,7 @@ config = mlc.ConfigDict(
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
                "clear_cache_between_blocks": False,
                "tune_chunk_size": tune_chunk_size,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
...
@@ -432,7 +432,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
...
@@ -503,7 +503,7 @@ def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask from a list of Biopython Residues."""
    coords_with_mask = mmcif_parsing.get_atom_coords(
@@ -1045,6 +1045,7 @@ class TemplateHitFeaturizer:
        filtered = list(
            sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
        )

        idx = list(range(len(filtered)))
        if(self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
...
@@ -81,15 +81,21 @@ class InputEmbedder(nn.Module):
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
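        # One-hot encode d by snapping each relative position to its nearest
        # bin (an explicit argmin replaces the old one_hot helper)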
        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
        d = d[..., None] - reshaped_bins
        d = torch.abs(d)
        d = torch.argmin(d, dim=-1)
        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
        d = d.to(ri.dtype)
        return self.linear_relpos(d)
    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -111,8 +117,15 @@ class InputEmbedder(nn.Module):
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
        pair_emb = add(pair_emb,
            tf_emb_i[..., None, :],
            inplace=inplace_safe
        )
        pair_emb = add(pair_emb,
            tf_emb_j[..., None, :, :],
            inplace=inplace_safe
        )

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
@@ -173,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -191,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
        """
        # [*, N, C_m]
        m_update = self.layer_norm_m(m)
        if(inplace_safe):
            m.copy_(m_update)
            m_update = m

        # [*, N, N, C_z]
        z_update = self.layer_norm_z(z)
        if(inplace_safe):
            z.copy_(z_update)
            z_update = z
@@ -223,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
        # [*, N, N, C_z]
        d = self.linear(d)

        z_update = add(z_update, d, inplace_safe)

        return m_update, z_update
...
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
            self.config["heads"],
        )

    def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
        if(self.template_config.offload_templates):
            return embed_templates_offload(self,
                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
            )
        elif(self.template_config.average_templates):
            return embed_templates_average(self,
                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
            )

        # Embed the templates one at a time (with a poor man's vmap)
        pair_embeds = []
        n = z.shape[-2]
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
                pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                chunk_size=self.globals.chunk_size,
                use_lma=self.globals.use_lma,
                inplace_safe=inplace_safe,
                _mask_trans=self.config._mask_trans,
            )
            del t_pair
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
        t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)

        ret = {}
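        # Store the pair embedding and drop the local reference right away so
        # the tensor can be freed (or offloaded) as early as possible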
        ret.update({"template_pair_embedding": t})
        del t

        if self.config.template.embed_angles:
            template_angle_feat = build_template_angle_feat(
                batch
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
            ret["template_angle_embedding"] = a

        return ret

    def iteration(self, feats, prevs, _recycle=True):
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]
        device = feats["target_feat"].device

        # Controls whether the model uses in-place operations throughout
        # The dual condition accounts for activation checkpoints
        inplace_safe = not (self.training or torch.is_grad_enabled())

        # Prep some features
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
            inplace_safe=inplace_safe,
        )

        # Unpack the recycling embeddings. Removing them from the list allows
        # them to be freed further down in this function, saving memory
        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

        # Initialize the recycling embeddings, if needs be
@@ -263,26 +266,37 @@ class AlphaFold(nn.Module):
                feats["aatype"], x_prev, None
            ).to(dtype=z.dtype)

        # The recycling embedder is memory-intensive, so we offload first
        if(self.globals.offload_inference and inplace_safe):
            m = m.cpu()
            z = z.cpu()

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
            inplace_safe=inplace_safe,
        )

        if(self.globals.offload_inference and inplace_safe):
            m = m.to(m_1_prev_emb.device)
            z = z.to(z_prev.device)

        # [*, S_c, N, C_m]
        m[..., 0, :, :] += m_1_prev_emb

        # [*, N, N, C_z]
        z = add(z, z_prev_emb, inplace=inplace_safe)

        # Deletions like these become significant for inference with large N,
        # where they free unused tensors and remove references to others such
        # that they can be offloaded later
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
@@ -291,6 +305,7 @@ class AlphaFold(nn.Module):
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
                inplace_safe=inplace_safe,
            )

            # [*, N, N, C_z]
@@ -299,7 +314,7 @@ class AlphaFold(nn.Module):
                inplace_safe,
            )

            if "template_angle_embedding" in template_embeds:
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]],
@@ -317,44 +332,78 @@ class AlphaFold(nn.Module):
        if self.config.extra_msa.enabled:
            # [*, S_e, N, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            if(self.globals.offload_inference):
                # To allow the extra MSA stack (and later the evoformer) to
                # offload its inputs, we remove all references to them here
                input_tensors = [a, z]
                del a, z

                # [*, N, N, C_z]
                z = self.extra_msa_stack._forward_offload(
                    input_tensors,
                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                    chunk_size=self.globals.chunk_size,
                    use_lma=self.globals.use_lma,
                    pair_mask=pair_mask.to(dtype=m.dtype),
                    _mask_trans=self.config._mask_trans,
                )

                del input_tensors
            else:
                # [*, N, N, C_z]
                z = self.extra_msa_stack(
                    a, z,
                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                    chunk_size=self.globals.chunk_size,
                    use_lma=self.globals.use_lma,
                    pair_mask=pair_mask.to(dtype=m.dtype),
                    inplace_safe=inplace_safe,
                    _mask_trans=self.config._mask_trans,
                )
        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        if(self.globals.offload_inference):
            input_tensors = [m, z]
            del m, z
            m, z, s = self.evoformer._forward_offload(
                input_tensors,
                msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
                pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
                chunk_size=self.globals.chunk_size,
                use_lma=self.globals.use_lma,
                _mask_trans=self.config._mask_trans,
            )

            del input_tensors
        else:
            m, z, s = self.evoformer(
                m,
                z,
                msa_mask=msa_mask.to(dtype=m.dtype),
                pair_mask=pair_mask.to(dtype=z.dtype),
                chunk_size=self.globals.chunk_size,
                use_lma=self.globals.use_lma,
                inplace_safe=inplace_safe,
                _mask_trans=self.config._mask_trans,
            )

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        del z
        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            outputs,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
            inplace_safe=inplace_safe,
            _offload_inference=self.globals.offload_inference,
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
@@ -368,7 +417,7 @@ class AlphaFold(nn.Module):
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = outputs["pair"]

        # [*, N, 3]
        x_prev = outputs["final_atom_positions"]
...
@@ -26,8 +26,8 @@ from openfold.model.primitives import (
    _attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)
@@ -94,16 +94,20 @@ class MSAAttention(nn.Module):
        use_memory_efficient_kernel: bool,
        use_lma: bool,
    ) -> torch.Tensor:
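        # The MSA LayerNorm now runs inside the chunked closure, so it is
        # applied one chunk at a time rather than over the full tensor at once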
        def fn(m, biases):
            m = self.layer_norm_m(m)
            return self.mha(
                q_x=m,
                kv_x=m,
                biases=biases,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
            )

        return chunk_layer(
            fn,
            {
                "m": m,
                "biases": biases,
            },
            chunk_size=chunk_size,
@@ -113,11 +117,9 @@ class MSAAttention(nn.Module):
    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
@@ -133,11 +135,20 @@ class MSAAttention(nn.Module):
            self.layer_norm_z is not None and  # benefit of
            self.linear_z is not None          # TorchScript
        ):
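            # Normalize and project the pair bias in 256-row chunks so the
            # full [N_res, N_res, C_z] intermediate is never held twice at once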
            chunks = []
            for i in range(0, z.shape[-3], 256):
                z_chunk = z[..., i: i + 256, :, :]

                # [*, N_res, N_res, C_z]
                z_chunk = self.layer_norm_z(z_chunk)

                # [*, N_res, N_res, no_heads]
                z_chunk = self.linear_z(z_chunk)

                chunks.append(z_chunk)

            z = torch.cat(chunks, dim=-3)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
@@ -151,6 +162,7 @@ class MSAAttention(nn.Module):
        mask: Optional[torch.Tensor],
        chunk_logits: int,
        checkpoint: bool,
        inplace_safe: bool = False
    ) -> torch.Tensor:
        """
        MSA attention with training-time chunking of the softmax computation.
@@ -160,7 +172,9 @@ class MSAAttention(nn.Module):
        MSA_DIM = -4

        def _get_qkv(m, z):
            m, mask_bias, z = self._prep_inputs(
                m, z, mask, inplace_safe=inplace_safe
            )
            q, k, v = self.mha._prep_qkv(m, m)
            return m, q, k, v, mask_bias, z
@@ -196,6 +210,7 @@ class MSAAttention(nn.Module):
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
@@ -217,10 +232,14 @@ class MSAAttention(nn.Module):
        if(_chunk_logits is not None):
            return self._chunked_msa_attn(
                m=m, z=z, mask=mask,
                chunk_logits=_chunk_logits,
                checkpoint=_checkpoint_chunks,
                inplace_safe=inplace_safe,
            )

        m, mask_bias, z = self._prep_inputs(
            m, z, mask, inplace_safe=inplace_safe
        )

        biases = [mask_bias]
        if(z is not None):
@@ -376,8 +395,13 @@ class MSAColumnGlobalAttention(nn.Module):
            "m": m,
            "mask": mask,
        }

        def fn(m, mask):
            m = self.layer_norm_m(m)
            return self.global_attention(m, mask, use_lma=use_lma)

        return chunk_layer(
            fn,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
@@ -405,11 +429,12 @@ class MSAColumnGlobalAttention(nn.Module):
        mask = mask.transpose(-1, -2)

        # [*, N_res, N_seq, C_in]
        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
        else:
            m = self.layer_norm_m(m)
            m = self.global_attention(m=m, mask=mask, use_lma=use_lma)

        # [*, N_seq, N_res, C_in]
...
@@ -20,7 +20,7 @@ import torch
import torch.nn as nn

from openfold.model.primitives import Linear
from openfold.utils.chunk_utils import chunk_layer


class OuterProductMean(nn.Module):
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
        norm = norm + self.eps

        # [*, N_res, N_res, C_z]
        if(inplace_safe):
            outer /= norm
        else:
            outer = outer / norm
...
@@ -18,7 +18,7 @@ import torch
import torch.nn as nn

from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.chunk_utils import chunk_layer


class PairTransition(nn.Module):
@@ -46,6 +46,9 @@ class PairTransition(nn.Module):
        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

    def _transition(self, z, mask):
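        # LayerNorm moved from forward() into _transition() so that, under
        # chunking, normalization also runs one chunk at a time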
        # [*, N_res, N_res, C_z]
        z = self.layer_norm(z)

        # [*, N_res, N_res, C_hidden]
        z = self.linear_1(z)
        z = self.relu(z)
@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
        # [*, N_res, N_res, 1]
        mask = mask.unsqueeze(-1)

        if chunk_size is not None:
            z = self._chunk(z, mask, chunk_size)
        else:
...
@@ -23,11 +23,11 @@ import torch.nn as nn
from scipy.stats import truncnorm

from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)
@@ -149,26 +149,26 @@ class Linear(nn.Linear):
            with torch.no_grad():
                self.bias.fill_(0)
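        # All weight/bias initialization now happens under no_grad, so the
        # in-place fills below never record autograd history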
        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    lecun_normal_init_(self.weight)
                elif init == "relu":
                    he_normal_init_(self.weight)
                elif init == "glorot":
                    glorot_uniform_init_(self.weight)
                elif init == "gating":
                    gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    normal_init_(self.weight)
                elif init == "final":
                    final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

class LayerNorm(nn.Module):
...
@@ -19,7 +19,7 @@ from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple, Sequence

from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
@@ -229,9 +229,12 @@ class InvariantPointAttention(nn.Module):
    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
@@ -246,8 +249,11 @@ class InvariantPointAttention(nn.Module):
        Returns:
            [*, N_res, C_s] single representation update
        """
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        a = torch.matmul(
@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)
@@ -402,9 +414,9 @@ class InvariantPointAttention(nn.Module):
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s
@@ -604,17 +616,20 @@ class StructureModule(nn.Module):
    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
@@ -622,6 +637,8 @@ class StructureModule(nn.Module):
        Returns:
            A dictionary of outputs
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])
@@ -630,7 +647,13 @@ class StructureModule(nn.Module):
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])
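        # Under offloading, park the raw pair tensor on the CPU and hand the
        # normalized copy to IPA through a mutable one-element reference list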
        z_reference_list = None
        if(_offload_inference):
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
@@ -647,11 +670,19 @@ class StructureModule(nn.Module):
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))
@@ -698,6 +729,13 @@ class StructureModule(nn.Module):
            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if(_offload_inference):
            evoformer_output_dict["pair"] = (
                evoformer_output_dict["pair"].to(s.device)
            )

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s
...
@@ -34,13 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
    chunk_layer,
    ChunkSizeTuner,
)
from openfold.utils.feats import (
    build_template_angle_feat,
    build_template_pair_feat,
)
from openfold.utils.tensor_utils import (
    add,
    permute_final_dims,
    flatten_final_dims,
    tensor_tree_map,
@@ -198,15 +201,20 @@ class TemplatePairStackBlock(nn.Module):
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
        _attn_chunk_size: Optional[int] = None,
    ):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
single_templates = [ single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4) t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
] ]
single_templates_masks = [ single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3) m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
] ]
for i in range(len(single_templates)): for i in range(len(single_templates)):
single = single_templates[i] single = single_templates[i]
single_mask = single_templates_masks[i] single_mask = single_templates_masks[i]
...@@ -215,33 +223,35 @@ class TemplatePairStackBlock(nn.Module): ...@@ -215,33 +223,35 @@ class TemplatePairStackBlock(nn.Module):
self.dropout_row( self.dropout_row(
self.tri_att_start( self.tri_att_start(
single, single,
chunk_size=chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe,
) )
), ),
_inplace, inplace_safe,
) )
single = add(single, single = add(single,
self.dropout_col( self.dropout_col(
self.tri_att_end( self.tri_att_end(
single, single,
chunk_size=chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe,
) )
), ),
_inplace, inplace_safe,
) )
tmu_update = self.tri_mul_out( tmu_update = self.tri_mul_out(
single, single,
mask=single_mask, mask=single_mask,
_inplace=_inplace, inplace_safe=inplace_safe,
_add_with_inplace=True, _add_with_inplace=True,
) )
if(not _inplace): if(not inplace_safe):
single = single + self.dropout_row(tmu_update) single = single + self.dropout_row(tmu_update)
else: else:
single = tmu_update single = tmu_update
...@@ -251,10 +261,10 @@ class TemplatePairStackBlock(nn.Module): ...@@ -251,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update = self.tri_mul_in( tmu_update = self.tri_mul_in(
single, single,
mask=single_mask, mask=single_mask,
_inplace=_inplace, inplace_safe=inplace_safe,
_add_with_inplace=True, _add_with_inplace=True,
) )
if(not _inplace): if(not inplace_safe):
single = single + self.dropout_row(tmu_update) single = single + self.dropout_row(tmu_update)
else: else:
single = tmu_update single = tmu_update
...@@ -267,13 +277,13 @@ class TemplatePairStackBlock(nn.Module): ...@@ -267,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask=single_mask if _mask_trans else None, mask=single_mask if _mask_trans else None,
chunk_size=chunk_size, chunk_size=chunk_size,
), ),
_inplace, inplace_safe,
) )
if(not _inplace): if(not inplace_safe):
single_templates[i] = single single_templates[i] = single
if(not _inplace): if(not inplace_safe):
z = torch.cat(single_templates, dim=-4) z = torch.cat(single_templates, dim=-4)
return z return z
...@@ -293,6 +303,7 @@ class TemplatePairStack(nn.Module): ...@@ -293,6 +303,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n, pair_transition_n,
dropout_rate, dropout_rate,
blocks_per_ckpt, blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9, inf=1e9,
**kwargs, **kwargs,
): ):
...@@ -333,12 +344,18 @@ class TemplatePairStack(nn.Module): ...@@ -333,12 +344,18 @@ class TemplatePairStack(nn.Module):
self.layer_norm = LayerNorm(c_t) self.layer_norm = LayerNorm(c_t)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward( def forward(
self, self,
t: torch.tensor, t: torch.tensor,
mask: torch.tensor, mask: torch.tensor,
chunk_size: int, chunk_size: int,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
): ):
""" """
...@@ -355,18 +372,34 @@ class TemplatePairStack(nn.Module): ...@@ -355,18 +372,34 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4] expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx) mask = mask.expand(*expand_idx)
t, = checkpoint_blocks( blocks = [
blocks=[ partial(
partial( b,
b, mask=mask,
mask=mask, chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
_mask_trans=_mask_trans, ) for b in blocks
_inplace=not (self.training or torch.is_grad_enabled()), ]
)
for b in self.blocks t, = checkpoint_blocks(
], blocks=blocks,
args=(t,), args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
...@@ -383,6 +416,7 @@ def embed_templates_offload( ...@@ -383,6 +416,7 @@ def embed_templates_offload(
pair_mask, pair_mask,
templ_dim, templ_dim,
template_chunk_size=256, template_chunk_size=256,
inplace_safe=False,
): ):
""" """
Args: Args:
...@@ -407,8 +441,6 @@ def embed_templates_offload( ...@@ -407,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference. with GPU memory than the original. Useful for long-sequence inference.
""" """
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = [] pair_embeds_cpu = []
n = z.shape[-2] n = z.shape[-2]
...@@ -491,6 +523,7 @@ def embed_templates_average( ...@@ -491,6 +523,7 @@ def embed_templates_average(
pair_mask, pair_mask,
templ_dim, templ_dim,
templ_group_size=2, templ_group_size=2,
inplace_safe=False,
): ):
""" """
Args: Args:
...@@ -519,8 +552,6 @@ def embed_templates_average( ...@@ -519,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely. to scale almost indefinitely.
""" """
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2] n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
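`TemplatePairStack` can now tune its chunk size at inference time: `ChunkSizeTuner` probes a representative block with progressively larger chunks, and the winning value is used to scale up the attention chunk size (`_attn_chunk_size`) when memory allows. Below is a toy illustration of the probing idea only, not OpenFold's actual `ChunkSizeTuner` logic:

```python
import torch

def tune_chunk_size_sketch(fn, args, min_chunk_size, max_chunk_size=512):
    """Toy illustration only: keep doubling the chunk size and retain the
    largest value that runs without exhausting GPU memory."""
    best = min_chunk_size
    candidate = min_chunk_size
    while candidate <= max_chunk_size:
        try:
            with torch.no_grad():
                fn(*args, chunk_size=candidate)
            best = candidate
            candidate *= 2
        except RuntimeError:  # typically a CUDA out-of-memory error
            break
    return best
```

Since the probe runs real forward passes, it is only triggered during inference (`assert(not self.training)` above), and the input is cloned so the trial runs cannot corrupt the actual activations.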
...
@@ -21,8 +21,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.utils.chunk_utils import chunk_layer
 from openfold.utils.tensor_utils import (
-    chunk_layer,
     permute_final_dims,
     flatten_final_dims,
 )
@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
 class TriangleAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, starting, inf=1e9
+        self, c_in, c_hidden, no_heads, starting=True, inf=1e9
     ):
         """
         Args:
@@ -62,25 +62,36 @@ class TriangleAttention(nn.Module):
         x: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
+        "triangle! triangle!"
         mha_inputs = {
             "q_x": x,
             "kv_x": x,
             "biases": biases,
         }

         return chunk_layer(
-            partial(self.mha, use_lma=use_lma),
+            partial(
+                self.mha,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma
+            ),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),
+            _out=x if inplace_safe else None,
         )

     def forward(self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -88,15 +99,14 @@ class TriangleAttention(nn.Module):
             [*, I, J, C_in] input tensor (e.g. the pair representation)
         Returns:
             [*, I, J, C_in] output tensor
         """
         if mask is None:
             # [*, I, J]
             mask = x.new_ones(
                 x.shape[:-1],
             )

-        # Shape annotations assume self.starting. Else, I and J are flipped
-        if not self.starting:
+        if(not self.starting):
             x = x.transpose(-2, -3)
             mask = mask.transpose(-1, -2)
@@ -115,27 +125,35 @@ class TriangleAttention(nn.Module):
         biases = [mask_bias, triangle_bias]

         if chunk_size is not None:
-            x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
+            x = self._chunk(
+                x,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
+            )
         else:
-            x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)
+            x = self.mha(
+                q_x=x,
+                kv_x=x,
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma
+            )

-        if not self.starting:
+        if(not self.starting):
             x = x.transpose(-2, -3)

         return x


-class TriangleAttentionStartingNode(TriangleAttention):
-    """
-    Implements Algorithm 13.
-    """
-    __init__ = partialmethod(TriangleAttention.__init__, starting=True)
+# Implements Algorithm 13
+TriangleAttentionStartingNode = TriangleAttention


 class TriangleAttentionEndingNode(TriangleAttention):
     """
     Implements Algorithm 14.
     """
     __init__ = partialmethod(TriangleAttention.__init__, starting=False)
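With `starting=True` now the default, `TriangleAttentionStartingNode` collapses into a plain alias of `TriangleAttention`, while the ending-node variant pins `starting=False` via `functools.partialmethod`. A self-contained sketch of that pattern, using hypothetical `Base`/`Ending` classes:

```python
from functools import partialmethod

class Base:
    def __init__(self, starting=True):
        self.starting = starting

class Ending(Base):
    # Pin a constructor argument, mirroring TriangleAttentionEndingNode.
    __init__ = partialmethod(Base.__init__, starting=False)

assert Base().starting is True
assert Ending().starting is False
```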
...

@@ -20,7 +20,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import add, chunk_layer, permute_final_dims
+from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.tensor_utils import add, permute_final_dims


 class TriangleMultiplicativeUpdate(nn.Module):
@@ -356,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
     def forward(self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        _inplace: bool = False,
+        inplace_safe: bool = False,
         _add_with_inplace: bool = False,
         _inplace_chunk_size: Optional[int] = 256,
     ) -> torch.Tensor:
@@ -369,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] output tensor
         """
-        if(_inplace):
+        if(inplace_safe):
             x = self._inference_forward(
                 z,
                 mask,
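The `_inplace` flag becomes `inplace_safe` here and throughout, and the decision moves to the caller: the template embedders no longer compute `not (model.training or torch.is_grad_enabled())` internally (see the deletions above). A caller-side sketch with toy sizes:

```python
import torch
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
)

# In-place updates are only safe when no autograd graph is being built,
# which is the check the template embedders used to perform themselves.
tmu = TriangleMultiplicationOutgoing(c_z=32, c_hidden=32).eval()
z = torch.randn(1, 16, 16, 32)  # [*, N_res, N_res, C_z]
mask = torch.ones(1, 16, 16)    # [*, N_res, N_res]

with torch.no_grad():
    inplace_safe = not (tmu.training or torch.is_grad_enabled())  # -> True
    z = tmu(z, mask=mask, inplace_safe=inplace_safe, _add_with_inplace=True)
```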
...
@@ -92,68 +92,76 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     )
     model = models[0]

-    if chain_id is not None:
-        chain = model[chain_id]
-    else:
-        chains = list(model.get_chains())
-        if len(chains) != 1:
-            raise ValueError(
-                "Only single chain PDBs are supported when chain_id not specified. "
-                f"Found {len(chains)} chains."
-            )
-        else:
-            chain = chains[0]
-
     atom_positions = []
     aatype = []
     atom_mask = []
     residue_index = []
+    chain_ids = []
     b_factors = []

-    for res in chain:
-        if res.id[2] != " ":
-            raise ValueError(
-                f"PDB contains an insertion code at chain {chain.id} and residue "
-                f"index {res.id[1]}. These are not supported."
-            )
-        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
-        restype_idx = residue_constants.restype_order.get(
-            res_shortname, residue_constants.restype_num
-        )
-        pos = np.zeros((residue_constants.atom_type_num, 3))
-        mask = np.zeros((residue_constants.atom_type_num,))
-        res_b_factors = np.zeros((residue_constants.atom_type_num,))
-        for atom in res:
-            if atom.name not in residue_constants.atom_types:
-                continue
-            pos[residue_constants.atom_order[atom.name]] = atom.coord
-            mask[residue_constants.atom_order[atom.name]] = 1.0
-            res_b_factors[
-                residue_constants.atom_order[atom.name]
-            ] = atom.bfactor
-        if np.sum(mask) < 0.5:
-            # If no known atom positions are reported for the residue then skip it.
-            continue
-        aatype.append(restype_idx)
-        atom_positions.append(pos)
-        atom_mask.append(mask)
-        residue_index.append(res.id[1])
-        b_factors.append(res_b_factors)
+    for chain in model:
+        if(chain_id is not None and chain.id != chain_id):
+            continue
+        for res in chain:
+            if res.id[2] != " ":
+                raise ValueError(
+                    f"PDB contains an insertion code at chain {chain.id} and residue "
+                    f"index {res.id[1]}. These are not supported."
+                )
+            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
+            restype_idx = residue_constants.restype_order.get(
+                res_shortname, residue_constants.restype_num
+            )
+            pos = np.zeros((residue_constants.atom_type_num, 3))
+            mask = np.zeros((residue_constants.atom_type_num,))
+            res_b_factors = np.zeros((residue_constants.atom_type_num,))
+            for atom in res:
+                if atom.name not in residue_constants.atom_types:
+                    continue
+                pos[residue_constants.atom_order[atom.name]] = atom.coord
+                mask[residue_constants.atom_order[atom.name]] = 1.0
+                res_b_factors[
+                    residue_constants.atom_order[atom.name]
+                ] = atom.bfactor
+            if np.sum(mask) < 0.5:
+                # If no known atom positions are reported for the residue then skip it.
+                continue
+            aatype.append(restype_idx)
+            atom_positions.append(pos)
+            atom_mask.append(mask)
+            residue_index.append(res.id[1])
+            chain_ids.append(chain.id)
+            b_factors.append(res_b_factors)

     parents = None
+    parents_chain_index = None
     if("PARENT" in pdb_str):
+        parents = []
+        parents_chain_index = []
+        chain_id = 0
         for l in pdb_str.split("\n"):
-            if("PARENT" in l and not "N/A" in l):
-                parents = l.split()[1:]
-                break
+            if("PARENT" in l):
+                if(not "N/A" in l):
+                    parent_names = l.split()[1:]
+                    parents.extend(parent_names)
+                    parents_chain_index.extend([
+                        chain_id for _ in parent_names
+                    ])
+                chain_id += 1
+
+    unique_chain_ids = np.unique(chain_ids)
+    chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
+    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

     return Protein(
         atom_positions=np.array(atom_positions),
         atom_mask=np.array(atom_mask),
         aatype=np.array(aatype),
         residue_index=np.array(residue_index),
+        chain_index=chain_index,
         b_factors=np.array(b_factors),
         parents=parents,
+        parents_chain_index=parents_chain_index,
     )
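`from_pdb_string` now walks every chain in the first model (or only `chain_id`, when one is given) and records a per-residue `chain_index`, with chain letters mapped alphabetically (A -> 0, B -> 1, and so on). A hedged usage sketch; the file path is illustrative only:

```python
from openfold.np import protein

# Illustrative path; any PDB file with alphabetic chain IDs works.
with open("example.pdb") as f:
    pdb_str = f.read()

# chain_id=None parses all chains in the first model.
prot = protein.from_pdb_string(pdb_str, chain_id=None)
print(prot.chain_index[:10])  # e.g. [0 0 0 ... 1 1] for chains A and B
```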
@@ -232,6 +240,56 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
     return pdb_headers


+def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
+    """ Add pdb headers to an existing PDB string. Useful during multi-chain
+        recycling
+    """
+    out_pdb_lines = []
+    lines = pdb_str.split('\n')
+
+    remark = prot.remark
+    if(remark is not None):
+        out_pdb_lines.append(f"REMARK {remark}")
+
+    parents_per_chain = None
+    if(prot.parents is not None and len(prot.parents) > 0):
+        parents_per_chain = []
+        if(prot.parents_chain_index is not None):
+            cur_chain = prot.parents_chain_index[0]
+            parent_dict = {}
+            for p, i in zip(prot.parents, prot.parents_chain_index):
+                parent_dict.setdefault(str(i), [])
+                parent_dict[str(i)].append(p)
+
+            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
+            for i in range(max_idx + 1):
+                chain_parents = parent_dict.get(str(i), ["N/A"])
+                parents_per_chain.append(chain_parents)
+        else:
+            parents_per_chain.append(prot.parents)
+    else:
+        parents_per_chain = [["N/A"]]
+
+    make_parent_line = lambda p: f"PARENT {' '.join(p)}"
+
+    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
+
+    chain_counter = 0
+    for i, l in enumerate(lines):
+        if("PARENT" not in l and "REMARK" not in l):
+            out_pdb_lines.append(l)
+        if("TER" in l and not "END" in lines[i + 1]):
+            chain_counter += 1
+            if(not chain_counter >= len(parents_per_chain)):
+                chain_parents = parents_per_chain[chain_counter]
+            else:
+                chain_parents = ["N/A"]
+
+            out_pdb_lines.append(make_parent_line(chain_parents))
+
+    return '\n'.join(out_pdb_lines)
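`add_pdb_headers` interleaves one PARENT record per chain, advancing on each TER record that is not immediately followed by END. A small sketch of the expected behavior on a hand-written two-chain fragment (all tensor-valued fields are dummies, and the PDB lines are abbreviated):

```python
import numpy as np
from openfold.np import protein

# Dummy two-residue, two-chain Protein; only the header fields matter here.
prot = protein.Protein(
    atom_positions=np.zeros((2, 37, 3)),
    atom_mask=np.ones((2, 37)),
    aatype=np.zeros((2,), dtype=np.int32),
    residue_index=np.array([1, 1]),
    chain_index=np.array([0, 1]),
    b_factors=np.zeros((2, 37)),
    remark=None,
    parents=["1ABC_A", "2XYZ_B"],
    parents_chain_index=[0, 1],
)

pdb_str = "ATOM      1 ...\nTER\nATOM      2 ...\nTER\nEND"
print(protein.add_pdb_headers(prot, pdb_str))
# PARENT 1ABC_A
# ATOM      1 ...
# TER
# PARENT 2XYZ_B
# ATOM      2 ...
# TER
# END
```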
+

 def to_pdb(prot: Protein) -> str:
     """Converts a `Protein` instance to a PDB string.
...
@@ -88,8 +88,6 @@ class AmberRelaxation(object):
             "total_per_residue_violations_mask"
         ]

-        headers = protein.get_pdb_headers(prot)
-        if(len(headers) > 0):
-            min_pdb = '\n'.join(['\n'.join(headers), min_pdb])
+        min_pdb = protein.add_pdb_headers(prot, min_pdb)

         return min_pdb, debug_data, violations