Commit 3ad85a19 authored by Sam DeLuca's avatar Sam DeLuca
Browse files

Merge remote-tracking branch 'cyrus/main' into run-multiple-models

parents 43b8c6f9 6da2cdaf
![header ](imgs/OpenFold_viz_banner.jpg)
![header ](imgs/of_banner.png)
# OpenFold
A faithful PyTorch reproduction of DeepMind's
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Features
......@@ -14,20 +14,27 @@ DeepMind experiments. It is omitted here for the sake of reducing clutter. In
cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is built to support inference with AlphaFold's official parameters. Try it out for yourself with
our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold is trainable in full precision or `bfloat16` with or without DeepSpeed,
and we've trained it from scratch, matching the performance of the original.
We've publicly released model weights and our training data — some 400,000
MSAs and PDB70 template hit files — under a permissive license. Model weights
are available via scripts in this repository while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
Additionally, OpenFold has the following advantages over the reference implementation:
OpenFold also supports inference using AlphaFold's official parameters.
- Openfold is **trainable** in full precision or `bfloat16` half-precision, with or without [DeepSpeed](https://github.com/microsoft/deepspeed).
- **Faster inference** on GPU.
OpenFold has the following advantages over the reference implementation:
- **Faster inference** on GPU for chains with < 1500 residues.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)).
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
kernels support in-place attention during inference and training. They use
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments that will be released alongside original OpenFold weights, trained from scratch using our code (more on that soon).
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
## Installation (Linux)
......@@ -70,7 +77,7 @@ To install the HH-suite to `/usr/bin`, run
## Usage
To download DeepMind's pretrained parameters and common ground truth data, run:
To download the databases used to train OpenFold and AlphaFold run:
```bash
bash scripts/download_data.sh data/
......@@ -96,12 +103,13 @@ Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system).
Alternatively, you can use raw MSAs from
Alternatively, you can use raw MSAs from our aforementioned MSA database or
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
the database, use `scripts/prep_proteinnet_msas.py` to convert the data into
a format recognized by the OpenFold parser. The resulting directory becomes the
`alignment_dir` used in subsequent steps. Use `scripts/unpack_proteinnet.py` to
extract `.core` files from ProteinNet text files.
the latter database, use `scripts/prep_proteinnet_msas.py` to convert the data
into a format recognized by the OpenFold parser. The resulting directory
becomes the `alignment_dir` used in subsequent steps. Use
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.
For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using
......@@ -124,30 +132,41 @@ python3 run_pretrained_openfold.py \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir ./ \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device cuda:1 \
--model_device "cuda:0" \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
--config_preset "model_1_ptm"
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_2_ptm.pt
```
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here.
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
Exactly one of `--openfold_checkpoint_path` or `--jax_param_path` must be specified
to run the inference script. These accept .pt/DeepSpeed OpenFold checkpoints
and AlphaFold's .npz JAX parameter files, respectively. For a breakdown of the
differences between the different parameter files, see the README downloaded to
`openfold/resources/openfold_params/`. Since OpenFold was trained under a
newer training schedule than the one from which the `model_n` config
presets are derived, there is no clean correspondence between `config_preset`
settings and OpenFold checkpoints; the only restraint is that `*_ptm`
checkpoints must be run with `*_ptm` config presets.
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config.
Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Powerusers can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Staats & Rabe preprint.
to `None` in the config. If a value is specified, OpenFold will attempt to
dynamically tune it, considering the chunk size specified in the config as a
minimum. This tuning process automatically ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. It is also recommended
to disable this feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
......@@ -156,15 +175,10 @@ the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch
instead.
By default, OpenFold will attempt to automatically tune the inference-time
`chunk_size` hyperparameter controlling a memory/runtime tradeoff in certain
modules during inference. The chunk size specified in the config is only
considered a minimum. This feature ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. To disable this
feature, set the `tune_chunk_size` option in the config to `False`.
To minimize memory usage during inference on long sequences, consider the
following changes:
As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
......@@ -178,6 +192,22 @@ approximation while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Powerusers can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Staats & Rabe preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficent AlphaFold-Multimer more than double the time.
### Training
......@@ -344,7 +374,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--hhblits_binary_path /opt/conda/bin/hhblits \
--hhsearch_binary_path /opt/conda/bin/hhsearch \
--kalign_binary_path /opt/conda/bin/kalign \
--param_path /database/params/params_model_1.npz
--openfold_checkpoint_path /database/openfold_params/finetuning_2_ptm.pt
```
## Copyright notice
......
......@@ -4,7 +4,7 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "OpenFold.ipynb",
"name": "Copy of OpenFold.ipynb",
"provenance": [],
"collapsed_sections": []
},
......@@ -111,7 +111,7 @@
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
" pbar.update(1)\n",
"except subprocess.CalledProcessError:\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
......@@ -132,9 +132,16 @@
"\n",
"GIT_REPO = 'https://github.com/aqlaboratory/openfold'\n",
"\n",
"SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"PARAMS_DIR = './openfold/openfold/resources/params'\n",
"PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n",
"OPENFOLD_PARAM_FILE_ID = \"1OpeMrfWEUSD_KqffbPqd5p7WsJjlC3ZE\"\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"OPENFOLD_PARAMS_PATH = os.path.join(\n",
" OPENFOLD_PARAMS_DIR, \"openfold_params.tar.gz\"\n",
")\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, os.path.basename(ALPHAFOLD_PARAM_SOURCE_URL)\n",
")\n",
"\n",
"try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
......@@ -144,6 +151,10 @@
" pbar.update(8)\n",
" # Install the required versions of all dependencies.\n",
" %shell conda env update -n base --file openfold/environment.yml\n",
" \n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" \n",
" # Run setup.py to install only Openfold.\n",
" %shell pip3 install --no-dependencies ./openfold\n",
" pbar.update(10)\n",
......@@ -152,17 +163,20 @@
" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n",
" \n",
" %shell mkdir -p /content/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/resources\n",
"\n",
" %shell mkdir --parents \"{PARAMS_DIR}\"\n",
" %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
" pbar.update(27)\n",
"\n",
" %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n",
" --directory=\"{PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{PARAMS_PATH}\"\n",
" %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
"\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
" %shell gdown --id \"{OPENFOLD_PARAM_FILE_ID}\" -O \"{OPENFOLD_PARAMS_PATH}\"\n",
" %shell tar --extract --verbose --file=\"{OPENFOLD_PARAMS_PATH}\" \\\n",
" --directory=\"{OPENFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{OPENFOLD_PARAMS_PATH}\"\n",
" pbar.update(55)\n",
"except subprocess.CalledProcessError:\n",
" print(captured)\n",
......@@ -171,6 +185,62 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
"from google.colab import files\n",
"import json\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output"
],
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
......@@ -191,12 +261,16 @@
"cellView": "form"
},
"source": [
"#@title Enter the amino acid sequence to fold ⬇️\n",
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"MIN_SEQUENCE_LENGTH = 16\n",
"MAX_SEQUENCE_LENGTH = 2500\n",
"\n",
"#@markdown ### Choose between OpenFold and AlphaFold model parameters ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
......@@ -225,39 +299,6 @@
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Python imports ---\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n",
"os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n",
"os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
"from google.colab import files\n",
"import json\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
"from openfold.data import data_pipeline\n",
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n",
"\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
......@@ -420,8 +461,25 @@
" cfg = config.model_config(model_name)\n",
" openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n",
" params_name = os.path.join(PARAMS_DIR, f\"params_{model_name}.npz\")\n",
" if(weight_set == \"AlphaFold\"):\n",
" params_name = os.path.join(ALPHAFOLD_PARAMS_DIR, f\"params_{model_name}.npz\")\n",
" import_jax_weights_(openfold_model, params_name, version=model_name)\n",
" elif(weight_set == \"OpenFold\"):\n",
" model_name_spl = model_name.split(\"_\")\n",
" if(model_name_spl[-1] == \"ptm\"):\n",
" of_model_name = \"finetuning_ptm_2.pt\"\n",
" else:\n",
" of_model_name = f\"finetuning_{model_name_spl[-1]}.pt\"\n",
" params_name = os.path.join(\n",
" OPENFOLD_PARAMS_DIR,\n",
" \"openfold_params\",\n",
" of_model_name\n",
" )\n",
" d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n",
" else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n",
" openfold_model = openfold_model.cuda()\n",
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
......
......@@ -119,7 +119,9 @@ def model_config(name, train=False, low_prec=False):
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
c.globals.use_lma = False
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
......@@ -314,6 +316,7 @@ config = mlc.ConfigDict(
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"use_lma": False,
"offload_inference": False,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
......@@ -364,6 +367,7 @@ config = mlc.ConfigDict(
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
},
"template_pointwise_attention": {
......@@ -409,7 +413,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"clear_cache_between_blocks": True,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
"eps": eps, # 1e-10,
......
......@@ -432,7 +432,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
_zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
......
......@@ -503,7 +503,7 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = True,
_zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
......@@ -1045,6 +1045,7 @@ class TemplateHitFeaturizer:
filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
......
......@@ -82,14 +82,20 @@ class InputEmbedder(nn.Module):
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
......@@ -111,8 +117,15 @@ class InputEmbedder(nn.Module):
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
)
pair_emb = add(pair_emb,
tf_emb_j[..., None, :, :],
inplace=inplace_safe
)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
......@@ -173,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
_inplace: bool = False,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
......@@ -191,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
"""
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(_inplace):
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(_inplace):
if(inplace_safe):
z.copy_(z_update)
z_update = z
......@@ -223,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d = self.linear(d)
z_update = add(z_update, d, _inplace)
z_update = add(z_update, d, inplace_safe)
return m_update, z_update
......
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from typing import Tuple, Sequence, Optional
from functools import partial
from openfold.model.primitives import Linear, LayerNorm
......@@ -29,6 +29,7 @@ from openfold.model.msa import (
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
TriangleAttention,
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
......@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
from openfold.utils.tensor_utils import add
class MSATransition(nn.Module):
......@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.layer_norm(m)
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
......@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
......@@ -155,13 +156,13 @@ class EvoformerBlockCore(nn.Module):
c_hidden_mul,
)
self.tri_att_start = TriangleAttentionStartingNode(
self.tri_att_start = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.tri_att_end = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
......@@ -174,17 +175,17 @@ class EvoformerBlockCore(nn.Module):
)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
def forward(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
......@@ -192,8 +193,10 @@ class EvoformerBlockCore(nn.Module):
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
# Need to dodge activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
m = add(
m,
......@@ -202,17 +205,30 @@ class EvoformerBlockCore(nn.Module):
),
inplace=inplace_safe,
)
z = add(z,
self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
),
inplace=inplace_safe,
if(_offload_inference and inplace_safe):
del m, z
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if(_offload_inference and inplace_safe):
del m, z
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
_inplace=inplace_safe,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
......@@ -225,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update = self.tri_mul_in(
z,
mask=pair_mask,
_inplace=inplace_safe,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
......@@ -240,23 +256,38 @@ class EvoformerBlockCore(nn.Module):
self.tri_att_start(
z,
mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma
chunk_size=_attn_chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.ps_dropout_col_layer(
self.ps_dropout_row_layer(
self.tri_att_end(
z,
mask=pair_mask,
chunk_size=chunk_size,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size,
......@@ -264,6 +295,13 @@ class EvoformerBlockCore(nn.Module):
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
device = z.device
del m, z
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
......@@ -317,37 +355,66 @@ class EvoformerBlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)
m = m + self.msa_att_col(
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
del m, z
m, z = self.core(
m,
z,
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
return m, z
......@@ -413,61 +480,85 @@ class ExtraMSABlock(nn.Module):
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
_chunk_logits: Optional[int] = 1024,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# If function calls could speak...
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m.clone() if torch.is_grad_enabled() else m,
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not _chunk_logits and not use_lma,
_chunk_logits=
_chunk_logits if torch.is_grad_enabled() else None,
use_memory_efficient_kernel=not use_lma,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
),
inplace=not (self.training or torch.is_grad_enabled()),
inplace=inplace_safe,
)
def fn(m, z):
m = add(m,
if(not inplace_safe):
input_tensors = [m, z]
del m, z
def fn(input_tensors):
m = add(input_tensors[0],
self.msa_att_col(
m,
input_tensors[0],
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
),
inplace=not (self.training or torch.is_grad_enabled()),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
del m
m, z = self.core(
m,
z,
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
return m, z
if(torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
m, z = checkpoint_fn(fn, input_tensors)
else:
m, z = fn(m, z)
m, z = fn(input_tensors)
return m, z
......@@ -570,6 +661,94 @@ class EvoformerStack(nn.Module):
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache()
return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# We don't want to write in-place during chunk tuning runs
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
m, z = input_tensors
s = self.linear(m[..., 0, :, :])
return m, z, s
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
......@@ -577,8 +756,9 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
m:
......@@ -601,32 +781,16 @@ class EvoformerStack(nn.Module):
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache()
return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
......@@ -664,7 +828,6 @@ class ExtraMSAStack(nn.Module):
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
......@@ -672,7 +835,6 @@ class ExtraMSAStack(nn.Module):
self.ckpt = ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.chunk_msa_attn = chunk_msa_attn
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
......@@ -689,7 +851,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=ckpt if chunk_msa_attn else False,
ckpt=False,
)
self.blocks.append(block)
......@@ -698,6 +860,90 @@ class ExtraMSAStack(nn.Module):
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
) for b in self.blocks
]
def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache()
return b(*args, **kwargs)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
return input_tensors[1]
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
......@@ -705,6 +951,7 @@ class ExtraMSAStack(nn.Module):
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
......@@ -722,53 +969,22 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
if(not self.chunk_msa_attn):
checkpoint_fn = get_checkpoint_fn()
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
_chunk_logits=None,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
) for b in self.blocks
]
def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache()
return b(*args, **kwargs)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(m,z),
min_chunk_size=chunk_size,
)
blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, *(m, z))
m, z = checkpoint_fn(b, m, z)
else:
m, z = b(m, z)
else:
for b in self.blocks:
m, z = b(
m,
z,
msa_mask,
pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
_mask_trans=_mask_trans
)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
......@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
self.config["heads"],
)
def embed_templates(self, batch, z, pair_mask, templ_dim):
def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
if(self.template_config.offload_templates):
return embed_templates_offload(
self, batch, z, pair_mask, templ_dim,
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif(self.template_config.average_templates):
return embed_templates_average(
self, batch, z, pair_mask, templ_dim
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
inplace_safe = not (self.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
......@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
del t_pair
......@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
......@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
ret["template_angle_embedding"] = a
ret.update({"template_pair_embedding": t})
del t
return ret
def iteration(self, feats, prevs, _recycle=True):
......@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
......@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function.
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
......@@ -263,22 +266,33 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
_inplace=not (self.training or torch.is_grad_enabled()),
inplace_safe=inplace_safe,
)
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
z = add(z, z_prev_emb, inplace=inplace_safe)
# This matters during inference with large N
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
......@@ -291,6 +305,7 @@ class AlphaFold(nn.Module):
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
......@@ -299,7 +314,7 @@ class AlphaFold(nn.Module):
inplace_safe,
)
if self.config.template.embed_angles:
if "template_angle_embedding" in template_embeds:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
......@@ -318,23 +333,53 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=z.dtype),
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del a
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
......@@ -342,6 +387,7 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
......@@ -349,12 +395,15 @@ class AlphaFold(nn.Module):
outputs["pair"] = z
outputs["single"] = s
del z
# Predict 3D structure
outputs["sm"] = self.structure_module(
s,
z,
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......@@ -368,7 +417,7 @@ class AlphaFold(nn.Module):
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
z_prev = outputs["pair"]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
......
......@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
......@@ -94,16 +94,20 @@ class MSAAttention(nn.Module):
use_memory_efficient_kernel: bool,
use_lma: bool,
) -> torch.Tensor:
mha = partial(
self.mha,
def fn(m, biases):
m = self.layer_norm_m(m)
return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
)
return chunk_layer(
mha,
fn,
{
"q_x": m,
"kv_x": m,
"m": m,
"biases": biases,
},
chunk_size=chunk_size,
......@@ -113,11 +117,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
......@@ -133,11 +135,20 @@ class MSAAttention(nn.Module):
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
......@@ -151,6 +162,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
......@@ -160,7 +172,9 @@ class MSAAttention(nn.Module):
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
......@@ -196,6 +210,7 @@ class MSAAttention(nn.Module):
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
......@@ -217,10 +232,14 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
......@@ -376,8 +395,13 @@ class MSAColumnGlobalAttention(nn.Module):
"m": m,
"mask": mask,
}
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer(
partial(self.global_attention, use_lma=use_lma),
fn,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
......@@ -405,11 +429,12 @@ class MSAColumnGlobalAttention(nn.Module):
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
#m = self.layer_norm_m(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
......
......@@ -20,7 +20,7 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
class OuterProductMean(nn.Module):
......@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_inplace: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
norm = norm + self.eps
# [*, N_res, N_res, C_z]
if(_inplace):
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
......
......@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module):
......@@ -46,6 +46,9 @@ class PairTransition(nn.Module):
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
......@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
......
......@@ -23,11 +23,11 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)
......@@ -149,6 +149,7 @@ class Linear(nn.Linear):
with torch.no_grad():
self.bias.fill_(0)
with torch.no_grad():
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
......@@ -161,7 +162,6 @@ class Linear(nn.Linear):
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
......
......@@ -19,7 +19,7 @@ from operator import mul
import torch
import torch.nn as nn
from typing import Optional, Tuple
from typing import Optional, Tuple, Sequence
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -229,9 +229,12 @@ class InvariantPointAttention(nn.Module):
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
z: Optional[torch.Tensor],
r: Rigid,
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
......@@ -246,7 +249,10 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe = not (self.training or torch.is_grad_enabled())
if(_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
#######################################
# Generate scalar and point activations
......@@ -298,7 +304,10 @@ class InvariantPointAttention(nn.Module):
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z)
b = self.linear_b(z[0])
if(_offload_inference):
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
a = torch.matmul(
......@@ -392,8 +401,11 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
if(_offload_inference):
z[0] = z[0].to(o_pt.device)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
......@@ -402,7 +414,7 @@ class InvariantPointAttention(nn.Module):
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
).to(dtype=z[0].dtype)
)
return s
......@@ -604,16 +616,19 @@ class StructureModule(nn.Module):
def forward(
self,
s,
z,
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
"""
Args:
s:
evoformer_output_dict:
Dictionary containing:
"single":
[*, N_res, C_s] single representation
z:
"pair":
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
......@@ -622,6 +637,8 @@ class StructureModule(nn.Module):
Returns:
A dictionary of outputs
"""
s = evoformer_output_dict["single"]
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
......@@ -630,7 +647,13 @@ class StructureModule(nn.Module):
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(z)
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if(_offload_inference):
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
# [*, N, C_s]
s_initial = s
......@@ -647,7 +670,15 @@ class StructureModule(nn.Module):
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
s = s + self.ipa(
s,
z,
rigids,
mask,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list
)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
......@@ -698,6 +729,13 @@ class StructureModule(nn.Module):
rigids = rigids.stop_rot_gradient()
del z, z_reference_list
if(_offload_inference):
evoformer_output_dict["pair"] = (
evoformer_output_dict["pair"].to(s.device)
)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
......
......@@ -34,13 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
chunk_layer,
ChunkSizeTuner,
)
from openfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.utils.tensor_utils import (
add,
chunk_layer,
permute_final_dims,
flatten_final_dims,
tensor_tree_map,
......@@ -198,15 +201,20 @@ class TemplatePairStackBlock(nn.Module):
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_inplace: bool = False,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
......@@ -215,33 +223,35 @@ class TemplatePairStackBlock(nn.Module):
self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
_inplace,
inplace_safe,
)
single = add(single,
self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
_inplace,
inplace_safe,
)
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
_inplace=_inplace,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not _inplace):
if(not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -251,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
_inplace=_inplace,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not _inplace):
if(not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -267,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
),
_inplace,
inplace_safe,
)
if(not _inplace):
if(not inplace_safe):
single_templates[i] = single
if(not _inplace):
if(not inplace_safe):
z = torch.cat(single_templates, dim=-4)
return z
......@@ -293,6 +303,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
**kwargs,
):
......@@ -333,12 +344,18 @@ class TemplatePairStack(nn.Module):
self.layer_norm = LayerNorm(c_t)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def forward(
self,
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
):
"""
......@@ -355,18 +372,34 @@ class TemplatePairStack(nn.Module):
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
blocks = [
partial(
b,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_inplace=not (self.training or torch.is_grad_enabled()),
)
for b in self.blocks
],
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=chunk_size,
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
t, = checkpoint_blocks(
blocks=blocks,
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
......@@ -383,6 +416,7 @@ def embed_templates_offload(
pair_mask,
templ_dim,
template_chunk_size=256,
inplace_safe=False,
):
"""
Args:
......@@ -407,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = []
n = z.shape[-2]
......@@ -491,6 +523,7 @@ def embed_templates_average(
pair_mask,
templ_dim,
templ_group_size=2,
inplace_safe=False,
):
"""
Args:
......@@ -519,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
......
......@@ -21,8 +21,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
......@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
class TriangleAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, starting, inf=1e9
self, c_in, c_hidden, no_heads, starting=True, inf=1e9
):
"""
Args:
......@@ -62,25 +62,36 @@ class TriangleAttention(nn.Module):
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
partial(self.mha, use_lma=use_lma),
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -95,8 +106,7 @@ class TriangleAttention(nn.Module):
x.shape[:-1],
)
# Shape annotations assume self.starting. Else, I and J are flipped
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
......@@ -115,27 +125,35 @@ class TriangleAttention(nn.Module):
biases = [mask_bias, triangle_bias]
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)
x = self.mha(
q_x=x,
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
)
if not self.starting:
if(not self.starting):
x = x.transpose(-2, -3)
return x
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
......@@ -20,7 +20,8 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import add, chunk_layer, permute_final_dims
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import add, permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
......@@ -356,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
_inplace: bool = False,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
......@@ -369,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(_inplace):
if(inplace_safe):
x = self._inference_forward(
z,
mask,
......
......@@ -92,24 +92,16 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
)
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
"Only single chain PDBs are supported when chain_id not specified. "
f"Found {len(chains)} chains."
)
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
......@@ -138,22 +130,38 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
parents = None
parents_chain_index = None
if("PARENT" in pdb_str):
parents = []
parents_chain_index = []
chain_id = 0
for l in pdb_str.split("\n"):
if("PARENT" in l and not "N/A" in l):
parents = l.split()[1:]
break
if("PARENT" in l):
if(not "N/A" in l):
parent_names = l.split()[1:]
parents.extend(parent_names)
parents_chain_index.extend([
chain_id for _ in parent_names
])
chain_id += 1
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
parents=parents,
parents_chain_index=parents_chain_index,
)
......@@ -232,6 +240,56 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
return pdb_headers
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
""" Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines = []
lines = pdb_str.split('\n')
remark = prot.remark
if(remark is not None):
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None
if(prot.parents is not None and len(prot.parents) > 0):
parents_per_chain = []
if(prot.parents_chain_index is not None):
cur_chain = prot.parents_chain_index[0]
parent_dict = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
for i in range(max_idx + 1):
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(prot.parents)
else:
parents_per_chain = [["N/A"]]
make_parent_line = lambda p: f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
chain_counter = 0
for i, l in enumerate(lines):
if("PARENT" not in l and "REMARK" not in l):
out_pdb_lines.append(l)
if("TER" in l and not "END" in lines[i + 1]):
chain_counter += 1
if(not chain_counter >= len(parents_per_chain)):
chain_parents = parents_per_chain[chain_counter]
else:
chain_parents = ["N/A"]
out_pdb_lines.append(make_parent_line(chain_parents))
return '\n'.join(out_pdb_lines)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
......
......@@ -88,8 +88,6 @@ class AmberRelaxation(object):
"total_per_residue_violations_mask"
]
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
min_pdb = '\n'.join(['\n'.join(headers), min_pdb])
min_pdb = protein.add_pdb_headers(prot, min_pdb)
return min_pdb, debug_data, violations
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment