Commit 1606ac08 authored by Christina Floristean

Merge branch 'main' into multimer

parents 58d65692 67a00a6c
......@@ -39,13 +39,14 @@ kernels support in-place attention during inference and training. They use
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and reduces the model's peak device memory requirement by a factor of 13. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference. To use this feature, simply set the `use_deepspeed_evo_attention` option in `openfold/config.py`, as sketched below.
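For example, a minimal sketch (assuming the standard `model_config` helper; the preset name is illustrative):

```python
from openfold.config import model_config

# Build a stock config, then enable the DeepSpeed kernel globally.
# Requires a DeepSpeed build that provides deepspeed.ops.deepspeed4science.
config = model_config("model_1_ptm")
config.globals.use_deepspeed_evo_attention = True
```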
## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on on your system. You'll need `git-lfs` to download OpenFold parameters.
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.
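On Debian-based systems, for example (package names assumed):

```bash
sudo apt install git-lfs aria2   # aria2 provides the aria2c binary
```

The `aws` CLI can be installed with Amazon's `awscliv2.zip` installer, as done in the Colab notebook.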
This package is currently supported for CUDA 11 and PyTorch 1.12.
......@@ -114,7 +115,13 @@ the model, consult `run_pretrained_openfold.py`.
### Inference
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, run e.g.:
pretrained parameters, first download the OpenFold weights e.g.:
```bash
bash scripts/download_openfold_params.sh openfold/resources
```
then run e.g.:
```bash
python3 run_pretrained_openfold.py \
......@@ -286,6 +293,14 @@ python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch) containing details about the structures you want to use as templates. If you do not place any such files, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you do want to use templates, you also need to pass the PDB mmCIF dataset to the command; a sketch of the expected layout follows.
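A hypothetical layout (directory and file names illustrative):

```bash
# One subdirectory per label; the .hhr file is optional and enables templates.
ls embeddings_output_dir/my_protein/
# my_protein.pt  my_protein.hhr
```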
Then download the SoloSeq model weights, e.g.:
```bash
bash scripts/download_openfold_soloseq_params.sh openfold/resources
```
Now, you are ready to run inference:
```bash
python run_pretrained_openfold.py \
......@@ -295,7 +310,7 @@ python run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
```
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will also be generated if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates, and only the generated ESM-1b embeddings will be used to predict the structure.
......
......@@ -12,7 +12,9 @@
},
"zero_optimization": {
"stage": 2,
"cpu_offload": true,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true
},
"activation_checkpointing": {
......@@ -20,5 +22,6 @@
"cpu_checkpointing": false,
"profile": false
},
"gradient_clipping": 0.1
"gradient_clipping": 0.1,
"zero_force_ds_cpu_optimizer": false
}
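This change tracks newer DeepSpeed releases: the deprecated top-level `cpu_offload` flag becomes an `offload_optimizer` block, and `"zero_force_ds_cpu_optimizer": false` lets ZeRO accept a client-provided optimizer while optimizer state is offloaded to the CPU. A quick sanity check, assuming the file is the repository's `deepspeed_config.json`:

```python
import json

with open("deepspeed_config.json") as f:
    cfg = json.load(f)

# ZeRO stage 2 with optimizer state offloaded to CPU.
assert cfg["zero_optimization"]["stage"] == 2
assert cfg["zero_optimization"]["offload_optimizer"]["device"] == "cpu"
```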
......@@ -31,7 +31,7 @@ dependencies:
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip:
- deepspeed==0.5.10
- deepspeed==0.12.4
- dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
{
"nbformat": 4,
"nbformat_minor": 0,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "OpenFold.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
"id": "view-in-github",
"colab_type": "text"
},
"language_info": {
"name": "python"
}
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
......@@ -57,10 +50,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
"cellView": "form",
"id": "rowN0bVYLe9n"
},
"outputs": [],
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
......@@ -78,16 +73,16 @@
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "woIxeCPygt7K",
"cellView": "form"
"cellView": "form",
"id": "woIxeCPygt7K"
},
"outputs": [],
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
......@@ -97,75 +92,54 @@
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n",
"\n",
"import sys\n",
"import os, time\n",
"from IPython.utils import io\n",
"import os\n",
"from sys import version_info\n",
"import subprocess\n",
"import tqdm.notebook\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"python_version = f\"{version_info.major}.{version_info.minor}\"\n",
"\n",
"\n",
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer\")\n",
"\n",
"\n",
"python_version = '.'.join(sys.version.split('.')[:2]) #get string like \"3.9\"\n",
"os.system(\"pip install -q \\\"torch<2\\\" biopython ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
" %shell sudo apt install --quiet --yes hmmer\n",
"\n",
" # Install py3dmol.\n",
" %shell pip install py3dmol\n",
"\n",
" %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
"\n",
" PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n",
"\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q conda==4.13.0\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python={python_version} \\\n",
" openmm=7.7.0 \\\n",
" pdbfixer \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79 \\\n",
" modelcif==0.7\n",
"\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo apt install --quiet --yes hmmer\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
"\n",
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
" %shell mkdir -p /content/openfold/openfold/resourcees\n",
" \n",
" commit = \"099769d2ecfd01a8baa8d950030df454a042c910\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" \n",
" %shell cp -f /content/stereo_chemical_props.txt /usr/local/lib/python3.10/site-packages/openfold/resources/\n",
"\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
"execution_count": null,
"outputs": []
" print(captured)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VzJ5iMjTtoZw",
"cellView": "form"
"cellView": "form",
"id": "VzJ5iMjTtoZw"
},
"outputs": [],
"source": [
"#@title Install OpenFold\n",
"#@title Download model weights \n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
......@@ -180,13 +154,6 @@
"\n",
"try:\n",
" with io.capture_output() as captured:\n",
" # Run setup.py to install only Openfold.\n",
" %shell rm -rf openfold\n",
" %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
......@@ -194,7 +161,14 @@
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
" elif(weight_set == 'OpenFold'):\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
"\n",
" %shell aws s3 cp \\\n",
" --no-sign-request \\\n",
" --region us-east-1 \\\n",
......@@ -203,14 +177,17 @@
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
"execution_count": null,
"outputs": []
" print(captured)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "_FpxxMo-mvcP"
},
"outputs": [],
"source": [
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
......@@ -219,8 +196,8 @@
"import unittest.mock\n",
"import sys\n",
"\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/dist-packages/')\n",
"sys.path.insert(0, f'/usr/local/lib/python{python_version}/site-packages/')\n",
"sys.path.append(f'/opt/conda/lib/python{python_version}/site-packages')\n",
"\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
......@@ -245,6 +222,10 @@
"import py3Dmol\n",
"import torch\n",
"import shutil\n",
"import tqdm\n",
"import tqdm.notebook\n",
"\n",
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n",
"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
......@@ -280,13 +261,7 @@
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output"
],
"metadata": {
"id": "_FpxxMo-mvcP",
"cellView": "form"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -301,10 +276,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
"cellView": "form",
"id": "2tTeTTsLKPjB"
},
"outputs": [],
"source": [
"#@title Search against genetic databases\n",
"\n",
......@@ -420,16 +397,16 @@
"plt.ylabel('Non-Gap Count')\n",
"plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUo6foMQxwS2",
"cellView": "form"
"cellView": "form",
"id": "XUo6foMQxwS2"
},
"outputs": [],
"source": [
"#@title Run OpenFold and download prediction\n",
"\n",
......@@ -693,9 +670,7 @@
"# --- Download the predictions ---\n",
"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
"files.download(f'{output_dir}.zip')"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -789,5 +764,24 @@
"* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details."
]
}
]
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "OpenFold.ipynb",
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
......@@ -29,19 +29,28 @@ def enforce_config_constraints(config):
(
"globals.use_lma",
"globals.use_flash",
"globals.use_deepspeed_evo_attention"
),
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
for options in mutually_exclusive_bools:
option_settings = [string_to_setting(o) for o in options]
if sum(option_settings) > 1:
raise ValueError(f"Only one of {', '.join(options)} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
if config.globals.use_flash and not fa_is_installed:
raise ValueError("use_flash requires that FlashAttention is installed")
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
raise ValueError(
"use_deepspeed_evo_attention requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
if(
config.globals.offload_inference and
not config.model.template.average_templates
......@@ -158,10 +167,12 @@ def model_config(
if name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
......@@ -218,7 +229,8 @@ def model_config(
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
# Default to DeepSpeed memory-efficient attention kernel unless use_lma is explicitly set
c.globals.use_deepspeed_evo_attention = not c.globals.use_lma
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
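For example, a hedged sketch of what this preset now selects (helper and preset names as in `openfold/config.py`; the constraint check requires a DS4Sci-enabled DeepSpeed install):

```python
from openfold.config import model_config

# long_sequence_inference enables offloading and, by default, the DeepSpeed
# kernel; LMA is kept instead only if use_lma was already requested.
c = model_config("model_1_ptm", long_sequence_inference=True)
assert c.globals.offload_inference
assert c.globals.use_deepspeed_evo_attention and not c.globals.use_lma
```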
......@@ -338,6 +350,11 @@ config = mlc.ConfigDict(
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"block_delete_msa": {
"msa_fraction_per_block": 0.3,
"randomize_num_blocks": False,
"num_blocks": 5,
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
......@@ -382,6 +399,7 @@ config = mlc.ConfigDict(
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_extra_msa": 1024,
......@@ -397,6 +415,7 @@ config = mlc.ConfigDict(
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
......@@ -412,6 +431,7 @@ config = mlc.ConfigDict(
"train": {
"fixed_size": True,
"subsample_templates": True,
"block_delete_msa": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
......@@ -441,11 +461,15 @@ config = mlc.ConfigDict(
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually
# exclusive with use_lma and use_flash.
"use_deepspeed_evo_attention": False,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
# exclusive with use_deepspeed_evo_attention and use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
# on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
......@@ -799,6 +823,7 @@ multimer_config_update = mlc.ConfigDict({
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"block_delete_msa" : False,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.,
......
......@@ -255,28 +255,33 @@ def block_delete_msa(protein, config):
* config.msa_fraction_per_block
).to(torch.int32)
if int(block_num_seq) == 0:
return protein
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
nb = int(torch.randint(
low=0,
high=config.num_blocks + 1,
size=(1,),
device=protein["msa"].device,
)[0])
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
del_indices = torch.unique(torch.reshape(del_blocks, [-1]))
# Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
keep_indices = uniques[counts == 1]
assert int(keep_indices[0]) == 0
for k in MSA_FEATURE_NAMES:
if k in protein:
protein[k] = torch.gather(protein[k], keep_indices)
protein[k] = torch.index_select(protein[k], 0, keep_indices)
return protein
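The rewritten transform keeps the query row by construction: deletion block starts are drawn from `[1, num_seq)`, and the surviving rows fall out of a concatenate-and-count trick. A standalone sketch of that trick (illustrative values):

```python
import torch

# Rows 0..7; suppose the sampled deletion blocks cover rows {2, 3, 5}.
num_seq = 8
del_indices = torch.tensor([2, 3, 5])

# Rows appearing exactly once in the concatenation are kept; row 0 (the
# query) can never occur in del_indices, so it always survives.
combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
uniques, counts = combined.unique(return_counts=True)
keep_indices = uniques[counts == 1]
print(keep_indices)  # tensor([0, 1, 4, 6, 7])
```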
......
......@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
if mode_cfg.block_delete_msa:
transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
......
......@@ -657,6 +657,7 @@ class TemplateEmbedder(nn.Module):
templ_dim,
chunk_size,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_lma=False,
inplace_safe=False
):
......@@ -707,6 +708,7 @@ class TemplateEmbedder(nn.Module):
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -893,6 +895,7 @@ class TemplateEmbedderMultimer(nn.Module):
chunk_size,
multichain_mask_2d,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_lma=False,
inplace_safe=False
):
......@@ -967,6 +970,7 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......
......@@ -90,7 +90,6 @@ class MSATransition(nn.Module):
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
......@@ -179,6 +178,7 @@ class PairStack(nn.Module):
z: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -225,6 +225,7 @@ class PairStack(nn.Module):
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -243,6 +244,7 @@ class PairStack(nn.Module):
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -360,6 +362,7 @@ class MSABlock(nn.Module, ABC):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -423,6 +426,7 @@ class EvoformerBlock(MSABlock):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -462,6 +466,7 @@ class EvoformerBlock(MSABlock):
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
)
),
......@@ -483,6 +488,7 @@ class EvoformerBlock(MSABlock):
m,
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
),
......@@ -527,6 +533,7 @@ class EvoformerBlock(MSABlock):
z=input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -602,6 +609,7 @@ class ExtraMSABlock(MSABlock):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -637,7 +645,8 @@ class ExtraMSABlock(MSABlock):
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not use_lma and m.is_cuda,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
......@@ -709,6 +718,7 @@ class ExtraMSABlock(MSABlock):
input_tensors[1],
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -853,6 +863,7 @@ class EvoformerStack(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
......@@ -866,6 +877,7 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
......@@ -905,6 +917,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
......@@ -916,6 +929,7 @@ class EvoformerStack(nn.Module):
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
......@@ -947,6 +961,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -965,10 +980,15 @@ class EvoformerStack(nn.Module):
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory efficient kernel.
Mutually exclusive with use_lma and use_flash.
use_lma:
Whether to use low-memory attention during inference.
Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
exclusive with use_lma and use_deepspeed_evo_attention.
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
......@@ -981,6 +1001,7 @@ class EvoformerStack(nn.Module):
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
......@@ -1065,6 +1086,7 @@ class ExtraMSAStack(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
......@@ -1077,6 +1099,7 @@ class ExtraMSAStack(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -1113,6 +1136,7 @@ class ExtraMSAStack(nn.Module):
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
......@@ -1125,6 +1149,7 @@ class ExtraMSAStack(nn.Module):
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
......@@ -1151,6 +1176,7 @@ class ExtraMSAStack(nn.Module):
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -1162,6 +1188,7 @@ class ExtraMSAStack(nn.Module):
z:
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient kernel
use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
......@@ -1175,6 +1202,7 @@ class ExtraMSAStack(nn.Module):
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
......
......@@ -103,8 +103,8 @@ class AlphaFold(nn.Module):
**self.config["recycling_embedder"],
)
if (self.template_config.enabled):
if(self.globals.is_multimer):
if self.template_config.enabled:
if self.globals.is_multimer:
self.template_embedder = TemplateEmbedderMultimer(
self.template_config,
)
......@@ -113,7 +113,7 @@ class AlphaFold(nn.Module):
self.template_config,
)
if (self.extra_msa_config.enabled):
if self.extra_msa_config.enabled:
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
......@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
)
def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
if (self.globals.is_multimer):
if self.globals.is_multimer:
asym_id = feats["asym_id"]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
......@@ -146,19 +146,20 @@ class AlphaFold(nn.Module):
templ_dim,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans = self.config._mask_trans
_mask_trans=self.config._mask_trans
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
else:
if (self.template_config.offload_templates):
if self.template_config.offload_templates:
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif (self.template_config.average_templates):
elif self.template_config.average_templates:
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
......@@ -169,6 +170,7 @@ class AlphaFold(nn.Module):
pair_mask.to(dtype=z.dtype),
templ_dim,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
......@@ -176,7 +178,7 @@ class AlphaFold(nn.Module):
return template_embeds
def tolerance_reached(self, prev_pos, next_pos, mask, no_batch_dims, eps=1e-8) -> bool:
def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
"""
Early stopping criteria based on criteria used in
AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
......@@ -188,6 +190,7 @@ class AlphaFold(nn.Module):
Returns:
Whether to stop recycling early based on the desired tolerance.
"""
def distances(points):
"""Compute all pairwise distances for a set of points."""
d = points[..., None, :] - points[..., None, :, :]
......@@ -210,7 +213,7 @@ class AlphaFold(nn.Module):
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
if feats[k].dtype == torch.float32:
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
......@@ -283,7 +286,7 @@ class AlphaFold(nn.Module):
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
if self.globals.offload_inference and inplace_safe:
m = m.cpu()
z = z.cpu()
......@@ -298,7 +301,7 @@ class AlphaFold(nn.Module):
del pseudo_beta_x_prev
if(self.globals.offload_inference and inplace_safe):
if self.globals.offload_inference and inplace_safe:
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
......@@ -334,7 +337,7 @@ class AlphaFold(nn.Module):
inplace_safe,
)
if(
if (
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m]
......@@ -344,7 +347,7 @@ class AlphaFold(nn.Module):
)
# [*, S, N]
if(not self.globals.is_multimer):
if not self.globals.is_multimer:
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
......@@ -358,7 +361,7 @@ class AlphaFold(nn.Module):
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
if(self.globals.is_multimer):
if self.globals.is_multimer:
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
else:
extra_msa_fn = build_extra_msa_feat
......@@ -367,7 +370,7 @@ class AlphaFold(nn.Module):
extra_msa_feat = extra_msa_fn(feats).to(dtype=z.dtype)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
if self.globals.offload_inference:
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
......@@ -378,6 +381,7 @@ class AlphaFold(nn.Module):
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
......@@ -390,6 +394,7 @@ class AlphaFold(nn.Module):
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
......@@ -400,7 +405,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
if self.globals.offload_inference:
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
......@@ -408,6 +413,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
......@@ -420,6 +426,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
......@@ -456,7 +463,7 @@ class AlphaFold(nn.Module):
early_stop = False
if self.globals.is_multimer:
early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask, no_batch_dims)
early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)
del x_prev
......
......@@ -92,6 +92,7 @@ class MSAAttention(nn.Module):
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
......@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
......@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -268,6 +271,7 @@ class MSAAttention(nn.Module):
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
......@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
......@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
......@@ -379,6 +385,7 @@ class MSAColumnAttention(nn.Module):
m,
mask=mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
)
......
......@@ -12,20 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import math
from typing import Optional, Callable, List, Tuple, Sequence
from typing import Optional, Callable, List, Tuple
import numpy as np
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
if deepspeed_is_installed:
import deepspeed
if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
......@@ -33,7 +35,6 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import (
......@@ -42,8 +43,8 @@ from openfold.utils.tensor_utils import (
)
DEFAULT_LMA_Q_CHUNK_SIZE=1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096
DEFAULT_LMA_Q_CHUNK_SIZE = 1024
DEFAULT_LMA_KV_CHUNK_SIZE = 4096
def _prod(nums):
......@@ -217,9 +218,9 @@ class LayerNorm(nn.Module):
d = x.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
......@@ -249,9 +250,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
d = t.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
......@@ -283,7 +284,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if(checkpoint and len(biases) > 2):
if checkpoint and len(biases) > 2:
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
......@@ -311,7 +312,7 @@ def _attention_chunked_trainable(
)
return b[tuple(idx)]
if(checkpoint):
if checkpoint:
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
......@@ -398,7 +399,8 @@ class Attention(nn.Module):
def _prep_qkv(self,
q_x: torch.Tensor,
kv_x: torch.Tensor
kv_x: torch.Tensor,
apply_scale: bool = True
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
]:
......@@ -417,6 +419,7 @@ class Attention(nn.Module):
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
if apply_scale:
q /= math.sqrt(self.c_hidden)
return q, k, v
......@@ -425,7 +428,7 @@ class Attention(nn.Module):
o: torch.Tensor,
q_x: torch.Tensor
) -> torch.Tensor:
if(self.linear_g is not None):
if self.linear_g is not None:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
......@@ -446,11 +449,12 @@ class Attention(nn.Module):
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
flash_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
......@@ -465,6 +469,10 @@ class Attention(nn.Module):
This should be the default choice for most. If none of the
"use_<...>" flags are True, a stock PyTorch implementation
is used instead
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory-efficient attention kernel.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
......@@ -476,50 +484,57 @@ class Attention(nn.Module):
Returns
[*, Q, C_q] attention update
"""
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError(
"If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided"
)
if(use_flash and biases is not None):
if use_flash and biases is not None:
raise ValueError(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError(
"Choose at most one alternative attention algorithm"
)
if(biases is None):
if biases is None:
biases = []
# [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x)
# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention)
# [*, Q, H, C_hidden]
if is_fp16_enabled():
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
if use_memory_efficient_kernel:
if len(biases) > 2:
raise ValueError(
"If use_memory_efficient_kernel is True, you may only "
"provide up to two bias terms"
)
o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
o = o.transpose(-2, -3)
elif(use_lma):
elif use_deepspeed_evo_attention:
if len(biases) > 2:
raise ValueError(
"If use_deepspeed_evo_attention is True, you may only "
"provide up to two bias terms"
)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_lma:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif(use_flash):
elif use_flash:
o = _flash_attn(q, k, v, flash_mask)
else:
o = _attention(q, k, v, biases)
......@@ -577,7 +592,7 @@ class GlobalAttention(nn.Module):
v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma):
if not use_lma:
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
......@@ -619,6 +634,72 @@ class GlobalAttention(nn.Module):
return m
@torch.jit.ignore
def _deepspeed_evo_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
):
"""""
Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
Args:
q:
[*, H, Q, C_hidden] query data
k:
[*, H, K, C_hidden] key data
v:
[*, H, V, C_hidden] value data
biases:
List of biases that broadcast to [*, H, Q, K]
"""
if not ds4s_is_installed:
raise ValueError(
"_deepspeed_evo_attn requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
def reshape_dims(x):
no_batch_dims = len(x.shape[:-3])
if no_batch_dims < 2:
return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
if no_batch_dims > 2:
return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x
# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape
if len(orig_shape[:-3]) != 2:
q = reshape_dims(q)
k = reshape_dims(k)
v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases]
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)
o = o.reshape(orig_shape)
return o
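The batch-dim normalization above can be checked in isolation; a minimal sketch (logic copied from `reshape_dims`, shapes illustrative):

```python
import torch

def reshape_dims(x):
    # DS4Sci_EvoformerAttention expects exactly two leading batch dims:
    # pad with singletons if there are fewer, flatten down to two if more.
    no_batch_dims = len(x.shape[:-3])
    if no_batch_dims < 2:
        return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
    if no_batch_dims > 2:
        return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
    return x

q = torch.zeros(4, 10, 8, 32)        # one batch dim: [N, Q, H, C_hidden]
print(reshape_dims(q).shape)         # torch.Size([1, 4, 10, 8, 32])

q = torch.zeros(2, 3, 4, 10, 8, 32)  # three batch dims, flattened to two
print(reshape_dims(q).shape)         # torch.Size([2, 12, 10, 8, 32])
```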
def _lma(
q: torch.Tensor,
k: torch.Tensor,
......@@ -683,7 +764,7 @@ def _lma(
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed):
if not fa_is_installed:
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
......@@ -735,8 +816,8 @@ def _flash_attn(q, k, v, kv_mask):
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p = 0.,
softmax_scale = 1., # q has been scaled already
dropout_p=0.,
softmax_scale=1., # q has been scaled already
)
# [*, B, N, H, C]
......
......@@ -20,7 +20,7 @@ from typing import Optional, List
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.primitives import LayerNorm, Attention
from openfold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
......@@ -48,7 +48,6 @@ from openfold.utils.feats import (
from openfold.utils.tensor_utils import (
add,
permute_final_dims,
flatten_final_dims,
tensor_tree_map,
)
......@@ -57,6 +56,7 @@ class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
......@@ -103,7 +103,6 @@ class TemplatePointwiseAttention(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
......@@ -212,13 +211,20 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n,
)
def tri_att_start_end(self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe):
def tri_att_start_end(self,
single: torch.Tensor,
_attn_chunk_size: Optional[int],
single_mask: torch.Tensor,
use_deepspeed_evo_attention: bool,
use_lma: bool,
inplace_safe: bool):
single = add(single,
self.dropout_row(
self.tri_att_start(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -232,6 +238,7 @@ class TemplatePairStackBlock(nn.Module):
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -241,14 +248,17 @@ class TemplatePairStackBlock(nn.Module):
return single
def tri_mul_out_in(self, single, single_mask, inplace_safe):
def tri_mul_out_in(self,
single: torch.Tensor,
single_mask: torch.Tensor,
inplace_safe: bool):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -261,7 +271,7 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -274,12 +284,13 @@ class TemplatePairStackBlock(nn.Module):
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
if _attn_chunk_size is None:
_attn_chunk_size = chunk_size
single_templates = [
......@@ -299,12 +310,15 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe=inplace_safe),
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe)
else:
single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single,
single = self.tri_mul_out_in(
single=self.tri_att_start_end(single=single,
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe),
single_mask=single_mask,
......@@ -319,10 +333,10 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe,
)
if (not inplace_safe):
if not inplace_safe:
single_templates[i] = single
if (not inplace_safe):
if not inplace_safe:
z = torch.cat(single_templates, dim=-4)
return z
......@@ -332,6 +346,7 @@ class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
......@@ -389,7 +404,7 @@ class TemplatePairStack(nn.Module):
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
if tune_chunk_size:
self.chunk_size_tuner = ChunkSizeTuner()
def forward(
......@@ -397,6 +412,7 @@ class TemplatePairStack(nn.Module):
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -410,7 +426,7 @@ class TemplatePairStack(nn.Module):
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
if mask.shape[-3] == 1:
expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
......@@ -420,6 +436,7 @@ class TemplatePairStack(nn.Module):
b,
mask=mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -427,8 +444,8 @@ class TemplatePairStack(nn.Module):
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
if chunk_size is not None and self.chunk_size_tuner is not None:
assert (not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
......@@ -510,11 +527,13 @@ def embed_templates_offload(
t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,
)
assert(sys.getrefcount(t) == 2)
assert (sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu())
......@@ -540,7 +559,7 @@ def embed_templates_offload(
del pair_chunks
if(inplace_safe):
if inplace_safe:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
......@@ -627,7 +646,9 @@ def embed_templates_average(
t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,
)
......@@ -639,19 +660,19 @@ def embed_templates_average(
)
denom = math.ceil(n_templ / templ_group_size)
if(inplace_safe):
if inplace_safe:
t /= denom
else:
t = t / denom
if(inplace_safe):
if inplace_safe:
out_tensor += t
else:
out_tensor = out_tensor + t
del t
if(inplace_safe):
if inplace_safe:
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
......
......@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
......@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
),
mha_inputs,
......@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
......@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
)
......
......@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
]
verify_arg_order(
......@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
("m", m),
("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_chunk_size)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("use_flash", torch.tensor(model.globals.use_flash)),
]
......@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]
......@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.transpose(-1, -2).float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]
......
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads OpenFold SoloSeq parameters.
#
# Usage: bash download_openfold_soloseq_params.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aws &> /dev/null ; then
echo "Error: aws could not be found. Please install aws."
exit 1
fi
DOWNLOAD_DIR="${1}/openfold_soloseq_params"
mkdir -p "${DOWNLOAD_DIR}"
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/openfold_soloseq_params/ "${DOWNLOAD_DIR}" --recursive
......@@ -13,6 +13,12 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass
# This setting is used to fix a worker assignment issue during data loading
conda env config vars set KMP_AFFINITY=none
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
......@@ -10,7 +10,6 @@ import numpy as np
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
......@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
def skip_unless_ds4s_installed():
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
def skip_unless_flash_attn_installed():
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
def alphafold_is_installed():
return importlib.util.find_spec("alphafold") is not None
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from random import randint
import torch
import numpy as np
from scipy.spatial.transform import Rotation
......@@ -127,3 +128,17 @@ def random_affines_4x4(dim):
affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4)
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
                            dtype=torch.float32, requires_grad=False):
    """Generate random query, key/value, mask, and bias tensors for attention tests."""
    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=False).cuda()
    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()

    # Turn the 0/1 mask into an additive bias: 0 where attention is allowed,
    # -inf where it is masked out.
    mask_bias = inf * (mask - 1)
    biases = [mask_bias, z_bias]

    return q, kv, mask, biases
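
# Shape reference: q and kv are [batch, n_seq, n, c_hidden]; the mask is
# [batch, n_seq, 1, 1, n]; both biases broadcast against attention logits of
# shape [batch, n_seq, no_heads, n, n].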
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
"""
import unittest
import numpy as np
import pickle
import torch
from torch.nn import functional as F
from openfold.data import data_transforms
from openfold.model.primitives import (
lecun_normal_init_,
Attention
)
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import random_template_feats, random_attention_inputs
@compare_utils.skip_unless_ds4s_installed()
class TestDeepSpeedKernel(unittest.TestCase):
def compare_attention_types(self, use_flash=False):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = 2e-2
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden)
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
# Change output params init for testing since they are initialized with 'final' init (zeros)
# Otherwise both will just return zero.
with torch.no_grad():
lecun_normal_init_(a.linear_g.weight)
lecun_normal_init_(a.linear_o.weight)
if use_flash:
biases = [biases[0]]
flash_mask = mask.reshape(batch_size * n_seq, n_res)
real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
else:
real_out = a(q, kv, biases=biases).cpu()
ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()
err = torch.max(torch.abs(ds_out - real_out))
self.assertTrue(err < eps, f'Error: {err}')
def test_ds_kernel_vs_attention_forward(self):
"""Compare regular attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=False)
@compare_utils.skip_unless_flash_attn_installed()
def test_ds_kernel_vs_flash_attn_forward(self):
"""Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=True)
def test_ds_kernel_vs_attention_backward(self):
"""Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = consts.eps
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden,
requires_grad=True)
attn = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
lecun_normal_init_(attn.linear_g.weight)
lecun_normal_init_(attn.linear_o.weight)
def clone(t):
# Create new params, clone values
t = t.clone()
if t.requires_grad:
t.retain_grad()
return t
def init_attn():
# Create new attention object with same initial weights
a_clone = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a_clone.load_state_dict(attn.state_dict())
return a_clone
# Clone param values and run attention with DS kernel
q_repro = clone(q)
kv_repro = clone(kv)
biases_repro = [clone(b) for b in biases]
a_repro = init_attn()
out_repro = a_repro(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
loss_repro = torch.mean(out_repro)
loss_repro.backward()
q_gt = clone(q)
kv_gt = clone(kv)
biases_gt = [clone(b) for b in biases]
# Clone param values and run attention without DS kernel
a_gt = init_attn()
out_gt = a_gt(q_gt, kv_gt, biases=biases_gt)
loss_gt = torch.mean(out_gt)
loss_gt.backward()
# Compare the grads of attention inputs
pairs = zip([q_repro, kv_repro, biases_repro[1]],
[q_gt, kv_gt, biases_gt[1]])
for i, item in enumerate(pairs):
t_repro, t_gt = item
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item #{i}: {err}')
# Compare the grads of model weights
a_repro_params = dict(a_repro.named_parameters())
a_gt_params = dict(a_gt.named_parameters())
for name in a_gt_params.keys():
t_repro = a_repro_params[name]
t_gt = a_gt_params[name]
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item {name}: {err}')
def compare_evoformer(self, dtype, eps):
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
since the kernel itself can run with either BF16 or FP16 precision.
"""
n_res = 20
n_seq = 18
c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,)
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
"pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
}
masks = {
"msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
"pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
}
with torch.cuda.amp.autocast(dtype=dtype):
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=False,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
# In practice, layer norms applied later in the network make any
# kernel rounding errors negligible
out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=True,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
self.assertTrue(err < eps, f'MSA Error: {err}')
err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
        self.assertTrue(err < eps, f'Pair Error: {err}')
def test_compare_evoformer_bf16(self):
"""Run evoformer comparison test with BF16 precision."""
self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)
def test_compare_evoformer_fp32(self):
"""Run evoformer comparison test with FP32 precision."""
self.compare_evoformer(dtype=torch.float32, eps=2e-2)
def test_compare_template_stack(self):
"""
Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
Kernel can be used for Triangle Attention in the Template Pair Stack.
"""
n_templ = consts.n_templ
n_res = 20
eps = 2e-2
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"].cpu()
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
        self.assertTrue(err < eps, f'Error: {err}')
def test_compare_model(self):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates.
"""
eps = 0.5
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model(batch)
# Enable kernel
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__":
unittest.main()